In [1]:
import os
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter

from prototypes.ocn_only.R1.config import OcnTrainer as Emulator 
from graphufs.stacked_utils import get_channel_index
from graphufs.datasets import Dataset

In [2]:
# Path components locating the experiment on disk. The data-loading cell joins
# them as <_scratch>/<_prototype>/<_expt>/training/<name>.zarr.
_scratch = "/pscratch/sd/n/nagarwal"  # root directory
_prototype = "ocn-only"  # prototype sub-directory
_expt = "R1"  # experiment sub-directory

In [3]:
# Both zarr stores live side by side in the experiment's "training" directory,
# so build that directory path once and open each store from it.
_training_dir = os.path.join(_scratch, _prototype, _expt, "training")
inputs = xr.open_zarr(os.path.join(_training_dir, "inputs.zarr"))
targets = xr.open_zarr(os.path.join(_training_dir, "targets.zarr"))

In [4]:
# Display the repr of the opened targets dataset (array size, shape, chunking).
targets

Unnamed: 0,Array,Chunk
Bytes,897.18 GiB,18.00 MiB
Shape,"(37983, 192, 384, 43)","(1, 192, 384, 32)"
Dask graph,75966 chunks in 2 graph layers,75966 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 897.18 GiB 18.00 MiB Shape (37983, 192, 384, 43) (1, 192, 384, 32) Dask graph 75966 chunks in 2 graph layers Data type float64 numpy.ndarray",37983  1  43  384  192,

Unnamed: 0,Array,Chunk
Bytes,897.18 GiB,18.00 MiB
Shape,"(37983, 192, 384, 43)","(1, 192, 384, 32)"
Dask graph,75966 chunks in 2 graph layers,75966 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [5]:
# Instantiate the emulator configuration and wrap it in a training-mode Dataset.
em = Emulator()
tds = Dataset(em, mode="training")

# Fetch the xarray objects for index 0; the third returned item is not used here.
xinputs, xtargets, _ = tds.get_xarrays(0)

# Channel metadata for inputs and targets. The printed `tmeta_tar` maps each
# channel number to a dict with 'varname', 'time' and, where present, 'z_l'.
tmeta_inp, tmeta_tar = (get_channel_index(xds) for xds in (xinputs, xtargets))

In [6]:
# Display the mapping from target channel number to variable name / level / time.
tmeta_tar

{0: {'varname': 'LW', 'time': 0},
 1: {'varname': 'SSH', 'time': 0},
 2: {'varname': 'SW', 'time': 0},
 3: {'varname': 'so', 'z_l': 0, 'time': 0},
 4: {'varname': 'so', 'z_l': 1, 'time': 0},
 5: {'varname': 'so', 'z_l': 2, 'time': 0},
 6: {'varname': 'so', 'z_l': 3, 'time': 0},
 7: {'varname': 'so', 'z_l': 4, 'time': 0},
 8: {'varname': 'so', 'z_l': 5, 'time': 0},
 9: {'varname': 'so', 'z_l': 6, 'time': 0},
 10: {'varname': 'so', 'z_l': 7, 'time': 0},
 11: {'varname': 'so', 'z_l': 8, 'time': 0},
 12: {'varname': 'so', 'z_l': 9, 'time': 0},
 13: {'varname': 'temp', 'z_l': 0, 'time': 0},
 14: {'varname': 'temp', 'z_l': 1, 'time': 0},
 15: {'varname': 'temp', 'z_l': 2, 'time': 0},
 16: {'varname': 'temp', 'z_l': 3, 'time': 0},
 17: {'varname': 'temp', 'z_l': 4, 'time': 0},
 18: {'varname': 'temp', 'z_l': 5, 'time': 0},
 19: {'varname': 'temp', 'z_l': 6, 'time': 0},
 20: {'varname': 'temp', 'z_l': 7, 'time': 0},
 21: {'varname': 'temp', 'z_l': 8, 'time': 0},
 22: {'varname': 'temp', 'z_l':

In [7]:
def animate(dataset, frames, vmin=None, vmax=None, output_file=None):
    """
    Animate the per-sample anomaly of a DataArray along its "sample" dimension.

    The figure is first drawn with the raw field at sample 0. Each animation
    frame then redraws the figure with
    ``dataset.isel(sample=frame) - dataset.mean(dim="sample")`` using the
    diverging "RdBu_r" colormap.

    Parameters:
    - dataset: xarray DataArray with a "sample" dimension (it is indexed with
      ``isel(sample=...)`` and drawn with ``.plot()``).
    - frames: Iterable of integer positions along "sample" to animate.
    - vmin, vmax: Color limits forwarded to the anomaly plot; None leaves the
      choice to xarray.
    - output_file: If provided, saves the animation to this file using the
      Pillow writer (e.g., 'animation.gif').
    """
    # Draw explicitly on `fig` rather than on pyplot's "current" figure so the
    # animation always targets the figure that FuncAnimation was given.
    fig = plt.figure(figsize=(8, 6))
    first_ax = fig.add_subplot(111)
    dataset.isel(sample=0).plot(ax=first_ax)
    first_ax.set_title("sample=0")

    # The mean over samples does not depend on the frame; compute it once
    # instead of once per frame.
    dataset_mean = dataset.mean(dim="sample")

    def update(frame):
        # Wipe the figure (axes and colorbar) before drawing the new frame.
        fig.clf()
        ax = fig.add_subplot(111)
        diff = dataset.isel(sample=frame) - dataset_mean
        artist = diff.plot(ax=ax, vmin=vmin, vmax=vmax, cmap="RdBu_r")
        ax.set_title(f"sample={frame}")
        # FuncAnimation only inspects the return value when blit=True, in which
        # case it expects an iterable of artists.
        return (artist,)

    # Create the animation
    ani = FuncAnimation(fig, update, frames=frames, interval=200)

    # If output_file is specified, save the animation
    if output_file:
        ani.save(output_file, writer="pillow")

    # Display the animation
    plt.show()

In [None]:
# Number of delta_t steps in one day.
# NOTE(review): assumes consecutive samples are delta_t apart — confirm.
delta_t = "6 hours"
dt = int(pd.Timedelta("1 day") / pd.Timedelta(delta_t))

# Select window number `iyear`, where every window spans `num_days_to_plot` days.
num_days_to_plot = 30
iyear = 0
window_len = num_days_to_plot * dt
sample0 = iyear * window_len
sampleEnd = sample0 + window_len

# Target channel to animate (index 12 is 'so' at z_l=9 in the printed target
# channel index). Original note: salinity is channel 41 for inputs.
channel = 12
kwargs = dict(vmin=-0.005, vmax=0.005, output_file="animation.gif")

# Pull the selected channel and sample window into memory.
dataset = (
    targets.targets
    .isel(channels=channel, sample=slice(sample0, sampleEnd))
    .load()
)

# Animate every sample of the window except position 0.
animate(dataset, frames=range(1, sampleEnd - sample0), **kwargs)