In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from datetime import datetime, timedelta
import rioxarray
import glob 
import contextily as ctx  
import geopandas as gpd  
from matplotlib.colors import LogNorm 

In [None]:
# Root folder that holds the datasets; change this single line to relocate the data.
DATA_DIR = Path("/Users/manuelmazza/Desktop/Tesi/Dati")

# Input sequences: variable "input", selected by "input_start_times" further down.
input_zarr = xr.open_zarr(str(DATA_DIR / "predictions_sept_input_xarray.zarr"))
# Forecasts: variable "forecast", selected by "forecast_start_times" further down.
data = xr.open_zarr(str(DATA_DIR / "predictions_sept_GPT_xarray.zarr"))

In [None]:
#############################################################
#################  FORECAST DOWNLOADING  ####################
#############################################################
# For every forecast start time inside the chosen window, stitch the input
# frames and the forecast frames into a single "sequence" variable, reproject
# each ensemble member to UTM 32N and write it to its own NetCDF file in the
# current working directory.

# TIME INTERVAL DEFINITION
formato_orario = "%Y-%m-%dT%H:%M:%S"
orario_inizio = datetime.strptime("2023-09-15T02:00:00", formato_orario)
orario_fine = datetime.strptime("2023-09-15T06:00:00", formato_orario)
intervallo_minuti = 5

# Build the list of forecast start times of interest (both ends included)
minuti_totali = int((orario_fine - orario_inizio).total_seconds() / 60)
orari_di_interesse = [
    orario_inizio + timedelta(minutes=offset)
    for offset in range(0, minuti_totali + 1, intervallo_minuti)
]

for forecast_time in data["forecast_start_times"]:
    forecast_stamp = pd.Timestamp(forecast_time.values)
    if forecast_stamp not in orari_di_interesse:
        continue

    # The input time is 20 minutes (4 steps of 5 minutes) before the first forecast
    input_time = forecast_stamp - pd.Timedelta(minutes=5 * 4)

    # Select both pieces and rename them so they share the same data variable
    input_data = input_zarr.sel(input_start_times=input_time).rename({"input": "sequence"})
    pred = data.sel(forecast_start_times=forecast_time).rename({"forecast": "sequence"})

    # Timestamp used in the output file names, without fractional seconds
    timestamp = str(input_time.to_numpy()).split(".")[0]

    # Concatenate along the time dimension, then drop coordinates that are no longer useful.
    # NOTE(review): "ens" is dropped here yet iterated right below; presumably the loop
    # then runs over positional member indices - confirm this is intended.
    combined = xr.concat([input_data, pred], dim="timestep").drop_vars(
        ["input_start_times", "forecast_start_times", "ens"]
    )

    for member in combined["ens"]:
        # Tag the source projection, fix the dimension order, declare the spatial
        # dimensions and reproject to WGS84 / UTM 32N.
        ens_data = (
            combined.sel(ens=member)
            .rio.write_crs("+proj=tmerc +lat_0=42.0 +lon_0=12.5 +ellps=WGS84")
            .transpose("timestep", "y", "x")
            .rio.set_spatial_dims(x_dim="x", y_dim="y")
            .rio.reproject("EPSG:32632")
        )

        # Remove any "_CRS" attribute from the dataset and from each variable before export
        ens_data.attrs.pop("_CRS", None)
        for var_name in ens_data.data_vars:
            ens_data[var_name].attrs.pop("_CRS", None)

        # NetCDF EXPORT
        output_filename = f"{timestamp}_{member.to_numpy()}_utm32n.nc"
        ens_data.to_netcdf(output_filename)

In [None]:
# Sanity check: open one exported member file and display its summary.
percorso_test = "2023-09-15T01:40:00_0_utm32n.nc"
test = xr.open_dataset(percorso_test)
test

In [None]:
#############################################################
##############  ENSEMBLE VISUAL COMPARISON  #################
#############################################################
# Draws one map panel per exported ensemble member (files named
# "<timestamp>_<member>_utm32n.nc" in the working directory) at a single
# time step, over an OpenStreetMap basemap with the Seveso basin outline.

# INITIAL TIME STAMP DEFINITION and TIME STEP OF INTEREST DEFINITION
timestamp_iniziale = "2023-09-15T01:40:00"
time_step = 27

timestamp_iniziale_dt = datetime.strptime(timestamp_iniziale, "%Y-%m-%dT%H:%M:%S")
timestamp_corrente = timestamp_iniziale_dt + timedelta(minutes=5 * time_step)

