In [None]:
from pathlib import Path
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import numpy as np
import geopandas as gpd
from matplotlib.colors import LogNorm
import rioxarray
from scipy.ndimage import uniform_filter
from pyproj import Geod
import matplotlib.dates as mdates
from tqdm import tqdm
from scipy.stats import entropy
from pysteps.verification.probscores import CRPS
from pysteps.verification.spatialscores import fss  
from pysteps.verification.salscores import sal 
from scipy.spatial.distance import euclidean  

In [None]:
#############################################################
#################### CRPS COMPUTATION #######################
#############################################################

def apply_transformations(raster):
    """Rotate the raster by a quarter turn, then reverse its row order.

    The array is rotated once by 90 degrees in the plane of axes (0, 1) with
    ``np.rot90`` and the result is flipped along axis 0. For a 2-D input the
    two steps combined amount to a transpose.
    """
    quarter_turn = np.rot90(raster, 1, (0, 1))
    return np.flipud(quarter_turn)

# Parameters
orario_iniziale = "2023-10-31T02:45:00"  # first valid time of the verification window
durata_minuti = 120  # minutes; passed below as window length and used as the largest lead time
# NOTE(review): durata_calcolo is not referenced by the driver loop of this cell,
# which passes durata_minuti instead — confirm which duration is intended.
durata_calcolo = 240 # CRPS COMPUTATION DURATION
intervallo_minuti = 5  # minutes between consecutive frames / lead times
num_membri = list(range(20))  # member indices that appear in the ensemble file names

# Directories containing the per-member NetCDF files of each model
ldcast_dir = Path("../../ens/oct_ld_ens")
gptcast_dir = Path("../../ens/oct_gpt_ens")

def find_file(directory, timestamp):
    """Return the first ``*.nc`` file (in sorted order) whose name contains ``timestamp``.

    Returns ``None`` when no file name in ``directory`` contains the substring.
    """
    matching = (path for path in sorted(directory.glob("*.nc")) if timestamp in path.name)
    return next(matching, None)

# Load the clipping extent, reproject it to EPSG:32632 and merge all of its
# features into a single geometry that is later passed as the clip area.
extent_shapefile_path = "../../data/extent_prova/extent_prova.shp"
extent = gpd.read_file(extent_shapefile_path)
extent = extent.to_crs("EPSG:32632")  
extent_union = extent.unary_union  

def ritaglia_dati(data, area):
    """Clip ``data`` to the geometry ``area`` through its ``rio`` accessor.

    When ``data`` carries no CRS, EPSG:32632 is written first; the clip is
    always performed in EPSG:32632 with ``drop=True``.
    """
    target_crs = "EPSG:32632"
    georeferenced = data if data.rio.crs else data.rio.write_crs(target_crs)
    return georeferenced.rio.clip([area], crs=target_crs, drop=True)

def create_observed_vector(orario_iniziale, durata_calcolo, intervallo_minuti, directory):
    """Build the list of reference rasters, one per valid time.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between consecutive valid times.
    directory : pathlib.Path
        Directory searched with ``find_file`` for a file whose name contains
        the valid time.

    Returns
    -------
    list of numpy.ndarray
        Frame ``timestep=3`` of the ``"sequence"`` variable of each file,
        clipped to the module-level ``extent_union``.

    Raises
    ------
    ValueError
        If no file is found for one of the valid times.
    """
    # Parse the start time once instead of on every iteration.
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    observed_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        orario_corrente = inizio + timedelta(minutes=minuto)
        file_corrente = find_file(directory, orario_corrente.strftime("%Y-%m-%dT%H:%M:%S"))
        if file_corrente is None:
            raise ValueError(f"Nessun file trovato per l'orario: {orario_corrente}")
        # The context manager closes the file handle; ``.values`` materialises
        # the data while the dataset is still open.
        # NOTE(review): timestep 3 is presumably the last observed input frame — confirm.
        # NOTE(review): ensemble members go through apply_transformations before
        # clipping, this reference frame does not — confirm both share orientation.
        with xr.open_dataset(file_corrente) as ds_corrente:
            raster_corrente = ritaglia_dati(ds_corrente["sequence"].isel(timestep=3), extent_union).values
        observed_vector.append(raster_corrente)
    return observed_vector

