In [None]:
pip install hvplot

In [None]:
import numpy as np
import xarray as xr
import hvplot.xarray
import holoviews as hv
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [None]:
# Update function for animation
def update(frame, ds):
    ax.clear()
    cax = ax.imshow(out_data.isel(time=frame).values, origin='lower', aspect='auto',
                    extent=[float(ds.lon.min()), float(ds.lon.max()), float(ds.lat.min()), float(ds.lat.max())])
    ax.set_title(f'Time: {str(out_data.time[frame].values)}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    cbar.update_normal(cax)
    return cax,

In [None]:
output_path = '/homes/g24meda/lab/4dvarnet-starter/outputs/strong_4dvar_qg/promising/best_lrgmod_01_lrgrad_100_nstep_20_sigma2_kernelsize21_alpha1_1_alpha2_05_avgpool2_dt10min/11-26-41/QG_new_4_nadirs_DC_2020a_ssh/test_data.nc'
#'/homes/g24meda/lab/4dvarnet-starter/outputs/2024-05-29/14-59-05/Trained_Bilin_Prior_4_nadirs_DC_2020a_ssh/test_data.nc'
#'/homes/g24meda/lab/4dvarnet-starter/outputs/2024-05-30/23-05-33/QG_4_nadirs_DC_2020a_ssh/test_data.nc'

In [None]:
test_out = xr.open_dataset(output_path)
test_out

# 2D plot

Side-by-side maps of the `out`, `tgt` and `inp` fields of the test dataset at a single time step.

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Plotting each subplot
test_out.out.isel(time=10).plot(ax=axs[0])
test_out.tgt.isel(time=10).plot(ax=axs[1])
test_out.inp.isel(time=10).plot(ax=axs[2])

plt.show()

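# Animation

The same three fields animated over the whole `time` dimension, all drawn on the colour scale of `tgt`.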

In [None]:
%matplotlib agg
# Extract the data variables and ensure they are loaded into memory
out_data = test_out['out'].load()
tgt_data = test_out['tgt'].load()
inp_data = test_out['inp'].load()

# Determine the global min and max values for the color scale for each variable
vmin, vmax = tgt_data.min().values, tgt_data.max().values

# Create a figure and axis for the plot
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Initialize the plots with the first time slice
cax_out = axs[0].imshow(out_data.isel(time=0).values, origin='lower', aspect='auto',
                        extent=[float(test_out.lon.min()), float(test_out.lon.max()), float(test_out.lat.min()), float(test_out.lat.max())],
                        cmap='coolwarm', vmin=vmin, vmax=vmax);
axs[0].set_title(f'Out Time: {str(out_data.time[0].values)}');
axs[0].set_xlabel('Longitude');
axs[0].set_ylabel('Latitude');
cbar_out = fig.colorbar(cax_out, ax=axs[0], orientation='vertical');
cbar_out.set_label('out');

cax_tgt = axs[1].imshow(tgt_data.isel(time=0).values, origin='lower', aspect='auto',
                        extent=[float(test_out.lon.min()), float(test_out.lon.max()), float(test_out.lat.min()), float(test_out.lat.max())],
                        cmap='coolwarm', vmin=vmin, vmax=vmax);
axs[1].set_title(f'Tgt Time: {str(tgt_data.time[0].values)}');
axs[1].set_xlabel('Longitude');
axs[1].set_ylabel('Latitude');
cbar_tgt = fig.colorbar(cax_tgt, ax=axs[1], orientation='vertical');
cbar_tgt.set_label('tgt');

cax_inp = axs[2].imshow(inp_data.isel(time=0).values, origin='lower', aspect='auto',
                        extent=[float(test_out.lon.min()), float(test_out.lon.max()), float(test_out.lat.min()), float(test_out.lat.max())],
                        cmap='coolwarm', vmin=vmin, vmax=vmax);
axs[2].set_title(f'Inp Time: {str(inp_data.time[0].values)}');
axs[2].set_xlabel('Longitude');
axs[2].set_ylabel('Latitude');
cbar_inp = fig.colorbar(cax_inp, ax=axs[2], orientation='vertical');
cbar_inp.set_label('inp');


# Update function for animation
def update(frame):
    for ax in axs:
        ax.clear()
    
    cax_out = axs[0].imshow(out_data.isel(time=frame).values, origin='lower', aspect='auto',
                            extent=[float(test_out.lon.min()), float(test_out.lon.max()), float(test_out.lat.min()), float(test_out.lat.max())],
                            cmap='coolwarm', vmin=vmin, vmax=vmax)
    axs[0].set_title(f'Out Time: {str(out_data.time[frame].values)}')
    axs[0].set_xlabel('Longitude')
    axs[0].set_ylabel('Latitude')
    cbar_out.update_normal(cax_out)

    cax_tgt = axs[1].imshow(tgt_data.isel(time=frame).values, origin='lower', aspect='auto',
                            extent=[float(test_out.lon.min()), float(test_out.lon.max()), float(test_out.lat.min()), float(test_out.lat.max())],
                            cmap='coolwarm', vmin=vmin, vmax=vmax)
    axs[1].set_title(f'Tgt Time: {str(tgt_data.time[frame].values)}')
    axs[1].set_xlabel('Longitude')
    axs[1].set_ylabel('Latitude')
    cbar_tgt.update_normal(cax_tgt)

    cax_inp = axs[2].imshow(inp_data.isel(time=frame).values, origin='lower', aspect='auto',
                            extent=[float(test_out.lon.min()), float(test_out.lon.max()), float(test_out.lat.min()), float(test_out.lat.max())],
                            cmap='coolwarm', vmin=vmin, vmax=vmax)
    axs[2].set_title(f'Inp Time: {str(inp_data.time[frame].values)}')
    axs[2].set_xlabel('Longitude')
    axs[2].set_ylabel('Latitude')
    cbar_inp.update_normal(cax_inp)

    return cax_out, cax_tgt, cax_inp

# Create the animation
ani = FuncAnimation(fig, update, frames=len(out_data.time), interval=200, blit=False)

# To display the animation inline in a Jupyter Notebook
from IPython.display import HTML
HTML(ani.to_jshtml())