file_pattern = f"{timestamp_iniziale}_*_utm32n.nc"
file_list = glob.glob(file_pattern)

if not file_list:
    raise FileNotFoundError(f"Nessun file trovato per il timestamp iniziale: {timestamp_iniziale}")


def _member_sort_key(path):
    """Return a sort key that orders member files by their numeric member index.

    A plain string sort would put "..._10_utm32n.nc" before "..._2_utm32n.nc",
    so the panel titles would no longer match the member files. Files whose
    member token is not a plain integer are placed last, ordered by name.
    """
    token = Path(path).stem.split("_")[-2]
    if token.isdecimal():
        return (0, int(token), path)
    return (1, 0, path)


file_list.sort(key=_member_sort_key)

# SHAPEFILE OF SEVESO BASIN
# Loaded, reprojected and merged once: it is identical for every panel.
shapefile_path = "/Users/manuelmazza/Desktop/Tesi/Dati/nodo_idraulico_shape/Bacini_OUT.shp"
bacini = gpd.read_file(shapefile_path)
bacini_filtrati = bacini[bacini["Nome_bacin"].isin(["SEVESO alto", "SEVESO medio", "SEVESO basso"])]
bacini_filtrati = bacini_filtrati.to_crs("EPSG:32632")
bacini_uniti = bacini_filtrati.unary_union

# The union can be a single polygon or several; collect every exterior ring.
if bacini_uniti.geom_type == "MultiPolygon":
    contorni_bacino = [geom.exterior.xy for geom in bacini_uniti.geoms]
else:
    contorni_bacino = [bacini_uniti.exterior.xy]

ensemble_data = [xr.open_dataset(file) for file in file_list]

try:
    num_ensemble = len(ensemble_data)
    cols = 5
    rows = (num_ensemble + cols - 1) // cols  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(20, 16))
    axes = axes.flatten()

    vmin, vmax = 1e-1, 1e2  # MINIMUM AND MAXIMUM OF THE SCALE

    for i, (member_ds, ax) in enumerate(zip(ensemble_data, axes)):
        # Use a dedicated name here: rebinding `data` would overwrite the
        # forecast dataset loaded at the top of the notebook.
        member_field = member_ds["sequence"].isel(timestep=time_step)

        # Hide values below the threshold so the basemap shows through
        masked_data = member_field.where(member_field >= 0.001)

        # The axes extent must be set before the basemap tiles are requested
        ax.set_xlim(masked_data.x.min().item(), masked_data.x.max().item())
        ax.set_ylim(masked_data.y.min().item(), masked_data.y.max().item())
        ctx.add_basemap(
            ax,
            crs="EPSG:32632",
            source=ctx.providers.OpenStreetMap.Mapnik,
            zoom=8,  # ZOOM
            alpha=0.5  # TRANSPARENCY
        )

        im = masked_data.plot.imshow(
            ax=ax,
            cmap="turbo",
            norm=LogNorm(vmin=vmin, vmax=vmax),
            add_colorbar=False,
            origin="upper",
            alpha=masked_data.notnull().astype(float),  # fully transparent where masked
            zorder=2
        )

        for x, y in contorni_bacino:
            ax.plot(x, y, color="black", linewidth=2, zorder=3)

        ax.set_title(f"Ensemble {i + 1}", fontsize=12)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Remove the unused panels of the grid
    for j in range(len(file_list), len(axes)):
        fig.delaxes(axes[j])

    # Shared horizontal colorbar above the panels
    cbar_ax = fig.add_axes([0.3, 0.92, 0.4, 0.02])
    cbar = fig.colorbar(im, cax=cbar_ax, orientation="horizontal")
    cbar.set_label("mm h⁻¹", fontsize=15)

    # TITLE
    fig.suptitle(
        f"Timestamp: {timestamp_corrente.strftime('%Y-%m-%d %H:%M:%S')} - Ensemble Members (GPTCast)",
        fontsize=25
    )

    plt.tight_layout(rect=[0, 0, 1, 0.9])
    plt.show()
finally:
    # Release the NetCDF file handles once the figure has been rendered
    for member_ds in ensemble_data:
        member_ds.close()

In [None]:
#############################################################
##########  Observation vs Simulation COMPARISON  ###########
#############################################################