def create_lead_time_vector(orario_iniziale, durata_calcolo, intervallo_minuti, lead_time, directory, num_membri):
    """Collect, for every valid time, the ensemble fields forecast ``lead_time`` minutes earlier.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between valid times and between forecast frames.
    lead_time : int
        Lead time in minutes.
    directory : pathlib.Path
        Directory holding files named ``<issue time>_<member>_utm32n.nc``.
    num_membri : iterable of int
        Member indices to look for.

    Returns
    -------
    list of dict
        One entry per valid time with keys ``"timestamp"`` (issue time of the
        forecast) and ``"ensemble"`` (list of dicts with ``"membro"``,
        ``"timestep"`` and the clipped ``"data"`` array). Missing files and
        files that are too short are skipped silently.
    """
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    # Frame index inside each file; it depends only on the lead time, so it is
    # computed once outside the loops.
    timestep_corrente = 3 + (lead_time // intervallo_minuti)

    lead_time_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        # Issue time of the run that is valid at (start + minuto) for this lead time.
        orario_corrente = inizio - timedelta(minutes=lead_time - minuto)
        prefisso = orario_corrente.strftime('%Y-%m-%dT%H:%M:%S')
        ensemble_stack = []
        for membro in num_membri:
            file_membro = directory / f"{prefisso}_{membro}_utm32n.nc"
            if not file_membro.exists():
                continue
            # The context manager guarantees the file is closed on every path,
            # including the early ``continue`` statements below.
            with xr.open_dataset(file_membro) as ds_membro:
                if timestep_corrente >= ds_membro["sequence"].shape[0]:
                    continue
                raster_membro = ds_membro["sequence"].isel(timestep=timestep_corrente).values
                if raster_membro.ndim != 2:
                    print(f"Error: ensemble member not 2D. Shape: {raster_membro.shape}")
                    continue
                raster_membro = apply_transformations(raster_membro)
                raster_membro_da = xr.DataArray(
                    raster_membro,
                    dims=["y", "x"],
                    coords={"x": ds_membro["x"].values, "y": ds_membro["y"].values},
                )
                raster_membro_da = raster_membro_da.rio.write_crs("EPSG:32632")
                raster_membro = ritaglia_dati(raster_membro_da, extent_union).values
            ensemble_stack.append({
                "membro": membro,
                "timestep": timestep_corrente,
                "data": raster_membro
            })
        lead_time_vector.append({
            "timestamp": orario_corrente,
            "ensemble": ensemble_stack
        })
    return lead_time_vector

def calculate_crps(observed_vector, lead_time_vector):
    """Compute one CRPS value per valid time that has at least one ensemble member.

    ``observed_vector[i]`` is compared with the members stored under
    ``lead_time_vector[i]["ensemble"]``. Valid times with an empty ensemble
    contribute nothing to the returned list.

    Raises
    ------
    ValueError
        If a member field and the observed field have different shapes.
    """
    crps_results = []
    for position, raster_osservato in enumerate(observed_vector):
        fields = []
        for entry in lead_time_vector[position]["ensemble"]:
            raster_membro = entry["data"]
            if raster_osservato.shape != raster_membro.shape:
                raise ValueError(
                    f"Observed dimension {raster_osservato.shape} "
                    f"does not correspond to simulated dimension {raster_membro.shape}."
                )
            fields.append(raster_membro)

        if not fields:
            continue
        # Members are stacked along a new leading axis before scoring.
        crps_results.append(CRPS(np.stack(fields, axis=0), raster_osservato))
    return crps_results

# Per-model results: lead time in minutes -> list of CRPS values (one per valid time)
results_crps = {"LDCast": {}, "GPTCast": {}}

# Placeholder for lead time 0 (a scalar here; the plotting cell replaces it with [0]).
for model_name in results_crps.keys():
    results_crps[model_name][0] = 0  

for model_name, model_dir in [("LDCast", ldcast_dir), ("GPTCast", gptcast_dir)]:
    # Reference frames are read from the same directory as the model's ensemble files.
    # NOTE(review): durata_minuti is passed as the window length; durata_calcolo is unused — confirm.
    observed_vector = create_observed_vector(orario_iniziale, durata_minuti, intervallo_minuti, model_dir)
    # Lead times run from one step up to durata_minuti inclusive.
    for lead_time in tqdm(range(5, durata_minuti + 1, intervallo_minuti), desc=f"Lead time {model_name}"):
        lead_time_vector = create_lead_time_vector(orario_iniziale, durata_minuti, intervallo_minuti, lead_time, model_dir, num_membri)
        crps_results = calculate_crps(observed_vector, lead_time_vector)
        results_crps[model_name][lead_time] = crps_results

print("CRPS computation completed for both models.")

In [None]:
#############################################################
######################## CRPS PLOT ##########################
#############################################################

# Store the lead-time-0 placeholder as a one-element list so that the plotting
# function can call len() on it like on every other lead time.
for model_name in results_crps.keys():
    results_crps[model_name][0] = [0]  

def plot_crps_with_dispersion(results_crps):
    """Plot CRPS against lead time for each model with mean line and ±1 std band.

    Parameters
    ----------
    results_crps : dict
        ``{model name: {lead time: sequence of CRPS values}}``. Models are
        coloured blue and orange in iteration order; models beyond the second
        are not drawn because the colour list has two entries.
    """
    plt.figure(figsize=(10, 6))

    for model_name, color in zip(results_crps.keys(), ["blue", "orange"]):
        crps_results = results_crps[model_name]
        lead_times = sorted(crps_results.keys())
        mean_crps = np.array([np.mean(crps_results[lead_time]) for lead_time in lead_times])
        std_crps = np.array([np.std(crps_results[lead_time]) for lead_time in lead_times])

        for position, lead_time in enumerate(lead_times):
            crps_values = crps_results[lead_time]
            # Only the first scatter call carries a label so the legend shows
            # a single "Values" entry per model.
            plt.scatter(
                [lead_time] * len(crps_values),
                crps_values,
                color=color,
                alpha=0.5,
                label=f"Values {model_name}" if position == 0 else "",
            )

        plt.plot(lead_times, mean_crps, color=color, linewidth=2, label=f"Mean {model_name}")

        # The band is drawn once per model, so it is always labelled. The
        # previous condition tested the loop variable after the loop had
        # finished, which left the band out of the legend.
        plt.fill_between(
            lead_times,
            mean_crps - std_crps,
            mean_crps + std_crps,
            color=color,
            alpha=0.2,
            label=f"{model_name} (± std)",
        )

    plt.title("CRPS evolution with Lead Time, October Case, Seveso Basin", fontsize=18)
    plt.xlabel("Lead Time [min]", fontsize=16)
    plt.ylabel("CRPS [mm h⁻¹]", fontsize=16)
    plt.grid()
    plt.legend(fontsize=14)
    plt.tick_params(axis="both", labelsize=15)

    plt.tight_layout()

    plt.show()

# Draw the CRPS figure from the per-model results dictionary.
plot_crps_with_dispersion(results_crps)

In [None]:
#############################################################
##################### SAL COMPUTATION #######################
#############################################################

def apply_transformations(raster):
    """Return ``raster`` rotated by 90 degrees (axes 0 and 1) with its rows reversed.

    For a 2-D array the combined result equals the transpose of the input.
    """
    return np.rot90(raster, k=1, axes=(0, 1))[::-1, :]

# Parameters for the SAL computation (same values as in the CRPS cell)
orario_iniziale = "2023-10-31T02:45:00"  # first valid time of the verification window
durata_minuti = 120  # minutes; passed below as window length and used as the largest lead time
# NOTE(review): durata_calcolo is not referenced by the driver loop of this cell — confirm it is needed.
durata_calcolo = 240 
intervallo_minuti = 5  # minutes between consecutive frames / lead times
num_membri = list(range(20))  # member indices that appear in the ensemble file names

# Directories containing the per-member NetCDF files of each model
ldcast_dir = Path("../../ens/oct_ld_ens")
gptcast_dir = Path("../../ens/oct_gpt_ens")

def find_file(directory, timestamp):
    """Look up the first sorted ``*.nc`` file in ``directory`` whose name contains ``timestamp``.

    ``None`` is returned when nothing matches.
    """
    hits = [candidate for candidate in sorted(directory.glob("*.nc")) if timestamp in candidate.name]
    return hits[0] if hits else None

# Load the clipping extent, reproject it to EPSG:32632 and merge all of its
# features into a single geometry that is later passed as the clip area.
extent_shapefile_path = "../../data/extent_prova/extent_prova.shp"
extent = gpd.read_file(extent_shapefile_path)
extent = extent.to_crs("EPSG:32632")  
extent_union = extent.unary_union  

def ritaglia_dati(data, area):
    """Clip ``data`` to ``area`` in EPSG:32632, dropping everything outside the clip.

    A CRS of EPSG:32632 is written onto ``data`` first when it has none.
    """
    crs_code = "EPSG:32632"
    if data.rio.crs:
        source = data
    else:
        source = data.rio.write_crs(crs_code)
    clipped = source.rio.clip([area], crs=crs_code, drop=True)
    return clipped

def create_observed_vector(orario_iniziale, durata_calcolo, intervallo_minuti, directory):
    """Build the list of reference rasters, one per valid time.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between consecutive valid times.
    directory : pathlib.Path
        Directory searched with ``find_file`` for a file whose name contains
        the valid time.

    Returns
    -------
    list of numpy.ndarray
        Frame ``timestep=3`` of the ``"sequence"`` variable of each file,
        clipped to the module-level ``extent_union``.

    Raises
    ------
    ValueError
        If no file is found for one of the valid times.
    """
    # Parse the start time once instead of on every iteration.
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    observed_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        orario_corrente = inizio + timedelta(minutes=minuto)
        file_corrente = find_file(directory, orario_corrente.strftime("%Y-%m-%dT%H:%M:%S"))
        if file_corrente is None:
            raise ValueError(f"No file for time: {orario_corrente}")
        # The context manager closes the file handle; ``.values`` materialises
        # the data while the dataset is still open.
        # NOTE(review): ensemble members go through apply_transformations before
        # clipping, this reference frame does not — confirm both share orientation.
        with xr.open_dataset(file_corrente) as ds_corrente:
            raster_corrente = ritaglia_dati(ds_corrente["sequence"].isel(timestep=3), extent_union).values
        observed_vector.append(raster_corrente)
    return observed_vector

def create_lead_time_vector(orario_iniziale, durata_calcolo, intervallo_minuti, lead_time, directory, num_membri):
    """Collect, for every valid time, the ensemble fields forecast ``lead_time`` minutes earlier.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between valid times and between forecast frames.
    lead_time : int
        Lead time in minutes.
    directory : pathlib.Path
        Directory holding files named ``<issue time>_<member>_utm32n.nc``.
    num_membri : iterable of int
        Member indices to look for.

    Returns
    -------
    list of dict
        One entry per valid time with keys ``"timestamp"`` (issue time of the
        forecast) and ``"ensemble"`` (list of dicts with ``"membro"``,
        ``"timestep"`` and the clipped ``"data"`` array). Missing files and
        files that are too short are skipped silently.
    """
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    # Frame index inside each file; it depends only on the lead time, so it is
    # computed once outside the loops.
    timestep_corrente = 3 + (lead_time // intervallo_minuti)

    lead_time_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        # Issue time of the run that is valid at (start + minuto) for this lead time.
        orario_corrente = inizio - timedelta(minutes=lead_time - minuto)
        prefisso = orario_corrente.strftime('%Y-%m-%dT%H:%M:%S')
        ensemble_stack = []
        for membro in num_membri:
            file_membro = directory / f"{prefisso}_{membro}_utm32n.nc"
            if not file_membro.exists():
                continue
            # The context manager guarantees the file is closed on every path,
            # including the early ``continue`` statements below.
            with xr.open_dataset(file_membro) as ds_membro:
                if timestep_corrente >= ds_membro["sequence"].shape[0]:
                    continue
                raster_membro = ds_membro["sequence"].isel(timestep=timestep_corrente).values
                if raster_membro.ndim != 2:
                    print(f"Error: ensemble member not 2D. Shape: {raster_membro.shape}")
                    continue
                raster_membro = apply_transformations(raster_membro)
                raster_membro_da = xr.DataArray(
                    raster_membro,
                    dims=["y", "x"],
                    coords={"x": ds_membro["x"].values, "y": ds_membro["y"].values},
                )
                raster_membro_da = raster_membro_da.rio.write_crs("EPSG:32632")
                raster_membro = ritaglia_dati(raster_membro_da, extent_union).values
            ensemble_stack.append({
                "membro": membro,
                "timestep": timestep_corrente,
                "data": raster_membro
            })
        lead_time_vector.append({
            "timestamp": orario_corrente,
            "ensemble": ensemble_stack
        })
    return lead_time_vector

def calculate_sal(observed_vector, lead_time_vector):
    """Compute ensemble-mean SAL components for every valid time.

    ``observed_vector[i]`` is scored against each member stored under
    ``lead_time_vector[i]["ensemble"]``; the per-member S, A and L values are
    averaged with ``np.mean``. Valid times with an empty ensemble add nothing
    to the result.

    Returns
    -------
    dict
        ``{"S": [...], "A": [...], "L": [...]}`` with one mean value per
        valid time that had at least one member.

    Raises
    ------
    ValueError
        If a member field and the observed field have different shapes.
    """
    sal_results = {"S": [], "A": [], "L": []}
    for i, raster_osservato in enumerate(observed_vector):
        S_lead_time = []
        A_lead_time = []
        L_lead_time = []
        for ensemble_entry in lead_time_vector[i]["ensemble"]:
            raster_membro = ensemble_entry["data"]
            if raster_osservato.shape != raster_membro.shape:
                raise ValueError(
                    f"Observed dimension {raster_osservato.shape} "
                    f"does not correspond to simulated dimension {raster_membro.shape}."
                )
            # pysteps' signature is sal(prediction, observation): the forecast
            # member goes first. Passing the observation first flips the sign
            # of the S and A components.
            S, A, L = sal(raster_membro, raster_osservato)
            S_lead_time.append(S)
            A_lead_time.append(A)
            L_lead_time.append(L)
        # The three lists always have the same length, so one check is enough.
        if S_lead_time:
            sal_results["S"].append(np.mean(S_lead_time))
            sal_results["A"].append(np.mean(A_lead_time))
            sal_results["L"].append(np.mean(L_lead_time))
    return sal_results

# Per-model results: lead time in minutes -> {"S": [...], "A": [...], "L": [...]}
results_sal = {"LDCast": {}, "GPTCast": {}}

# Placeholder for lead time 0 (scalars here, whereas other lead times hold lists).
for model_name in results_sal.keys():
    results_sal[model_name][0] = {"S": 0, "A": 0, "L": 0}  

for model_name, model_dir in [("LDCast", ldcast_dir), ("GPTCast", gptcast_dir)]:
    # Reference frames are read from the same directory as the model's ensemble files.
    # NOTE(review): durata_minuti is passed as the window length; durata_calcolo is unused — confirm.
    observed_vector = create_observed_vector(orario_iniziale, durata_minuti, intervallo_minuti, model_dir)
    # Lead times run from one step up to durata_minuti inclusive.
    for lead_time in tqdm(range(5, durata_minuti + 1, intervallo_minuti), desc=f"Lead time {model_name}"):
        lead_time_vector = create_lead_time_vector(orario_iniziale, durata_minuti, intervallo_minuti, lead_time, model_dir, num_membri)
        sal_results = calculate_sal(observed_vector, lead_time_vector)
        results_sal[model_name][lead_time] = sal_results

print("SAL computation completed for both models.")

In [None]:
#############################################################
######################### SAL PLOT ##########################
#############################################################

# Add a lead-time-0 placeholder only when it is missing. The SAL computation
# cell already stores key 0, so after running that cell this guard does nothing.
for model_name in results_sal.keys():
    if 0 not in results_sal[model_name]:
        results_sal[model_name][0] = {"S": [0], "A": [0], "L": [0]}  

def plot_sal_components(results_sal, durata_minuti, intervallo_minuti):
    """Draw the S, A and L components against lead time in three stacked panels.

    Parameters
    ----------
    results_sal : dict
        ``{model name: {lead time: {"S": ..., "A": ..., "L": ...}}}``; each
        component may be a scalar, a list or an array.
    durata_minuti : int
        Accepted for call compatibility; not used inside the function.
    intervallo_minuti : int
        Lead time whose scatter points carry the "Values" legend entry.
    """
    fig, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)
    palette = {"LDCast": "blue", "GPTCast": "orange"}
    y_ranges = {"S": (-2, 2), "A": (-2, 2), "L": (0, 2)}
    panels = zip(axes, ["S", "A", "L"], ["Structure [S]", "Amplitude [A]", "Location [L]"])

    for ax, component, heading in panels:
        for model_name, model_results in results_sal.items():
            tint = palette[model_name]
            lead_times = sorted(model_results.keys())
            centres = []
            spreads = []

            for lead_time in lead_times:
                values = model_results[lead_time][component]
                # Scalars (such as the lead-time-0 placeholder) are wrapped so
                # that len() and the NaN-aware statistics work uniformly.
                if not isinstance(values, (list, np.ndarray)):
                    values = [values]
                values = np.array(values)

                centres.append(np.nanmean(values))
                spreads.append(np.nanstd(values))

                tag = f"Values {model_name}" if lead_time == intervallo_minuti else ""
                ax.scatter([lead_time] * len(values), values, color=tint, alpha=0.5, label=tag)

            ax.plot(lead_times, centres, color=tint, label=f"Mean {model_name}")

            centre_arr = np.array(centres)
            spread_arr = np.array(spreads)
            ax.fill_between(lead_times, centre_arr - spread_arr, centre_arr + spread_arr,
                            color=tint, alpha=0.3)

        ax.set_title(f"{heading}, October Case, Seveso Basin", fontsize=14)
        ax.set_ylabel(f"{component} [-]", fontsize=16)
        ax.set_ylim(*y_ranges[component])
        ax.tick_params(axis="both", labelsize=14)
        # The amplitude panel keeps its legend at the bottom, the others at the top.
        ax.legend(loc="lower left" if component == "A" else "upper left", fontsize=14)
        ax.grid(True)

    axes[-1].set_xlabel("Lead Time [min]", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    plt.show()

# Draw the three SAL panels from the per-model results dictionary.
plot_sal_components(results_sal, durata_minuti, intervallo_minuti)

In [None]:
#############################################################
#################### FSS COMPUTATION ########################
#############################################################

def apply_transformations(raster):
    """Apply a 90-degree rotation in the (0, 1) plane followed by a flip along axis 0.

    On 2-D input this is the same as transposing the array.
    """
    turned = np.rot90(raster, axes=(0, 1))
    flipped = turned[::-1, :]
    return flipped

# Parameters for the FSS computation
orario_iniziale = "2023-10-31T02:45:00"  # first valid time of the verification window
durata_minuti = 120  # minutes; passed below as window length and used as the largest lead time
durata_calcolo = 120  # not referenced by the driver loop of this cell
intervallo_minuti = 5  # minutes between consecutive frames / lead times
num_membri = list(range(20))  # member indices that appear in the ensemble file names

# Directories containing the per-member NetCDF files of each model
ldcast_dir = Path("../../ens/oct_ld_ens")
gptcast_dir = Path("../../ens/oct_gpt_ens")

# Values passed as the ``scale`` argument of pysteps' fss.
# NOTE(review): the plots label these as km; this presumably relies on a 1 km grid — confirm.
finestre_spaziali = [1, 8, 14, 64, 128, 267]  # SPATIAL WINDOW
# Values passed as the ``thr`` argument of pysteps' fss (plots label them mm h⁻¹).
soglie = [0.1, 2.5, 10] # THRESHOLDS

def find_file(directory, timestamp):
    """Scan the sorted ``*.nc`` files of ``directory`` and return the first whose name contains ``timestamp``.

    Falls through to ``None`` when no name matches.
    """
    ordered = sorted(directory.glob("*.nc"))
    position = 0
    while position < len(ordered):
        if timestamp in ordered[position].name:
            return ordered[position]
        position += 1
    return None

# Load the clipping extent, reproject it to EPSG:32632 and merge all of its
# features into a single geometry that is later passed as the clip area.
extent_shapefile_path = "../../data/extent_prova/extent_prova.shp"
extent = gpd.read_file(extent_shapefile_path)
extent = extent.to_crs("EPSG:32632")  
extent_union = extent.unary_union  

def ritaglia_dati(data, area):
    """Return ``data`` clipped to ``area`` (EPSG:32632, ``drop=True``).

    Input without a CRS is first tagged with EPSG:32632.
    """
    has_crs = bool(data.rio.crs)
    if not has_crs:
        data = data.rio.write_crs("EPSG:32632")
    return data.rio.clip([area], crs="EPSG:32632", drop=True)

def create_observed_vector(orario_iniziale, durata_calcolo, intervallo_minuti, directory):
    """Build the list of reference rasters, one per valid time.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between consecutive valid times.
    directory : pathlib.Path
        Directory searched with ``find_file`` for a file whose name contains
        the valid time.

    Returns
    -------
    list of numpy.ndarray
        Frame ``timestep=3`` of the ``"sequence"`` variable of each file,
        clipped to the module-level ``extent_union``.

    Raises
    ------
    ValueError
        If no file is found for one of the valid times.
    """
    # Parse the start time once instead of on every iteration.
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    observed_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        orario_corrente = inizio + timedelta(minutes=minuto)
        file_corrente = find_file(directory, orario_corrente.strftime("%Y-%m-%dT%H:%M:%S"))
        if file_corrente is None:
            raise ValueError(f"No file for time: {orario_corrente}")
        # The context manager closes the file handle; ``.values`` materialises
        # the data while the dataset is still open.
        # NOTE(review): ensemble members go through apply_transformations before
        # clipping, this reference frame does not — confirm both share orientation.
        with xr.open_dataset(file_corrente) as ds_corrente:
            raster_corrente = ritaglia_dati(ds_corrente["sequence"].isel(timestep=3), extent_union).values
        observed_vector.append(raster_corrente)
    return observed_vector

def create_lead_time_vector(orario_iniziale, durata_calcolo, intervallo_minuti, lead_time, directory, num_membri):
    """Collect, for every valid time, the ensemble fields forecast ``lead_time`` minutes earlier.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between valid times and between forecast frames.
    lead_time : int
        Lead time in minutes.
    directory : pathlib.Path
        Directory holding files named ``<issue time>_<member>_utm32n.nc``.
    num_membri : iterable of int
        Member indices to look for.

    Returns
    -------
    list of dict
        One entry per valid time with keys ``"timestamp"`` (issue time of the
        forecast) and ``"ensemble"`` (list of dicts with ``"membro"``,
        ``"timestep"`` and the clipped ``"data"`` array). Missing files and
        files that are too short are skipped silently.
    """
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    # Frame index inside each file; it depends only on the lead time, so it is
    # computed once outside the loops.
    timestep_corrente = 3 + (lead_time // intervallo_minuti)

    lead_time_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        # Issue time of the run that is valid at (start + minuto) for this lead time.
        orario_corrente = inizio - timedelta(minutes=lead_time - minuto)
        prefisso = orario_corrente.strftime('%Y-%m-%dT%H:%M:%S')
        ensemble_stack = []
        for membro in num_membri:
            file_membro = directory / f"{prefisso}_{membro}_utm32n.nc"
            if not file_membro.exists():
                continue
            # The context manager guarantees the file is closed on every path,
            # including the early ``continue`` statements below.
            with xr.open_dataset(file_membro) as ds_membro:
                if timestep_corrente >= ds_membro["sequence"].shape[0]:
                    continue
                raster_membro = ds_membro["sequence"].isel(timestep=timestep_corrente).values
                if raster_membro.ndim != 2:
                    print(f"Error: ensemble member not 2D. Shape: {raster_membro.shape}")
                    continue
                raster_membro = apply_transformations(raster_membro)
                raster_membro_da = xr.DataArray(
                    raster_membro,
                    dims=["y", "x"],
                    coords={"x": ds_membro["x"].values, "y": ds_membro["y"].values},
                )
                raster_membro_da = raster_membro_da.rio.write_crs("EPSG:32632")
                raster_membro = ritaglia_dati(raster_membro_da, extent_union).values
            ensemble_stack.append({
                "membro": membro,
                "timestep": timestep_corrente,
                "data": raster_membro
            })
        lead_time_vector.append({
            "timestamp": orario_corrente,
            "ensemble": ensemble_stack
        })
    return lead_time_vector

def calculate_fss(observed_vector, lead_time_vector, finestre_spaziali, soglie):
    """Compute the ensemble-mean FSS per spatial window, threshold and valid time.

    Parameters
    ----------
    observed_vector : list of numpy.ndarray
        Reference fields, one per valid time.
    lead_time_vector : list of dict
        Entries with an ``"ensemble"`` list of ``{"data": ndarray, ...}``
        dicts, aligned index-by-index with ``observed_vector``.
    finestre_spaziali : iterable
        Values passed to ``fss`` as its scale argument.
    soglie : iterable
        Values passed to ``fss`` as its threshold argument.

    Returns
    -------
    dict
        ``{window: {threshold: [mean FSS per valid time]}}``. Valid times with
        no ensemble are reported on stdout and skipped.

    Raises
    ------
    ValueError
        If a member field and the observed field have different shapes.
    """
    fss_results = {window: {threshold: [] for threshold in soglie} for window in finestre_spaziali}
    for i, raster_osservato in enumerate(observed_vector):
        if "ensemble" not in lead_time_vector[i] or not lead_time_vector[i]["ensemble"]:
            print(f"Attention: No ensemble for lead time[{i}]")
            continue

        for window in finestre_spaziali:
            for threshold in soglie:
                fss_lead_time = []
                for ensemble_entry in lead_time_vector[i]["ensemble"]:
                    raster_membro = ensemble_entry["data"]
                    if raster_osservato.shape != raster_membro.shape:
                        raise ValueError(
                            f"Observed dimension {raster_osservato.shape} "
                            f"does not correspond to simulated dimension {raster_membro.shape}."
                        )
                    # pysteps' signature is fss(X_f, X_o, thr, scale): the
                    # forecast goes first, matching the documented API.
                    fss_value = fss(raster_membro, raster_osservato, threshold, window)
                    fss_lead_time.append(fss_value)
                if fss_lead_time:
                    fss_results[window][threshold].append(np.mean(fss_lead_time))
    return fss_results

# Per-model results: lead time in minutes -> {window: {threshold: [FSS per valid time]}}
results_fss = {"LDCast": {}, "GPTCast": {}}

# Placeholder for lead time 0: a scalar FSS of 1 for every window/threshold pair.
for model_name in results_fss.keys():
    results_fss[model_name][0] = {window: {threshold: 1 for threshold in soglie} for window in finestre_spaziali}

for model_name, model_dir in [("LDCast", ldcast_dir), ("GPTCast", gptcast_dir)]:
    # Reference frames are read from the same directory as the model's ensemble files.
    observed_vector = create_observed_vector(orario_iniziale, durata_minuti, intervallo_minuti, model_dir)
    # Lead times run from one step up to durata_minuti inclusive.
    for lead_time in tqdm(range(5, durata_minuti + 1, intervallo_minuti), desc=f"Lead time {model_name}"):
        lead_time_vector = create_lead_time_vector(orario_iniziale, durata_minuti, intervallo_minuti, lead_time, model_dir, num_membri)
        fss_results = calculate_fss(observed_vector, lead_time_vector, finestre_spaziali, soglie)
        results_fss[model_name][lead_time] = fss_results

print("FSS computation completed for both models.")

In [None]:
#############################################################
#################### FSS vs THRESHOLD #######################
#############################################################

def plot_fss_lines(results_fss, finestra, soglie):
    """Plot mean FSS against lead time for one spatial window, one line per model and threshold.

    Parameters
    ----------
    results_fss : dict
        ``{model name: {lead time: {window: {threshold: value(s)}}}}``.
    finestra : hashable
        Spatial-window key to select from the results.
    soglie : sequence
        Threshold keys to draw. 0.1, 2.5 and 10 keep their fixed line styles;
        any other threshold gets a style picked by its position in ``soglie``
        instead of raising ``KeyError``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = {"LDCast": "blue", "GPTCast": "orange"}
    line_styles = {0.1: "-", 2.5: "--", 10: ":"}
    fallback_styles = ["-", "--", ":", "-."]

    for model_name, model_results in results_fss.items():
        lead_times = sorted(model_results.keys())
        # Models without a fixed colour use matplotlib's colour cycle.
        color_kwargs = {"color": colors[model_name]} if model_name in colors else {}

        for position, threshold in enumerate(soglie):
            mean_values = [
                np.mean(model_results[lead_time][finestra][threshold])
                for lead_time in lead_times
            ]
            linestyle = line_styles.get(threshold, fallback_styles[position % len(fallback_styles)])
            ax.plot(lead_times, mean_values, linestyle=linestyle, marker="o",
                    label=f"{model_name} (T = {threshold} mm h⁻¹)", **color_kwargs)

    ax.set_title(f"FSS vs Lead Time, October Case, Seveso Basin (Window: {finestra} km x {finestra} km)", fontsize=18)
    ax.set_xlabel("Lead Time [min]", fontsize=16)
    ax.set_ylabel("FSS [-]", fontsize=16)
    ax.legend(loc="upper right", fontsize=14)
    ax.tick_params(axis="both", labelsize=15)
    ax.grid(True)
    plt.tight_layout()

    plt.show()

# FSS against lead time for window key 14 and the three thresholds.
plot_fss_lines(results_fss, finestra=14, soglie=[0.1, 2.5, 10])

In [None]:
#############################################################
################### FSS vs SPATIAL (LD) #####################
#############################################################

def plot_fss_spatial_window(results_fss, model_name, threshold, lead_times, seveso_window, y_limits=None):
    """Plot mean FSS against spatial window for one model and threshold, one line per lead time.

    Parameters
    ----------
    results_fss : dict
        ``{model name: {lead time: {window: {threshold: value(s)}}}}``.
    model_name : str
        Model key to select from ``results_fss``.
    threshold : hashable
        Threshold key to select.
    lead_times : sequence
        Lead times to draw. 20, 40, 60, 80 and 100 keep their fixed colours;
        other lead times use matplotlib's colour cycle instead of raising
        ``KeyError``.
    seveso_window : number
        x position of the vertical reference line.
    y_limits : tuple, optional
        Limits passed to ``ax.set_ylim`` when truthy.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = {20: "blue", 40: "orange", 60: "green", 80: "red", 100: "purple"}

    model_results = results_fss[model_name]
    # The list of windows is read from a single lead time. Lead time 20 stays
    # the reference when present; otherwise the first requested one is used.
    reference_lead = 20 if 20 in model_results else lead_times[0]
    # Window 267 is left out of the axis.
    spatial_windows = [w for w in sorted(model_results[reference_lead].keys()) if w != 267]

    for lead_time in lead_times:
        mean_values = [
            np.mean(model_results[lead_time][window][threshold])
            for window in spatial_windows
        ]
        color_kwargs = {"color": colors[lead_time]} if lead_time in colors else {}
        ax.plot(spatial_windows, mean_values, linestyle="-",
                marker="o", label=f"Lead Time: {lead_time} min", **color_kwargs)

    ax.axvline(x=seveso_window, color="red", linestyle="--", label="Seveso Basin")

    ax.set_title(f"FSS vs Spatial Window, October Case, {model_name}, Seveso Basin (Threshold: {threshold}  mm h⁻¹)", fontsize=18)
    ax.set_xlabel("Spatial Window [km x km]", fontsize=16)
    ax.set_ylabel("FSS [-]", fontsize=16)
    ax.set_xticks(spatial_windows)
    ax.set_xticklabels([f"{w}x{w}" for w in spatial_windows], rotation=90, ha="right")
    ax.legend(loc="lower right", fontsize=15)
    ax.tick_params(axis="both", labelsize=15)

    if y_limits:
        ax.set_ylim(y_limits)

    ax.grid(True)
    plt.tight_layout()

    plt.show()

# One figure per threshold (0.1, 2.5, 10) for the LDCast results.
plot_fss_spatial_window(results_fss, model_name="LDCast", threshold=0.1, lead_times=[20, 40, 60, 80, 100], seveso_window=14, y_limits=(0, 1))
plot_fss_spatial_window(results_fss, model_name="LDCast", threshold=2.5, lead_times=[20, 40, 60, 80, 100], seveso_window=14, y_limits=(0, 1))
plot_fss_spatial_window(results_fss, model_name="LDCast", threshold=10, lead_times=[20, 40, 60, 80, 100], seveso_window=14, y_limits=(0, 1))

In [None]:
#############################################################
################### FSS vs SPATIAL (GPT) ####################
#############################################################

def plot_fss_spatial_window(results_fss, model_name, threshold, lead_times, seveso_window, y_limits=None):
    """Plot mean FSS against spatial window for one model and threshold, one line per lead time.

    Parameters
    ----------
    results_fss : dict
        ``{model name: {lead time: {window: {threshold: value(s)}}}}``.
    model_name : str
        Model key to select from ``results_fss``.
    threshold : hashable
        Threshold key to select.
    lead_times : sequence
        Lead times to draw. 20, 40, 60, 80 and 100 keep their fixed colours;
        other lead times use matplotlib's colour cycle instead of raising
        ``KeyError``.
    seveso_window : number
        x position of the vertical reference line.
    y_limits : tuple, optional
        Limits passed to ``ax.set_ylim`` when truthy.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = {20: "blue", 40: "orange", 60: "green", 80: "red", 100: "purple"}

    model_results = results_fss[model_name]
    # The list of windows is read from a single lead time. Lead time 20 stays
    # the reference when present; otherwise the first requested one is used.
    reference_lead = 20 if 20 in model_results else lead_times[0]
    # Window 267 is left out of the axis.
    spatial_windows = [w for w in sorted(model_results[reference_lead].keys()) if w != 267]

    for lead_time in lead_times:
        mean_values = [
            np.mean(model_results[lead_time][window][threshold])
            for window in spatial_windows
        ]
        color_kwargs = {"color": colors[lead_time]} if lead_time in colors else {}
        ax.plot(spatial_windows, mean_values, linestyle="-",
                marker="o", label=f"Lead Time: {lead_time} min", **color_kwargs)

    ax.axvline(x=seveso_window, color="red", linestyle="--", label="Seveso Basin")

    ax.set_title(f"FSS vs Spatial Window, October Case, {model_name}, Seveso Basin (Threshold: {threshold}  mm h⁻¹)", fontsize=16)
    ax.set_xlabel("Spatial Window [km x km]", fontsize=16)
    ax.set_ylabel("FSS [-]", fontsize=16)
    ax.set_xticks(spatial_windows)
    ax.set_xticklabels([f"{w}x{w}" for w in spatial_windows], rotation=90, ha="right")
    ax.legend(loc="lower right", fontsize=15)
    ax.tick_params(axis="both", labelsize=15)

    if y_limits:
        ax.set_ylim(y_limits)

    ax.grid(True)
    plt.tight_layout()

    plt.show()

# One figure per threshold (0.1, 2.5, 10) for the GPTCast results.
plot_fss_spatial_window(results_fss, model_name="GPTCast", threshold=0.1, lead_times=[20, 40, 60, 80, 100], seveso_window=14, y_limits=(0, 1))
plot_fss_spatial_window(results_fss, model_name="GPTCast", threshold=2.5, lead_times=[20, 40, 60, 80, 100], seveso_window=14, y_limits=(0, 1))
plot_fss_spatial_window(results_fss, model_name="GPTCast", threshold=10, lead_times=[20, 40, 60, 80, 100], seveso_window=14, y_limits=(0, 1))

In [None]:
#############################################################
##################### CoM COMPUTATION #######################
#############################################################

def apply_transformations(raster):
    """Rotate ``raster`` by 90 degrees over axes (0, 1), then mirror it along axis 0.

    A 2-D input therefore comes back transposed.
    """
    return np.flipud(np.rot90(raster, k=1, axes=(0, 1)))

# Parameters for the centre-of-mass computation
orario_iniziale = "2023-10-31T02:45:00"  # first valid time of the verification window
durata_minuti = 120  # minutes
durata_calcolo = 120  # minutes
intervallo_minuti = 5  # minutes between consecutive frames / lead times
num_membri = list(range(20))  # member indices that appear in the ensemble file names

# Directories containing the per-member NetCDF files of each model
ldcast_dir = Path("../../ens/oct_ld_ens")
gptcast_dir = Path("../../ens/oct_gpt_ens")

def find_file(directory, timestamp):
    """Pick the first ``*.nc`` file, in sorted order, whose name includes ``timestamp``.

    Gives ``None`` if there is no such file in ``directory``.
    """
    found = None
    for candidate in sorted(directory.glob("*.nc")):
        if timestamp not in candidate.name:
            continue
        found = candidate
        break
    return found

# Load the clipping extent, reproject it to EPSG:32632 and merge all of its
# features into a single geometry that is later passed as the clip area.
extent_shapefile_path = "../../data/extent_prova/extent_prova.shp"
extent = gpd.read_file(extent_shapefile_path)
extent = extent.to_crs("EPSG:32632")  
extent_union = extent.unary_union  

def ritaglia_dati(data, area):
    """Clip ``data`` to the single geometry ``area`` using EPSG:32632 and ``drop=True``.

    If ``data`` has no CRS attached, EPSG:32632 is written to it beforehand.
    """
    epsg = "EPSG:32632"
    prepared = data.rio.write_crs(epsg) if not data.rio.crs else data
    geometries = [area]
    return prepared.rio.clip(geometries, crs=epsg, drop=True)

def create_observed_vector(orario_iniziale, durata_calcolo, intervallo_minuti, directory):
    """Build the list of reference rasters with their coordinates, one per valid time.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between consecutive valid times.
    directory : pathlib.Path
        Directory searched with ``find_file`` for a file whose name contains
        the valid time.

    Returns
    -------
    list of dict
        Each entry holds ``"data"`` (frame ``timestep=3`` of ``"sequence"``
        clipped to the module-level ``extent_union``) and the ``"x"`` / ``"y"``
        coordinate arrays of the clipped raster.

    Raises
    ------
    ValueError
        If no file is found for one of the valid times.
    """
    # Parse the start time once instead of on every iteration.
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    observed_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        orario_corrente = inizio + timedelta(minutes=minuto)
        file_corrente = find_file(directory, orario_corrente.strftime("%Y-%m-%dT%H:%M:%S"))
        if file_corrente is None:
            raise ValueError(f"No file for time: {orario_corrente}")
        # Everything needed is copied into NumPy arrays while the dataset is
        # open; the context manager then closes the file handle.
        with xr.open_dataset(file_corrente) as ds_corrente:
            raster_corrente = ritaglia_dati(ds_corrente["sequence"].isel(timestep=3), extent_union)
            voce = {
                "data": raster_corrente.values,
                "x": raster_corrente.coords['x'].values,
                "y": raster_corrente.coords['y'].values
            }
        observed_vector.append(voce)
    return observed_vector

def create_lead_time_vector(orario_iniziale, durata_calcolo, intervallo_minuti, lead_time, directory, num_membri):
    """Collect, for every valid time, the clipped ensemble fields (with coordinates) forecast ``lead_time`` minutes earlier.

    Parameters
    ----------
    orario_iniziale : str
        First valid time, formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the window in minutes (the end point is included).
    intervallo_minuti : int
        Step in minutes between valid times and between forecast frames.
    lead_time : int
        Lead time in minutes.
    directory : pathlib.Path
        Directory holding files named ``<issue time>_<member>_utm32n.nc``.
    num_membri : iterable of int
        Member indices to look for.

    Returns
    -------
    list of dict
        One entry per valid time with keys ``"timestamp"`` (issue time of the
        forecast) and ``"ensemble"`` (list of dicts with ``"membro"``,
        ``"timestep"``, ``"data"``, ``"x"`` and ``"y"``). Missing files and
        files that are too short are skipped silently.
    """
    inizio = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    # Frame index inside each file; it depends only on the lead time, so it is
    # computed once outside the loops.
    timestep_corrente = 3 + (lead_time // intervallo_minuti)

    lead_time_vector = []
    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        # Issue time of the run that is valid at (start + minuto) for this lead time.
        orario_corrente = inizio - timedelta(minutes=lead_time - minuto)
        prefisso = orario_corrente.strftime('%Y-%m-%dT%H:%M:%S')
        ensemble_stack = []
        for membro in num_membri:
            file_membro = directory / f"{prefisso}_{membro}_utm32n.nc"
            if not file_membro.exists():
                continue
            # The context manager guarantees the file is closed on every path,
            # including the early ``continue`` below.
            with xr.open_dataset(file_membro) as ds_membro:
                if timestep_corrente >= ds_membro["sequence"].shape[0]:
                    continue
                raster_membro = ds_membro["sequence"].isel(timestep=timestep_corrente).values
                raster_trasformato = apply_transformations(raster_membro)

                raster_da = xr.DataArray(
                    raster_trasformato,
                    dims=["y", "x"],
                    coords={"y": ds_membro.coords["y"].values, "x": ds_membro.coords["x"].values}
                )

            # raster_da is backed by in-memory arrays, so clipping no longer
            # needs the open file.
            raster_ritagliato = ritaglia_dati(raster_da, extent_union)

            if raster_ritagliato.ndim == 2:
                ensemble_stack.append({
                    "membro": membro,
                    "timestep": timestep_corrente,
                    "data": raster_ritagliato.values,
                    "x": raster_ritagliato.coords['x'].values,
                    "y": raster_ritagliato.coords['y'].values
                })
            else:
                print(f"Error: Ensemble member not 2D. Shape: {raster_ritagliato.shape}")
        lead_time_vector.append({
            "timestamp": orario_corrente,
            "ensemble": ensemble_stack
        })
    return lead_time_vector

def calcola_centro_di_massa(raster, x_coords, y_coords):
    """Return the intensity-weighted mean (x, y) position of ``raster``.

    NaN cells are treated as zero weight. ``x_coords`` and ``y_coords`` are
    multiplied element-wise with the raster, so they must broadcast against
    it. ``None`` is returned when the weights sum to zero.
    """
    weights = np.nan_to_num(raster, nan=0)
    total_mass = np.sum(weights)
    if total_mass == 0:
        return None
    weighted_x = np.sum(x_coords * weights)
    weighted_y = np.sum(y_coords * weights)
    return weighted_x / total_mass, weighted_y / total_mass

def calcola_distanze_e_centri(orario_iniziale, durata_calcolo, intervallo_minuti, lead_time, directory, num_membri, extent_union):
    """Compare observed and ensemble-mean centres of mass for one lead time.

    For every step ``minuto`` in ``range(0, durata_calcolo + 1, intervallo_minuti)``
    the observed centre of mass is computed from ``create_observed_vector``'s
    output, and the centre of mass of each available ensemble member is computed
    from the file ``<timestamp>_<membro>_utm32n.nc`` in ``directory``, where
    ``timestamp = orario_iniziale - (lead_time - minuto)`` minutes.

    Parameters
    ----------
    orario_iniziale : str
        Start time formatted as ``%Y-%m-%dT%H:%M:%S``.
    durata_calcolo : int
        Length of the evaluated window, in minutes.
    intervallo_minuti : int
        Step between evaluated times, in minutes.
    lead_time : int
        Forecast lead time, in minutes.
    directory : pathlib.Path
        Folder holding the ensemble member files.
    num_membri : iterable
        Ensemble member identifiers used in the file names.
    extent_union : geometry
        Clipping geometry forwarded to ``ritaglia_dati``.

    Returns
    -------
    distanze : list of float
        Distance between the ensemble-mean and the observed centre, divided by
        1000 (km if the coordinates are metres, as the EPSG:32632 clip suggests).
        Only steps where BOTH centres exist contribute an entry.
    centri_osservati : list
        One entry per step; ``None`` where the observation has no mass.
    centri_simulati_medi : list of tuple
        Mean of the member centres; only steps with at least one member centre.

    Notes
    -----
    ``distanze`` and ``centri_simulati_medi`` skip steps without data, so they
    are NOT index-aligned with ``centri_osservati`` or with the time axis.
    """
    distanze = []
    centri_osservati = []
    centri_simulati_medi = []

    observed_vector = create_observed_vector(orario_iniziale, durata_calcolo, intervallo_minuti, directory)

    # Both values are independent of the loop variables, so compute them once.
    orario_base = datetime.strptime(orario_iniziale, "%Y-%m-%dT%H:%M:%S")
    # NOTE(review): the constant 3 presumably skips the input/context frames
    # stored at the start of "sequence" -- confirm against the model output.
    timestep_corrente = 3 + (lead_time // intervallo_minuti)

    for minuto in range(0, durata_calcolo + 1, intervallo_minuti):
        orario_corrente = orario_base - timedelta(minutes=lead_time - minuto)
        etichetta_orario = orario_corrente.strftime('%Y-%m-%dT%H:%M:%S')

        raster_osservato = observed_vector[minuto // intervallo_minuti]
        x_coords, y_coords = np.meshgrid(raster_osservato['x'], raster_osservato['y'])
        centro_osservato = calcola_centro_di_massa(raster_osservato['data'], x_coords, y_coords)
        centri_osservati.append(centro_osservato)

        centri_simulati = []
        for membro in num_membri:
            file_membro = directory / f"{etichetta_orario}_{membro}_utm32n.nc"
            if not file_membro.exists():
                continue
            # Context manager closes the NetCDF handle; the original leaked one
            # open file per member, per step, per lead time.
            with xr.open_dataset(file_membro) as ds_membro:
                sequenza = ds_membro["sequence"]
                # Bound-check the dimension that is actually indexed below.
                if timestep_corrente >= sequenza.sizes["timestep"]:
                    continue
                raster_membro = ritaglia_dati(sequenza.isel(timestep=timestep_corrente), extent_union)
                if raster_membro.ndim != 2:
                    continue
                # np.meshgrid(x, y) is (y, x)-shaped: put the raster in the same
                # orientation so a (x, y)-ordered array is not weighted against
                # transposed coordinates.
                if set(raster_membro.dims) == {"x", "y"}:
                    raster_membro = raster_membro.transpose("y", "x")
                # Materialise everything before the file is closed.
                dati_membro = raster_membro.values
                x_membro = raster_membro.coords['x'].values
                y_membro = raster_membro.coords['y'].values

            centro_membro = calcola_centro_di_massa(dati_membro, *np.meshgrid(x_membro, y_membro))
            if centro_membro is not None:
                centri_simulati.append(centro_membro)

        if len(centri_simulati) > 0:
            x_medio = np.mean([c[0] for c in centri_simulati])
            y_medio = np.mean([c[1] for c in centri_simulati])
            centro_simulato_medio = (x_medio, y_medio)
            centri_simulati_medi.append(centro_simulato_medio)

            if centro_osservato is not None:
                delta_x = centro_simulato_medio[0] - centro_osservato[0]
                delta_y = centro_simulato_medio[1] - centro_osservato[1]
                distanze.append(np.sqrt(delta_x**2 + delta_y**2) / 1000)

    return distanze, centri_osservati, centri_simulati_medi

# Result containers: model name -> lead time (minutes) -> computed values.
distanze_centro_di_massa = {nome: {} for nome in ("LDCast", "GPTCast")}
centri_di_massa = {nome: {} for nome in ("LDCast", "GPTCast")}

# Run the centre-of-mass comparison for every model and every lead time.
for model_name, model_dir in zip(("LDCast", "GPTCast"), (ldcast_dir, gptcast_dir)):
    for lead_time in tqdm(
        range(5, durata_minuti + 1, intervallo_minuti),
        desc=f"Lead time {model_name}",
    ):
        distanze, centri_osservati, centri_simulati_medi = calcola_distanze_e_centri(
            orario_iniziale=orario_iniziale,
            durata_calcolo=durata_calcolo,
            intervallo_minuti=intervallo_minuti,
            lead_time=lead_time,
            directory=model_dir,
            num_membri=num_membri,
            extent_union=extent_union,
        )
        distanze_centro_di_massa[model_name][lead_time] = distanze
        centri_di_massa[model_name][lead_time] = dict(
            osservati=centri_osservati,
            simulati_medi=centri_simulati_medi,
        )

print("Distance and center of mass computation completed.")

In [None]:
#############################################################
######################## CoM PLOT ###########################
#############################################################

def calcola_centro_shapefile(extent_union):
    """Return the ``(x, y)`` coordinates of the geometry's ``centroid``."""
    centroide = extent_union.centroid
    return centroide.x, centroide.y

# Centroid of the clipping extent (uses the extent_union already in scope).
centro_shapefile = calcola_centro_shapefile(extent_union)

# Lead times (minutes) shown in the figures and the colour cycle paired with them.
lead_times_da_plottare = [20, 40, 60]
colori_lead_time = ["blue", "orange", "green", "red", "purple"]

# Basin polygons, read in their native CRS.
shapefile_path = "../../data/nodo_idraulico_shape/Bacini_OUT.shp"
bacini = gpd.read_file(shapefile_path)

# Clipping extent, reprojected to EPSG:32632 and dissolved into one geometry.
extent_shapefile_path = "../../data/extent_prova/extent_prova.shp"
extent = gpd.read_file(extent_shapefile_path).to_crs("EPSG:32632")
extent_union = extent.unary_union

# Keep only the three Seveso sub-basins, reproject them and merge into one outline.
bacini_filtrati = bacini.loc[
    bacini["Nome_bacin"].isin(["SEVESO alto", "SEVESO medio", "SEVESO basso"])
].to_crs("EPSG:32632")
bacini_uniti = bacini_filtrati.unary_union

# One figure per (model, highlighted lead time): the highlighted lead time is
# drawn fully opaque, the other lead times are kept as faint context.
for model_name in ["LDCast", "GPTCast"]:
    for lead_time_focus in lead_times_da_plottare:
        plt.figure(figsize=(10, 10))
        plt.title(f"CoM Tracking, October Case - {model_name} (Lead Time: {lead_time_focus} min)", fontsize=16)
        plt.xlabel("X Coordinate [mE]", fontsize=16)
        plt.ylabel("Y Coordinate [mN]", fontsize=16)

        ax = plt.gca()
        gpd.GeoSeries(bacini_uniti).plot(ax=ax, color="red", edgecolor="red", alpha=0.1, zorder=1, label="Basin Boundary")

        # Observed centres of mass.
        for lead_time in lead_times_da_plottare:
            centri_osservati = centri_di_massa[model_name][lead_time]["osservati"]
            # Steps without observed mass are stored as None. Test the FILTERED
            # list: unpacking zip(*[]) raises ValueError when every entry is None.
            punti_osservati = [c for c in centri_osservati if c is not None]
            if punti_osservati:
                x_osservati, y_osservati = zip(*punti_osservati)
                plt.scatter(
                    x_osservati, y_osservati,
                    color="black",
                    alpha=1.0 if lead_time == lead_time_focus else 0.1,
                    label=""
                )

        # Empty scatter: gives the legend a single "Observed" entry.
        plt.scatter([], [], color="black", alpha=1.0, label="Observed")

        # Ensemble-mean centres of mass, one colour per lead time.
        for lead_time, colore in zip(lead_times_da_plottare, colori_lead_time):
            centri_simulati = centri_di_massa[model_name][lead_time]["simulati_medi"]
            punti_simulati = [c for c in centri_simulati if c is not None]
            if punti_simulati:
                x_simulati, y_simulati = zip(*punti_simulati)
                plt.scatter(
                    x_simulati, y_simulati,
                    color=colore,
                    alpha=1.0 if lead_time == lead_time_focus else 0.3,
                    label=f"Lead Time: {lead_time} min" if lead_time == lead_time_focus else ""
                )

        # Show northings as plain integers instead of offset/scientific notation.
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

        plt.legend(loc="upper right", fontsize=16)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.grid()

        plt.show()

In [None]:
#############################################################
######################## DIST PLOT ##########################
#############################################################

# Time window shown on the x axis, sampled every `intervallo_minuti` minutes.
start_time = datetime(2023, 10, 31, 4, 00)
end_time = datetime(2023, 10, 31, 6, 00)
# total_seconds() keeps whole days; timedelta.seconds would silently drop them
# if the window were ever extended to 24 h or more.
time_range = [
    start_time + timedelta(minutes=i)
    for i in range(0, int((end_time - start_time).total_seconds() // 60) + 1, intervallo_minuti)
]

for model_name in ["LDCast", "GPTCast"]:
    plt.figure(figsize=(12, 6))
    plt.title(f"Euclidean Distance between Centres of Mass, October Case - {model_name}", fontsize=16)
    # NOTE(review): "CEST/UTC+1" is self-contradictory (CEST is UTC+2, CET is
    # UTC+1) -- confirm the time zone of the data and correct the label.
    plt.xlabel("Date and Hour (CEST/UTC+1)", fontsize=14)
    plt.ylabel("Euclidean Distance [km]", fontsize=14)

    for lead_time, colore in zip(lead_times_da_plottare, colori_lead_time):
        distanze = distanze_centro_di_massa[model_name][lead_time]
        if not distanze:
            # Nothing to draw; np.mean([]) would only yield NaN plus a warning.
            print(f"Model: {model_name}, Lead Time: {lead_time} min, no distances available")
            continue

        # Distances are mapped onto the time axis by position, starting at
        # start_time. Trim BOTH sequences to the common length: plt.plot raises
        # when x and y differ in length, which happens whenever more distances
        # were computed than there are slots in the window.
        # NOTE(review): the computation skips steps without data, so positions
        # may not correspond to the nominal timestamps -- verify the alignment.
        n_punti = min(len(time_range), len(distanze))
        time_range_lead = time_range[:n_punti]

        plt.plot(time_range_lead, distanze[:n_punti], marker="o", label=f"Lead Time: {lead_time} min", color=colore)

        # The mean summarises the full series, including values outside the window.
        media_distanze = np.mean(distanze)
        plt.axhline(y=media_distanze, color=colore, linestyle="--", alpha=0.7, label=f"Mean (Lead Time: {lead_time} min)")

        print(f"Model: {model_name}, Lead Time: {lead_time} min, Mean Distance: {media_distanze:.2f} km")

    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d %H:%M"))
    plt.gca().xaxis.set_major_locator(mdates.MinuteLocator(interval=30))
    plt.xticks(rotation=45, fontsize=12)
    plt.yticks(fontsize=12)

    plt.xlim(start_time, end_time)
    plt.ylim(0, 18)

    plt.legend(loc="upper left", fontsize=14)
    plt.grid()

    plt.show()