# MOHID Lagrangian

This Jupyter Notebook aims to help implement and run the MOHID Lagrangian model.

***
**Note 1**: Execute each cell through the <button class="btn btn-default btn-xs"><i class="icon-play fa fa-play"></i></button> button from the top MENU (or keyboard shortcut `Shift` + `Enter`).<br>
<br>
**Note 2**: Use the Kernel and Cell menus to restart the kernel and clear outputs.<br>
***

# Table of contents
- [1. Import required libraries](#1.-Import-required-libraries)
- [2. Download CMEMS for the area of interest](#2.-Download-CMEMS-for-the-area-of-interest)
    - [2.1 Draw a polygon to select the CMEMS area for download](#2.1-Draw-a-polygon-to-select-the-CMEMS-area-for-download)
    - [2.2 Set CMEMS product, dates and depths](#2.2-Set-CMEMS-product,-dates-and-depths)
    - [2.3 Download CMEMS](#2.3-Download-CMEMS)
- [3. Define sources](#3.-Define-sources)
    - [3.1 Load a NetCDF dataset](#3.1-Load-a-NetCDF-dataset)
    - [3.2 Load a MOHID HDF5 dataset](#3.2-Load-a-MOHID-HDF5-dataset)
    - [3.3 Draw markers on the map to define the source coordinates](#3.3-Draw-markers-on-the-map-to-define-the-source-coordinates)
- [4. Setup MOHID Lagrangian xml input files](#4.-Setup-MOHID-Lagrangian-xml-input-files)
    - [4.1 Parameter definitions](#4.1-Parameter-definitions)
    - [4.2 Simulation definitions](#4.2-Simulation-definitions)
    - [4.3 Source definitions](#4.3-Source-definitions)
- [5. Run MOHID Lagrangian](#5.-Run-MOHID-Lagrangian)
- [6. Visualize the final results](#6.-Visualize-the-final-results)

# 1. Import required libraries

In [None]:
from update_xml_case import *
import copernicusmarine
import os
from ipyleaflet import Map, TileLayer, DrawControl, GeoJSON, Marker
import json
import re
import datetime
import time
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_hex
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import shutil
import subprocess
import sys
import vtk
import folium
import matplotlib as mpl
from folium.plugins import MeasureControl
import glob

In [None]:
# Set run case
name = "Plastic_Case"  # case name: used for the case folder and the XML file stem

dirpath = "run_cases"  # folder that contains the case folders

# Path of the case XML, relative to the case folder entered below
xml_file_path = f"{name}.xml"

# Remember the working directory of the first execution of this cell. Without
# this, re-running the cell (or changing `name`) would resolve the relative
# path from the already-changed working directory and fail.
if "_notebook_root" not in globals():
    _notebook_root = os.getcwd()

# Construct the path and change the working directory
os.chdir(os.path.join(_notebook_root, dirpath, name))

# 2. Download CMEMS for the area of interest
Skip this step if you already have your own hydrodynamic dataset; you can load it directly in section 3.1 (NetCDF) or section 3.2 (MOHID HDF5).

## 2.1 Draw a polygon to select the CMEMS area for download

In [None]:
# World map centred at (0, 0) on which the download area is selected
m = Map(center=[0, 0], zoom=2)

# Copernicus Marine WMTS tiles, shown as a background layer named
# "Sea Water Velocity"
wmts_layer = TileLayer(
    name="Sea Water Velocity",
    url=(
        "https://wmts.marine.copernicus.eu/teroWmts/?"
        "service=WMTS&request=GetTile&version=1.0.0&"
        "layer=GLOBAL_ANALYSISFORECAST_PHY_001_024/"
        "cmems_mod_glo_phy-cur_anfc_0.083deg_P1M-m_202406/sea_water_velocity&"
        "tilematrixset=EPSG:3857&tilematrix={z}&tilerow={y}&tilecol={x}&"
        "format=image/png&transparent=True"
    ),
)
m.add_layer(wmts_layer)

# Drawing toolbar: only the rectangle tool gets options; the other tools are
# given empty dicts, which disables them.
draw_control = DrawControl(
    rectangle={"shapeOptions": {"color": "blue"}},
    polygon={},
    polyline={},
    circle={},
    marker={},
)

# Function to handle drawn shapes and print bounds
def handle_draw(target, action, geo_json):
    """Store and print the bounding box of a newly drawn shape.

    Callback for the draw control. Only the "created" action is handled;
    any other action is ignored.

    Args:
        target: Object that emitted the event (unused).
        action: Name of the draw event.
        geo_json: GeoJSON-like dict of the drawn shape; the first ring of
            ``geo_json["geometry"]["coordinates"]`` is read as a list of
            ``[lon, lat]`` positions.

    Side effects:
        Sets the module-level ``min_lat``, ``max_lat``, ``min_lon`` and
        ``max_lon`` and prints them.
    """
    global min_lat, max_lat, min_lon, max_lon

    if action != "created":
        return

    # Outer ring of the drawn shape
    ring = geo_json["geometry"]["coordinates"][0]

    # Each position is [lon, lat]
    lats = [vertex[1] for vertex in ring]
    lons = [vertex[0] for vertex in ring]

    min_lat, max_lat = min(lats), max(lats)
    min_lon, max_lon = min(lons), max(lons)

    print("✅ Polygon Bounds:")
    print(f"  - min_lon (West): {min_lon}")
    print(f"  - min_lat (South): {min_lat}")
    print(f"  - max_lon (East): {max_lon}")
    print(f"  - max_lat (North): {max_lat}")

# Register handle_draw as the draw callback; on a "created" event it stores
# the drawn shape's bounds in min_lon / max_lon / min_lat / max_lat.
draw_control.on_draw(handle_draw)

# Add the drawing toolbar to the map
m.add_control(draw_control)

# Last expression of the cell: displays the interactive map
m

## 2.2 Set CMEMS product, dates and depths

In [None]:
# Folder (inside the current working directory) that receives the download
output_dir = os.path.join(os.getcwd(),"Boundary_Conditions")

# Choose ONE product id by leaving exactly one of the lines below uncommented.

#hourly instantaneous
#product_id = "cmems_mod_glo_phy_anfc_merged-uv_PT1H-i"

#6-hourly instantaneous
#product_id = "cmems_mod_glo_phy-cur_anfc_0.083deg_PT6H-i"

#daily mean
product_id = "cmems_mod_glo_phy-cur_anfc_0.083deg_P1D-m"

#fill in info for CMEMS download
# First and last day of the request; the download cell appends " 00:00:00" to both
start = datetime.date(2024,1,1)
end = datetime.date(2024,2,1)

# Depth range passed as minimum_depth / maximum_depth to the download.
# Equal values select a single depth (presumably metres, matching a model
# depth level - confirm against the product's depth axis).
start_depth = 0.49402499198913574
end_depth = 0.49402499198913574
#end_depth = 5727.9

## 2.3 Download CMEMS

In [None]:
#####################################################
def download_file():
    """Download the 'uo' and 'vo' variables of the selected CMEMS product.

    Takes no arguments; the request is built from module-level settings:
    ``product_id``, ``min_lon``/``max_lon``/``min_lat``/``max_lat``,
    ``start_depth``/``end_depth``, ``start``/``end``, ``output_dir`` and
    ``output_file_name``. Both dates are sent with a time of 00:00:00 and
    the output is requested as NetCDF3-compatible.
    """
    midnight = " 00:00:00"

    copernicusmarine.subset(
        dataset_id=product_id,
        minimum_longitude=min_lon,
        maximum_longitude=max_lon,
        minimum_latitude=min_lat,
        maximum_latitude=max_lat,
        minimum_depth=start_depth,
        maximum_depth=end_depth,
        start_datetime=start.strftime("%Y-%m-%d") + midnight,
        end_datetime=end.strftime("%Y-%m-%d") + midnight,
        variables=["uo", "vo"],
        output_directory=output_dir,
        output_filename=output_file_name,
        netcdf3_compatible=True,
    )
                
#####################################################
     
# Name the file after the requested period, e.g. cmems_20240101_20240201.nc
output_file_name = f"cmems_{start:%Y%m%d}_{end:%Y%m%d}.nc"

# Create the target folder if it is missing. exist_ok=True replaces the
# separate existence check, so there is no gap between checking and creating.
os.makedirs(output_dir, exist_ok=True)

download_file()


# 3. Define sources

## 3.1 Load a NetCDF dataset

In [None]:
output_dir = os.path.join(os.getcwd(), "Boundary_Conditions")
output_file_name = "CMEMS_cur.nc" #Load your own NetCDF instead of CMEMS, otherwise comment out this line
CurFName = os.path.join(output_dir, output_file_name)

# Open the data file once. `dataset` is the name used by the following cells;
# `CurDS` is kept as an alias so both names refer to the same open file
# instead of holding two separate handles on it.
dataset = xr.open_dataset(CurFName, engine="netcdf4")
CurDS = dataset

# First time step and first depth level of each velocity component
U = dataset['uo'].isel(time=0).isel(depth=0).squeeze()  # Result is 2D: (lat, lon)
V = dataset['vo'].isel(time=0).isel(depth=0).squeeze()  # Result is 2D: (lat, lon)

# Show info of dataset
CurDS

## 3.2 Load a MOHID HDF5 dataset 

In [None]:
import os
import h5py
import xarray as xr
import numpy as np
from pathlib import Path

# Build the path of the MOHID HDF5 file inside <cwd>/Boundary_Conditions
output_dir = os.path.join(os.getcwd(), "Boundary_Conditions")
output_file_name = "Hydrodynamic_2_Surface.hdf5"
CurFName = os.path.join(output_dir, output_file_name)

# Open the HDF5 file using h5py and load data into NumPy arrays.
with h5py.File(CurFName, 'r') as f:
    # Access the 'Results' group for velocity data.
    res_group = f['Results']
    instant = "00001"  # Suffix of the stored instant to read (only this one is loaded).
    vel_u_key = f"velocity U_{instant}"
    vel_v_key = f"velocity V_{instant}"
    
    # Read the U and V velocity components (not the modulus) at the last index
    # of the first axis, dropping singleton dimensions.
    # NOTE(review): the first axis is presumably the vertical layer and -1 the
    # surface layer - confirm against the file layout.
    velocity_u = np.squeeze(res_group['velocity U'][vel_u_key][-1,:,:])
    velocity_v = np.squeeze(res_group['velocity V'][vel_v_key][-1,:,:])
    
    # Take one column of 'Longitude' and one row of 'Latitude' as 1-D axes.
    # NOTE(review): this assumes a regular grid whose values are cell corners
    # (as the variable names say) - confirm.
    lon_corners = (f['Grid']['Longitude'][:,0])
    lat_corners = (f['Grid']['Latitude'][0,:])

# Transpose so the arrays match the ("latitude", "longitude") dimension order
# used for the Dataset below.
velocity_u = velocity_u.T
velocity_v = velocity_v.T

# Compute the center values by averaging adjacent elements
lon_center = (lon_corners[:-1] + lon_corners[1:]) / 2
lat_center = (lat_corners[:-1] + lat_corners[1:]) / 2

# Create an xarray Dataset with matching dimensions, using the same variable
# and coordinate names ('uo', 'vo', 'latitude', 'longitude') as the NetCDF path.
# Note that this Dataset has no 'time' coordinate.
dataset = xr.Dataset(
    {
        "uo": (("latitude","longitude"), velocity_u),
        "vo": (("latitude","longitude"), velocity_v),
    },
    coords={
        "longitude": lon_center,
        "latitude": lat_center,
    }
)

U = dataset['uo']
V = dataset['vo']

# Keep only values that are greater than -99 and non-zero; everything else
# becomes NaN (presumably fill values and land cells - confirm).
U = U.where((U > -99) & (U != 0), np.nan)
V = V.where((V > -99) & (V != 0), np.nan)

## 3.3 Draw markers on the map to define the source coordinates

In [None]:
# -------------------------------
# Start timing
# -------------------------------
start_time = time.time()

# -------------------------------
# Velocity magnitude, with optional downsampling for large datasets
# -------------------------------
downsample_factor = 4  # Adjust as needed.
stride = slice(None, None, downsample_factor)

# Magnitude of the 2D velocity field, keeping every downsample_factor-th node
zi = np.sqrt(U**2 + V**2)[stride, stride]
lon = dataset['longitude'].values[stride]
lat = dataset['latitude'].values[stride]

print("zi:", zi.shape)
print("Latitude size:", lat.shape)
print("Longitude size:", lon.shape)

# -------------------------------
# Coordinate grids (2D arrays of node longitudes and latitudes)
# -------------------------------
LonGrid, LatGrid = np.meshgrid(lon, lat)

# -------------------------------
# One hex colour per node: viridis scaled between the smallest and largest
# non-NaN magnitude
# -------------------------------
colormap = plt.cm.viridis
norm = Normalize(vmin=np.nanmin(zi), vmax=np.nanmax(zi))
rgba_values = colormap(norm(zi.values.ravel()))
flat_colors = [to_hex(rgba) for rgba in rgba_values]
colors = np.array(flat_colors).reshape(zi.shape)
print(f"Time spent on color mapping: {time.time() - start_time:.2f} seconds")

# -------------------------------
# Corners of every grid cell for the GeoJSON layers: cell (r, c) spans
# nodes (r, c) to (r + 1, c + 1)
# -------------------------------
lower = slice(None, -1)  # all nodes except the last  -> south / west side
upper = slice(1, None)   # all nodes except the first -> north / east side

lon_sw, lat_sw = LonGrid[lower, lower], LatGrid[lower, lower]  # Southwest corner
lon_se, lat_se = LonGrid[lower, upper], LatGrid[lower, upper]  # Southeast corner
lon_ne, lat_ne = LonGrid[upper, upper], LatGrid[upper, upper]  # Northeast corner
lon_nw, lat_nw = LonGrid[upper, lower], LatGrid[upper, lower]  # Northwest corner

# Magnitude assigned to each cell: the value at its southwest node
zi_cells = zi.values[lower, lower]
print("zi_cells:", zi_cells.shape)
def generate_geojson_batches(batch_size=50):
    """Yield GeoJSON FeatureCollections, one per tile of grid cells.

    The cell grid is walked in tiles of ``batch_size`` x ``batch_size``
    cells. Each cell becomes one Polygon feature whose corners come from the
    module-level ``lon_*`` / ``lat_*`` corner arrays and whose fill colour
    comes from ``colors``. Cells whose value in ``zi_cells`` is NaN are
    skipped, and tiles with no remaining cell are not yielded.

    Args:
        batch_size: Number of rows and of columns per tile.

    Yields:
        dict: ``{"type": "FeatureCollection", "features": [...]}``.
    """
    n_rows, n_cols = lon_sw.shape
    for i in range(0, n_rows, batch_size):
        i_end = min(i + batch_size, n_rows)
        for j in range(0, n_cols, batch_size):
            j_end = min(j + batch_size, n_cols)
            features = []
            for r in range(i, i_end):
                for c in range(j, j_end):
                    if np.isnan(zi_cells[r, c]):
                        continue
                    # Plain Python floats keep the structure serialisable
                    # regardless of the NumPy dtype of the coordinates.
                    south_west = [float(lon_sw[r, c]), float(lat_sw[r, c])]
                    ring = [
                        south_west,
                        [float(lon_se[r, c]), float(lat_se[r, c])],
                        [float(lon_ne[r, c]), float(lat_ne[r, c])],
                        [float(lon_nw[r, c]), float(lat_nw[r, c])],
                        # GeoJSON linear rings must be closed: the last
                        # position repeats the first.
                        list(south_west),
                    ]
                    features.append({
                        "type": "Feature",
                        "geometry": {
                            "type": "Polygon",
                            "coordinates": [ring]
                        },
                        "properties": {
                            "fill": str(colors[r, c]),
                            "stroke": "#000000",
                            "fill-opacity": 0.5,
                            "stroke-width": 0.2
                        }
                    })
            if features:
                yield {"type": "FeatureCollection", "features": features}

# -------------------------------
# Create the map and add dataset layers
# -------------------------------
# Centre the map on the middle of the (downsampled) coordinate ranges
map_center = [(lat.min() + lat.max()) / 2, (lon.min() + lon.max()) / 2]
m = Map(center=map_center, zoom=4)


def _cell_style(feature):
    """Translate a feature's stored properties into the layer's style keys."""
    props = feature["properties"]
    return {
        "fillColor": props["fill"],
        "color": props["stroke"],
        "weight": props["stroke-width"],
        "fillOpacity": props["fill-opacity"]
    }


# One GeoJSON layer per batch of grid cells
for batch_geojson in generate_geojson_batches():
    geojson_layer = GeoJSON(data=batch_geojson, style_callback=_cell_style)
    m.add_layer(geojson_layer)

print(f"Total time for GeoJSON layers: {time.time() - start_time:.2f} seconds")

# -------------------------------
# Global store for drawn markers
# -------------------------------
# Dictionary mapping marker_id -> {'location': [...], 'name': str or None}.
# When a marker is created its location is stored as [lat, lon] and its name
# is None until the user confirms one.
markers_dict = {}
marker_counter = 0  # Unique marker counter

def ask_marker_name(marker, marker_id):
    """Show a small form asking the user to name a marker.

    A text field and a "Confirm" button are displayed. When the button is
    clicked, the entered text (stripped of surrounding whitespace, or
    ``"Marker <marker_id>"`` if nothing was entered) is stored on the marker
    as ``marker.marker_name`` and in ``markers_dict[marker_id]['name']``,
    a confirmation line is printed, and the form is closed.

    Args:
        marker: Marker object that receives the ``marker_name`` attribute.
        marker_id: Key of the marker's entry in the global ``markers_dict``.
    """
    name_input = widgets.Text(
        description='Source name:',
        placeholder='Enter source name',
        value='',
        disabled=False,
        style={'description_width': 'initial'}
    )
    confirm_button = widgets.Button(
        description='Confirm',
        button_style='success',
        disabled=False
    )
    widget_box = widgets.VBox([name_input, confirm_button])

    def on_button_clicked(b):
        # Fall back to a generated name when the field is empty or blank
        marker_name = name_input.value.strip() or f"Marker {marker_id}"
        marker.marker_name = marker_name
        markers_dict[marker_id]['name'] = marker_name
        print(f"Marker {marker_id} named '{marker_name}'")
        widget_box.close()  # Remove the form once the name is confirmed

    confirm_button.on_click(on_button_clicked)
    display(widget_box)

def handle_draw(target, action, geo_json):
    """Track markers drawn on the map in the global ``markers_dict``.

    Callback for the draw control.

    * ``"created"``: registers the new point under the next marker id, asks
      the user for a name, and adds a draggable marker to the map whose
      position is written back to ``markers_dict`` whenever it is moved.
    * ``"edited"``: only prints a message.
    * ``"deleted"``: removes entries whose id is found in the feature's
      ``properties["marker_id"]``.

    Locations are always stored as ``[lat, lon]``.

    Args:
        target: Object that emitted the event (unused).
        action: Name of the draw event.
        geo_json: GeoJSON-like dict describing the affected feature(s).
    """
    global marker_counter, markers_dict
    
    if action == "created":
        coords = geo_json["geometry"]["coordinates"]  # [lon, lat]
        marker_id = marker_counter  # Assign an ID
        marker_counter += 1
        # Store the location as [lat, lon] with a placeholder for the name
        markers_dict[marker_id] = {'location': [coords[1], coords[0]], 'name': None}
    
        # Create a draggable marker
        custom_marker = Marker(location=[coords[1], coords[0]], draggable=True)
        custom_marker.marker_id = marker_id  # Attach marker ID
    
        # Ask user to enter a name for the new marker
        ask_marker_name(custom_marker, marker_id)
    
        # Function to update stored marker coordinates when moved
        def on_location_change(change, m_id=marker_id):
            new_location = change["new"]  # New [lat, lon]
            # Keep the same [lat, lon] order that is stored at creation;
            # swapping here would make dragged markers inconsistent with
            # markers that were never moved.
            markers_dict[m_id]['location'] = [new_location[0], new_location[1]]
            print(f"Marker {m_id} moved to: {markers_dict[m_id]['location']}")
    
        # Observe marker movement
        custom_marker.observe(on_location_change, names="location")
    
        # Add the new draggable marker to the map
        m.add_layer(custom_marker)
        
        # Remove any map layer that is a GeoJSON Point (the default marker)
        for layer in list(m.layers):
            if isinstance(layer, GeoJSON) and layer.data.get("geometry", {}).get("type", "") == "Point":
                m.remove_layer(layer)
    
    elif action == "edited":
        print("Edit event received; marker updates are handled via the location observer.")
    
    elif action == "deleted":
        # The event may carry a collection of features or a single feature
        features = geo_json.get("features", [])
        if not features:
            features = [geo_json]
        for feature in features:
            # NOTE(review): nothing in this cell writes a "marker_id" property
            # into the drawn features, so this lookup may never match - confirm.
            marker_id = feature.get("properties", {}).get("marker_id")
            if marker_id is not None and marker_id in markers_dict:
                print(f"Deleted marker {marker_id} named '{markers_dict[marker_id]['name']}' at {markers_dict[marker_id]['location']}")
                markers_dict.pop(marker_id)
        print("Current markers after deletion:", markers_dict)

# Drawing toolbar for placing sources: tools that are not needed get an empty
# dict, and the marker tool is enabled with a non-empty dict.
draw_control = DrawControl(
    polygon={},
    polyline={},
    rectangle={},
    circle={},
    marker={"repeatMode": False},
)

# Route draw events to handle_draw and attach the toolbar to the map
draw_control.on_draw(handle_draw)
m.add_control(draw_control)

# -------------------------------
# Display the map
# -------------------------------
m

# 4. Setup MOHID Lagrangian xml input files

## 4.1 Parameter definitions

In [None]:
# Get time limits (min and max time values in the dataset's 'time' variable).
# A dataset without 'time' (the one built from a MOHID HDF5 file in section 3.2
# has none) gets an explicit message instead of a bare lookup error.
if "time" not in dataset.variables:
    raise KeyError(
        "dataset has no 'time' variable; set Start and End manually "
        "(see the commented datetime.datetime lines in this cell)"
    )
time_values = dataset['time'].values
start, end = time_values.min(), time_values.max()

# Convert the numpy.datetime64 to a Python datetime object using pandas
Start = pd.to_datetime(start) #Date of initial instant based on nc file
End = pd.to_datetime(end) #Date of final instant based on nc file
#Start = datetime.datetime(2024, 1, 1, 0, 0, 0)
#End = datetime.datetime(2024, 1, 2, 0, 0, 0)

Integrator = 2 #Integration Algorithm 1:Euler, 2:Multi-Step Euler, 3:RK4 (default=1)
Threads = "auto" #Computation threads for shared memory computation (default=auto)
OutputWriteTime = 86400 #Time out data (seconds)
BufferSize = 86400 #control the amount of hydrodynamic data to store in RAM memory (seconds)

# Write these settings into the case XML
update_parameter_definitions(xml_file_path,Start,End,Integrator,Threads,OutputWriteTime,BufferSize)

## 4.2 Simulation definitions

In [None]:
# -------------------------------
# Get area limits
# -------------------------------
# Use NumPy's own reductions on the coordinate arrays rather than the builtin
# min()/max(), which iterate the arrays element by element in Python.
lat_values = dataset['latitude'].values
lon_values = dataset['longitude'].values
min_lat, max_lat = lat_values.min(), lat_values.max()
min_lon, max_lon = lon_values.min(), lon_values.max()

resolution = 8000 #metres (m)
timestep = 180 #seconds (s)
BoundingBoxMin = min_lon,min_lat,-1 #defines the corners of your simulation domain x,y,z (deg,deg,m)
BoundingBoxMax = max_lon, max_lat,1 #defines the corners of your simulation domain x,y,z (deg,deg,m)
VerticalVelMethod = 3 #1:From velocity fields, 2:Divergence based, 3:Disabled. Default = 1
BathyminNetcdf = 0 #bathymetry is a property in the netcdf. 1:true, 0:false (computes from layer depth and openPoints. Default = 1
RemoveLandTracer = 0 #Remove tracers on land 0:No, 1:Yes. Default = 1
TracerMaxAge = 0 #maximum tracer age. Default = 0.0. read if > 0

# Write these settings into the case XML
update_simulation_definitions(xml_file_path,resolution,timestep,BoundingBoxMin,BoundingBoxMax,VerticalVelMethod,BathyminNetcdf,RemoveLandTracer,TracerMaxAge)


## 4.3 Source definitions

In [None]:
rate_seconds = 3600 #emission step in seconds
rate_trcPerEmission = 5 #number of tracers emitted every rate_seconds

# Write the sources into the case XML, using the markers collected in
# markers_dict (section 3.3) and the emission settings above.
update_source_definitions(xml_file_path, markers_dict,rate_seconds,rate_trcPerEmission)

# 5. Run MOHID Lagrangian 

In [None]:
dirout = f"{name}_out"

# Paths to executables and scripts (relative to the current working directory)
tools = r"../build/bin/RELEASE"
mohidlagrangian = os.path.join(tools, "MOHIDLagrangian.exe")

preprocessor_dir = r"../src/MOHIDLagrangianPreProcessor"
preprocessor = os.path.join(preprocessor_dir, "MOHIDLagrangianPreProcessor.py")

postprocessor_dir = r"../src/MOHIDLagrangianPostProcessor"
postprocessor = os.path.join(postprocessor_dir, "MOHIDLagrangianPostprocessor.py")

# Start from an empty output directory: any previous one is deleted first
if os.path.exists(dirout):
    shutil.rmtree(dirout)
os.makedirs(dirout)

# Copy XML configuration file
shutil.copy(f"{name}.xml", dirout)

# The three stages, in execution order, each paired with the message printed
# if it exits with a non-zero status. The pre-processor and the model read the
# XML copy inside dirout; the post-processor is given the XML in the case folder.
copied_xml = f"{dirout}/{name}.xml"
stages = [
    (
        [sys.executable, preprocessor, "-i", copied_xml, "-o", dirout],
        "Preprocessing failed.",
    ),
    (
        [mohidlagrangian, "-i", copied_xml, "-o", dirout],
        "Execution of MOHIDLagrangian failed.",
    ),
    (
        [sys.executable, "-W", "ignore", postprocessor, "-i", f"{name}.xml", "-o", dirout],
        "Postprocessing failed.",
    ),
]

# Run the stages one after another; stop at the first one that fails
for command, failure_message in stages:
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError:
        print(failure_message)
        sys.exit(1)

print("All done!")

# 6. Visualize the final results

In [None]:
# Output directory of the run: the case name followed by "_out"
dirout = f"{name}_out"  # e.g., "case_out" or "Plastic_Case_out"

# Use glob to find all files with the pattern '{name}_*.vtu' in dirout.
# The path is relative to the current working directory.
vtu_files = glob.glob(os.path.join(dirout, f"{name}_*.vtu"))

# Stop early with an explicit message when the run produced no matching files
if not vtu_files:
    raise ValueError(f"No VTU files matching '{name}_*.vtu' found in the directory: {dirout}")

# Function to extract the sequence number from a filename.
def extract_seq(filename):
    """Return the numeric sequence at the end of a VTU file name.

    Only the base name is inspected; it must end in ``_<digits>.vtu``.

    Args:
        filename: Path or file name of a VTU file.

    Returns:
        int: The digits as an integer, or -1 when the name does not end in
        ``_<digits>.vtu``.
    """
    found = re.search(r'_(\d+)\.vtu$', os.path.basename(filename))
    return int(found.group(1)) if found else -1

# Pick the file with the highest sequence number (names without a sequence
# number rank as -1).
latest_vtu_file = max(vtu_files, key=extract_seq)
print("The latest VTU file is:", latest_vtu_file)

# Use the found file directly as the result filename.
ResFName = latest_vtu_file

# ---------------------------
# Read the VTU file using VTK and extract point data including 'source'
# ---------------------------
reader = vtk.vtkXMLUnstructuredGridReader()
reader.SetFileName(ResFName)
reader.Update()
data = reader.GetOutput()

points = data.GetPoints()
num_points = points.GetNumberOfPoints()

# The per-point 'source' array is required for colouring
source_field = data.GetPointData().GetArray("source")
if source_field is None:
    raise ValueError("No 'source' field found in the point data!")

# Per-point coordinates (first component -> lons, second -> lats) and source value
point_ids = range(num_points)
xyz = [points.GetPoint(i) for i in point_ids]
lons = [p[0] for p in xyz]
lats = [p[1] for p in xyz]
sources = [source_field.GetValue(i) for i in point_ids]

unique_sources = sorted(set(sources))
nunique = len(unique_sources)
print("Unique sources:", unique_sources)

# Use Matplotlib's "Dark2" colormap for a set of high-contrast colors.
cmap_dark2 = plt.get_cmap('Dark2')
colors_dark2 = [mpl.colors.rgb2hex(cmap_dark2(i)) for i in range(cmap_dark2.N)]
# One colour per source; wrap around when there are more sources than colours.
palette_size = len(colors_dark2)
colors_list = [colors_dark2[i % palette_size] for i in range(nunique)]
color_mapping = dict(zip(unique_sources, colors_list))

# ---------------------------
# Create a Folium map with a length scale and a measuring ruler.
# ---------------------------
# Without points there is no centre to compute (the mean below would divide
# by zero), so stop with an explicit message.
if not lats:
    raise ValueError(f"The VTU file '{ResFName}' contains no points; nothing to plot.")

# Centre the map on the mean position of all points
center = [sum(lats) / len(lats), sum(lons) / len(lons)]
map_vtu = folium.Map(location=center, zoom_start=4, control_scale=True)

measure_control = MeasureControl(
    position="topleft",
    primary_length_unit="meters",
    secondary_length_unit="miles",
    active_color="red",
    completed_color="green"
)
map_vtu.add_child(measure_control)

# ---------------------------
# Add colored markers for each point.
# ---------------------------
# The loop variables are deliberately not called `lat` / `lon`, so the
# coordinate arrays with those names from section 3.3 are not overwritten.
for pt_lat, pt_lon, src in zip(lats, lons, sources):
    color = color_mapping.get(src, "black")
    folium.CircleMarker(
        location=[pt_lat, pt_lon],
        radius=1,
        color=color,
        fill=True,
        fill_color=color,
        fill_opacity=0.7,
        popup=f"Source: {src}"
    ).add_to(map_vtu)

# The HTML copy of the map is written into the run's output directory
map_vtu.save(os.path.join(dirout, "map_vtu.html"))
print("Map saved as map_vtu.html")

# Display the map inline in a Jupyter Notebook (if applicable)
map_vtu