# MOHID Water

This Jupyter Notebook aims to help implement and run the MOHID Water model.

***
**Note 1**: Execute each cell through the <button class="btn btn-default btn-xs"><i class="icon-play fa fa-play"></i></button> button from the top MENU (or keyboard shortcut `Shift` + `Enter`).<br>
<br>
**Note 2**: Use the Kernel and Cell menus to restart the kernel and clear outputs.<br>
***

# Table of contents
- [1. Import required libraries](#1.-Import-required-libraries)
- [2. General options](#2.-General-options)
    - [2.1 Set run case](#2.1-Set-run-case)
    - [2.2 Load MOHID griddata](#2.2-Load-MOHID-griddata)
    - [2.3 Define a bounding box](#2.3-Define-a-bounding-box)
    - [2.4 Set dates](#2.4-Set-dates)
- [3. Boundary Conditions](#3.-Boundary-Conditions)
    - [3.1 Oceanic](#3.1-Oceanic)
        - [3.1.1 Create Copernicus Marine credentials file](#3.1.1-Create-Copernicus-Marine-credentials-file)
        - [3.1.2 Set CMEMS product](#3.1.2-Set-CMEMS-product)
        - [3.1.3 Download CMEMS](#3.1.3-Download-CMEMS)
    - [3.2 Meteorological](#3.2-Meteorological)
        - [3.2.1 Setup the CDS API personal access token](#3.2.1-Setup-the-CDS-API-personal-access-token)
        - [3.2.2 Download ERA5 Reanalysis](#3.2.2-Download-ERA5-Reanalysis)
    - [3.3 Tide](#3.3-Tide)
        - [3.3.1 Download FES2014.zip](#3.3.1-Download-FES2014.zip)
        - [3.3.2 Crop FES2014.hdf5 to your grid area](#3.3.2-Crop-FES2014.hdf5-to-your-grid-area)
        - [3.3.3 Plot a specific dataset (e.g., M2 amplitude)](#3.3.3-Plot-a-specific-dataset-(e.g.,-M2-amplitude))

# 1. Import required libraries

In [None]:
import copernicusmarine
import cdsapi
import os
from ipyleaflet import Map, TileLayer, DrawControl, GeoJSON, Marker
import json
import re
import datetime
import time
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_hex
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import shutil
import subprocess
import sys
import matplotlib as mpl
from folium.plugins import MeasureControl
import glob
import zipfile
import h5py
import requests
import pathlib
from tqdm import tqdm
import contextily as ctx

# 2. General options

## 2.1 Set run case

In [None]:
# Run-case folder, resolved as <current working directory>/<dirpath>/<name>.
dirpath = "run_cases"
name = "Coastal3D_Operational"
case_dir = os.path.join(os.getcwd(), dirpath, name)

## 2.2 Load MOHID griddata

In [None]:
# Load the MOHID grid/bathymetry file of the run case (Level 1).
grid_data = "Level1.dat"
# Build the path from separate components: a hard-coded "\\" separator only
# works on Windows (on POSIX it becomes part of a single file name).
file_path = os.path.join(case_dir, "GeneralData", "Batim", grid_data)

with open(file_path, 'r') as f:
    lines = f.readlines()

grid_data = []   # cell values between <BeginGridData2D> and <EndGridData2D>
x_coords = []    # optional XX values between <BeginXX> and <EndXX>
y_coords = []    # optional YY values between <BeginYY> and <EndYY>
n_rows, n_cols = None, None  # None means "not found in the file"
x0 = y0 = dx = dy = None     # ORIGIN / DX / DY header values, if present
start_reading_xx = False
start_reading_yy = False
start_reading_grid = False

for line in lines:
    line = line.strip()  # Remove leading and trailing spaces
    parts = line.split()

    # Header lines look like "KEYWORD : value [value]", so values start at parts[2].
    if line.startswith("ILB_IUB"):
        # parts[3] is used as the row count (presumably assumes ILB = 1).
        n_rows = int(parts[3])
    elif line.startswith("JLB_JUB"):
        n_cols = int(parts[3])
    elif line.startswith("ORIGIN"):
        x0 = float(parts[2])
        y0 = float(parts[3])
    elif line.startswith("DX"):
        dx = float(parts[2])
    elif line.startswith("DY"):
        dy = float(parts[2])
    elif "<BeginXX>" in line:
        start_reading_xx = True
        continue
    elif "<EndXX>" in line:
        start_reading_xx = False
    elif start_reading_xx:
        try:
            x_coords.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")
    elif "<BeginYY>" in line:
        start_reading_yy = True
        continue
    elif "<EndYY>" in line:
        start_reading_yy = False
    elif start_reading_yy:
        try:
            y_coords.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")
    elif "<BeginGridData2D>" in line:
        start_reading_grid = True
        continue
    elif "<EndGridData2D>" in line:
        start_reading_grid = False
    elif start_reading_grid:
        try:
            grid_data.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")

# Debugging prints
print(f"Extracted Dimensions: n_rows={n_rows}, n_cols={n_cols}")
print(f"Grid Data Length: {len(grid_data)}")

# Validate the header *before* using it to build coordinates, so a missing
# keyword produces a clear message instead of a TypeError/NameError.
if n_rows is None or n_cols is None:
    raise ValueError("Grid dimensions could not be determined from the file.")
if x0 is None or y0 is None:
    raise ValueError("Grid ORIGIN could not be determined from the file.")

# Corner coordinates are either listed explicitly (XX/YY, offsets from ORIGIN)
# or generated from a constant spacing (DX/DY). Each axis is resolved on its own.
if x_coords:
    x_coords = np.array(x_coords) + x0
elif dx is not None:
    x_coords = np.linspace(x0, x0 + dx * n_cols, n_cols + 1)
else:
    raise ValueError("Neither XX values nor DX could be found in the grid file.")

if y_coords:
    y_coords = np.array(y_coords) + y0
elif dy is not None:
    y_coords = np.linspace(y0, y0 + dy * n_rows, n_rows + 1)
else:
    raise ValueError("Neither YY values nor DY could be found in the grid file.")

# Check if data size matches expected shape
expected_size = n_rows * n_cols
if len(grid_data) != expected_size:
    raise ValueError(f"Mismatch: Grid data size {len(grid_data)} does not match expected ({expected_size}).")

# Cell values as an (n_rows, n_cols) array.
zi = np.array(grid_data).reshape(n_rows, n_cols)

# 2-D corner coordinates, shape (n_rows + 1, n_cols + 1) for a regular grid.
x_grid, y_grid = np.meshgrid(x_coords, y_coords)

print(f"Loaded grid data shape: {zi.shape}")

## 2.3 Define a bounding box

In [None]:
# Bounding box of the MOHID grid, padded on every side, used to request the
# boundary-condition data.
np_x, np_y = np.array(x_grid), np.array(y_grid)

c = 0.25  # padding, in degrees
min_lon, max_lon = np.min(np_x) - c, np.max(np_x) + c
min_lat, max_lat = np.min(np_y) - c, np.max(np_y) + c

print("✅ Polygon Bounds:")
for label, value in (
    ("min_lon (West)", min_lon),
    ("min_lat (South)", min_lat),
    ("max_lon (East)", max_lon),
    ("max_lat (North)", max_lat),
):
    print(f"  - {label}: {value}")

## 2.4 Set dates

In [None]:
# Boundary-condition download window, as "%Y-%m-%d" strings.
# For the initial run, use a 5-day interval so the model can warm up.
start_date_str, end_date_str = "2025-1-1", "2025-1-6"

# daily = 1 -> one file per day; otherwise a single file for start..end.
daily = 0

forecast = 0

# The two settings below are only used when forecast = 1.
refday_to_start = 0  # reference day: 0 = today, -1 = yesterday, 1 = tomorrow
number_of_runs = 1

# 3. Boundary Conditions

## 3.1 Oceanic 

### 3.1.1 Create Copernicus Marine credentials file
It has to be done only once!

In [None]:
# Log in to Copernicus Marine. The login command checks your Copernicus Marine
# credentials and creates the configuration file (needed only once per machine).
copernicusmarine.login()

### 3.1.2 Set CMEMS product

In [None]:
# CMEMS global physics analysis/forecast products (0.083 degree).
#
# Alternative - 6-hourly instantaneous fields:
#product_id = ["cmems_mod_glo_phy_anfc_0.083deg_PT6H-i","cmems_mod_glo_phy-cur_anfc_0.083deg_PT6H-i","cmems_mod_glo_phy-so_anfc_0.083deg_PT6H-i","cmems_mod_glo_phy-thetao_anfc_0.083deg_PT6H-i"]

# Daily means (currents, salinity, temperature, and the base phy product).
product_id = [
    "cmems_mod_glo_phy-cur_anfc_0.083deg_P1D-m",
    "cmems_mod_glo_phy-so_anfc_0.083deg_P1D-m",
    "cmems_mod_glo_phy-thetao_anfc_0.083deg_P1D-m",
    "cmems_mod_glo_phy_anfc_0.083deg_P1D-m",
]

# Vertical extent of the download. Set end_depth = start_depth to fetch the
# first level only.
start_depth, end_depth = 0.49402499198913574, 5727.9

### 3.1.3 Download CMEMS

In [None]:
backup_path = (os.path.join(case_dir, r"GeneralData\\BoundaryConditions\\CMEMS\\backup"))

#This file can later be used as input to CMEMS2HDF5.py for operational purposes
input_file = os.path.join(os.getcwd(),"work","CMEMS","Input_CMEMS2HDF5.py")

with open(input_file, 'w') as file:
    file.write(f"backup_path=r'{backup_path}'\n")
    file.write(f"daily={daily}\n")
    file.write(f"forecast={forecast}\n")
    file.write(f"number_of_runs={number_of_runs}\n")
    file.write(f"refday_to_start={refday_to_start}\n")
    file.write(f"product_id={product_id}\n")
    file.write(f"start_depth={start_depth}\n")
    file.write(f"end_depth={end_depth}\n")
    file.write(f"min_lon={min_lon}\n")
    file.write(f"max_lon={max_lon}\n")
    file.write(f"min_lat={min_lat}\n")
    file.write(f"max_lat={max_lat}\n")
    file.write(f"start_date_str='{start_date_str}'\n")
    file.write(f"end_date_str='{end_date_str}'\n")

%cd work/CMEMS/
%run CMEMS2HDF5.py

# Return to the original directory
%cd -
    

## 3.2 Meteorological

### 3.2.1 Setup the CDS API personal access token
It has to be done only once!

If you do not have an account yet, please register (https://cds.climate.copernicus.eu/).
If you are not logged in, please log in.
Once logged in, copy the URL and key.

Create a file named .cdsapirc in your home directory.

`$HOME/.cdsapirc` (in your Unix/Linux environment)

`%USERPROFILE%\.cdsapirc` (in your Windows environment; `%USERPROFILE%` is usually the `C:\Users\Username` folder).

Paste the URL and key into .cdsapirc file.

The CDS API expects to find the .cdsapirc file in your home directory.

### 3.2.2 Download ERA5 Reanalysis

In [None]:
backup_path = (os.path.join(case_dir, r"GeneralData\\BoundaryConditions\\ERA5\\backup"))

#This file can later be used as input to ERA52HDF5.py for operational purposes
input_file = os.path.join(os.getcwd(),"work","ERA5","Input_ERA52HDF5.py")

with open(input_file, 'w') as file:
    file.write(f"backup_path=r'{backup_path}'\n")
    file.write(f"daily={daily}\n")
    file.write(f"forecast={forecast}\n")
    file.write(f"number_of_runs={number_of_runs}\n")
    file.write(f"refday_to_start={refday_to_start}\n")
    file.write(f"min_lon={min_lon}\n")
    file.write(f"max_lon={max_lon}\n")
    file.write(f"min_lat={min_lat}\n")
    file.write(f"max_lat={max_lat}\n")
    file.write(f"start_date_str='{start_date_str}'\n")
    file.write(f"end_date_str='{end_date_str}'\n")

%cd work/ERA5/
%run ERA52HDF5.py

# Return to the original directory
%cd -

## 3.3 Tide

### 3.3.1 Download FES2014.zip
It has to be done only once!

In [None]:
# FES2014 archive to download.
url = "http://www.mohid.com/PublicData/Products/Software/FES2014.zip"

# pathlib keeps these paths valid on every operating system:
# the ZIP goes to ~/FES2014.zip and is unpacked into ~/FES2014.
home_dir = pathlib.Path.home()
local_zip_path, extract_dir = home_dir / "FES2014.zip", home_dir / "FES2014"

def download_file(url, save_path, timeout=60):
    """
    Stream ``url`` to ``save_path`` while showing a tqdm progress bar.

    Parameters:
        url (str): Address of the file to download.
        save_path (path-like): Destination file; overwritten if it exists.
        timeout (float): Seconds ``requests`` waits for the server to connect
            or send data. Without a timeout, a stalled server hangs the cell
            forever.

    Raises:
        RuntimeError: If the HTTP request fails. Any partially written file is
            removed first. Raising stops the cell with a clear error; ``exit()``
            is an interactive-session helper, not an error-handling tool, and is
            unreliable inside a notebook kernel.
    """
    try:
        # Use the response as a context manager so the connection is always released.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise exception for HTTP errors
            # content-length may be missing; None gives tqdm an open-ended bar.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(save_path, 'wb') as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))
    except requests.exceptions.RequestException as e:
        # Don't leave a truncated archive that a later step could mistake for a valid one.
        pathlib.Path(save_path).unlink(missing_ok=True)
        raise RuntimeError(f"Error downloading the file: {e}") from e

    print("Download completed successfully.")

def extract_zip(zip_path, extract_to):
    """
    Extract every member of the ZIP archive ``zip_path`` into ``extract_to``.

    Parameters:
        zip_path (path-like): Archive to extract.
        extract_to (path-like): Destination directory.

    Raises:
        RuntimeError: If ``zip_path`` is not a valid ZIP archive. Raising, instead
            of calling ``exit(1)``, stops the notebook cell without touching the kernel.
    """
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    except zipfile.BadZipFile as err:
        raise RuntimeError("Error: The downloaded file is not a valid ZIP archive.") from err
    print(f"Extraction completed. Files are available in '{extract_to}' directory.")

def verify_extraction(extract_to):
    """
    Check that ``extract_to`` exists and contains at least one entry.

    Parameters:
        extract_to (path-like): Directory the archive was extracted into
            (str or pathlib.Path).

    Raises:
        RuntimeError: If the directory is missing or empty. Raising, instead of
            calling ``exit(1)``, stops the notebook cell without touching the kernel.
    """
    target = pathlib.Path(extract_to)
    if not target.exists() or not any(target.iterdir()):
        raise RuntimeError("Error: Extraction failed or directory is empty.")

def cleanup(zip_path):
    """
    Delete the downloaded archive. This is best-effort: failures are printed, not raised.

    Parameters:
        zip_path (path-like): File to delete (str or pathlib.Path).
    """
    try:
        pathlib.Path(zip_path).unlink()  # Deletes the ZIP file
    except OSError as e:
        # Only filesystem errors are expected here; anything else is a real bug
        # and should surface instead of being swallowed.
        print(f"Error deleting ZIP file: {e}")
    else:
        print(f"Deleted {zip_path}")

# Download the archive, unpack it into extract_dir, check the result and
# finally delete the ZIP.
download_file(url, local_zip_path)
extract_dir.mkdir(exist_ok=True)  # extraction target must exist beforehand
extract_zip(local_zip_path, extract_dir)
verify_extraction(extract_dir)
cleanup(local_zip_path)

### 3.3.2 Crop FES2014.hdf5 to your grid area

In [None]:
# Input: the full FES2014 file unpacked in the home folder.
# Output: the cropped copy stored inside the run case.
home_dir = pathlib.Path.home()
input_file = home_dir.joinpath("FES2014", "FES2014.hdf5")
output_file = os.path.join(case_dir, "GeneralData/BoundaryConditions/FES2014/FES2014.hdf5")
    
def get_bbox_indices(lon_arr, lat_arr, min_lon, max_lon, min_lat, max_lat):
    """
    Return the smallest index window of a 2-D grid that covers a lon/lat box.

    Parameters:
        lon_arr, lat_arr : 2-D numpy arrays of longitudes / latitudes.
        min_lon, max_lon : inclusive longitude limits.
        min_lat, max_lat : inclusive latitude limits.

    Returns:
        (row_start, row_end, col_start, col_end) with half-open ends, ready
        to be used as slice bounds.

    Raises:
        ValueError: if no grid point lies inside the box.
    """
    inside = ((lon_arr >= min_lon) & (lon_arr <= max_lon)
              & (lat_arr >= min_lat) & (lat_arr <= max_lat))

    # Rows / columns that contain at least one point inside the box.
    hit_rows = np.flatnonzero(inside.any(axis=1))
    hit_cols = np.flatnonzero(inside.any(axis=0))
    if hit_rows.size == 0:
        raise ValueError("No grid points fall within the provided bounding box.")

    # +1 on the ends because Python slices exclude their stop index.
    return hit_rows[0], hit_rows[-1] + 1, hit_cols[0], hit_cols[-1] + 1

def create_dataset_with_attrs(out_group, key, data, source_ds):
    """
    Create ``key`` in ``out_group`` holding ``data``, mirroring ``source_ds``'s
    storage options and attributes.

    Copied storage options: compression (and its options), the shuffle and
    Fletcher32 filters, and chunking. Chunk dimensions are clipped to the new
    data shape (and kept >= 1), because a cropped dataset can be smaller than
    the original chunk.

    Parameters:
        out_group : HDF5 group (or file) where the new dataset will be created.
        key       : The name for the new dataset.
        data      : The data to be stored.
        source_ds : The source dataset from which attributes and storage options are copied.

    Returns:
        The newly created dataset.
    """
    ds_kwargs = {}
    # Preserve compression if present.
    if source_ds.compression:
        ds_kwargs['compression'] = source_ds.compression
        if source_ds.compression_opts is not None:
            ds_kwargs['compression_opts'] = source_ds.compression_opts
    # Shuffle and Fletcher32 are filters in the same pipeline as compression.
    # Dropping them silently changes the compression ratio / checksumming of the copy.
    if getattr(source_ds, 'shuffle', False):
        ds_kwargs['shuffle'] = True
    if getattr(source_ds, 'fletcher32', False):
        ds_kwargs['fletcher32'] = True
    # Preserve chunking, but never exceed the new data's dimensions; HDF5 also
    # rejects zero-sized chunk dimensions.
    if source_ds.chunks:
        ds_kwargs['chunks'] = tuple(
            max(1, min(new_dim, orig_chunk))
            for new_dim, orig_chunk in zip(data.shape, source_ds.chunks)
        )
    new_ds = out_group.create_dataset(key, data=data, **ds_kwargs)
    # Copy all dataset attributes.
    for attr_key, attr_val in source_ds.attrs.items():
        new_ds.attrs[attr_key] = attr_val
    return new_ds

def recursive_copy(in_group, out_group, crop_slice, grid_shape):
    """
    Recursively copy ``in_group`` into ``out_group``, cropping gridded datasets.

    Datasets are handled according to their last two dimensions:
    - equal to ``grid_shape`` (the shape of the Longitude/Latitude arrays):
      cropped with ``crop_slice``;
    - one smaller in each direction (values defined per grid cell): cropped
      with the matching cell window;
    - anything else: copied unchanged.
    Groups are recreated with their attributes. Dataset storage options and
    attributes are preserved through ``create_dataset_with_attrs``.

    Parameters:
        in_group   : h5py group (or file) to read from.
        out_group  : h5py group (or file) to write into.
        crop_slice : (row_slice, col_slice) window on the corner grid.
        grid_shape : shape of the corner (Longitude/Latitude) arrays.
    """
    row_start, row_end = crop_slice[0].start, crop_slice[0].stop
    col_start, col_end = crop_slice[1].start, crop_slice[1].stop
    # Cell-defined fields have one fewer row/column than the corner grid, so
    # their window ends one index earlier.
    cell_shape = (grid_shape[0] - 1, grid_shape[1] - 1)
    cell_slice = (slice(row_start, row_end - 1), slice(col_start, col_end - 1))

    for key in in_group:
        item = in_group[key]
        if isinstance(item, h5py.Dataset):
            spatial_slice = None
            if item.ndim >= 2:
                if item.shape[-2:] == tuple(grid_shape):
                    spatial_slice = crop_slice
                elif item.shape[-2:] == cell_shape:
                    spatial_slice = cell_slice

            if spatial_slice is None:
                data = item[()]
            else:
                # Read only the cropped window from disk instead of loading the
                # whole (global) array into memory and slicing it afterwards.
                data = item[(slice(None),) * (item.ndim - 2) + spatial_slice]
            create_dataset_with_attrs(out_group, key, data, item)
        elif isinstance(item, h5py.Group):
            # Recursively copy groups and their attributes.
            new_group = out_group.create_group(key)
            for attr_key, attr_val in item.attrs.items():
                new_group.attrs[attr_key] = attr_val
            recursive_copy(item, new_group, crop_slice, grid_shape)

def crop_hdf5_file(input_file, output_file, min_lon, max_lon, min_lat, max_lat):
    """
    Write a copy of ``input_file`` cropped to a geographic bounding box.

    The crop window is computed from the ``Grid/Longitude`` and ``Grid/Latitude``
    datasets. Datasets on that grid (or on its cells) are cropped, everything
    else is copied as is, and attributes are preserved, including those of the
    root group.

    Parameters:
        input_file  : Path to the original HDF5 file.
        output_file : Path to the new (cropped) HDF5 file. Missing parent
                      folders are created; an existing file is overwritten.
        min_lon, max_lon : Geographic longitude limits.
        min_lat, max_lat : Geographic latitude limits.

    Raises:
        ValueError: if no grid point lies inside the bounding box (no output is written).
    """
    with h5py.File(input_file, 'r') as fin:
        grid = fin['Grid']  # assumes a 'Grid' group exists.
        lon_arr = grid['Longitude'][()]
        lat_arr = grid['Latitude'][()]
        grid_shape = lon_arr.shape
        row_start, row_end, col_start, col_end = get_bbox_indices(
            lon_arr, lat_arr, min_lon, max_lon, min_lat, max_lat
        )
        print(f"Cropping rows: {row_start} to {row_end}, columns: {col_start} to {col_end}")
        crop_slice = (slice(row_start, row_end), slice(col_start, col_end))

        # h5py does not create missing folders, so make sure the target exists.
        os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)

        with h5py.File(output_file, 'w') as fout:
            # recursive_copy only copies attributes of sub-groups; the root
            # group's attributes are copied here so they are not lost.
            for attr_key, attr_val in fin.attrs.items():
                fout.attrs[attr_key] = attr_val
            recursive_copy(fin, fout, crop_slice, grid_shape)

    print(f"Cropping completed. Output saved to '{output_file}'.")

if __name__ == '__main__':
    # Crop the full FES2014 file to the current bounding box.
    crop_hdf5_file(input_file, output_file,
                   min_lon=min_lon, max_lon=max_lon,
                   min_lat=min_lat, max_lat=max_lat)

### 3.3.3 Plot a specific dataset (e.g., M2 amplitude)

In [None]:
# Cropped FES2014 file of the run case, and the dataset inside it to draw.
file_path = os.path.join(case_dir, "GeneralData/BoundaryConditions/FES2014/FES2014.hdf5")
dataset_key = '/Results/water level/M2/amplitude'  # Update as needed.
    
def plot_dataset(file_path, dataset_key):
    """
    Plot one dataset of an HDF5 file, georeferenced when possible.

    - Values below -99 are masked as invalid.
    - If ``Grid/Longitude`` and ``Grid/Latitude`` exist and the dataset has one
      fewer row and column than them (values per grid cell), the data is drawn
      with ``pcolormesh`` over an Esri World Imagery basemap from contextily.
      If the basemap cannot be fetched (e.g. no network), the plot is still
      produced without it.
    - Otherwise the array is drawn with ``imshow`` in index space.
    The figure is saved as ``plot_<dataset_key with '/' replaced by '_'>.png``
    in the current directory and displayed inline.

    Parameters:
        file_path (str): Path to the HDF5 file.
        dataset_key (str): The key/path of the dataset to plot.
    """
    with h5py.File(file_path, 'r') as f:
        data = f[dataset_key][()]
        print(f"Loaded dataset '{dataset_key}' with shape {data.shape}.")
        # Keep the try narrow: only a missing coordinate dataset should select
        # the imshow fallback, not an unrelated KeyError raised while plotting.
        try:
            lon_arr = f["Grid/Longitude"][()]
            lat_arr = f["Grid/Latitude"][()]
        except KeyError:
            lon_arr = lat_arr = None

    # Mask invalid values below -99.
    data = np.ma.masked_where(data < -99, data)
    print("Applied mask for values below -99.")

    # Construct an output filename based on the dataset key.
    output_filename = f"plot_{dataset_key.strip('/').replace('/', '_')}.png"
    print(f"Output image filename will be: {output_filename}")

    if lon_arr is None:
        print("Coordinate arrays not found. Using imshow for plotting without basemap.")
    else:
        print(f"Found coordinate arrays with shapes {lon_arr.shape} and {lat_arr.shape}.")
        # Corner coordinates are one larger than cell values in each dimension.
        if data.shape == (lon_arr.shape[0] - 1, lon_arr.shape[1] - 1):
            fig, ax = plt.subplots(figsize=(8, 6))
            mesh = ax.pcolormesh(lon_arr, lat_arr, data, shading='auto', cmap='viridis', zorder=2)
            ax.set_xlabel("Longitude")
            ax.set_ylabel("Latitude")
            ax.set_title(f"Plot of {dataset_key}")

            # Coordinates are lon/lat degrees, hence crs="EPSG:4326". The basemap is
            # best-effort: tile download needs network access and may fail.
            try:
                ctx.add_basemap(ax, crs="EPSG:4326", source=ctx.providers.Esri.WorldImagery,
                                alpha=1., zorder=1)
            except Exception as err:
                print(f"Warning: could not add satellite basemap ({err}). Plotting without it.")

            fig.colorbar(mesh, ax=ax, label='Data values')
            fig.savefig(output_filename)
            print(f"Saved plot to {output_filename}")
            plt.show()  # Shows the plot inline
            plt.close(fig)
            return
        print("Data shape does not match expected cell-centered dimensions. Using fallback imshow.")

    # Fallback: plot in index space when georeferenced coordinates are unusable.
    fig = plt.figure(figsize=(8, 6))
    im = plt.imshow(data, origin='lower', cmap='viridis')
    plt.xlabel("Column Index")
    plt.ylabel("Row Index")
    plt.title(f"Plot of {dataset_key}")
    plt.colorbar(im, label='Data values')
    plt.savefig(output_filename)
    print(f"Saved plot to {output_filename}")
    plt.show()
    plt.close(fig)

if __name__ == '__main__':
    # Draw the selected dataset (saved as PNG and shown inline).
    plot_dataset(file_path=file_path, dataset_key=dataset_key)

# 3. Define sources

## 3.3 Draw markers on the map to define the source coordinates

In [None]:
import matplotlib.colors as mcolors

# Start timing
start_time = time.time()

# Corner coordinates of the MOHID grid (from the grid-loading cell).
LonGrid = np.array(x_grid)
LatGrid = np.array(y_grid)
# The grid extent is deliberately NOT written into min_lon/max_lon/min_lat/max_lat.
# Those names hold the padded bounding box used for the boundary-condition
# downloads. Overwriting them here (they are not used in this cell) would
# silently shrink the requested area if one of those cells is re-run later.

# -------------------------
# Create the map, centred on the grid.
m = Map(center=(LatGrid.mean(), LonGrid.mean()), zoom=8)
marker = None  # For interactive marker

# Block (batch) layers, keyed by (block_row, block_col).
block_layers = {}

# Block (batch) size in cells per side; controls the spatial grouping of layers.
block_size = 10  # Adjust as needed.

# Corners of every grid cell. With LonGrid/LatGrid of shape (M, N), each of
# these arrays has shape (M-1, N-1): one entry per cell.
lon_sw = LonGrid[:-1, :-1]  # Southwest corner
lon_se = LonGrid[:-1,  1:]  # Southeast corner
lon_ne = LonGrid[1:,   1:]  # Northeast corner
lon_nw = LonGrid[1:,  :-1]  # Northwest corner

lat_sw = LatGrid[:-1, :-1]
lat_se = LatGrid[:-1,  1:]
lat_ne = LatGrid[1:,   1:]
lat_nw = LatGrid[1:,  :-1]

# -------------------------
# Discrete colour-binning state, filled in by precompute_color_grid.
_nbins = 10
_bins = None
_discrete_colors = None

def map_value_to_color(value):
    """
    Return the hex colour of the discrete bin that contains ``value``.

    -99 marks an invalid cell and maps to a fully transparent colour. Uses the
    module-level ``_bins``, ``_nbins`` and ``_discrete_colors`` set by
    ``precompute_color_grid``.
    """
    if value != -99:
        # np.digitize is 1-based; clamp so values on or beyond the outer edges
        # land in the first/last bin.
        idx = min(max(int(np.digitize(value, _bins)) - 1, 0), _nbins - 1)
        return _discrete_colors[idx]
    return "#ffffff00"

def precompute_color_grid(zi, nbins=10):
    """
    Set up ``nbins`` equal-width viridis colour bins for the depth grid ``zi``.

    The bins span the min..max of the valid cells (those != -99), or 0..1 if
    there are none. The module-level ``_nbins``, ``_bins`` and
    ``_discrete_colors`` are updated for later use by ``map_value_to_color``.

    Returns:
        An array shaped like ``zi`` with the hex colour of every cell.
    """
    global _bins, _nbins, _discrete_colors
    _nbins = nbins

    valid_values = zi[zi != -99]
    if valid_values.size:
        lo, hi = valid_values.min(), valid_values.max()
    else:
        lo, hi = 0, 1
    _bins = np.linspace(lo, hi, nbins + 1)

    viridis = plt.colormaps.get_cmap('viridis')
    _discrete_colors = list(map(mcolors.to_hex, viridis(np.linspace(0, 1, nbins))))

    return np.vectorize(map_value_to_color)(zi)

# Build the colour bins once for the whole grid; the per-cell colours are kept
# for reference only.
color_mapped_zi = precompute_color_grid(zi, _nbins)

# -------------------------
# New: Generate a GeoJSON FeatureCollection for a given block (spatial batch).
def generate_block_geojson(block_row, block_col, block_size):
    """
    Build a GeoJSON FeatureCollection with one polygon per valid cell of a block.

    The block covers cell rows [block_row*block_size, (block_row+1)*block_size)
    and columns [block_col*block_size, (block_col+1)*block_size), clipped to
    the shape of ``zi``. Cells equal to -99 are skipped. Each polygon is the
    closed ring SW -> SE -> NE -> NW -> SW from the precomputed corner arrays,
    and its properties carry fill colour, stroke style and the (i, j) indices.
    """
    total_rows, total_cols = zi.shape
    row_range = range(block_row * block_size, min((block_row + 1) * block_size, total_rows))
    col_range = range(block_col * block_size, min((block_col + 1) * block_size, total_cols))

    # Corner arrays in ring order; SW is repeated to close the polygon.
    ring_corners = (
        (lon_sw, lat_sw),
        (lon_se, lat_se),
        (lon_ne, lat_ne),
        (lon_nw, lat_nw),
        (lon_sw, lat_sw),
    )

    def cell_feature(i, j):
        ring = [[float(lon[i, j]), float(lat[i, j])] for lon, lat in ring_corners]
        return {
            "type": "Feature",
            "geometry": {"type": "Polygon", "coordinates": [ring]},
            "properties": {
                "fill": map_value_to_color(zi[i, j]),
                "stroke": "#000000",
                "fill-opacity": 0.5,
                "stroke-width": 0.2,
                "i": i,
                "j": j,
            },
        }

    return {
        "type": "FeatureCollection",
        "features": [cell_feature(i, j) for i in row_range for j in col_range if zi[i, j] != -99],
    }

# -------------------------
# Update (or initially generate) all blocks.
def update_all_blocks():
    """
    Rebuild the grid-cell layers on the map.

    Every layer in ``block_layers`` is removed from ``m``. Then one GeoJSON layer
    per non-empty block of ``block_size`` x ``block_size`` cells is added and
    recorded under its (block_row, block_col) key.
    """
    global block_layers
    for old_layer in block_layers.values():
        m.remove_layer(old_layer)
    block_layers = {}

    def style_from_properties(feature):
        props = feature["properties"]
        return {
            "fillColor": props["fill"],
            "color": props["stroke"],
            "weight": props["stroke-width"],
            "fillOpacity": props["fill-opacity"],
        }

    total_rows, total_cols = zi.shape
    # Ceiling division: a partial block at the grid edge still gets a layer.
    n_block_rows = -(-total_rows // block_size)
    n_block_cols = -(-total_cols // block_size)

    for br in range(n_block_rows):
        for bc in range(n_block_cols):
            collection = generate_block_geojson(br, bc, block_size)
            if not collection["features"]:
                continue  # no valid cells in this block
            layer = GeoJSON(data=collection, style_callback=style_from_properties)
            m.add_layer(layer)
            block_layers[(br, bc)] = layer

# Render every block of grid cells once, then report how long setup took.
update_all_blocks()
print(f"Total time: {time.time() - start_time:.2f} seconds")

# -------------------------------
# Markers drawn by the user
# -------------------------------
# marker_id -> {'location': [lon, lat], 'name': str}
markers_dict = {}
marker_counter = 0  # id given to the next marker

def ask_marker_name(marker, marker_id):
    """
    Show a text box + button that lets the user name a newly drawn marker.

    On confirmation, the name (or "Marker <id>" if left blank) is stored on
    ``marker.marker_name`` and in ``markers_dict[marker_id]['name']``, and the
    widgets are closed.
    """
    name_input = widgets.Text(
        value='',
        placeholder='Enter source name',
        description='Source name:',
        disabled=False,
        style={'description_width': 'initial'},
    )
    confirm_button = widgets.Button(description='Confirm', disabled=False, button_style='success')
    widget_box = widgets.VBox([name_input, confirm_button])

    def on_button_clicked(_button):
        marker_name = name_input.value.strip() or f"Marker {marker_id}"
        marker.marker_name = marker_name
        markers_dict[marker_id]['name'] = marker_name
        print(f"Marker {marker_id} named '{marker_name}'")
        widget_box.close()  # remove the prompt once answered

    confirm_button.on_click(on_button_clicked)
    display(widget_box)

def handle_draw(target, action, geo_json):
    """
    DrawControl callback for marker events.

    created: replace the drawn point with a draggable Marker, register it in
             markers_dict as {'location': [lon, lat], 'name': None} and ask
             the user for a name.
    edited:  only logged; moves are tracked by each Marker's location observer.
    deleted: remove entries whose GeoJSON properties carry a known marker_id.

    Parameters:
        target: the DrawControl that fired the event (unused).
        action (str): "created", "edited" or "deleted".
        geo_json (dict): GeoJSON of the affected feature(s).
    """
    global marker_counter

    if action == "created":
        coords = geo_json["geometry"]["coordinates"]  # GeoJSON order: [lon, lat]
        lon, lat = coords[0], coords[1]
        marker_id = marker_counter  # Assign an ID
        marker_counter += 1
        # Stored as [lon, lat], the same order the move observer below writes
        # and the order documented for markers_dict. Storing [lat, lon] here
        # made a marker's coordinates flip order after its first drag.
        markers_dict[marker_id] = {'location': [lon, lat], 'name': None}

        # Marker locations are [lat, lon].
        custom_marker = Marker(location=[lat, lon], draggable=True)
        custom_marker.marker_id = marker_id  # Attach marker ID

        # Ask user to enter a name for the new marker
        ask_marker_name(custom_marker, marker_id)

        # Keep markers_dict in sync when the marker is dragged. m_id is bound
        # as a default so each closure keeps its own id.
        def on_location_change(change, m_id=marker_id):
            new_location = change["new"]  # New [lat, lon]
            markers_dict[m_id]['location'] = [new_location[1], new_location[0]]
            print(f"Marker {m_id} moved to: {markers_dict[m_id]['location']}")

        custom_marker.observe(on_location_change, names="location")

        # Add the new draggable marker to the map
        m.add_layer(custom_marker)

        # Remove the default marker added by DrawControl.
        # NOTE(review): confirm DrawControl actually exposes its drawn point as a
        # GeoJSON layer in m.layers; if not, this loop removes nothing.
        for layer in list(m.layers):
            if isinstance(layer, GeoJSON) and layer.data.get("geometry", {}).get("type", "") == "Point":
                m.remove_layer(layer)

    elif action == "edited":
        print("Edit event received; marker updates are handled via the location observer.")

    elif action == "deleted":
        # NOTE(review): marker_id is stored as a Python attribute on the Marker,
        # not in GeoJSON properties. Confirm that deleted features carry it,
        # otherwise no entry is ever removed here.
        features = geo_json.get("features", [])
        if not features:
            features = [geo_json]
        for feature in features:
            marker_id = feature.get("properties", {}).get("marker_id")
            if marker_id is not None and marker_id in markers_dict:
                print(f"Deleted marker {marker_id} named '{markers_dict[marker_id]['name']}' at {markers_dict[marker_id]['location']}")
                markers_dict.pop(marker_id)
        print("Current markers after deletion:", markers_dict)

# Drawing toolbar with only the marker tool enabled.
draw_control = DrawControl()
# An empty dict disables a tool; a non-empty dict enables it with those options.
for unused_tool in ("polygon", "polyline", "rectangle", "circle"):
    setattr(draw_control, unused_tool, {})
draw_control.marker = {"repeatMode": False}

draw_control.on_draw(handle_draw)
m.add_control(draw_control)

# -------------------------------
# Display the map (last expression of the cell)
# -------------------------------
m

# 4. Setup MOHID Water input files

In [None]:
# MOHID Water input folder of the Level 1 domain. Separate components keep the
# path separator consistent on every OS (no mixed "\" and "/" on Windows).
data_dir = os.path.join(case_dir, "Level_1", "data")

In [None]:
# Run type flag: 0 for the initial run, 1 for a continuation run.
# NOTE(review): not read by any cell shown here; presumably used when the
# MOHID input files are configured - confirm.
continuous = 0 #if initial run continuous=0, else continuous=1

## 4.1 Model

In [None]:
def write_date(file_name, next_start_date, next_end_date):
    """
    Rewrite the START and END keywords of a MOHID input file in place.

    A line of the form ``[spaces]START[spaces]:[spaces]<anything>`` becomes the
    same keyword and separator followed by ``YYYY MM DD 0 0 0`` built from
    ``next_start_date``; END lines are handled likewise with ``next_end_date``.
    All other lines are written back unchanged.

    Parameters:
        file_name (str): Path of the MOHID .dat file to edit.
        next_start_date, next_end_date: Objects supporting ``strftime``
            (e.g. ``datetime.datetime``). Only the date part is written; the
            time is always ``0 0 0``.
    """
    # Each pattern captures the keyword (with any leading indentation) and the
    # exact spacing around ':' so only the value is replaced.
    replacements = (
        (re.compile(r"^(?P<leading>\s*START)(?P<sep>\s*:\s*).*$"), next_start_date),
        (re.compile(r"^(?P<leading>\s*END)(?P<sep>\s*:\s*).*$"), next_end_date),
    )

    def rewrite(line):
        for pattern, when in replacements:
            found = pattern.match(line)
            if found:
                return f"{found.group('leading')}{found.group('sep')}{when.strftime('%Y %m %d ')}0 0 0\n"
        return line

    with open(file_name, 'r') as file:
        file_lines = [rewrite(line) for line in file.readlines()]

    with open(file_name, 'w') as file:
        file.writelines(file_lines)

# start_date / end_date are not defined in any earlier cell, so this call
# raised NameError. Derive them from the "%Y-%m-%d" strings set in section 2.4.
start_date = datetime.datetime.strptime(start_date_str, "%Y-%m-%d")
end_date = datetime.datetime.strptime(end_date_str, "%Y-%m-%d")

file_name = os.path.join(data_dir, "Model_1.dat")
write_date(file_name, start_date, end_date)

# 5. Run MOHID Water

# 6. Visualize results