# MOHID Preprocessing

- Create regular grids
- Download, load and filter coastlines for grid region
- Perform interpolation on bathymetric data
- Update griddata depth values
- Convert the griddata to a shapefile
- Plot bathymetry

***
**Note 1**: Execute each cell through the <button class="btn btn-default btn-xs"><i class="icon-play fa fa-play"></i></button> button from the top MENU (or keyboard shortcut `Shift` + `Enter`).<br>
<br>
**Note 2**: Use the Kernel and Cell menus to restart the kernel and clear outputs.<br>
***

# Table of contents
- [1. Import required libraries](#1.-Import-required-libraries)
- [2. Load the XYZ data](#2.-Load-the-XYZ-data)
- [3. Grid](#3.-Grid)
    - [3.1 Load a previously generated Mohid grid file](#3.1-Load-a-previously-generated-Mohid-grid-file)
    - [3.2 Create a new grid](#3.2-Create-a-new-grid)
        - [3.2.1 Get grid dimensions and spacing](#3.2.1-Get-grid-dimensions-and-spacing)
        - [3.2.2 Grid generation](#3.2.2-Grid-generation)
        - [3.2.3 Save the grid to a MOHID-compatible file](#3.2.3-Save-the-grid-to-a-MOHID-compatible-file) 
- [4. Coastline](#4.-Coastline)
    - [4.1 Load coastline](#4.1-Load-coastline)
        - [4.1.1 Load your coastline](#4.1.1-Load-your-coastline)
        - [4.1.2 GSHHG Coastline Data](#4.1.2-GSHHG-Coastline-Data)
        - [4.1.3 OSM Coastline Data](#4.1.3-OSM-Coastline-Data)
    - [4.2 Filter Coastlines for Grid Region](#4.2-Filter-Coastlines-for-Grid-Region)
    - [4.3 Show coastline](#4.3-Show-coastline)
- [5. Bathymetry](#5.-Bathymetry)
    - [5.1 Load a previously generated Mohid griddata file](#5.1-Load-a-previously-generated-Mohid-griddata-file)
    - [5.2 Create a new griddata](#5.2-Create-a-new-griddata)
- [6. Visualize and update depth values by clicking on the map](#6.-Visualize-and-update-depth-values-by-clicking-on-the-map)
- [7. Save the griddata to a MOHID-compatible file](#7.-Save-the-griddata-to-a-MOHID-compatible-file)
- [8. Convert the griddata to a shapefile](#8.-Convert-the-griddata-to-a-shapefile)
- [9. Save shapefile to MOHID griddata](#9.-Save-shapefile-to-MOHID-griddata)
- [10. Plot MOHID griddata](#10.-Plot-MOHID-griddata)

# 1. Import required libraries

In [None]:
import numpy as np
import pandas as pd
from IPython.display import clear_output, display
from ipyleaflet import Map, Marker, basemaps, Popup, Polyline, Circle, GeoData, Polygon, GeoJSON, LayerGroup
from ipywidgets import HTML
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize, to_hex
import random
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter, label
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
from urllib.request import urlopen, Request
import io
from io import StringIO
from PIL import Image
import geopandas as gpd
from shapely.geometry import Point, box, Polygon
import requests
import zipfile
import os
from scipy.ndimage import label, find_objects, gaussian_filter
import matplotlib.colors as mcolors
import ipywidgets as widgets
import shapefile
from shapely import contains_xy
from shapely.ops import unary_union
from mpl_toolkits.axes_grid1 import make_axes_locatable
import time
import fiona
import pyogrio
from pathlib import Path
from math import radians, cos, sin
import math

# 2. Load the XYZ data
This step is optional (only perform if you already have an xyz file with bathymetry data).
If you already have a previously generated Mohid griddata file, go to [5.1 Load a previously generated Mohid griddata file](#5.1-Load-a-previously-generated-Mohid-griddata-file)

In [None]:
# --- Settings ---------------------------------------------------------------
file_path = Path(r'Data/data.xyz')  # Replace with your file path
MAX_POINTS = 1000  # upper bound on the number of circles drawn on the map

# --- Read the file -----------------------------------------------------------
raw_lines = file_path.read_text().splitlines()
stripped_lines = [raw.strip() for raw in raw_lines]

# The file may wrap its records in <begin_xyz> ... <end_xyz> markers
has_begin = any(s.startswith("<begin_xyz>") for s in stripped_lines)
has_end = any(s.startswith("<end_xyz>") for s in stripped_lines)

# --- Keep only the data records ----------------------------------------------
if has_begin and has_end:
    # Both markers present: keep the non-empty, non-tag lines found between them
    data_lines = []
    in_block = False
    for s in stripped_lines:
        if s.startswith("<begin_xyz>"):
            in_block = True
        elif s.startswith("<end_xyz>"):
            in_block = False
        elif in_block and s and not s.startswith("<"):
            data_lines.append(s)
else:
    # Markers missing: every non-empty line that is not a tag counts as data
    data_lines = [s for s in stripped_lines if s and not s.startswith("<")]

# --- Parse the whitespace-separated records into a DataFrame -----------------
txt = "\n".join(data_lines)
data = pd.read_csv(
    StringIO(txt),
    sep=r"\s+",
    header=None,
    names=["longitude", "latitude", "depth"],
)
data_list = data.to_dict(orient="records")

# --- Randomly thin the points when there are more than MAX_POINTS ------------
data_sample = random.sample(data_list, MAX_POINTS) if len(data_list) > MAX_POINTS else data_list

# --- Centre the map on the mean position of the sampled points ---------------
center_lat = np.mean([pt["latitude"] for pt in data_sample])
center_lon = np.mean([pt["longitude"] for pt in data_sample])
grid_map = Map(center=(center_lat, center_lon), zoom=10)

# --- Colour scale spanning the sampled depth range ---------------------------
cmap = cm.viridis
sample_depths = [pt["depth"] for pt in data_sample]
norm = Normalize(vmin=min(sample_depths), vmax=max(sample_depths))

def value_to_color(val):
    """Return the hex colour string for `val` using the cell's `norm` and `cmap`."""
    rgba = cmap(norm(val))
    return to_hex(rgba)

# Draw one filled circle per sampled point, coloured by its depth
for pt in data_sample:
    depth_color = value_to_color(pt["depth"])  # outline and fill share one colour
    circle = Circle(
        location=(pt["latitude"], pt["longitude"]),
        radius=30,
        weight=1,
        color=depth_color,
        fill_color=depth_color,
        fill_opacity=0.7,
    )
    grid_map.add_layer(circle)

# Leaving the map as the cell's last expression makes the notebook render it
grid_map

# 3. Grid
Load a previously generated Mohid grid file or create a new grid 

## 3.1 Load a previously generated Mohid grid file

In [None]:
# Location of the MOHID grid file to read
file_path = Path(r"mohid_grid.grd") #define the path to your grid

with open(file_path, 'r') as f:
    lines = f.readlines()

# Parser state. The dimensions start as None so that a missing header entry
# can be detected after parsing.
grid_data, x_coords, y_coords, polylines = [], [], [], []
n_rows = n_cols = None
start_reading_xx = start_reading_yy = start_reading_grid = False

for line in lines:
    line = line.strip()
    parts = line.split()

    # --- Header keywords: values are picked from the whitespace-split tokens ---
    if line.startswith("ILB_IUB"):
        n_rows = int(parts[3])
        continue
    if line.startswith("JLB_JUB"):
        n_cols = int(parts[3])
        continue
    if line.startswith("ORIGIN"):
        x0 = float(parts[2])
        y0 = float(parts[3])
        continue
    if line.startswith("DX"):
        dx = float(parts[2])
        continue
    if line.startswith("DY"):
        dy = float(parts[2])
        continue
    if line.startswith("GRID_ANGLE"):
        grid_angle = float(parts[2])
        continue

    # --- Optional <BeginXX> ... <EndXX> block, one value per line ---
    if "<BeginXX>" in line:
        start_reading_xx = True
        continue
    if "<EndXX>" in line:
        start_reading_xx = False
        continue
    if start_reading_xx:
        try:
            x_coords.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")
        continue

    # --- Optional <BeginYY> ... <EndYY> block, one value per line ---
    if "<BeginYY>" in line:
        start_reading_yy = True
        continue
    if "<EndYY>" in line:
        start_reading_yy = False
        continue
    if start_reading_yy:
        try:
            y_coords.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")

# Report what was found in the header
print(f"Extracted Dimensions: n_rows={n_rows}, n_cols={n_cols}")

# Fresh map centred on the grid origin (latitude first, then longitude)
grid_map = Map(center=(y0, x0), zoom=8)
display(grid_map)

# 5. Grid generation function
def generate_grid(x0, y0):
    """Build the rotated grid anchored at (x0, y0) and draw it on `grid_map`.

    The corner coordinates are stored in the globals `Xr` (x) and `Yr` (y).
    Spacing, size and rotation come from the globals `dx`, `dy`, `n_cols`,
    `n_rows` and `grid_angle` (degrees). Any grid layer previously recorded
    in `polylines` is removed from the map before the new one is added.
    """
    global polylines, Xr, Yr

    # Unrotated corner mesh: (n_rows + 1) x (n_cols + 1) nodes
    x_coords = np.linspace(x0, x0 + dx * n_cols, n_cols + 1)
    y_coords = np.linspace(y0, y0 + dy * n_rows, n_rows + 1)
    X, Y = np.meshgrid(x_coords, y_coords)

    # Rotate every node about the origin (x0, y0)
    theta = radians(grid_angle)
    ct, st = cos(theta), sin(theta)
    off_x = X - x0
    off_y = Y - y0
    Xr = x0 + (off_x * ct - off_y * st)
    Yr = y0 + (off_x * st + off_y * ct)

    # Take the previously drawn grid off the map
    for old_group in polylines:
        if old_group in grid_map.layers:
            grid_map.remove_layer(old_group)
    polylines.clear()

    # One polyline per mesh row and per mesh column; locations are (lat, lon) pairs
    grid_group = LayerGroup()
    for j in range(n_rows + 1):
        row_pts = list(zip(Yr[j], Xr[j]))
        grid_group.add_layer(Polyline(locations=row_pts, color="blue", weight=1))
    for i in range(n_cols + 1):
        col_pts = list(zip(Yr[:, i], Xr[:, i]))
        grid_group.add_layer(Polyline(locations=col_pts, color="blue", weight=1))

    grid_map.add_layer(grid_group)
    polylines.append(grid_group)
    
# Validate the header before building the grid: generate_grid() multiplies by
# n_rows / n_cols, so a missing ILB_IUB / JLB_JUB entry would otherwise surface
# as an unrelated TypeError instead of this message.
if n_rows is None or n_cols is None:
    raise ValueError("Grid dimensions could not be determined from the file.")

if not x_coords:
    # No explicit <BeginXX>/<BeginYY> coordinates were read:
    # build (and draw) a regular grid from the header values.
    generate_grid(x0, y0)
else:
    # Explicit coordinates were read; they are shifted by the origin
    # before the corner mesh is built.
    x_coords = np.array(x_coords) + x0
    y_coords = np.array(y_coords) + y0

    Xr, Yr = np.meshgrid(x_coords, y_coords)

print(f"Loaded grid")

## 3.2 Create a new grid

### 3.2.1 Get grid dimensions and spacing

In [None]:
# Ask for the number of grid cells along each axis.
# int() raises ValueError if the typed text is not an integer.
nx = int(input("Enter the number of cells in the x-direction (nx): "))
ny = int(input("Enter the number of cells in the y-direction (ny): "))

In [None]:
# Ask for the cell size along each axis (the prompts request degrees).
# float() raises ValueError if the typed text is not a number.
dx = float(input("Enter the cell size in the x-direction (dx, in degrees): "))
dy = float(input("Enter the cell size in the y-direction (dy, in degrees): "))

### 3.2.2 Grid generation
Click on the map to create the grid

In [None]:
# Full interactive rotatable grid with manual angle input
import numpy as np
from math import radians, cos, sin
from ipyleaflet import Map, Marker, Polyline, LayerGroup
from ipywidgets import HTML, FloatText

# Shared state read and written by the callbacks defined further down
marker = None        # origin marker currently shown on the map, if any
polylines = []       # grid layers currently shown on the map
x0 = y0 = None       # grid origin; stays unset until the map is clicked
grid_angle = 0.0     # rotation angle in degrees

# Reuse a map created by an earlier cell, otherwise start from a world view
if "grid_map" not in globals():
    grid_map = Map(center=(0.0, 0.0), zoom=2)
display(grid_map)

# Short usage instructions shown under the map
instructions = HTML("""
<h4>Interactive Rotatable Grid</h4>
<ol>
  <li>Click on the map to set the grid origin.</li>
  <li>Enter a specific angle (in degrees) to rotate the grid.</li>
</ol>
""")
display(instructions)

# Numeric field for typing the rotation angle
angle_input = FloatText(description='Angle (°):', value=0.0, step=0.1)
display(angle_input)

def on_angle_change(change):
    """Store the new rotation angle and redraw the grid once an origin exists.

    `change` is the mapping handed over by the widget observer; its 'new'
    entry becomes the global `grid_angle`.
    """
    global grid_angle
    grid_angle = change['new']

    # Nothing to redraw until the origin has been set
    if x0 is None or y0 is None:
        return
    generate_grid(x0, y0)

# Call on_angle_change whenever the widget's 'value' trait changes
angle_input.observe(on_angle_change, names='value')

# 5. Grid generation function
def generate_grid(x0, y0):
    """Build the rotated grid anchored at (x0, y0) and draw it on `grid_map`.

    The corner coordinates are stored in the globals `Xr` (x) and `Yr` (y).
    Spacing, size and rotation come from the globals `dx`, `dy`, `nx`, `ny`
    and `grid_angle` (degrees). Any grid layer previously recorded in
    `polylines` is removed from the map before the new one is added.
    """
    global polylines, Xr, Yr

    # Unrotated corner mesh: (ny + 1) x (nx + 1) nodes
    x_coords = np.linspace(x0, x0 + dx * nx, nx + 1)
    y_coords = np.linspace(y0, y0 + dy * ny, ny + 1)
    X, Y = np.meshgrid(x_coords, y_coords)

    # Rotate every node about the origin (x0, y0)
    theta = radians(grid_angle)
    ct, st = cos(theta), sin(theta)
    off_x = X - x0
    off_y = Y - y0
    Xr = x0 + (off_x * ct - off_y * st)
    Yr = y0 + (off_x * st + off_y * ct)

    # Take the previously drawn grid off the map
    for old_group in polylines:
        if old_group in grid_map.layers:
            grid_map.remove_layer(old_group)
    polylines.clear()

    # One polyline per mesh row and per mesh column; locations are (lat, lon) pairs
    grid_group = LayerGroup()
    for j in range(ny + 1):
        row_pts = list(zip(Yr[j], Xr[j]))
        grid_group.add_layer(Polyline(locations=row_pts, color="blue", weight=1))
    for i in range(nx + 1):
        col_pts = list(zip(Yr[:, i], Xr[:, i]))
        grid_group.add_layer(Polyline(locations=col_pts, color="blue", weight=1))

    grid_map.add_layer(grid_group)
    polylines.append(grid_group)

# 6. Click handler to set origin and draw grid
def handle_map_click(**kwargs):
    """Map interaction callback: on a click, move the origin and redraw the grid.

    Events whose "type" is not "click" are ignored. For a click, the
    "coordinates" entry is unpacked as (lat, lng) and stored in the globals
    `y0` and `x0`; the origin marker is replaced and the grid regenerated.
    """
    global marker, x0, y0
    if kwargs.get("type") != "click":
        return

    lat, lng = kwargs["coordinates"]
    x0, y0 = lng, lat

    # Replace the previous origin marker, if one is on the map
    if marker is not None and marker in grid_map.layers:
        grid_map.remove_layer(marker)
    marker = Marker(location=(y0, x0))
    grid_map.add_layer(marker)

    generate_grid(x0, y0)

# Register the click handler; it receives every map interaction event and filters for clicks itself
grid_map.on_interaction(handle_map_click)

### 3.2.3 Save the grid to a MOHID-compatible file

In [None]:
# Refuse to write a grid without an origin: x0 / y0 stay None until the map has
# been clicked, and str(None) would put the literal "None" into the LATITUDE,
# LONGITUDE and ORIGIN entries of the file.
if x0 is None or y0 is None:
    raise ValueError("Grid origin is not set: click on the map to place the grid before saving.")

# Timestamp recorded in the file header
now = datetime.now()
formatted_date_time = now.strftime("%d-%m-%Y %H:%M:%S")

output_file = "mohid_grid.grd"

# Header entries, in the order they are written. ILB_IUB carries the row
# count (ny) and JLB_JUB the column count (nx).
header_lines = [
    "PROJ4_STRING              : +proj=longlat +datum=WGS84 +no_defs\n",
    "COMENT1                   : Grid generated by MOHID Jupyter Notebook\n",
    f"COMENT1                   : Generation Time: {formatted_date_time}\n",
    f"LATITUDE                  : {y0}\n",
    f"LONGITUDE                 : {x0}\n",
    "COORD_TIP                 : 4\n",
    f"ILB_IUB                   : 1 {ny}\n",
    f"JLB_JUB                   : 1 {nx}\n",
    f"ORIGIN                    : {x0} {y0}\n",
    f"GRID_ANGLE                : {grid_angle}\n",
    "CONSTANT_SPACING_X        : 1\n",
    "CONSTANT_SPACING_Y        : 1\n",
    f"DX                        : {dx}\n",
    f"DY                        : {dy}\n",
]

with open(output_file, "w") as f:
    f.writelines(header_lines)

print(f"\nGrid saved to {output_file}")

# 4. Coastline

## 4.1 Load coastline
Load your coastline shapefile or download it from:
 - GSHHG (coarse resolution)
 - OSM (fine resolution)

### 4.1.1 Load your coastline
Shapefile or MOHID polygon file

#### Shapefile

In [None]:
# Path to your own coastline shapefile; the coastline filtering cell reads this variable
coastline_shapefile = "Coastlines/Para_3_completo.shp" #define the path to your coastline data

#### MOHID polygon file 

In [None]:
# Paths
coastline_file = Path(r"")  #define the path to your coastline data

out_dir = Path("Coastlines")
out_dir.mkdir(parents=True, exist_ok=True)

# 1. Read the coordinates, keeping every <beginpolygon> ... <endpolygon> block
#    as its own ring. Joining all blocks into a single ring would connect
#    unrelated polygons with spurious edges.
rings = []
current_ring = []
with coastline_file.open() as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        if line.startswith(("<beginpolygon>", "<endpolygon>")):
            # A marker closes whatever ring was being collected
            if current_ring:
                rings.append(current_ring)
                current_ring = []
            continue
        parts = line.split()
        if len(parts) != 3:
            # Only three-column records are treated as coordinates
            continue
        lon, lat, _ = parts
        current_ring.append((float(lon), float(lat)))

# A file without a closing marker (or without markers at all) still yields its ring
if current_ring:
    rings.append(current_ring)

if not rings:
    raise ValueError(f"No polygon coordinates found in {coastline_file}")

# 2. Ensure every ring is closed (first point repeated at the end)
for ring in rings:
    if ring[0] != ring[-1]:
        ring.append(ring[0])

# 3. Build one Shapely polygon per ring and wrap them in a GeoDataFrame
polygons = [Polygon(ring) for ring in rings]
gdf = gpd.GeoDataFrame({"geometry": polygons}, crs="EPSG:4326")

# 4. Export to Shapefile
coastline_shapefile = out_dir / "output_polygon.shp"
gdf.to_file(coastline_shapefile, driver="ESRI Shapefile")

print(f"Shapefile written to: {coastline_shapefile.resolve()}")

### 4.1.2 GSHHG Coastline Data
Download (first time) or load GSHHG Coastline Data

In [None]:
# Define URL and local paths
gshhg_url = "https://www.soest.hawaii.edu/pwessel/gshhg/gshhg-shp-2.3.7.zip"
zip_path = "gshhg_shapefiles.zip"
extract_path = "gshhg_data"

# Download the archive only if it is not already present.
# The body is streamed to a temporary ".part" file and renamed once complete,
# so a failed or interrupted download is never mistaken for a finished one
# on the next run. raise_for_status() stops an HTTP error page from being
# saved as the archive.
if not os.path.exists(zip_path):
    print("Downloading GSHHG data...")
    partial_path = zip_path + ".part"
    with requests.get(gshhg_url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    os.replace(partial_path, zip_path)

# Extract the files (skipped when the target folder already exists)
if not os.path.exists(extract_path):
    print("Extracting GSHHG data...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_path)

print("GSHHG data is ready.")

coastline_shapefile = os.path.join(extract_path, "GSHHS_shp", "f", "GSHHS_f_L1.shp") #if using downloaded GSHHG coastline data


### 4.1.3 OSM Coastline Data
Download (first time) or load OSM Coastline Data

In [None]:
import os
from zipfile import ZipFile
import requests
import numpy as np
import geopandas as gpd
from shapely.geometry import box
from ipyleaflet import Map, GeoJSON

# 1. Download OSM land-polygons (WGS84, split into tiles)
url = "https://osmdata.openstreetmap.de/download/land-polygons-split-4326.zip"
# Named osm_zip_path rather than `zipfile`: binding a string to `zipfile` would
# shadow the zipfile module for the rest of the notebook session and break
# any cell that calls zipfile.ZipFile afterwards.
osm_zip_path = "land-polygons.zip"
out_dir = "land-polygons"

# The body is streamed to a temporary ".part" file and renamed once complete,
# so a failed or interrupted download is never mistaken for a finished one
# on the next run. raise_for_status() stops an HTTP error page from being
# saved as the archive.
if not os.path.exists(osm_zip_path):
    print("Downloading land polygons…")
    partial_path = osm_zip_path + ".part"
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    os.replace(partial_path, osm_zip_path)

if not os.path.exists(out_dir):
    print("Extracting…")
    with ZipFile(osm_zip_path, "r") as z:
        z.extractall(out_dir)

# 2. Path of the land-polygons layer
#    The shapefile is called 'land_polygons.shp' inside the unzip folder
coastline_shapefile = os.path.join(out_dir, "land-polygons-split-4326", "land_polygons.shp")


## 4.2 Filter Coastlines for Grid Region

In [None]:
# Bounding box of the grid corners, padded by `c` cells on every side
np_x = np.array(Xr)
np_y = np.array(Yr)

c = 3
x_min = np.min(np_x) - dx * c
y_min = np.min(np_y) - dy * c
x_max = np.max(np_x) + dx * c
y_max = np.max(np_y) + dy * c

bbox = [x_min, y_min, x_max, y_max] #[min_lon, min_lat, max_lon, max_lat]

# Read only the features that intersect the bounding box instead of loading
# the whole coastline dataset into memory first. The box is given in the
# dataset's own coordinates, the same assumption the clipping step below makes.
coastlines = gpd.read_file(coastline_shapefile, bbox=tuple(float(v) for v in bbox))

# Create a bounding box geometry
bbox_geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=coastlines.crs)

# Clip the coastlines to the bounding box
filtered_coastlines = gpd.overlay(coastlines, bbox_geom, how='intersection')

## 4.3 Show coastline

In [None]:
# Centre the map on the middle of the bounding box
# (bbox is ordered [min_lon, min_lat, max_lon, max_lat])
west, south, east, north = bbox
center_lat = (south + north) / 2
center_lon = (west + east) / 2
m = Map(center=(center_lat, center_lon), zoom=8)

# Draw the clipped coastlines as a thin blue layer
geo_layer = GeoData(
    geo_dataframe=filtered_coastlines,
    style={"color": "blue", "weight": 1},
)
m.add_layer(geo_layer)

# Leaving the map as the cell's last expression makes the notebook render it
m

# 5. Bathymetry
Load a previously generated Mohid griddata file or create a new griddata

## 5.1 Load a previously generated Mohid griddata file

In [None]:
# Location of the MOHID griddata file to read
file_path = "mohid_griddata.dat" #define the path for your griddata file

with open(file_path, 'r') as f:
    lines = f.readlines()

# Parser state. The dimensions start as None so that a missing header entry
# can be detected after parsing.
grid_data, x_coords, y_coords = [], [], []
nx = ny = None
start_reading_xx = start_reading_yy = start_reading_grid = False

for line in lines:
    line = line.strip()
    parts = line.split()

    # --- Header keywords: values are picked from the whitespace-split tokens ---
    if line.startswith("ILB_IUB"):
        ny = int(parts[3])
        continue
    if line.startswith("JLB_JUB"):
        nx = int(parts[3])
        continue
    if line.startswith("ORIGIN"):
        x0 = float(parts[2])
        y0 = float(parts[3])
        continue
    if line.startswith("DX"):
        dx = float(parts[2])
        continue
    if line.startswith("DY"):
        dy = float(parts[2])
        continue
    if line.startswith("GRID_ANGLE"):
        grid_angle = float(parts[2])
        continue

    # --- Optional <BeginXX> ... <EndXX> block, one value per line ---
    if "<BeginXX>" in line:
        start_reading_xx = True
        continue
    if "<EndXX>" in line:
        start_reading_xx = False
        continue
    if start_reading_xx:
        try:
            x_coords.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")
        continue

    # --- Optional <BeginYY> ... <EndYY> block, one value per line ---
    if "<BeginYY>" in line:
        start_reading_yy = True
        continue
    if "<EndYY>" in line:
        start_reading_yy = False
        continue
    if start_reading_yy:
        try:
            y_coords.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")
        continue

    # --- <BeginGridData2D> ... <EndGridData2D> block, one depth per line ---
    if "<BeginGridData2D>" in line:
        start_reading_grid = True
        continue
    if "<EndGridData2D>" in line:
        start_reading_grid = False
        continue
    if start_reading_grid:
        try:
            grid_data.append(float(line))
        except ValueError:
            print(f"Warning: Skipping invalid line -> {line}")

# Report what was read
print(f"Extracted Dimensions: n_rows={ny}, n_cols={nx}")
print(f"Grid Data Length: {len(grid_data)}")

# Stop early when the header did not provide both dimensions
if ny is None or nx is None:
    raise ValueError("Grid dimensions could not be determined from the file.")

if x_coords:
    # Explicit coordinates were read; they are shifted by the origin
    # before the corner mesh is built.
    x_coords = np.array(x_coords) + x0
    y_coords = np.array(y_coords) + y0

    Xr, Yr = np.meshgrid(x_coords, y_coords)
else:
    # Regular corner mesh: (ny + 1) x (nx + 1) nodes
    x_coords = np.linspace(x0, x0 + dx * nx, nx + 1)
    y_coords = np.linspace(y0, y0 + dy * ny, ny + 1)
    X, Y = np.meshgrid(x_coords, y_coords)

    # Rotate every node about the origin (x0, y0) by grid_angle degrees
    theta = radians(grid_angle)
    ct, st = cos(theta), sin(theta)
    off_x = X - x0
    off_y = Y - y0
    Xr = x0 + (off_x * ct - off_y * st)
    Yr = y0 + (off_x * st + off_y * ct)

# The file must hold exactly one value per cell
expected_size = ny * nx
if len(grid_data) != expected_size:
    raise ValueError(f"Mismatch: Grid data size {len(grid_data)} does not match expected ({expected_size}).")

# Depth array with one row per grid row and one column per grid column
zi = np.array(grid_data).reshape(ny, nx)

print(f"Loaded grid data shape: {zi.shape}")

## 5.2 Create a new griddata
Interpolate bathymetry considering the loaded xyz data and coastline

In [None]:
# Keep only the XYZ records whose three fields are all present (non-NaN)
mask = ~(np.isnan(data['longitude']) | np.isnan(data['latitude']) | np.isnan(data['depth']))
lons = data['longitude'][mask]
lats = data['latitude'][mask]
depths = data['depth'][mask]

# Every cell starts as no-data (-99); there is one cell fewer than corner nodes per axis
n_cell_rows, n_cell_cols = Xr.shape[0] - 1, Xr.shape[1] - 1
zi = np.full((n_cell_rows, n_cell_cols), -99, dtype=float)

# Coastline geometries as a plain list
coast_polygons = list(filtered_coastlines.geometry)

# Cell centres: mean of the four (rotated) corner nodes of each cell
xc = 0.25*(Xr[:-1, :-1] + Xr[:-1, 1:] + Xr[1:, :-1] + Xr[1:, 1:])
yc = 0.25*(Yr[:-1, :-1] + Yr[:-1, 1:] + Yr[1:, :-1] + Yr[1:, 1:])

# A cell is water when its centre is not inside the merged coastline geometry
water_mask = ~contains_xy(filtered_coastlines.union_all(), xc, yc)

# Nearest-neighbour interpolation of the XYZ depths, water cells only
zi[water_mask] = griddata(
    (lons, lats),
    depths,
    (xc[water_mask], yc[water_mask]),
    method="nearest",
)

# Apply depth constraints
#zi[zi < 0] = 0  # Ensure minimum depth is 0

# Smooth the water depths without letting land values bleed in:
# land is masked as NaN, values and a 0/1 weight field are filtered separately,
# and the filtered values are renormalised by the filtered weights.
sigma      = 1.0        # Gaussian smoothing sigma
depths_nan = np.where(water_mask, zi, np.nan)
w_filt   = gaussian_filter(np.isfinite(depths_nan).astype(float), sigma=sigma)
num_filt = gaussian_filter(np.nan_to_num(depths_nan), sigma=sigma)
valid = w_filt > 0  # divide only where some weight is present; elsewhere stays NaN
zi_smooth = np.divide(num_filt, w_filt, out=np.full_like(num_filt, np.nan), where=valid)

# Smoothed depths on water, no-data value on land
nodata_val = -99.0
zi = np.where(water_mask, zi_smooth, nodata_val)

# Discard connected water patches smaller than min_patch cells
min_patch  = 20         # min cells to keep a water patch
labels, nfeat = label(water_mask)
patch_sizes = np.bincount(labels.ravel(), minlength=nfeat + 1)
too_small = patch_sizes < min_patch
too_small[0] = False    # label 0 is the background, not a water patch
zi[too_small[labels]] = nodata_val

print(f"\nBathymetric data interpolated")

# 6. Visualize and update depth values by clicking on the map
This code is currently efficient only for moderately sized griddata files. To update large griddata files (e.g., 400 x 400 cells), first convert the griddata to a shapefile, modify the cell depths in QGIS, and then convert the shapefile back to a MOHID-compatible griddata file. You can use this Jupyter Notebook to convert between formats.

In [None]:
# -----------------------
# Setup and timing
# -----------------------
start_time = time.time()

# Corner coordinates of the grid (taken from Xr / Yr) and their extent
LonGrid, LatGrid = np.array(Xr), np.array(Yr)
min_lon, max_lon = LonGrid.min(), LonGrid.max()
min_lat, max_lat = LatGrid.min(), LatGrid.max()

# Dedicated output area for the interactive controls
output = widgets.Output()
display(output)

# -----------------------
# Create the Map
# -----------------------
m = Map(center=(LatGrid.mean(), LonGrid.mean()), zoom=8)
marker = None  # marker of the last clicked position

# Map layers are grouped in square blocks of cells so that a single block
# can be redrawn on its own
block_layers = {}
block_size = 10  # Adjust as needed

# Corner coordinates of every cell: corner arrays of shape (M, N) give (M-1, N-1) cells
lon_sw, lat_sw = LonGrid[:-1, :-1], LatGrid[:-1, :-1]  # Southwest corner
lon_se, lat_se = LonGrid[:-1,  1:], LatGrid[:-1,  1:]  # Southeast corner
lon_ne, lat_ne = LonGrid[1:,   1:], LatGrid[1:,   1:]  # Northeast corner
lon_nw, lat_nw = LonGrid[1:,  :-1], LatGrid[1:,  :-1]  # Northwest corner

# -----------------------
# Color Mapping Globals and Functions
# -----------------------
_nbins = 10
_bins = None
_discrete_colors = None

def map_value_to_color(value):
    """Return the colour for a depth value.

    The no-data value -99 maps to a fully transparent colour. Any other value
    is placed in one of the `_nbins` bins delimited by `_bins` and gets the
    matching entry of `_discrete_colors`; values outside the bin range are
    clamped to the first or last bin.
    """
    if value == -99:
        return "#ffffff00"  # Transparent for invalid cells

    # np.digitize counts from 1, so shift to a 0-based index before clamping
    raw_index = np.digitize(value, _bins) - 1
    last_index = _nbins - 1
    clamped = min(max(raw_index, 0), last_index)
    return _discrete_colors[int(clamped)]

def precompute_color_grid(zi, nbins=10):
    """
    Return an array holding the colour of every cell of the depth grid zi.

    As a side effect this sets the globals _nbins, _bins and _discrete_colors:
    the bins span the range of the valid depths (cells different from -99),
    or 0..1 when no cell is valid, and the colours are sampled from the
    viridis colormap. No-data cells get the colour map_value_to_color assigns
    to -99.
    """
    global _bins, _nbins, _discrete_colors
    _nbins = nbins

    # Depth range of the valid cells, with a fallback for an all-no-data grid
    valid_depths = zi[zi != -99]
    if valid_depths.size:
        lowest, highest = valid_depths.min(), valid_depths.max()
    else:
        lowest, highest = 0, 1
    _bins = np.linspace(lowest, highest, nbins + 1)

    # One hex colour per bin, evenly spaced along the colormap
    palette = plt.colormaps.get_cmap('viridis')
    _discrete_colors = [mcolors.to_hex(rgba) for rgba in palette(np.linspace(0, 1, nbins))]

    colorize = np.vectorize(map_value_to_color)
    return colorize(zi)

# Precompute colors for the entire grid (used as initial reference).
# This call also initialises the globals _bins, _nbins and _discrete_colors
# that map_value_to_color reads.
color_mapped_zi = precompute_color_grid(zi, nbins=_nbins)

# -----------------------
# GeoJSON Block Generation & Map Layer Updates
# -----------------------
def generate_block_geojson(block_row, block_col, block_size):
    """
    Build the GeoJSON FeatureCollection for one block of grid cells.

    The block covers rows block_row*block_size .. (block_row+1)*block_size - 1
    and the matching column range, both clipped to the grid size. Cells whose
    depth is -99 are left out. Every feature is the cell's corner polygon and
    carries its styling values plus the cell indices "i" (row) and "j" (column).
    """
    total_rows, total_cols = zi.shape
    first_row = block_row * block_size
    first_col = block_col * block_size
    last_row = min(first_row + block_size, total_rows)
    last_col = min(first_col + block_size, total_cols)

    # Corner order SW -> SE -> NE -> NW, then SW again to close the ring
    corner_grids = (
        (lon_sw, lat_sw),
        (lon_se, lat_se),
        (lon_ne, lat_ne),
        (lon_nw, lat_nw),
        (lon_sw, lat_sw),
    )

    features = []
    for i in range(first_row, last_row):
        for j in range(first_col, last_col):
            depth = zi[i, j]
            if depth == -99:
                continue
            ring = [[float(lon[i, j]), float(lat[i, j])] for lon, lat in corner_grids]
            features.append({
                "type": "Feature",
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [ring],
                },
                "properties": {
                    "fill": map_value_to_color(depth),
                    "stroke": "#000000",
                    "fill-opacity": 0.5,
                    "stroke-width": 0.2,
                    "i": i,
                    "j": j,
                },
            })
    return {"type": "FeatureCollection", "features": features}

def update_all_blocks():
    """
    Redraw the whole grid: remove every block layer from the map, then
    create and add a fresh GeoJSON layer for each block that has features.
    """
    global block_layers

    def feature_style(feature):
        # Translate the styling stored in the feature properties into layer style keys
        props = feature["properties"]
        return {
            "fillColor": props["fill"],
            "color": props["stroke"],
            "weight": props["stroke-width"],
            "fillOpacity": props["fill-opacity"],
        }

    for old_layer in block_layers.values():
        m.remove_layer(old_layer)
    block_layers = {}

    # Number of blocks per axis, rounding up so partial blocks at the edges are covered
    total_rows, total_cols = zi.shape
    n_block_rows = (total_rows + block_size - 1) // block_size
    n_block_cols = (total_cols + block_size - 1) // block_size

    for br in range(n_block_rows):
        for bc in range(n_block_cols):
            fc = generate_block_geojson(br, bc, block_size)
            if not fc["features"]:
                continue  # nothing to draw in this block
            layer = GeoJSON(data=fc, style_callback=feature_style)
            m.add_layer(layer)
            block_layers[(br, bc)] = layer

# For initial rendering, generate all blocks.
update_all_blocks()
# Elapsed time since start_time was taken at the top of this cell
print(f"Total time: {time.time() - start_time:.2f} seconds")
# Show the map with the depth cells
display(m)

# -----------------------
# Function to Compute Grid Cell from Clicked Coordinate
# -----------------------
def get_grid_index(lon, lat):
    """
    Find the grid cell whose centre is closest to the clicked position.

    Returns:
       (j, i): j is the column index and i is the row index in the depth grid.
       (None, None), after printing a warning, when the position lies outside
       the bounding box of the grid corners.
    """
    outside_lon = lon < min_lon or lon > max_lon
    outside_lat = lat < min_lat or lat > max_lat
    if outside_lon or outside_lat:
        print("Warning: Click is outside grid boundaries.")
        return None, None

    # Cell centres: mean of the four corner nodes of each cell
    centers_lon = 0.25 * (LonGrid[:-1, :-1] + LonGrid[:-1, 1:] + LonGrid[1:, :-1] + LonGrid[1:, 1:])
    centers_lat = 0.25 * (LatGrid[:-1, :-1] + LatGrid[:-1, 1:] + LatGrid[1:, :-1] + LatGrid[1:, 1:])

    # Distance from the click to every centre (in coordinate units); keep the nearest
    distances = np.sqrt((centers_lon - lon) ** 2 + (centers_lat - lat) ** 2)
    nearest = np.argmin(distances)
    i, j = np.unravel_index(nearest, distances.shape)
    return j, i  # (column index, row index)

# -----------------------
# Function to Update Only the Block that Contains a Changed Cell
# -----------------------
def update_block_for_cell(i, j):
    """
    Redraw only the block that contains cell (i, j).

    The block's current layer, if any, is removed from the map. A new layer
    is added when the block still has features; otherwise the block is
    dropped from block_layers.
    """
    global block_layers

    def feature_style(feature):
        # Translate the styling stored in the feature properties into layer style keys
        props = feature["properties"]
        return {
            "fillColor": props["fill"],
            "color": props["stroke"],
            "weight": props["stroke-width"],
            "fillOpacity": props["fill-opacity"],
        }

    key = (i // block_size, j // block_size)
    fc = generate_block_geojson(key[0], key[1], block_size)

    had_layer = key in block_layers
    if had_layer:
        m.remove_layer(block_layers[key])

    if not fc["features"]:
        # Nothing left to draw: forget the block
        if had_layer:
            del block_layers[key]
        return

    new_layer = GeoJSON(data=fc, style_callback=feature_style)
    m.add_layer(new_layer)
    block_layers[key] = new_layer

# -----------------------
# Interactive Cell Depth Update Callback
# -----------------------
def update_depth(**kwargs):
    """
    Map-interaction callback: let the user edit the depth of the clicked cell.

    Only events whose "type" is "click" are handled. A marker is moved to the
    click position, the clicked cell is looked up with get_grid_index, and a
    text box plus a Submit button are shown inside the `output` widget. On
    submit the new value is stored in `zi` and only the block containing the
    cell is redrawn.

    Parameters:
        **kwargs: event payload; the keys read here are "type" and
            "coordinates" (unpacked as latitude, longitude).
    """
    global marker
    if kwargs.get("type") != "click":
        return

    lat_click, lon_click = kwargs["coordinates"]
    if marker:
        m.remove_layer(marker)
    marker = Marker(location=(lat_click, lon_click))
    m.add_layer(marker)

    j, i = get_grid_index(lon_click, lat_click)
    if j is None or i is None:
        return

    # Create the text box and button using ipywidgets.
    input_box = widgets.Text(
        placeholder='Enter new depth value...',
        description=f'Cell ({i}, {j}) Depth {zi[i, j]:.1f}:',
        disabled=False,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='400px')
    )

    submit_button = widgets.Button(description="Submit")

    def on_submit(_):
        # Only the parsing is guarded, so an error raised while redrawing the
        # block is not misreported as bad user input.
        try:
            new_val = float(input_box.value)
            # float() accepts "nan" and "inf"; neither is a usable depth.
            if not np.isfinite(new_val):
                raise ValueError("depth must be finite")
        except ValueError:
            # Print inside the output widget: text printed from a widget
            # callback is otherwise not shown next to the controls.
            with output:
                print("Invalid input. Enter a numeric value.")
            return
        zi[i, j] = new_val
        update_block_for_cell(i, j)

    submit_button.on_click(on_submit)

    # Instead of a global clear_output that wipes the map, clear only our output widget.
    with output:
        output.clear_output(wait=True)
        display(input_box, submit_button)

# Register update_depth as the map's interaction handler; update_depth itself
# discards every event whose type is not "click".
m.on_interaction(update_depth)

# 7. Save the griddata to a MOHID-compatible file

In [None]:
rows, cols = zi.shape

# Array copies of the corner coordinates; not used further in this cell.
np_x = np.array(Xr)
np_y = np.array(Yr)

# Timestamp recorded in the file header (day-month-year hour:minute:second).
now = datetime.now()
formatted_date_time = now.strftime("%d-%m-%Y %H:%M:%S")

output_file = "mohid_griddata.dat"

# Header keywords, one per line. The second comment line uses its own keyword
# (COMENT2) instead of repeating COMENT1.
header_lines = [
    "PROJ4_STRING              : +proj=longlat +datum=WGS84 +no_defs\n",
    "COMENT1                   : Grid generated by MOHID Jupyter Notebook\n",
    f"COMENT2                   : Generation Time: {formatted_date_time}\n",
    f"LATITUDE                  : {y0}\n",
    f"LONGITUDE                 : {x0}\n",
    "COORD_TIP                 : 4\n",
    f"ILB_IUB                   : 1 {int(rows)}\n",
    f"JLB_JUB                   : 1 {int(cols)}\n",
    f"ORIGIN                    : {x0} {y0}\n",
    f"GRID_ANGLE                : {grid_angle}\n",
    "CONSTANT_SPACING_X        : 1\n",
    "CONSTANT_SPACING_Y        : 1\n",
    f"DX                        : {dx}\n",
    f"DY                        : {dy}\n",
    "FILL_VALUE                : -99\n",
]

with open(output_file, "w") as f:
    f.writelines(header_lines)
    f.write("<BeginGridData2D>\n")
    # One depth per line, row by row (first index outermost).
    f.writelines(f"{value:.1f}\n" for grid_row in zi for value in grid_row)
    # Terminate the closing tag with a newline so the file ends on a complete line.
    f.write("<EndGridData2D>\n")

print(f"\nGrid saved to {output_file}")

# 8. Convert the griddata to a shapefile

In [None]:
def _as_clockwise(ring):
    """Return the closed ring `ring` (list of (x, y), first == last) in clockwise order."""
    ref_x, ref_y = ring[0]
    # Shoelace sum taken relative to the first vertex to limit round-off;
    # a positive value means the vertices run counter-clockwise.
    twice_area = sum(
        (px - ref_x) * (qy - ref_y) - (qx - ref_x) * (py - ref_y)
        for (px, py), (qx, qy) in zip(ring[:-1], ring[1:])
    )
    return ring[::-1] if twice_area > 0 else ring


# Write one polygon per grid cell. The context manager closes the shapefile
# files even when an error interrupts the loop.
with shapefile.Writer('depth_grid_cells', shapefile.POLYGON) as w:
    w.field('i', 'N')       # Row index (i)
    w.field('j', 'N')       # Column index (j)
    w.field('Depth', 'F', decimal=2)  # Depth value (float)

    # Visit every cell of the depth grid in row-major order.
    for i, j in np.ndindex(zi.shape):
        depth = zi[i, j]

        # Values at or below -99 are stored as NaN in the attribute table.
        if depth <= -99:
            depth = float('nan')

        # Four corners of the cell, taken from the corner-coordinate arrays.
        lon1, lat1 = round(Xr[i, j], 15), round(Yr[i, j], 15)              # (i,   j)
        lon2, lat2 = round(Xr[i, j+1], 15), round(Yr[i, j+1], 15)          # (i,   j+1)
        lon3, lat3 = round(Xr[i+1, j+1], 15), round(Yr[i+1, j+1], 15)      # (i+1, j+1)
        lon4, lat4 = round(Xr[i+1, j], 15), round(Yr[i+1, j], 15)          # (i+1, j)

        ring = [(lon1, lat1), (lon2, lat2), (lon3, lat3), (lon4, lat4), (lon1, lat1)]

        # Outer rings are written clockwise; counter-clockwise rings denote holes.
        w.poly([_as_clockwise(ring)])

        # Attributes: 1-based row, 1-based column, depth.
        w.record(i + 1, j + 1, depth)

print("Shapefile 'depth_grid_cells.shp' created successfully.")


# 9. Save shapefile to MOHID griddata

In [None]:
# ───────── USER CONFIG ─────────
shapefile_path = "depth_grid_cells.shp"
output_file    = "mohid_griddata_from_shapefile.dat"
fill_val       = -99.0

# 1. Read & inspect
gdf = gpd.read_file(shapefile_path)

# 2. Handle missing CRS by inferring WGS84 if bounds look geographic.
#    All four bounds are checked so projected coordinates with a small
#    lower-left corner are not mistaken for degrees.
if gdf.crs is None:
    minx, miny, maxx, maxy = gdf.total_bounds
    if -180 <= minx <= maxx <= 180 and -90 <= miny <= maxy <= 90:
        gdf.set_crs("EPSG:4326", inplace=True)
        print("No CRS: assuming EPSG:4326")
    else:
        raise ValueError("Shapefile has no CRS and bounds exceed geographic ranges.")

# 3. Reproject to WGS84 if needed
if gdf.crs.to_string() != "EPSG:4326":
    gdf = gdf.to_crs("EPSG:4326")

# 4. Flatten MultiPolygons
gdf.geometry = gdf.geometry.map(
    lambda geom: geom if geom.geom_type == "Polygon" else unary_union(geom)
)

# 5. Grid angle = direction in which the column index j increases.
#    Cell edges alone cannot give this: every rectangle has as many edges
#    along one axis as along the other, so an edge-direction histogram always
#    ties. The i/j attributes (also required by step 8) resolve the ambiguity.
xs = np.array([p.centroid.x for p in gdf.geometry])
ys = np.array([p.centroid.y for p in gdf.geometry])
i_idx = gdf["i"].to_numpy().astype(int)
j_idx = gdf["j"].to_numpy().astype(int)


def _index_direction(fixed_idx, moving_idx):
    """
    Return the angle (degrees, atan2 convention) along which `moving_idx`
    grows, measured between the two extreme centroids of the line of cells
    sharing the smallest `fixed_idx`. Returns None when that line holds a
    single index value.
    """
    line = np.flatnonzero(fixed_idx == fixed_idx.min())
    first = line[np.argmin(moving_idx[line])]
    last = line[np.argmax(moving_idx[line])]
    if moving_idx[first] == moving_idx[last]:
        return None
    return math.degrees(math.atan2(ys[last] - ys[first], xs[last] - xs[first]))


grid_angle = _index_direction(i_idx, j_idx)
if grid_angle is None:
    # Single-column grid: take the i direction and turn it back by 90 degrees.
    i_direction = _index_direction(j_idx, i_idx)
    if i_direction is None:
        raise ValueError("Need at least two cells along a row or a column to derive the grid angle.")
    grid_angle = i_direction - 90.0
if abs(grid_angle) < 1e-9:
    grid_angle = 0.0  # drop round-off noise (and a possible -0.0)

# 6. Compute true dx, dy by rotating centroids into grid-aligned axes
theta = math.radians(-grid_angle)
cos_t, sin_t = math.cos(theta), math.sin(theta)
xr = xs * cos_t - ys * sin_t
yr = xs * sin_t + ys * cos_t
ux = np.unique(np.round(xr, 6))
uy = np.unique(np.round(yr, 6))
dx = float(pd.Series(np.diff(ux)).median())
dy = float(pd.Series(np.diff(uy)).median())

# 7. Compute true rotated-grid origin
# 7.1 Add rotated centroid coords to GeoDataFrame
gdf["xr"] = xr
gdf["yr"] = yr

# 7.2 Find southwesternmost cell in rotated space
sw_index = (gdf["xr"] + gdf["yr"]).idxmin()
sw_cell = gdf.loc[sw_index]

# 7.3 Rotate each corner of that cell and pick the SW corner
corners = list(sw_cell.geometry.exterior.coords)
rotated_corners = [
    (
        cx * cos_t - cy * sin_t,     # x' (rotated)
        cx * sin_t + cy * cos_t,     # y' (rotated)
        cx,                          # original lon
        cy                           # original lat
    )
    for cx, cy in corners
]
# Minimise x' + y' (same criterion as 7.2): two corners share nearly the same
# x', so ordering on x' first would let round-off decide between them.
sw_corner = min(rotated_corners, key=lambda t: t[0] + t[1])
origin_x, origin_y = sw_corner[2], sw_corner[3]

# 8. Build depth array (i/j attributes are 1-based); missing depths become fill_val
rows, cols = int(i_idx.max()), int(j_idx.max())
zi = np.full((rows, cols), fill_val, dtype=float)
if "Depth" in gdf.columns:
    zi[i_idx - 1, j_idx - 1] = gdf["Depth"].astype(float).fillna(fill_val).to_numpy()

# 9. Write MOHID .dat
now = datetime.now().strftime("%d-%m-%Y %H:%M:%S")
with open(output_file, "w") as f:
    f.write("PROJ4_STRING              : +proj=longlat +datum=WGS84 +no_defs\n")
    f.write("COMMENT1                  : Grid generated by MOHID export\n")
    f.write(f"COMMENT2                  : Generated on {now}\n")
    f.write(f"LATITUDE                  : {origin_y}\n")
    f.write(f"LONGITUDE                 : {origin_x}\n")
    f.write("COORD_TIP                 : 4\n")
    f.write(f"ILB_IUB                   : 1 {rows}\n")
    f.write(f"JLB_JUB                   : 1 {cols}\n")
    f.write(f"ORIGIN                    : {origin_x} {origin_y}\n")
    f.write(f"GRID_ANGLE                : {grid_angle:.6f}\n")
    f.write("CONSTANT_SPACING_X        : 1\n")
    f.write("CONSTANT_SPACING_Y        : 1\n")
    f.write(f"DX                        : {dx:.6f}\n")
    f.write(f"DY                        : {dy:.6f}\n")
    f.write(f"FILL_VALUE                : {int(fill_val)}\n")
    f.write("<BeginGridData2D>\n")
    # One depth per line in row-major order.
    f.writelines(f"{value:.1f}\n" for value in zi.ravel())
    f.write("<EndGridData2D>\n")

print(f"Saved: {output_file} (angle={grid_angle:.2f}°, dx={dx:.6f}, dy={dy:.6f})")

# 10. Plot MOHID griddata

In [None]:
np_x = np.array(Xr)
np_y = np.array(Yr)

# Bounding box of the grid corner coordinates.
x_min = np.min(np_x)
y_min = np.min(np_y)
x_max = np.max(np_x)
y_max = np.max(np_y)

# Size of the first cell, used only to pad the plot extent. Stored under
# dedicated names so the grid spacing globals dx/dy written by the export
# cell are not overwritten.
cell_dx = np.abs(np_x[0, 1] - np_x[0, 0])
cell_dy = np.abs(np_y[1, 0] - np_y[0, 0])

# Expand the extent by n cells in all directions
n = 5

lon_min = x_min - n * cell_dx
lon_max = x_max + n * cell_dx
lat_min = y_min - n * cell_dy
lat_max = y_max + n * cell_dy

extent = [
    lon_min,  # Left
    lon_max,  # Right
    lat_min,  # Bottom
    lat_max]  # Top

def calculate_zoom_level():
    """
    Estimate a tile zoom level from the padded plot extent.

    Reads the module-level lat_min/lat_max/lon_min/lon_max values, takes the
    larger of the two spans, and returns int(log2(360 / span)) + 1 limited to
    the range 1..19, so a smaller extent yields a higher zoom.
    """
    widest_span = max(lat_max - lat_min, lon_max - lon_min)
    level = int(np.log2(360 / widest_span)) + 1
    return max(1, min(level, 19))

# Zoom level for the background tiles, derived from the padded extent.
zoom_level = calculate_zoom_level()
print(f"Automatically calculated zoom level: {zoom_level}")

# 15x15 inch figure holding a single PlateCarree map axes limited to the extent.
fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(projection=ccrs.PlateCarree())
ax.set_extent(extent)

def image_spoof(self, tile):
    """
    Fetch one map tile, sending an explicit User-agent header with the request.

    Written to be installed as the get_image method of a tile source.

    Parameters:
        self: the tile source; its _image_url, desired_tile_form and
            tileextent members are used.
        tile: tile identifier, passed through to _image_url and tileextent.

    Returns:
        (img, extent, origin): the downloaded image converted to
        self.desired_tile_form, the value of self.tileextent(tile), and the
        string 'lower'.
    """
    request = Request(self._image_url(tile))
    request.add_header('User-agent', 'Anaconda 3')
    # The with-block closes the HTTP response even if reading it fails.
    with urlopen(request) as response:
        payload = io.BytesIO(response.read())
    img = Image.open(payload).convert(self.desired_tile_form)
    return img, self.tileextent(tile), 'lower'

# Route GoogleTiles downloads through image_spoof, then add satellite tiles
# as the map background. A higher zoom level gives a sharper background image.
cimgt.GoogleTiles.get_image = image_spoof
osm_img = cimgt.GoogleTiles(style='satellite')
ax.add_image(osm_img, zoom_level)

# Hide cells at or below -99 so the background shows through them.
zi_mask = np.ma.masked_where(zi <= -99, zi, copy=False)

# Colour scale spans the unmasked depth range.
norm = Normalize(vmin=zi_mask.min(), vmax=zi_mask.max())

# Depth grid drawn on the corner coordinates.
pc = ax.pcolormesh(Xr, Yr, zi_mask, cmap='viridis', norm=norm)

# Optional overlay of the raw data points:
#if 'depths' in locals() and 'lons' in locals() and 'lats' in locals():
#    scatter = ax.scatter(lons, lats, c=depths, s=0.1, cmap='viridis', norm=norm, label='Data Points')

ax.set_title('Interpolated Bathymetric Data', fontsize=18, loc='center')

# Colourbar in its own narrow axes to the right of the map.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1, axes_class=plt.Axes)

cbar = fig.colorbar(pc, cax=cax, orientation="vertical")
cbar.set_label('Depth (m)', labelpad=25, rotation=270, fontsize=16)
cbar.ax.tick_params(labelsize=14)

fig.savefig('Griddata.png', format='png', dpi=300, bbox_inches='tight')