In [None]:
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.colors import LogNorm
import contextily as ctx
import geopandas as gpd
# NOTE(review): the `.rio` accessor used throughout is provided by `rioxarray`.
# It may only be registered as a side effect of xarray loading its backend
# plugins; consider an explicit `import rioxarray` so a fresh kernel never
# depends on that.

In [None]:
# Observed input sequences and GPTCast ensemble forecasts, stored as Zarr.
INPUT_ZARR_PATH = "../../data/predictions_sept_input_xarray.zarr"
FORECAST_ZARR_PATH = "../../data/predictions_sept_GPT_xarray.zarr"

input_zarr = xr.open_zarr(INPUT_ZARR_PATH)  # data variable "input", indexed by input_start_times
data = xr.open_zarr(FORECAST_ZARR_PATH)     # data variable "forecast", indexed by forecast_start_times

In [None]:
#############################################################
#########  ENSEMBLE STD: COMPUTE AND SAVE NETCDF  ###########
#############################################################

# For every forecast start time: prepend the observed input frames to the
# forecast, take the standard deviation across ensemble members, reproject
# it to UTM 32N and write one NetCDF file named after the input start time.

N_INPUT_FRAMES = 4        # observed frames that precede each forecast
FRAME_STEP_MINUTES = 5    # spacing between consecutive frames
NATIVE_CRS = "+proj=tmerc +lat_0=42.0 +lon_0=12.5 +ellps=WGS84"
TARGET_CRS = "EPSG:32632"  # WGS84 / UTM zone 32N

for forecast_time in data["forecast_start_times"]:
    # The input sequence starts 4 * 5 = 20 minutes before the first forecast.
    input_time = forecast_time - pd.Timedelta(minutes=FRAME_STEP_MINUTES * N_INPUT_FRAMES)
    input_data = input_zarr.sel(input_start_times=input_time)
    pred = data.sel(forecast_start_times=forecast_time)

    # Give both pieces the same data-variable name so they can be concatenated.
    input_data = input_data.rename({"input": "sequence"})
    pred = pred.rename({"forecast": "sequence"})

    # Drop any stored CRS coordinate; the CRS is re-attached after the std below.
    if "spatial_ref" in input_data.coords:
        input_data = input_data.drop_vars("spatial_ref")
    if "spatial_ref" in pred.coords:
        pred = pred.drop_vars("spatial_ref")

    # Observed frames followed by predicted frames along "timestep".
    combined = xr.concat([input_data, pred], dim="timestep")
    # e.g. "2023-09-15T00:45:00" (the nanosecond fraction is stripped).
    timestamp = f"{input_time.to_numpy()}".split(".")[0]

    # The selection times are now scalar coordinates that are no longer useful.
    combined = combined.drop_vars(["input_start_times", "forecast_start_times"])

    # ENSEMBLE STANDARD DEVIATION (across members, NaNs ignored)
    ensemble_std = combined.std(dim="ens", skipna=True)

    ensemble_std = ensemble_std.rio.write_crs(NATIVE_CRS)
    ensemble_std = ensemble_std.transpose("timestep", "y", "x")

    # x is the x dimension and y the y dimension. The previous swapped mapping
    # (x_dim="y", y_dim="x") contradicted the (timestep, y, x) order set above,
    # which is the order rioxarray expects for (..., y_dim, x_dim).
    ensemble_std = ensemble_std.rio.set_spatial_dims(x_dim="x", y_dim="y")

    # WGS84 UTM 32N REPROJECTION
    ensemble_std = ensemble_std.rio.reproject(TARGET_CRS)

    output_filename_std = f"{timestamp}_ensemble_std_utm32n.nc"
    ensemble_std.to_netcdf(output_filename_std)

In [None]:
# Open one of the files written above for inspection and plotting.
# CHANGE THE FILENAME TO SELECT A DIFFERENT FORECAST.
std_file = "2023-09-15T00:45:00_ensemble_std_utm32n.nc"
test = xr.open_dataset(std_file)
test

