# MOHID Lagrangian

This Jupyter Notebook aims to help implement and run the MOHID Lagrangian model.

***
**Note 1**: Execute each cell through the <button class="btn btn-default btn-xs"><i class="icon-play fa fa-play"></i></button> button from the top MENU (or keyboard shortcut `Shift` + `Enter`).<br>
<br>
**Note 2**: Use the Kernel and Cell menus to restart the kernel and clear outputs.<br>
***

# Table of contents
- [1. Import required libraries](#1.-Import-required-libraries)
- [2. General options](#2.-General-options)
    - [2.1 Set run case](#2.1-Set-run-case)
    - [2.2 Set dates](#2.2-Set-dates)
    - [2.3 Draw a polygon to select the area of interest](#2.3-Draw-a-polygon-to-select-the-area-of-interest)
- [3. Download currents for the area of interest](#3.-Download-currents-for-the-area-of-interest)
    - [3.1 Create Copernicus Marine credentials file](#3.1-Create-Copernicus-Marine-credentials-file)
    - [3.2 Set CMEMS product and depths](#3.2-Set-CMEMS-product-and-depths) 
    - [3.3 Download CMEMS](#3.3-Download-CMEMS)
- [4. Download wind field for the area of interest](#4.-Download-wind-field-for-the-area-of-interest)
    - [4.1 Setup the CDS API personal access token](#4.1-Setup-the-CDS-API-personal-access-token)
    - [4.2 Download ERA5 Reanalysis](#4.2-Download-ERA5-Reanalysis)
- [5. Load the hydrodynamic solution](#5.-Load-the-hydrodynamic-solution)
    - [5.1 Load a NetCDF file from CMEMS](#5.1-Load-a-NetCDF-file-from-CMEMS)
    - [5.2 Load a HDF5 file from MOHID](#5.2-Load-a-HDF5-file-from-MOHID)
- [6. Define sources location](#6.-Define-sources-location)
    - [6.1 Draw markers on the map to define the source coordinates](#6.1-Draw-markers-on-the-map-to-define-the-source-coordinates)
- [7. Setup MOHID Lagrangian xml input files](#7.-Setup-MOHID-Lagrangian-xml-input-files)
    - [7.1 Parameter definitions](#7.1-Parameter-definitions)
    - [7.2 Simulation definitions](#7.2-Simulation-definitions)
    - [7.3 Source definitions](#7.3-Source-definitions)
- [8. Run MOHID Lagrangian](#8.-Run-MOHID-Lagrangian)
- [9. Visualize results](#9.-Visualize-results)
    - [9.1 Particles](#9.1-Particles)
    - [9.2 Heatmap](#9.2-Heatmap)

# 1. Import required libraries

In [None]:
from update_xml_case import *
import copernicusmarine
import cdsapi
import zipfile
import os
from ipyleaflet import Map, TileLayer, DrawControl, GeoJSON, Marker
import json
import re
import datetime
import time
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_hex
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import shutil
import subprocess
import sys
import vtk
import folium
import matplotlib as mpl
from folium.plugins import HeatMap, HeatMapWithTime, MeasureControl
import glob
import h5py
import netCDF4
from pathlib import Path
from folium.plugins import TimestampedGeoJson

# 2. General options

## 2.1 Set run case

In [None]:
# Name of the run case; it is also the name of its folder and of its XML file
name = "Plastic_Case"

# Folder (relative to the notebook start directory) that holds the run cases
dirpath = "run_cases"

# Case XML file, referenced relative to the run case folder
xml_file_path = f"{name}.xml"

# Remember the directory the notebook was started from the first time this cell
# runs. Without this, re-running the cell would resolve "run_cases/<name>"
# relative to the already-changed working directory and fail.
if "notebook_root" not in globals():
    notebook_root = os.getcwd()

# Construct the path and change the working directory
os.chdir(os.path.join(notebook_root, dirpath, name))

## 2.2 Set dates

In [None]:
# Simulation period; both bounds are naive datetimes at 00:00:00
start_date = datetime.datetime(year=2025, month=1, day=1)
end_date = datetime.datetime(year=2025, month=2, day=1)

## 2.3 Draw a polygon to select the area of interest

In [None]:
# World map centred at (0, 0) on which the area of interest is drawn
m = Map(center=[0, 0], zoom=2)

# Copernicus Marine WMTS tiles (sea water velocity) shown as a background layer
wmts_url = (
    "https://wmts.marine.copernicus.eu/teroWmts/?"
    "service=WMTS&request=GetTile&version=1.0.0&"
    "layer=GLOBAL_ANALYSISFORECAST_PHY_001_024/"
    "cmems_mod_glo_phy-cur_anfc_0.083deg_P1M-m_202406/sea_water_velocity&"
    "tilematrixset=EPSG:3857&tilematrix={z}&tilerow={y}&tilecol={x}&"
    "format=image/png&transparent=True"
)
wmts_layer = TileLayer(url=wmts_url, name="Sea Water Velocity", no_wrap=True)
m.add_layer(wmts_layer)

# Drawing toolbar: only the rectangle tool gets options (blue outline);
# every other tool receives an empty dict, which disables it
draw_control = DrawControl(
    rectangle={"shapeOptions": {"color": "blue"}},
    polygon={},
    polyline={},
    circle={},
    marker={},
)

# Draw callback: stores and prints the bounding box of the drawn shape
def handle_draw(target, action, geo_json):
    """
    Record the bounding box of a newly drawn shape in module-level globals.

    Only the "created" action is handled. The first coordinate ring of the
    GeoJSON geometry is read as [lon, lat] pairs and its extremes are stored
    in ``min_lon``, ``max_lon``, ``min_lat`` and ``max_lat`` and printed.
    ``target`` is not used.
    """
    global min_lat, max_lat, min_lon, max_lon

    if action != "created":
        return

    # Outer ring of the drawn shape: a list of [lon, lat] vertices
    ring = geo_json["geometry"]["coordinates"][0]

    lons = []
    lats = []
    for vertex in ring:
        lons.append(vertex[0])
        lats.append(vertex[1])

    min_lon, max_lon = min(lons), max(lons)
    min_lat, max_lat = min(lats), max(lats)

    # Report the selected bounds
    print("✅ Polygon Bounds:")
    print(f"  - min_lon (West): {min_lon}")
    print(f"  - min_lat (South): {min_lat}")
    print(f"  - max_lon (East): {max_lon}")
    print(f"  - max_lat (North): {max_lat}")

# Register handle_draw as the callback invoked on draw events
draw_control.on_draw(handle_draw)

# Add DrawControl to the map
m.add_control(draw_control)

# Display the interactive map (last expression of the cell)
m

# 3. Download currents for the area of interest
Skip this step if you have another hydrodynamic dataset. 

## 3.1 Create Copernicus Marine credentials file
This has to be done only once!

In [None]:
# The login command will check your Copernicus Marine credentials and create the configuration file.
copernicusmarine.login()

## 3.2 Set CMEMS product and depths

In [None]:
# Folder, under the current working directory, where the currents are stored
output_dir_cmems = os.path.join(os.getcwd(), "nc_fields", "currents")

# CMEMS dataset to download. Alternatives kept for reference:
#   hourly instantaneous:
#product_id = "cmems_mod_glo_phy_anfc_merged-uv_PT1H-i"
#   6-hourly instantaneous:
#product_id = "cmems_mod_glo_phy-cur_anfc_0.083deg_PT6H-i"
#   daily mean (selected):
product_id = "cmems_mod_glo_phy-cur_anfc_0.083deg_P1D-m"

# Depth range requested from the product. Both bounds are identical here,
# presumably selecting a single near-surface level - confirm against the
# product's depth axis.
start_depth = end_depth = 0.49402499198913574
#end_depth = 5727.9

## 3.3 Download CMEMS

In [None]:
#####################################################
def download_file():
    """
    Download the 'uo' and 'vo' variables with ``copernicusmarine.subset``.

    All settings are read from notebook globals: ``product_id``, the
    ``min_lon``/``max_lon``/``min_lat``/``max_lat`` bounds, ``start_depth``,
    ``end_depth``, ``start_date``, ``end_date``, ``output_dir_cmems`` and
    ``output_file_cmems``. Only the date part of ``start_date``/``end_date``
    is used; the time of day is always sent as 00:00:00.
    """
    day_format = '%Y-%m-%d'
    midnight = ' 00:00:00'

    subset_request = dict(
        dataset_id=product_id,
        minimum_longitude=min_lon,
        maximum_longitude=max_lon,
        minimum_latitude=min_lat,
        maximum_latitude=max_lat,
        minimum_depth=start_depth,
        maximum_depth=end_depth,
        start_datetime=start_date.strftime(day_format) + midnight,
        end_datetime=end_date.strftime(day_format) + midnight,
        variables=['uo', 'vo'],
        output_directory=output_dir_cmems,
        output_filename=output_file_cmems,
        netcdf3_compatible=True,
    )
    copernicusmarine.subset(**subset_request)

#####################################################
     
# Output file name, tagged with the first and last day of the simulation
output_file_cmems = f"cmems_{start_date:%Y%m%d}_{end_date:%Y%m%d}.nc"

# Create the target folder on first use
if not os.path.exists(output_dir_cmems):
    os.makedirs(output_dir_cmems)

# Remove every NetCDF file already present in the target folder
for stale_file in glob.iglob(os.path.join(output_dir_cmems, "*.nc")):
    os.remove(stale_file)

download_file()


# 4. Download wind field for the area of interest

## 4.1 Setup the CDS API personal access token
It has to be done only once!

If you do not have an account yet, please register (https://cds.climate.copernicus.eu/). If you are not logged in, please log in. Once logged in, copy the URL and key.

Create a file named .cdsapirc in your home directory.

$HOME/.cdsapirc (in your Unix/Linux environment)

`%USERPROFILE%\.cdsapirc` (in your Windows environment; `%USERPROFILE%` is usually the `C:\Users\<Username>` folder).

Paste the URL and key into .cdsapirc file.

The CDS API expects to find the .cdsapirc file in your home directory.

## 4.2 Download ERA5 Reanalysis

In [None]:
# Folder for the wind fields and path of the temporary ZIP download
era5_dir = os.path.join(os.getcwd(), "nc_fields", "winds")
target = os.path.join(era5_dir, "ERA5.zip")

# Create the folder on first use
if not os.path.exists(era5_dir):
    os.makedirs(era5_dir)

# Remove every NetCDF file already present in the wind folder
for stale_file in glob.iglob(os.path.join(era5_dir, "*.nc")):
    os.remove(stale_file)
            
# Collect every calendar year, month and day-of-month touched by the
# simulation period. The request is expressed as separate year / month / day
# lists, so each list must contain every value in the period exactly once.
# Using only the two end points would drop the months (or years) in between,
# and appending day numbers while crossing a month boundary would duplicate them.
# Iterating over dates (not datetimes) guarantees the last day is included
# even when start_date carries a time of day.
years_in_period = set()
months_in_period = set()
days_in_period = set()
current_date = start_date.date()
last_date = end_date.date()
while current_date <= last_date:
    years_in_period.add(current_date.year)
    months_in_period.add(current_date.month)
    days_in_period.add(current_date.day)
    current_date += datetime.timedelta(days=1)

# Sorted, de-duplicated, zero-padded day numbers
days = [f"{day:02d}" for day in sorted(days_in_period)]

dataset = "reanalysis-era5-single-levels"
request = {
    "product_type": ["reanalysis"],
    "variable": [
                '10m_u_component_of_wind', 
                '10m_v_component_of_wind',
    ],
    "year": [str(year) for year in sorted(years_in_period)],
    "month": [f"{month:02d}" for month in sorted(months_in_period)],
    "day": days,
    # Every hour of the day
    "time": [f"{hour:02d}:00" for hour in range(24)],
    "data_format": "netcdf",
    "download_format": "zip",
    # Bounding box in the order North, West, South, East
    "area": [max_lat, min_lon, min_lat, max_lon]
}

#target = 'ERA5.zip'
client = cdsapi.Client()

# Submit the request; a failure is reported but does not stop the cell here
try:
    client.retrieve(dataset, request, target)
except Exception as e:
    print("An error occurred:", e)
else:
    print("Download completed successfully!")

# Unpack the archive into the wind folder and delete it afterwards.
# Any error is reported and the cell continues.
try:
    with zipfile.ZipFile(target, 'r') as archive:
        archive.extractall(era5_dir)
    print("Extraction completed successfully.")

    # The ZIP file is only removed once extraction has succeeded
    os.remove(target)
    print("Zip file has been removed.")

except zipfile.BadZipFile:
    print("Error: The file is not a valid ZIP archive.")
except FileNotFoundError:
    print("Error: The ZIP file was not found.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

# Rename the extracted file so that its name carries the simulation period
period_tag = f"{start_date:%Y%m%d}_{end_date:%Y%m%d}"
original_file_name = os.path.join(era5_dir, "data_stream-oper_stepType-instant.nc")
output_file_era5 = os.path.join(era5_dir, f"era5_{period_tag}.nc")
os.rename(original_file_name, output_file_era5)

## Open the file in update mode ('r+' allows in-place modifications)
#with netCDF4.Dataset(output_file_era5, mode='r+') as ds:
#    ds.renameVariable('valid_time', 'time')
#    ds.renameDimension('valid_time', 'time')

print(f"Files extracted to {era5_dir}")

# 5. Load the hydrodynamic solution
Load a NetCDF file from CMEMS or a HDF5 file from MOHID

## 5.1 Load a NetCDF file from CMEMS

In [None]:
# Path of the currents NetCDF; by default, the file named after the simulation period
output_dir_cmems = os.path.join(os.getcwd(), "nc_fields", "currents")
output_file_cmems = f"cmems_{start_date:%Y%m%d}_{end_date:%Y%m%d}.nc"
#output_file_cmems = "cmems_20250101_20250105.nc" #Uncomment this line to load your own NetCDF instead of the downloaded CMEMS file
CurFName = os.path.join(output_dir_cmems, output_file_cmems)

# Open the data file once. 'dataset' and 'CurDS' refer to the same object,
# so the file is not opened twice and no second handle is left behind.
CurDS = xr.open_dataset(CurFName, engine="netcdf4")
dataset = CurDS

# First time step and first depth level of each velocity component
U = dataset['uo'].isel(time=0, depth=0).squeeze()  # Result is 2D: (lat, lon)
V = dataset['vo'].isel(time=0, depth=0).squeeze()  # Result is 2D: (lat, lon)

# Show info of dataset
CurDS

## 5.2 Load a HDF5 file from MOHID

In [None]:
# Location of the MOHID hydrodynamic HDF5 file under the working directory
output_dir = os.path.join(os.getcwd(), "Boundary_Conditions")
output_file_name = "Hydrodynamic_2_Surface.hdf5"
CurFName = os.path.join(output_dir, output_file_name)

# Time instant to read; it is the suffix of the dataset names in 'Results'
instant = "00001"
vel_u_key = f"velocity U_{instant}"
vel_v_key = f"velocity V_{instant}"

# Read everything needed into NumPy arrays while the file is open
with h5py.File(CurFName, 'r') as f:
    res_group = f['Results']

    # U and V velocity components: last index along the first axis
    # (presumably the surface layer - TODO confirm), singleton axes removed
    velocity_u = np.squeeze(res_group['velocity U'][vel_u_key][-1, :, :])
    velocity_v = np.squeeze(res_group['velocity V'][vel_v_key][-1, :, :])

    # Grid corner coordinates: first column of Longitude, first row of Latitude
    grid_group = f['Grid']
    lon_corners = grid_group['Longitude'][:, 0]
    lat_corners = grid_group['Latitude'][0, :]

# Transpose so that the arrays match the (latitude, longitude) dims used below
velocity_u, velocity_v = velocity_u.T, velocity_v.T

# Cell centres are the midpoints of consecutive corners
lon_center = (lon_corners[1:] + lon_corners[:-1]) / 2
lat_center = (lat_corners[1:] + lat_corners[:-1]) / 2

# Wrap the fields in an xarray Dataset using the same names as the CMEMS file
horizontal_dims = ("latitude", "longitude")
dataset = xr.Dataset(
    data_vars={
        "uo": (horizontal_dims, velocity_u),
        "vo": (horizontal_dims, velocity_v),
    },
    coords={
        "longitude": lon_center,
        "latitude": lat_center,
    },
)

U, V = dataset['uo'], dataset['vo']

# Mask values that are <= -99 or exactly 0 (replaced by NaN)
U = U.where((U > -99) & (U != 0), np.nan)
V = V.where((V > -99) & (V != 0), np.nan)

# 6. Define sources location

## 6.1 Draw markers on the map to define the source coordinates

In [None]:
# -------------------------------
# Start timing
# -------------------------------
# Only used by the commented-out timing prints below
start_time = time.time()

zi = np.sqrt(U**2 + V**2)  # Velocity magnitude, 2D array

# Create an Output widget to capture the callback prints
output = widgets.Output()
display(output)

# -------------------------------
# Optional Downsampling for large datasets
# -------------------------------
# The same stride is applied to the field and to both coordinate vectors so
# that they stay aligned.
downsample_factor = 4  # Adjust as needed.
zi = zi[::downsample_factor, ::downsample_factor]
lon = dataset['longitude'].values[::downsample_factor]
lat = dataset['latitude'].values[::downsample_factor]

#print("zi:", zi.shape)
#print("Latitude size:", lat.shape)
#print("Longitude size:", lon.shape)

# -------------------------------
# Get coordinate grids
# -------------------------------
# 2D grids with rows following 'lat' and columns following 'lon'
LonGrid, LatGrid = np.meshgrid(lon, lat)
        
# -------------------------------
# Precompute color mapping (vectorized)
# -------------------------------
# Magnitudes are scaled between their NaN-ignoring min and max, mapped through
# viridis, and stored as hex strings in an array shaped like zi.
colormap = plt.cm.viridis
norm = Normalize(vmin=np.nanmin(zi), vmax=np.nanmax(zi))
flat_colors = [to_hex(c) for c in colormap(norm(zi.values.ravel()))]
colors = np.array(flat_colors).reshape(zi.shape)
#print(f"Time spent on color mapping: {time.time() - start_time:.2f} seconds")

# -------------------------------
# Precompute grid cell corners for GeoJSON layers
# -------------------------------
# Each cell is spanned by four neighbouring grid points; the shifted slices
# below give, for every cell (r, c), the coordinates of its four corners.
lon_sw = LonGrid[:-1, :-1]   # Southwest corner
lon_se = LonGrid[:-1,  1:]   # Southeast corner
lon_ne = LonGrid[1:,   1:]   # Northeast corner
lon_nw = LonGrid[1:,  :-1]   # Northwest corner

lat_sw = LatGrid[:-1, :-1]
lat_se = LatGrid[:-1,  1:]
lat_ne = LatGrid[1:,   1:]
lat_nw = LatGrid[1:,  :-1]

# Value attached to each cell: the magnitude at the cell's first
# (lon_sw / lat_sw) grid point
zi_cells = zi.values[:-1, :-1]
#print("zi_cells:", zi_cells.shape)
def generate_geojson_batches(batch_size=50):
    """
    Lazily produce one GeoJSON FeatureCollection per tile of grid cells.

    The cell grid (shape of ``lon_sw``) is cut into tiles of at most
    ``batch_size`` x ``batch_size`` cells. Every cell whose value in
    ``zi_cells`` is not NaN becomes a polygon feature with corners SW, SE,
    NE, NW, coloured with ``colors[r, c]``. Tiles that contain no valid cell
    yield nothing.
    """
    total_rows, total_cols = lon_sw.shape
    for row_start in range(0, total_rows, batch_size):
        row_stop = min(row_start + batch_size, total_rows)
        for col_start in range(0, total_cols, batch_size):
            col_stop = min(col_start + batch_size, total_cols)

            # One polygon per non-NaN cell of this tile, in row-major order
            tile_features = [
                {
                    "type": "Feature",
                    "geometry": {
                        "type": "Polygon",
                        "coordinates": [[
                            [lon_sw[r, c], lat_sw[r, c]],
                            [lon_se[r, c], lat_se[r, c]],
                            [lon_ne[r, c], lat_ne[r, c]],
                            [lon_nw[r, c], lat_nw[r, c]]
                        ]]
                    },
                    "properties": {
                        "fill": colors[r, c],
                        "stroke": "#000000",
                        "fill-opacity": 0.5,
                        "stroke-width": 0.2
                    }
                }
                for r in range(row_start, row_stop)
                for c in range(col_start, col_stop)
                if not np.isnan(zi_cells[r, c])
            ]

            if tile_features:
                yield {"type": "FeatureCollection", "features": tile_features}

# -------------------------------
# Create the map and add dataset layers
# -------------------------------
# The map is centred on the middle of the downsampled coordinate ranges
map_center = [(lat.min() + lat.max()) / 2, (lon.min() + lon.max()) / 2]
m = Map(center=map_center, zoom=4)


def _cell_style(feature):
    """Translate the properties stored on a cell feature into Leaflet style keys."""
    props = feature["properties"]
    return {
        "fillColor": props["fill"],
        "color": props["stroke"],
        "weight": props["stroke-width"],
        "fillOpacity": props["fill-opacity"]
    }


# One GeoJSON layer per batch of cells
for batch_geojson in generate_geojson_batches():
    geojson_layer = GeoJSON(data=batch_geojson, style_callback=_cell_style)
    m.add_layer(geojson_layer)

#print(f"Total time for GeoJSON layers: {time.time() - start_time:.2f} seconds")

# -------------------------------
# Global store for drawn markers
# -------------------------------
# Dictionary mapping marker_id -> {'location': [...], 'name': str or None}.
# When a marker is created, handle_draw stores the location as [lat, lon]
# and the name as None until the user confirms one.
markers_dict = {}
marker_counter = 0  # Next marker id to assign; incremented for each new marker

import ipywidgets as widgets
from IPython.display import display

def ask_marker_name(marker, marker_id):
    """
    Prompt the user for the source name of a newly drawn marker.

    A text box and a "Confirm" button are shown inside the shared ``output``
    widget. On confirmation the entered name (or "Marker <id>" when the box
    is empty) is attached to ``marker`` as ``marker_name`` and stored in
    ``markers_dict[marker_id]['name']``; the prompt is then closed.
    """
    with output:
        print("ask_marker_name invoked for marker", marker_id)

        # Text field and confirmation button, stacked vertically
        name_field = widgets.Text(
            value='',
            placeholder='Enter source name',
            description='Source name:',
            disabled=False,
            style={'description_width': 'initial'}
        )
        ok_button = widgets.Button(
            description='Confirm',
            disabled=False,
            button_style='success'
        )
        # The container must exist before the callback is defined, because
        # the callback captures it through a default argument.
        prompt_box = widgets.VBox([name_field, ok_button])

        def on_button_clicked(b, widget=prompt_box):
            # Fall back to a generic name when nothing (or only blanks) was typed
            chosen_name = name_field.value.strip() or f"Marker {marker_id}"
            marker.marker_name = chosen_name
            markers_dict[marker_id]['name'] = chosen_name
            print(f"Marker {marker_id} named '{chosen_name}'")
            # Remove the prompt once the name is confirmed
            widget.close()

        ok_button.on_click(on_button_clicked)
        display(prompt_box)

def handle_draw(target, action, geo_json):
    """
    DrawControl callback that keeps ``markers_dict`` in sync with the map.

    Parameters
    ----------
    target : unused.
    action : str
        "created", "edited" or "deleted".
    geo_json : dict
        GeoJSON describing the drawn, edited or deleted feature(s).

    On "created", the drawn point is replaced by a draggable ``Marker``, the
    user is asked for a name, and the marker is stored in ``markers_dict`` as
    ``{'location': [lat, lon], 'name': None}``. On "deleted", entries whose
    ``marker_id`` property is present in ``markers_dict`` are removed.

    The stored location is always ``[lat, lon]``. Previously a dragged marker
    was stored as ``[lon, lat]`` while a freshly created one was stored as
    ``[lat, lon]``, so dragging silently swapped the coordinates.
    NOTE(review): the ``[lat, lon]`` order follows the creation path; confirm
    it is the order ``update_source_definitions`` expects.
    """
    global marker_counter, markers_dict

    if action == "created":
        coords = geo_json["geometry"]["coordinates"]  # GeoJSON order: [lon, lat]
        lat_lon = [coords[1], coords[0]]

        marker_id = marker_counter  # Assign an ID
        marker_counter += 1
        # Store with location and placeholder for marker name
        markers_dict[marker_id] = {'location': list(lat_lon), 'name': None}

        # Create a draggable marker at the drawn position
        custom_marker = Marker(location=list(lat_lon), draggable=True)
        custom_marker.marker_id = marker_id  # Attach marker ID

        # Ask user to enter a name for the new marker
        ask_marker_name(custom_marker, marker_id)

        # Update the stored coordinates when the marker is dragged
        def on_location_change(change, m_id=marker_id):
            new_location = change["new"]  # Marker location: [lat, lon]
            # Keep the same [lat, lon] order used when the marker was created
            markers_dict[m_id]['location'] = [new_location[0], new_location[1]]
            print(f"Marker {m_id} moved to: {markers_dict[m_id]['location']}")

        # Observe marker movement
        custom_marker.observe(on_location_change, names="location")

        # Add the new draggable marker to the map
        m.add_layer(custom_marker)

        # Remove the default marker added by DrawControl
        for layer in list(m.layers):
            if isinstance(layer, GeoJSON) and layer.data.get("geometry", {}).get("type", "") == "Point":
                m.remove_layer(layer)

    elif action == "edited":
        print("Edit event received; marker updates are handled via the location observer.")

    elif action == "deleted":
        # Accept either a FeatureCollection or a single feature
        features = geo_json.get("features", [])
        if not features:
            features = [geo_json]
        for feature in features:
            marker_id = feature.get("properties", {}).get("marker_id")
            if marker_id is not None and marker_id in markers_dict:
                print(f"Deleted marker {marker_id} named '{markers_dict[marker_id]['name']}' at {markers_dict[marker_id]['location']}")
                markers_dict.pop(marker_id)
        print("Current markers after deletion:", markers_dict)

# Drawing toolbar restricted to markers: the tools that are not needed get an
# empty dict, and the marker tool gets a non-empty dict to enable its button.
draw_control = DrawControl(
    polygon={},
    polyline={},
    rectangle={},
    circle={},
    marker={"repeatMode": False},
)

draw_control.on_draw(handle_draw)
m.add_control(draw_control)

# -------------------------------
# Display the map 
# -------------------------------
m

# 7. Setup MOHID Lagrangian xml input files

## 7.1 Parameter definitions

In [None]:
# Alternative (disabled): take the time limits from the dataset's 'time' variable
#start, end = min(dataset['time'].values), max(dataset['time'].values)

# Convert the simulation start/end (Python datetime objects) with pandas
Start = pd.to_datetime(start_date) 
End = pd.to_datetime(end_date)
#Start = datetime.datetime(2024, 1, 1, 0, 0, 0)
#End = datetime.datetime(2024, 1, 2, 0, 0, 0)

Integrator = 2 #Integration Algorithm 1:Euler, 2:Multi-Step Euler, 3:RK4 (default=1)
Threads = "auto" #Computation threads for shared memory computation (default=auto)
OutputWriteTime = 86400 #Time between outputs (seconds); also used by the visualization cells as the time step between frames
BufferSize = 86400 #Controls the amount of hydrodynamic data to store in RAM (seconds)

# Write these parameters into the case XML file (helper from update_xml_case)
update_parameter_definitions(xml_file_path,Start,End,Integrator,Threads,OutputWriteTime,BufferSize)

## 7.2 Simulation definitions

In [None]:
# -------------------------------
# Get area limits
# -------------------------------
# The extent is taken from the coordinates of the loaded hydrodynamic dataset
lat_values = dataset['latitude'].values
lon_values = dataset['longitude'].values
min_lat, max_lat = lat_values.min(), lat_values.max()
min_lon, max_lon = lon_values.min(), lon_values.max()

resolution = 8000  # metres (m)
timestep = 180  # seconds (s)

# Corners of the simulation domain as x, y, z (deg, deg, m)
BoundingBoxMin = (min_lon, min_lat, -1)
BoundingBoxMax = (max_lon, max_lat, 1)

VerticalVelMethod = 3  # 1:From velocity fields, 2:Divergence based, 3:Disabled. Default = 1
BathyminNetcdf = 0  # bathymetry is a property in the netcdf. 1:true, 0:false (computes from layer depth and openPoints). Default = 1
RemoveLandTracer = 0  # Remove tracers on land 0:No, 1:Yes. Default = 1
TracerMaxAge = 0  # maximum tracer age. Default = 0.0. read if > 0

# Write these settings into the case XML file (helper from update_xml_case)
update_simulation_definitions(
    xml_file_path,
    resolution,
    timestep,
    BoundingBoxMin,
    BoundingBoxMax,
    VerticalVelMethod,
    BathyminNetcdf,
    RemoveLandTracer,
    TracerMaxAge,
)


## 7.3 Source definitions

In [None]:
# Show the sources collected on the map (marker id -> location and name)
print(markers_dict)

In [None]:
start_emission = 0 #start of emission in seconds, counted from start_date
end_emission = 86400 #end of emission in seconds, counted from start_date. Per the original note, use 'end' to emit until end_date (verify against update_source_definitions)

rate_seconds = 3600 #emission step in seconds
rate_trcPerEmission = 10 #number of tracers emitted every rate_seconds

# Write one source per marker into the case XML file (helper from update_xml_case)
update_source_definitions(xml_file_path, markers_dict,rate_seconds,rate_trcPerEmission,start_emission,end_emission)

# 8. Run MOHID Lagrangian 

In [None]:
# Output folder of the run
dirout = f"{name}_out"

# Paths to executables and scripts (relative to the run case folder)
tools = r"../build/bin/RELEASE"
mohidlagrangian = os.path.join(tools, "MOHIDLagrangian.exe")

preprocessor_dir = r"../src/MOHIDLagrangianPreProcessor"
preprocessor = os.path.join(preprocessor_dir, "MOHIDLagrangianPreProcessor.py")

postprocessor_dir = r"../src/MOHIDLagrangianPostProcessor"
postprocessor = os.path.join(postprocessor_dir, "MOHIDLagrangianPostprocessor.py")

# Start from an empty output directory
if os.path.exists(dirout):
    shutil.rmtree(dirout)
os.makedirs(dirout)

# Copy XML configuration file
shutil.copy(f"{name}.xml", dirout)

# The three stages, in execution order, each with the message printed when it
# exits with a non-zero status. The pre-processor and the model read the copy
# of the XML inside the output folder; the post-processor reads the original.
copied_xml = f"{dirout}/{name}.xml"
stages = (
    (
        [sys.executable, preprocessor, "-i", copied_xml, "-o", dirout],
        "Preprocessing failed.",
    ),
    (
        [mohidlagrangian, "-i", copied_xml, "-o", dirout],
        "Execution of MOHIDLagrangian failed.",
    ),
    (
        [sys.executable, "-W", "ignore", postprocessor, "-i", f"{name}.xml", "-o", dirout],
        "Postprocessing failed.",
    ),
)

for command, failure_message in stages:
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError:
        # Stop at the first failing stage
        print(failure_message)
        sys.exit(1)

print("All done!")

# 9. Visualize results

## 9.1 Particles

In [None]:
# ---- PARAMETERS ----
dirout  = f"{name}_out"              # e.g. "Plastic_Case_out"
pattern = os.path.join(dirout, f"{name}_*.vtu")

# ---- 1) Collect only numeric-suffix VTUs ----
# Keep the files named "<name>_<digits>.vtu"; other VTU files matched by the
# glob pattern are ignored.
timestep_name = re.compile(rf'{re.escape(name)}_[0-9]+\.vtu$')
all_vtu = []
for candidate in glob.glob(pattern):
    if timestep_name.search(os.path.basename(candidate)):
        all_vtu.append(candidate)

if not all_vtu:
    raise FileNotFoundError(f"No timestep VTU files in {dirout!r}")

def seq(fn):
    """Return the integer found in the "_<digits>.vtu" suffix of a file name."""
    file_name = os.path.basename(fn)
    suffix = re.search(r'_(\d+)\.vtu$', file_name)
    return int(suffix.group(1))

# Order the files by their timestep index
all_vtu.sort(key=seq)

# ---- 2) Scan for unique 'source' values ----
unique_sources = set()
for fn in all_vtu:
    rdr = vtk.vtkXMLUnstructuredGridReader()
    rdr.SetFileName(fn)
    rdr.Update()
    arr = rdr.GetOutput().GetPointData().GetArray("source")
    if not arr:
        # Files without a 'source' point array contribute nothing
        print(f"  ⚠️  Skipping scan for {os.path.basename(fn)} (no 'source')")
        continue
    unique_sources.update(arr.GetValue(i) for i in range(arr.GetNumberOfTuples()))

unique_sources = sorted(unique_sources)
if not unique_sources:
    raise RuntimeError("No 'source' values found in any VTU.")

# Assign each source a hex color sampled from the Dark2 colormap
cmap = plt.get_cmap("Dark2")
n_sources = len(unique_sources)
colors = []
color_map = {}
for rank, source_id in enumerate(unique_sources):
    hex_color = mpl.colors.rgb2hex(cmap(rank / n_sources))
    colors.append(hex_color)
    color_map[source_id] = hex_color

# ---- 3) Build time-stamped GeoJSON features ----
# Time between two consecutive rendered frames
dt = datetime.timedelta(seconds=OutputWriteTime)

features=[]
current_time = start_date

for step, fn in enumerate(all_vtu):
    rdr = vtk.vtkXMLUnstructuredGridReader()
    rdr.SetFileName(fn); rdr.Update()
    data = rdr.GetOutput()
    pts  = data.GetPoints()

    # Pull out the 'source' array; a file without it is skipped entirely.
    # The clock below is NOT advanced for skipped files, so timestamps count
    # rendered frames rather than file indices.
    src = data.GetPointData().GetArray("source")
    if src is None:
        continue

    # Advance the clock before stamping: the first rendered frame is
    # therefore labelled start_date + dt.
    # NOTE(review): confirm this matches the simulation time of the first output file.
    current_time += dt
    tstr = current_time.strftime("%Y-%m-%dT%H:%M:%SZ")

    # One point feature per particle, coloured by its source
    for i in range(pts.GetNumberOfPoints()):
        lon, lat, _ = pts.GetPoint(i)
        val         = src.GetValue(i)
        features.append({
            "type":"Feature",
            "geometry":{"type":"Point","coordinates":[lon,lat]},
            "properties":{
                "time": tstr,
                "icon": "circle",
                "iconstyle":{
                    "fillColor": color_map[val],
                    "fillOpacity": 0.7,
                    "stroke": False,
                    "radius": 4
                },
                "popup": f"Step: {step}, Source: {val}"
            }
        })

if not features:
    raise RuntimeError("No features generated—check your VTUs and 'source' array.")

# Distinct frame timestamps; only used by the commented-out debug print below
times = sorted({f['properties']['time'] for f in features})
#print("All timestamps:", times)

geojson = {"type": "FeatureCollection", "features": features}

# ---- 4) Create & save the animated Folium map ----
# The map is centred on the first feature
first_lon, first_lat = features[0]["geometry"]["coordinates"]
m = folium.Map(
    location=[first_lat, first_lon],
    zoom_start=4,
    control_scale=True,
)

# Measuring tool shown in the top-left corner
measure_options = {
    "position": "topleft",
    "primary_length_unit": "meters",
    "secondary_length_unit": "miles",
    "active_color": "red",
    "completed_color": "green",
}
measure_control = MeasureControl(**measure_options)
m.add_child(measure_control)

# Animated particle layer: one frame every OutputWriteTime seconds
particle_animation = TimestampedGeoJson(
    data=geojson,
    period=f"PT{OutputWriteTime}S",
    duration="PT0S",             # show only the instant
    transition_time=200,         # ms fade time
    add_last_point=True,
    auto_play=True,
    loop=True,
    time_slider_drag_update=True,
)
particle_animation.add_to(m)

# Save a standalone HTML copy next to the results, then display the map
out_html = os.path.join(dirout, "animated_map.html")
m.save(out_html)
print("✅ Saved animated map to", out_html)
m

## 9.2 Heatmap

In [None]:
# PARAMETERS
dirout          = f"{name}_out"
os.makedirs(dirout, exist_ok=True)
# The glob pattern is defined here as well, so that this cell does not depend
# on the Particles cell having been run first (it used to raise NameError).
pattern         = os.path.join(dirout, f"{name}_*.vtu")

# Collect only numeric-suffix VTUs ----
all_vtu = [
    fn for fn in glob.glob(pattern)
    if re.search(rf'{re.escape(name)}_[0-9]+\.vtu$', os.path.basename(fn))
]
if not all_vtu:
    raise FileNotFoundError(f"No timestep VTU files in {dirout!r}")

def seq(fn):
    """Extract the numeric timestep index from a "..._<digits>.vtu" file name."""
    index_match = re.search(r'_(\d+)\.vtu$', os.path.basename(fn))
    index_text = index_match.group(1)
    return int(index_text)

# Order the files by their timestep index
all_vtu.sort(key=seq)

# accumulate frames & timestamps
dt = datetime.timedelta(seconds=OutputWriteTime)
heat_data, timestamps = [], []
all_lats, all_lons = [], []
current_time = start_date

for fn in all_vtu:
    rdr = vtk.vtkXMLUnstructuredGridReader()
    rdr.SetFileName(fn)
    rdr.Update()
    pts = rdr.GetOutput().GetPoints()

    # Only files that contain at least one point produce a frame; the clock
    # is advanced per produced frame, not per file.
    if pts and pts.GetNumberOfPoints() != 0:
        coords = []
        for i in range(pts.GetNumberOfPoints()):
            lon, lat, _ = pts.GetPoint(i)
            coords.append([lat, lon])
        all_lats.extend(point[0] for point in coords)
        all_lons.extend(point[1] for point in coords)

        current_time = current_time + dt
        timestamps.append(current_time.strftime("%Y-%m-%dT%H:%M:%SZ"))
        heat_data.append(coords)

if not heat_data:
    raise RuntimeError("No data extracted from VTUs")

# Centre the map on the mean position of all particles over all frames
center_lat = np.mean(all_lats)
center_lon = np.mean(all_lons)
m = folium.Map(location=[center_lat, center_lon], zoom_start=6)

# add measure tool
m.add_child(MeasureControl())

# static density layer: every particle position of every frame, in its own
# toggleable group
all_points = []
for frame in heat_data:
    all_points.extend(frame)
density_layer = HeatMap(data=all_points, min_opacity=0.3, radius=8, blur=12)
density_group = folium.FeatureGroup(name="Overall Density").add_to(m)
density_layer.add_to(density_group)

# animated layer: one heatmap per frame, labelled with its timestamp
animated_layer = HeatMapWithTime(
    data=heat_data,
    index=timestamps,
    name="Time-series Heat",
    auto_play=True,
    max_opacity=0.8,
    radius=7,
    position="bottomright",
)
animated_layer.add_to(m)

# toggle control + save
folium.LayerControl().add_to(m)
out_path = os.path.join(dirout, "heatmap_with_time.html")
m.save(out_path)
print("✅ Saved animated heatmap to:", out_path)
m