def read_raster(file_path, time_step):
    """Return one time step of the "sequence" variable as a NumPy array.

    Parameters
    ----------
    file_path : str or path-like
        Dataset file to open with ``xr.open_dataset``.
    time_step : int
        Positional index along the "timestep" dimension.

    Returns
    -------
    numpy.ndarray
        Values of ``sequence`` at the requested time step.
    """
    # The context manager closes the file handle as soon as the values are in
    # memory; without it every call leaves a dataset open.
    with xr.open_dataset(file_path) as dataset:
        return dataset["sequence"].isel(timestep=time_step).values

# Parameters
orario_specificato = "2023-09-15T01:40:00"  # TIME STAMP
durata_minuti = 100
intervallo_minuti = 20
num_membri = list(range(20))

# Each "timestep" index spans 5 minutes, as elsewhere in this notebook.
minuti_per_step = 5
# Index used for the observed frame in every file.
# NOTE(review): presumably the last of the 4 input frames - confirm.
ultimo_step_input = 3

if intervallo_minuti % minuti_per_step:
    raise ValueError(
        f"intervallo_minuti ({intervallo_minuti}) deve essere un multiplo di {minuti_per_step} minuti"
    )
passi_per_intervallo = intervallo_minuti // minuti_per_step

# First NetCDF file (in sorted order) whose name contains the requested timestamp
file_list = sorted(Path(".").glob("*.nc"))
file_specificato = next(
    (file for file in file_list if orario_specificato in file.name),
    None,
)

if file_specificato is None:
    raise ValueError(f"Nessun file trovato per l'orario specificato: {orario_specificato}")

# Not used by the plot below; kept because later cells may rely on these names.
ds_specificato = xr.open_dataset(file_specificato)
raster_osservato = read_raster(file_specificato, time_step=ultimo_step_input)

# One file per column: the same member, started every `intervallo_minuti`
orario_iniziale = datetime.strptime(orario_specificato, "%Y-%m-%dT%H:%M:%S")
file_paths = [
    f"{(orario_iniziale + timedelta(minutes=minuto)).strftime('%Y-%m-%dT%H:%M:%S')}_0_utm32n.nc"  # MEMBER = 0, CHANGE IT EVENTUALLY
    for minuto in range(0, durata_minuti + 1, intervallo_minuti)
]

# Fail before creating the figure if any required file is missing
for file_path in file_paths:
    if not Path(file_path).exists():
        raise ValueError(f"File non trovato: {file_path}")

# Forecast steps of the reference file that line up with columns 1..N-1.
# Derived from the parameters so it stays consistent when they change;
# with the defaults this is [7, 11, 15, 19, 23].
time_steps_ldcast = [
    ultimo_step_input + k * passi_per_intervallo
    for k in range(1, len(file_paths))
]


def _style_panel(ax):
    """Apply the shared panel look: dashed gray grid, no ticks or tick labels."""
    ax.grid(True, color="gray", linestyle="--", linewidth=0.5)
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)


# squeeze=False keeps `axes` two-dimensional even with a single column
fig, axes = plt.subplots(
    2, len(file_paths), figsize=(15, 6), constrained_layout=True, squeeze=False
)

# Top row: observed frame of each successive file
for i, file_path in enumerate(file_paths):
    raster = read_raster(file_path, time_step=ultimo_step_input)
    im = axes[0, i].imshow(raster, cmap="turbo", norm=LogNorm(vmin=1e-1, vmax=1e2))
    axes[0, i].set_title(f"+{i * intervallo_minuti} min", fontsize=10)
    _style_panel(axes[0, i])

# Bottom row: forecast frames of the reference file, starting at the second column.
# The loop variable is not called `time_step`, so it does not overwrite the
# notebook-level `time_step` setting used by the ensemble comparison cell.
for colonna, passo in enumerate(time_steps_ldcast, start=1):
    raster = read_raster(file_specificato, time_step=passo)
    im = axes[1, colonna].imshow(raster, cmap="turbo", norm=LogNorm(vmin=1e-1, vmax=1e2))
    _style_panel(axes[1, colonna])

# There is no forecast at +0 min, so the first bottom panel is left blank
axes[1, 0].axis("off")

cbar = fig.colorbar(im, ax=axes[:, -1], orientation="vertical", fraction=0.05, pad=0.04)
cbar.set_label("mm h⁻¹", fontsize=12)

axes[0, 0].set_ylabel("Observation", fontsize=12)
if time_steps_ldcast:
    axes[1, 1].set_ylabel("GPTCast", fontsize=12)

plt.show()