# Test pretrained model
This notebook tests the pretrained model on a single datacube taken from the radar dataset (https://arcodatahub.com/datasets/datasets/italian-radar-dpc-sri.zarr).

In [None]:
import sys

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pysteps.visualization.precipfields as pysteps_plot
import torch
import xarray as xr
from IPython.display import HTML, display

# Make the parent directory importable so the local `lightning_model` module resolves.
sys.path.append('../')

from lightning_model import RadarLightningModel

# Load radar data
We first load a sample of the Italian radar dataset, saved in the `data` folder.

In [None]:
# Open the radar sample stored in the data folder and show its summary.
sample_path = "../data/test_radar_sample_54.nc"
radar = xr.open_dataarray(sample_path)
radar

This sample contains 18 sequences of radar images covering the whole of Italy, from 28 to 29 October 2024. It captures one of the most intense precipitation events over Italy in 2024.

In [None]:
# Animate the full observed sequence in the loaded datacube.
fig, ax = plt.subplots(figsize=(4, 4.5))


def update_radar_frame(frame):
    """Redraw the radar field at time index `frame` of `radar` and return the axes."""
    ax.clear()
    data = radar.isel(time=frame)
    pysteps_plot.plot_precip_field(data.values, ax=ax, colorbar=False)
    ax.set_title(f'Precipitation - {data.time.values}')
    return ax,


# One frame per time step; repeat so the loop keeps playing in the notebook.
ani = animation.FuncAnimation(fig, update_radar_frame, frames=len(radar.time),
                              interval=500, blit=False, repeat=True)

# Embed the animation as JS/HTML, then close the static figure so it is not shown twice.
display(HTML(ani.to_jshtml()))
plt.close(fig)

# Initialize the model 
Initialize the model and load the weights from the checkpoint. You can change the number of future steps (`forecast_steps`) and the ensemble size (`ensemble_size`); the other hyperparameters are fixed by the checkpoint.

In [None]:
# Set the model's tunable parameters
forecast_steps = 12   # number of future frames to predict
ensemble_size  = 10   # number of ensemble members; the plots below show the first two, so keep >= 2
# torch device for inference ('gpu' is not a valid torch device type); use CUDA when present
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model and load the checkpoint
model = RadarLightningModel.from_checkpoint(checkpoint_path="../checkpoints/ConvGRU-CRPS_6past_12fut.ckpt")

# Run the inference
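We now run inference: the first `past_steps` frames condition the model, and the following `forecast_steps` frames are kept as ground truth to compare against the forecast.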

In [None]:
# Split the datacube: the first `past_steps` frames condition the model and the
# next `forecast_steps` frames (from the configuration cell) are kept as ground truth.
# NOTE(review): the checkpoint name says "6past" but 7 frames are passed here —
# confirm how many input frames RadarLightningModel.predict expects.
past_steps = 7
past, future = radar[:past_steps], radar[past_steps:past_steps + forecast_steps]
# `future` holds fewer than forecast_steps frames if the datacube is too short.

# Use the device selected in the configuration cell. torch has no 'gpu' device
# type, so map that alias to 'cuda', and fall back to CPU when CUDA is unavailable.
predict_device = 'cuda' if device == 'gpu' else device
if predict_device.startswith('cuda') and not torch.cuda.is_available():
    predict_device = 'cpu'

# Predict the future rain rate with the configured ensemble size.
pred = model.predict(past, forecast_steps, ensemble_size=ensemble_size, device=predict_device)

### Plot the forecast

In [None]:
# Compare the observed future frames against the ensemble mean and two members.
# NOTE(review): assumes `pred` is indexed as (member, step, y, x) — confirm against
# RadarLightningModel.predict.
assert len(pred) >= 2, "need at least two ensemble members to plot Member 1 and Member 2"
ensemble_mean = np.nanmean(pred, axis=0)

panel_labels = ['Ground Truth', 'Ensemble Mean', 'Member 1', 'Member 2']
data_sources = [future.values, ensemble_mean, pred[0], pred[1]]

# Animate only the steps every panel actually has: the ground-truth slice is
# shorter than forecast_steps when the datacube runs out of frames.
n_frames = min(forecast_steps, *(len(data) for data in data_sources))

fig, axs = plt.subplots(1, 4, figsize=(16, 4.5))

# Draw the first step once with colorbars. Presumably pysteps places each colorbar
# on its own axes, so the ax.clear() calls in the animation leave it in place — verify.
for ax, label, data in zip(axs, panel_labels, data_sources):
    pysteps_plot.plot_precip_field(data[0], ax=ax, units='mm/h', colorscale='pysteps')
    ax.set_title(label, fontsize=14)

plt.tight_layout()


def update_forecast_frame(frame):
    """Redraw all four panels for forecast step `frame` (0-based) and return the axes."""
    for ax, label, data in zip(axs, panel_labels, data_sources):
        ax.clear()
        pysteps_plot.plot_precip_field(data[frame], ax=ax, units='mm/h',
                                       colorscale='pysteps', colorbar=False)
        ax.set_title(f'{label} - Step {frame}', fontsize=14)
    return axs


anim = animation.FuncAnimation(fig, update_forecast_frame, frames=n_frames,
                               interval=500, blit=False)

# Build the HTML only once: to_jshtml renders every frame of the animation.
display(HTML(anim.to_jshtml()))
plt.close(fig)