In [None]:
#############################################################
######################  STD PLOT  ###########################
#############################################################

# Wall-clock time of timestep 0 in `test`; timestep i is start_time + i * 5 min.
# NOTE(review): this must match the sequence loaded in the previous cell
# (whose file name carries its input start time) — TODO confirm.
start_time = datetime.strptime("2023-09-15T01:00:00", "%Y-%m-%dT%H:%M:%S")  # INITIAL TIMESTAMP
time_interval = timedelta(minutes=5)

time_step = 4  # TIME STEP TO PLOT
current_time = start_time + time_step * time_interval
timestamp = current_time.strftime("%Y-%m-%dT%H:%M:%S")

# Timesteps 0-3 are the observed input frames prepended to the forecast.
if time_step < 4:
    title_text = f"{timestamp} UTC (Observed)"
else:
    title_text = f"{timestamp} UTC (Predicted, GPTCast)"

fig, ax = plt.subplots(figsize=(8, 6))

# Dedicated name: assigning to `data` here would overwrite the forecast
# dataset opened at the top of the notebook and break re-running the
# file-writing cell.
std_frame = test["sequence"].isel(timestep=time_step).astype("float64")

if std_frame.rio.crs is None:
    std_frame = std_frame.rio.write_crs("EPSG:32632")

# Pixels with negligible spread become NaN and are drawn transparent.
masked_data = std_frame.where(std_frame >= 0.001)

# Fix the extent before adding the basemap so tiles are fetched for this area.
xmin, xmax = masked_data.x.min().item(), masked_data.x.max().item()
ymin, ymax = masked_data.y.min().item(), masked_data.y.max().item()
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

ctx.add_basemap(
    ax,
    crs="EPSG:32632",
    source=ctx.providers.OpenStreetMap.Mapnik,
    zoom=10,
    alpha=0.9,
)

vmin, vmax = 1e-1, 1e2  # range of the logarithmic colour scale

im = masked_data.plot.imshow(
    ax=ax,
    cmap="viridis",
    norm=LogNorm(vmin=vmin, vmax=vmax),
    add_colorbar=False,
    origin="upper",
    alpha=masked_data.notnull().astype(float),
    zorder=2,  # above the basemap, below the basin outline
)

# A log colour scale is unreadable without a colorbar.
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Standard Deviation [mm h⁻¹]", fontsize=12)

# SEVESO SHAPEFILE: dissolve the three Seveso sub-basins into one outline.
shapefile_path = "../../data/nodo_idraulico_shape/Bacini_OUT.shp"
bacini = gpd.read_file(shapefile_path)

bacini_filtrati = bacini[bacini["Nome_bacin"].isin(["SEVESO alto", "SEVESO medio", "SEVESO basso"])]
bacini_filtrati = bacini_filtrati.to_crs("EPSG:32632")

# NOTE(review): `unary_union` is deprecated in geopandas >= 1.0 in favour of
# `union_all()`; kept for compatibility with older versions.
bacini_uniti = bacini_filtrati.unary_union

# Only exterior rings are drawn (holes, if any, are ignored).
if bacini_uniti.geom_type == "MultiPolygon":
    for geom in bacini_uniti.geoms:
        x, y = geom.exterior.xy
        ax.plot(x, y, color="black", linewidth=2, zorder=3)
else:
    x, y = bacini_uniti.exterior.xy
    ax.plot(x, y, color="black", linewidth=2, zorder=3)

ax.ticklabel_format(style="plain", axis="both")  # metres, no scientific notation

ax.set_xlabel("")
ax.set_ylabel("")

ax.set_title(title_text, fontsize=16)

plt.show()

In [None]:
#############################################################
###################  STD ANIMATION  #########################
#############################################################

# Wall-clock time of timestep 0 in `test`; timestep i is start_time + i * 5 min.
# NOTE(review): the file loaded above is named after 00:45 while this start
# time is 02:45 — confirm both refer to the same sequence.
start_time = datetime.strptime("2023-09-15T02:45:00", "%Y-%m-%dT%H:%M:%S")
time_interval = timedelta(minutes=5)

