In [1]:
import xarray as xr

# Open both NetCDF outputs. sagnn1 is passed below as plot_example's `prediction`
# and sagnn2 as its `actual`.
# NOTE(review): the filenames look like they encode array sizes
# (e.g. time x lead x lat x lon) — confirm against the files.
sagnn1, sagnn2 = map(xr.open_dataset, ('output_1452x7x49x69.nc', 'output_15x7x49x69.nc'))


In [2]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

import cartopy.crs as ccrs

from IPython.display import HTML


def _valid_times(dataset, init_time) -> np.ndarray:
    """Return the absolute valid time of every forecast lead in ``dataset``.

    Each entry is ``init_time + lead`` for every ``lead`` in the dataset's
    ``prediction_timedelta`` coordinate (cast to ``timedelta64[ns]``).
    """
    leads = np.array(dataset['prediction_timedelta'].values).astype('timedelta64[ns]')
    return np.array([np.datetime64(init_time + lead) for lead in leads])


def _linear_levels(lo: float, hi: float, n_intervals: int) -> np.ndarray:
    """Return ``n_intervals + 1`` evenly spaced contour boundaries spanning ``[lo, hi]``.

    Fixed boundaries make every animation frame share one colour scale, so a
    colorbar drawn for the first frame stays valid for all later frames.
    """
    if not hi > lo:
        # A constant field would give non-increasing boundaries, which
        # contourf rejects; widen the range so the panel is still drawable.
        hi = lo + 1.0
    return np.linspace(lo, hi, max(int(n_intervals), 1) + 1)


def plot_example(prediction, actual, seed: int = 0, offset: int = 0, frame_rate: int = 16,
                 levels: int = 10, show_error: bool = False) -> HTML:
    """Animate predicted vs. actual ``wind_speed`` on a PlateCarree map, one frame per lead time.

    Parameters
    ----------
    prediction, actual : xarray.Dataset
        Datasets with a ``wind_speed`` variable over ``time``,
        ``prediction_timedelta``, ``latitude`` and ``longitude``. The map grid
        and extent are taken from ``actual``.
    seed : int
        Index along ``time`` of the prediction to show.
    offset : int
        Extra index offset for ``actual``: the actual sample is
        ``actual.isel(time=seed + offset)``.
    frame_rate : int
        Animation frames per second.
    levels : int
        Number of equal-width contour intervals shared by the prediction and
        actual panels across all frames. An explicit, increasing sequence of
        boundaries is also accepted and used as-is.
    show_error : bool
        If True, add a third panel showing ``|actual - prediction|`` per frame.

    Returns
    -------
    IPython.display.HTML
        The animation rendered via ``FuncAnimation.to_jshtml``.
    """
    lon, lat = actual.longitude, actual.latitude
    bounds = [lon.min().item(), lon.max().item(), lat.min().item(), lat.max().item()]
    proj = ccrs.PlateCarree()

    pred = prediction.isel(time=seed)
    target = actual.isel(time=seed + offset)

    # Absolute valid times per lead; prediction and actual differ when offset != 0.
    times_pred = _valid_times(prediction, pred.time.values)
    times_target = _valid_times(actual, target.time.values)

    print('Predictions:', times_pred[0])
    print('Actual:', times_target[0])

    dims = ('prediction_timedelta', 'latitude', 'longitude')
    pred_states = pred['wind_speed'].transpose(*dims).values
    target_states = target['wind_speed'].transpose(*dims).values
    # Never index past the shorter of the two lead-time axes.
    n_frames = min(pred_states.shape[0], target_states.shape[0])
    pred_states, target_states = pred_states[:n_frames], target_states[:n_frames]

    # One shared value range across both panels and all frames (NaN-tolerant).
    vmin = float(min(np.nanmin(pred_states), np.nanmin(target_states)))
    vmax = float(max(np.nanmax(pred_states), np.nanmax(target_states)))
    if isinstance(levels, (int, np.integer)):
        n_intervals = int(levels)
        value_levels = _linear_levels(vmin, vmax, n_intervals)
    else:
        value_levels = np.asarray(levels, dtype=float)
        n_intervals = len(value_levels) - 1

    # (title prefix, frames, valid times, contour boundaries, cmap, colorbar label)
    panels = [
        ('Predicted', pred_states, times_pred, value_levels, None, 'Wind Speed (m/s)'),
        ('Actual', target_states, times_target, value_levels, None, 'Wind Speed (m/s)'),
    ]
    if show_error:
        err_states = np.abs(target_states - pred_states)
        err_levels = _linear_levels(float(np.nanmin(err_states)), float(np.nanmax(err_states)), n_intervals)
        panels.append(('Error', err_states, times_target, err_levels, 'coolwarm', 'Absolute Error (m/s)'))

    fig, axs = plt.subplots(1, len(panels), figsize=(14, 7), subplot_kw={'projection': proj})
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

    def animate(i):
        """Redraw every panel for lead-time index ``i``; return the new contour sets."""
        contour_sets = []
        for ax, (name, states, times, lvls, cmap, _label) in zip(axs, panels):
            ax.clear()
            ax.coastlines()
            # ax.clear() discards the extent set earlier, so re-apply it each frame.
            ax.set_extent(bounds, crs=proj)
            contour_sets.append(ax.contourf(lon, lat, states[i], levels=lvls,
                                            vmin=lvls[0], vmax=lvls[-1], cmap=cmap,
                                            transform=proj))
            ax.set_title(f'{name} {i} - {times[i]}')
        return contour_sets

    # Draw the first frame once so colorbars can be attached; since the contour
    # boundaries are fixed, these colorbars remain correct for every frame.
    for ax, contour_set, panel in zip(axs, animate(0), panels):
        fig.colorbar(contour_set, ax=ax, orientation='vertical', label=panel[-1], shrink=0.3)

    ani = FuncAnimation(fig, animate, frames=n_frames, interval=1000 / frame_rate)

    # Close the static figure; only the HTML animation is returned for display.
    plt.close(fig)

    return HTML(ani.to_jshtml())


# Animate the first sample (time index 0) of both datasets at 3 frames per second.
plot_example(
    prediction=sagnn1,  # shown in the "Predicted" panel
    actual=sagnn2,      # shown in the "Actual" panel; also sets the map extent
    seed=0,
    offset=0,
    frame_rate=3,
    levels=10,
)

Predictions: 2020-01-01T12:00:00.000000000
Actual: 2020-01-01T12:00:00.000000000