N_FRAMES = 28            # frames in the animation, one per timestep
FRAME_INTERVAL_MS = 500  # display time of each frame

# The basin outline does not change between frames: load and dissolve it once
# here instead of re-reading the shapefile inside the per-frame callback.
shapefile_path = "../../data/nodo_idraulico_shape/Bacini_OUT.shp"
bacini = gpd.read_file(shapefile_path)
bacini_filtrati = bacini[bacini["Nome_bacin"].isin(["SEVESO alto", "SEVESO medio", "SEVESO basso"])]
bacini_filtrati = bacini_filtrati.to_crs("EPSG:32632")
bacini_uniti = bacini_filtrati.unary_union
if bacini_uniti.geom_type == "MultiPolygon":
    seveso_polygons = list(bacini_uniti.geoms)
else:
    seveso_polygons = [bacini_uniti]

fig, ax = plt.subplots(figsize=(8, 6))

im = None    # image of the current frame
cbar = None  # created on the first frame, refreshed afterwards


def update_plot(time_step):
    """Redraw ``ax`` for one animation frame.

    Parameters
    ----------
    time_step : int
        Index along the ``timestep`` dimension of ``test["sequence"]``.

    Clears the axes and redraws the basemap, the std field and the Seveso
    outline. The colorbar is created once and only refreshed afterwards,
    because it lives on its own axes and is not removed by ``ax.clear()``.
    """
    global im, cbar

    current_time = start_time + time_step * time_interval
    timestamp = current_time.strftime("%Y-%m-%dT%H:%M:%S")

    # Timesteps 0-3 are observed inputs (black title); the rest are predictions (red).
    if time_step < 4:
        title_text, title_color = f"{timestamp} UTC (Observed)", "black"
    else:
        title_text, title_color = f"{timestamp} UTC (Predicted, GPTCast)", "red"

    frame = test["sequence"].isel(timestep=time_step).astype("float64")

    if frame.rio.crs is None:
        frame = frame.rio.write_crs("EPSG:32632")

    # Pixels with negligible spread become NaN and are drawn transparent.
    masked_data = frame.where(frame >= 0.001)

    ax.clear()
    # Fix the extent before adding the basemap so tiles are fetched for this area.
    ax.set_xlim(masked_data.x.min().item(), masked_data.x.max().item())
    ax.set_ylim(masked_data.y.min().item(), masked_data.y.max().item())

    ctx.add_basemap(
        ax,
        crs="EPSG:32632",
        source=ctx.providers.OpenStreetMap.Mapnik,
        zoom=8,
        alpha=1,
    )

    vmin, vmax = 1e-1, 1e2  # range of the logarithmic colour scale

    im = masked_data.plot.imshow(
        ax=ax,
        cmap="viridis",
        norm=LogNorm(vmin=vmin, vmax=vmax),
        add_colorbar=False,
        origin="upper",
        alpha=masked_data.notnull().astype(float),
        zorder=2,  # above the basemap, below the basin outline
    )

    for polygon in seveso_polygons:
        x, y = polygon.exterior.xy
        ax.plot(x, y, color="black", linewidth=2, zorder=3)

    # A single colorbar block: a second, copy-pasted block (labelled
    # "Intensity") was unreachable because `cbar` is always set by this point.
    if cbar is None:
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label("Standard Deviation [mm h⁻¹]", fontsize=12)
    else:
        cbar.update_normal(im)

    ax.ticklabel_format(style="plain", axis="both")  # metres, no scientific notation

    ax.set_xlabel("")
    ax.set_ylabel("")

    ax.set_title(title_text, fontsize=14, color=title_color)


ani = FuncAnimation(fig, update_plot, frames=N_FRAMES, interval=FRAME_INTERVAL_MS)

# Saved video plays at the same rate as the on-screen animation (2 fps).
ani.save("animation.mp4", writer="ffmpeg", fps=1000 // FRAME_INTERVAL_MS)

# Show the animation
plt.show()