## Visualize WoFS-Cast Predictions

In [1]:
import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

from wofscast.model import WoFSCastModel
from wofscast.border_mask import BORDER_MASK_NUMPY

from wofscast.data_generator import (load_chunk, 
                                     WRFZarrFileProcessor,
                                     WoFSDataProcessor, 
                                     dataset_to_input,
                                     add_local_solar_time
                                    )
from wofscast import checkpoint
from wofscast import rollout 
from wofscast.wofscast_task_config import (DBZ_TASK_CONFIG, 
                                           WOFS_TASK_CONFIG, 
                                           DBZ_TASK_CONFIG_1HR,
                                           DBZ_TASK_CONFIG_FULL
                                          )


from wofscast.diffusion import EDMPrecond
from diffusers import UNet2DModel

In [2]:
# For plotting. 
import os
import numpy as np
import xarray 
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import make_axes_locatable

from wofscast.plot import WoFSColors, WoFSLevels
from datetime import datetime, timedelta
import pandas as pd
import xarray as xr

def get_case_date(path):
    """Return the case date (``YYYYMMDD``) for a forecast file path.

    The second and third underscore-separated fields of the base name are
    joined and parsed with the format ``%Y-%m-%d_%H%M%S.zarr``. A start time
    whose hour is earlier than 14 is assigned to the previous calendar day;
    otherwise the start day itself is used.
    """
    fields = os.path.basename(path).split('_')
    start = datetime.strptime(f'{fields[1]}_{fields[2]}', '%Y-%m-%d_%H%M%S.zarr')

    # Starts before hour 14 are filed under the preceding day's case.
    days_back = 1 if start.hour < 14 else 0
    return (start.date() - timedelta(days=days_back)).strftime('%Y%m%d')
    


def to_datetimes(path, n_times=13):
    """Build the time stamps that follow the first two steps of a forecast file.

    The base name of ``path`` must split on ``'__'`` into exactly three parts:
    the time-range stem, the step frequency (passed to ``pd.date_range`` as
    ``freq``), and a third field that is not used here. The start time is
    parsed from the text before ``'_to'`` in the stem.

    Returns the ``pd.date_range`` of ``n_times`` stamps with its first two
    entries removed (presumably the two input steps -- TODO confirm).
    """
    stem, step, _member = os.path.basename(path).split('__')
    start_text = stem.split('_to')[0]
    start_time = pd.Timestamp(
        datetime.strptime(start_text, 'wrfwof_%Y-%m-%d_%H%M%S.zarr')
    )

    all_times = pd.date_range(start=start_time, periods=n_times, freq=step)
    return all_times[2:]

# The border mask must be broadcastable against `preds` and `tars`.
def border_difference_check(preds, tars, border_mask):
    """Return the largest absolute prediction/target difference on the border.

    Points where ``border_mask`` is falsy are replaced by NaN so that only
    border points contribute to the NaN-aware maximum.
    """
    abs_error = np.abs(preds - tars)
    border_only = np.where(border_mask, abs_error, np.nan)
    return np.nanmax(border_only)

# Human-readable labels for model variable names; consumers fall back to the
# raw variable name when a key is absent.
display_name_mapper = {
    'U': 'U-wind Comp.',
    'V': 'V-wind Comp.',
    'W': 'Vert. Velocity',
    'T': 'Pot. Temp.',
    'GEOPOT': 'Geopot. Height',
    'QVAPOR': 'QVAPOR',
    'T2': '2-m Temp.',
    'COMPOSITE_REFL_10CM': 'Comp. Refl.',
    'UP_HELI_MAX': '2-5 km UH',
    'RAIN_AMOUNT': 'Rain Rate',
}

# Display units for each plotted variable. 'T2' and 'RAIN_AMOUNT' are listed
# in the units the animator converts them to before plotting (deg F, inches).
# 'UP_HELI_MAX' is included so that every variable with a display name also
# has units; without it the plot labels fall back to the raw variable name.
units_mapper = {
    'T': 'K',
    'QVAPOR': 'kg/kg',
    'T2': 'F',
    'U': 'm/s',
    'V': 'm/s',
    'W': 'm/s',
    'GEOPOT': 'm',
    'RAIN_AMOUNT': 'in',
    'COMPOSITE_REFL_10CM': 'dBZ',
    'UP_HELI_MAX': 'm^2/s^2',
}

## Predict with WoFS-Cast

In [12]:
# Select the data source: raw per-time WoFS zarr files (True) or a single
# pre-built multi-time zarr dataset (False).
use_raw_zarr = False 

if use_raw_zarr: 

    # NOTE(review): this branch defines neither `path` nor `case_date`, both of
    # which are read by later cells (to_datetimes / MRMSDataLoader) -- confirm
    # before switching `use_raw_zarr` on.
    base_path = '/work2/wofs_zarr/2021/20210503/2300/ENS_MEM_07'
    zarr_files = [os.path.join(base_path, f) for f in os.listdir(base_path)]
    zarr_files.sort()

    # Keep every other file, at most 14 of them, wrapped as a single-element
    # list of lists (presumably one chunk of time steps -- TODO confirm what
    # load_chunk expects).
    zarr_files = [zarr_files[::2][:14]]
    
    process_fn = preprocessor = WoFSDataProcessor()
    dataset = load_chunk(zarr_files, batch_over_time=True, gpu_batch_size=32, preprocess_fn=process_fn)
    
    # NOTE(review): `.compute()` presumably materializes a lazily loaded
    # dataset in memory -- confirm against load_chunk's return type.
    dataset = dataset.compute()
else:
    base_path = '/work/mflora/wofs-cast-data/full_domain_datasets/2021'
    name = 'wrfwof_2021-05-15_020000.zarr_to_2021-05-15_041000.zarr__10min__ens_mem_09.zarr'

    path = os.path.join(base_path, name)
    
    def preprocess_fn(dataset):
        """Apply `add_local_solar_time` to `dataset` and return the result."""
        # The full WoFSDataProcessor pipeline below is intentionally disabled;
        # only the local-solar-time step is applied.
        #_path = '/work/mflora/wofs-cast-data/datasets_zarr/2021/'
        #latlon_path = os.path.join(_path, 'wrfwof_2021-05-15_040000_to_2021-05-15_043000__10min__ens_mem_09.zarr')
        #preprocess_fn = WoFSDataProcessor(latlon_path=latlon_path)
        
        #dataset = preprocess_fn(dataset)
        
        dataset = add_local_solar_time(dataset) 
        
        return dataset 
    
    # NOTE(review): positional arguments here (1, preprocess_fn) differ from the
    # keyword form used in the raw-zarr branch -- verify against load_chunk's
    # signature.
    dataset = load_chunk([path], 1, preprocess_fn)
    dataset = dataset.compute() 
    
    # Case date string (YYYYMMDD) derived from the file name; used later to
    # locate the MRMS files.
    case_date = get_case_date(path)

In [13]:
from wofscast.normalization import normalize, unnormalize 

class PyTorchScaler:
    """Hold normalization statistics and expose scale / unscale helpers.

    ``mean`` and ``std`` are stored as given and forwarded to the module-level
    ``normalize`` / ``unnormalize`` functions in the order ``(x, std, mean)``.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def scale(self, x):
        """Return ``normalize(x, std, mean)``."""
        scaled = normalize(x, self.std, self.mean)
        return scaled

    def unscale(self, x):
        """Return ``unnormalize(x, std, mean)``."""
        restored = unnormalize(x, self.std, self.mean)
        return restored

In [14]:
# Read the per-level mean / standard deviation files and wrap them in a scaler.
norm_stats_path = '/work/mflora/wofs-cast-data/full_normalization_stats'
mean_file = os.path.join(norm_stats_path, 'mean_by_level.nc')
stddev_file = os.path.join(norm_stats_path, 'stddev_by_level.nc')

mean_by_level = xarray.load_dataset(mean_file)
stddev_by_level = xarray.load_dataset(stddev_file)

scaler = PyTorchScaler(mean=mean_by_level, std=stddev_by_level)

In [6]:
%%time 

# Corey's biggest model 
# Path to the saved model parameters (.npz).
MODEL_PATH = '/work/cpotvin/WOFSCAST/model/wofscast_test_v178.npz'

model = WoFSCastModel()
# NOTE(review): load_model is called twice on the same file; the second call
# omits the `tiling` option. Presumably only one of the two is intended --
# confirm which, since the later call may override the tiled configuration.
model.load_model(MODEL_PATH, **{'tiling' : (2,2)})
 
model.load_model(MODEL_PATH)
# Split the dataset into model inputs, targets and forcings for lead times
# from 10 to 120 minutes (12 target steps requested).
inputs, targets, forcings = dataset_to_input(dataset, model.task_config, 
                                             target_lead_times=slice('10min', '120min'), 
                                             batch_over_time=False, n_target_steps=12)


# NOTE(review): the saved output of this cell is a TypeError from the scanned
# rollout: the carry output has a 'level' coordinate that the carry input
# lacks. Until that is resolved `predictions` is never defined and the
# animation cell fails with a NameError.
predictions = model.predict(inputs, targets, forcings)



TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ:
["the input tree structure is:\nPyTreeDef((CustomNode(Dataset[_HashableCoords({'lat': <xarray.IndexVariable 'lat' (lat: 150)>\narray([32.1181  , 32.14526 , 32.17244 , 32.199604, 32.22678 , 32.253963,\n       32.281155, 32.30834 , 32.33552 , 32.362713, 32.38991 , 32.417103,\n       32.4443  , 32.47151 , 32.49872 , 32.52592 , 32.553135, 32.580345,\n       32.60757 , 32.63479 , 32.662   , 32.689224, 32.716454, 32.74368 ,\n       32.770912, 32.79815 , 32.825386, 32.852627, 32.879864, 32.907116,\n       32.93436 , 32.96161 , 32.988876, 33.016125, 33.04338 , 33.070644,\n       33.0979  , 33.125164, 33.15244 , 33.17971 , 33.20698 , 33.234257,\n       33.26154 , 33.288815, 33.316105, 33.343388, 33.370674, 33.397964,\n       33.42526 , 33.452553, 33.47985 , 33.50716 , 33.534466, 33.561768,\n       33.589077, 33.61639 , 33.6437  , 33.671024, 33.69835 , 33.72566 ,\n       33.752975, 33.780304, 33.80763 , 33.83496 , 33.862297, 33.889633,\n       33.91697 , 33.944317, 33.971657, 33.999   , 34.026352, 34.053703,\n       34.08105 , 34.10841 , 34.13577 , 34.163128, 34.190506, 34.217857,\n       34.24522 , 34.27261 , 34.299976, 34.32735 , 34.354725, 34.382095,\n       34.409477, 34.436863, 34.464252, 34.491634, 34.519024, 34.54642 ,\n       34.573822, 34.601215, 34.628616, 34.656017, 34.683426, 34.71083 ,\n       34.738243, 34.76565 , 34.793068, 34.820488, 34.847908, 34.875328,\n       34.902752, 34.930176, 34.957607, 34.985035, 35.01247 , 35.03989 ,\n       35.067337, 35.09477 , 35.122204, 35.149654, 35.17709 , 35.20454 ,\n       35.231995, 35.25944 , 35.286896, 35.314358, 35.341816, 35.369274,\n       35.39674 , 35.42421 , 35.451675, 35.479145, 35.506615, 35.534092,\n       35.561554, 35.58904 , 35.61652 , 35.644   , 35.671486, 35.69898 ,\n       35.72646 , 35.753952, 35.781445, 35.80893 , 35.836433, 35.86393 ,\n       35.891434, 35.91893 , 35.946438, 35.973946, 36.00145 , 36.028973,\n       36.056473, 36.08398 , 36.111507, 36.13903 , 36.166542, 36.194077],\n      dtype=float32), 
'lon': <xarray.IndexVariable 'lon' (lon: 150)>\narray([79.53119 , 79.56244 , 79.59366 , 79.62488 , 79.6561  , 79.68732 ,\n       79.71857 , 79.74979 , 79.781006, 79.812256, 79.843475, 79.874695,\n       79.905945, 79.937195, 79.968414, 79.99963 , 80.03088 , 80.06213 ,\n       80.09335 , 80.12457 , 80.15582 , 80.18707 , 80.21829 , 80.24954 ,\n       80.28076 , 80.31201 , 80.34326 , 80.37451 , 80.40573 , 80.43698 ,\n       80.4682  , 80.49945 , 80.5307  , 80.56195 , 80.5932  , 80.62445 ,\n       80.65567 , 80.68692 , 80.71817 , 80.74942 , 80.78064 , 80.81189 ,\n       80.84314 , 80.87439 , 80.90564 , 80.93689 , 80.96814 , 80.99939 ,\n       81.03064 , 81.06189 , 81.09314 , 81.12439 , 81.15564 , 81.18689 ,\n       81.21814 , 81.24936 , 81.28061 , 81.31186 , 81.34311 , 81.37436 ,\n       81.40561 , 81.43686 , 81.46811 , 81.49939 , 81.53064 , 81.56189 ,\n       81.59314 , 81.62439 , 81.65564 , 81.68689 , 81.71814 , 81.74939 ,\n       81.78064 , 81.81189 , 81.84314 , 81.87439 , 81.90564 , 81.93689 ,\n       81.96814 , 81.99939 , 82.03064 , 82.06189 , 82.09314 , 82.12439 ,\n       82.15564 , 82.18689 , 82.21814 , 82.24939 , 82.28064 , 82.31189 ,\n       82.34314 , 82.37439 , 82.40564 , 82.43689 , 82.46814 , 82.49939 ,\n       82.53064 , 82.56189 , 82.59314 , 82.62439 , 82.65564 , 82.68689 ,\n       82.71814 , 82.74939 , 82.78064 , 82.81189 , 82.84311 , 82.87436 ,\n       82.90561 , 82.93686 , 82.96811 , 82.99936 , 83.03058 , 83.06183 ,\n       83.09308 , 83.12433 , 83.15558 , 83.18683 , 83.21805 , 83.2493  ,\n       83.28055 , 83.31177 , 83.34302 , 83.37427 , 83.40552 , 83.43674 ,\n       83.46799 , 83.49921 , 83.53046 , 83.56171 , 83.592926, 83.624146,\n       83.655396, 83.686646, 83.717865, 83.749084, 83.780334, 83.811584,\n       83.842804, 83.87402 , 83.90527 , 83.93649 , 83.96771 , 83.99896 ,\n       84.03018 , 84.0614  , 84.09265 , 84.12387 , 84.15509 , 84.18631 ],\n      dtype=float32), 'time': <xarray.IndexVariable 'time' (time: 2)>\narray([-600000000000,         
    0], dtype='timedelta64[ns]')})], [{'COMPOSITE_REFL_10CM': CustomNode(Variable[('batch', 'time', 'lat', 'lon')], [*]), 'GEOPOT': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'QVAPOR': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'RAIN_AMOUNT': CustomNode(Variable[('batch', 'time', 'lat', 'lon')], [*]), 'T': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'T2': CustomNode(Variable[('batch', 'time', 'lat', 'lon')], [*]), 'U': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'V': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'W': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*])}, {}]), CustomNode(namedtuple[InternalState], [None, CustomNode(defaultdict[(<class 'dict'>, ())], []), (*, (*,))])))\n", "the output tree structure is:\nPyTreeDef((CustomNode(Dataset[_HashableCoords({'lat': <xarray.IndexVariable 'lat' (lat: 150)>\narray([32.1181  , 32.14526 , 32.17244 , 32.199604, 32.22678 , 32.253963,\n       32.281155, 32.30834 , 32.33552 , 32.362713, 32.38991 , 32.417103,\n       32.4443  , 32.47151 , 32.49872 , 32.52592 , 32.553135, 32.580345,\n       32.60757 , 32.63479 , 32.662   , 32.689224, 32.716454, 32.74368 ,\n       32.770912, 32.79815 , 32.825386, 32.852627, 32.879864, 32.907116,\n       32.93436 , 32.96161 , 32.988876, 33.016125, 33.04338 , 33.070644,\n       33.0979  , 33.125164, 33.15244 , 33.17971 , 33.20698 , 33.234257,\n       33.26154 , 33.288815, 33.316105, 33.343388, 33.370674, 33.397964,\n       33.42526 , 33.452553, 33.47985 , 33.50716 , 33.534466, 33.561768,\n       33.589077, 33.61639 , 33.6437  , 33.671024, 33.69835 , 33.72566 ,\n       33.752975, 33.780304, 33.80763 , 33.83496 , 33.862297, 33.889633,\n       33.91697 , 33.944317, 33.971657, 33.999   , 34.026352, 34.053703,\n       34.08105 , 34.10841 , 34.13577 , 34.163128, 34.190506, 34.217857,\n       34.24522 , 34.27261 , 34.299976, 34.32735 , 34.354725, 
34.382095,\n       34.409477, 34.436863, 34.464252, 34.491634, 34.519024, 34.54642 ,\n       34.573822, 34.601215, 34.628616, 34.656017, 34.683426, 34.71083 ,\n       34.738243, 34.76565 , 34.793068, 34.820488, 34.847908, 34.875328,\n       34.902752, 34.930176, 34.957607, 34.985035, 35.01247 , 35.03989 ,\n       35.067337, 35.09477 , 35.122204, 35.149654, 35.17709 , 35.20454 ,\n       35.231995, 35.25944 , 35.286896, 35.314358, 35.341816, 35.369274,\n       35.39674 , 35.42421 , 35.451675, 35.479145, 35.506615, 35.534092,\n       35.561554, 35.58904 , 35.61652 , 35.644   , 35.671486, 35.69898 ,\n       35.72646 , 35.753952, 35.781445, 35.80893 , 35.836433, 35.86393 ,\n       35.891434, 35.91893 , 35.946438, 35.973946, 36.00145 , 36.028973,\n       36.056473, 36.08398 , 36.111507, 36.13903 , 36.166542, 36.194077],\n      dtype=float32), 'lon': <xarray.IndexVariable 'lon' (lon: 150)>\narray([79.53119 , 79.56244 , 79.59366 , 79.62488 , 79.6561  , 79.68732 ,\n       79.71857 , 79.74979 , 79.781006, 79.812256, 79.843475, 79.874695,\n       79.905945, 79.937195, 79.968414, 79.99963 , 80.03088 , 80.06213 ,\n       80.09335 , 80.12457 , 80.15582 , 80.18707 , 80.21829 , 80.24954 ,\n       80.28076 , 80.31201 , 80.34326 , 80.37451 , 80.40573 , 80.43698 ,\n       80.4682  , 80.49945 , 80.5307  , 80.56195 , 80.5932  , 80.62445 ,\n       80.65567 , 80.68692 , 80.71817 , 80.74942 , 80.78064 , 80.81189 ,\n       80.84314 , 80.87439 , 80.90564 , 80.93689 , 80.96814 , 80.99939 ,\n       81.03064 , 81.06189 , 81.09314 , 81.12439 , 81.15564 , 81.18689 ,\n       81.21814 , 81.24936 , 81.28061 , 81.31186 , 81.34311 , 81.37436 ,\n       81.40561 , 81.43686 , 81.46811 , 81.49939 , 81.53064 , 81.56189 ,\n       81.59314 , 81.62439 , 81.65564 , 81.68689 , 81.71814 , 81.74939 ,\n       81.78064 , 81.81189 , 81.84314 , 81.87439 , 81.90564 , 81.93689 ,\n       81.96814 , 81.99939 , 82.03064 , 82.06189 , 82.09314 , 82.12439 ,\n       82.15564 , 82.18689 , 82.21814 , 82.24939 , 82.28064 , 
82.31189 ,\n       82.34314 , 82.37439 , 82.40564 , 82.43689 , 82.46814 , 82.49939 ,\n       82.53064 , 82.56189 , 82.59314 , 82.62439 , 82.65564 , 82.68689 ,\n       82.71814 , 82.74939 , 82.78064 , 82.81189 , 82.84311 , 82.87436 ,\n       82.90561 , 82.93686 , 82.96811 , 82.99936 , 83.03058 , 83.06183 ,\n       83.09308 , 83.12433 , 83.15558 , 83.18683 , 83.21805 , 83.2493  ,\n       83.28055 , 83.31177 , 83.34302 , 83.37427 , 83.40552 , 83.43674 ,\n       83.46799 , 83.49921 , 83.53046 , 83.56171 , 83.592926, 83.624146,\n       83.655396, 83.686646, 83.717865, 83.749084, 83.780334, 83.811584,\n       83.842804, 83.87402 , 83.90527 , 83.93649 , 83.96771 , 83.99896 ,\n       84.03018 , 84.0614  , 84.09265 , 84.12387 , 84.15509 , 84.18631 ],\n      dtype=float32), 'time': <xarray.IndexVariable 'time' (time: 2)>\narray([-600000000000,             0], dtype='timedelta64[ns]'), 'level': <xarray.IndexVariable 'level' (level: 17)>\narray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])})], [{'COMPOSITE_REFL_10CM': CustomNode(Variable[('batch', 'time', 'lat', 'lon')], [*]), 'GEOPOT': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'QVAPOR': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'RAIN_AMOUNT': CustomNode(Variable[('batch', 'time', 'lat', 'lon')], [*]), 'T': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'T2': CustomNode(Variable[('batch', 'time', 'lat', 'lon')], [*]), 'U': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'V': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*]), 'W': CustomNode(Variable[('batch', 'time', 'lat', 'lon', 'level')], [*])}, {}]), CustomNode(namedtuple[InternalState], [None, CustomNode(defaultdict[(<class 'dict'>, ())], []), (*, (*,))])))\n"]
Revise the scanned function so that its output is a pair where the first element has the same pytree structure as the first argument.

In [7]:

class WoFSCastAnimator:
    """Animate a WoFS target field next to the matching WoFS-Cast prediction.

    Calling an instance builds a two-panel figure (left: targets, right:
    predictions) with a shared horizontal colorbar and returns a
    ``FuncAnimation`` with one frame per prediction time.

    Parameters
    ----------
    domain_size : int
        Number of grid points along each side of the plotted domain; used to
        build the coordinate grid for the wind vectors.
    plot_border : bool, optional
        Stored on the instance; not read by any method of this class.
    dts : sequence, optional
        Per-frame time labels, indexed by frame number. When ``None`` the
        frame index is shown instead.
    """

    # Variables for which values below 1 are masked out before contouring.
    _MASKED_VARS = ('REFL_10CM', 'UP_HELI_MAX', 'COMPOSITE_REFL_10CM')
    # Wind vectors are drawn at every n-th grid point in each direction.
    _WIND_STRIDE = 5

    def __init__(self, domain_size, plot_border=False, dts=None):
        self.dts = dts
        self.plot_border = plot_border
        self.domain_size = domain_size

    def __call__(self, var, level, inputs, predictions, targets, mrms_dz=None):
        """Create the figure and return the animation.

        Parameters
        ----------
        var : str
            Name of the variable to plot.
        level : int or {'none', 'max', 'min'}
            Level index, or 'none' for variables without a level dimension,
            or 'max' / 'min' to reduce over the level dimension.
        inputs, predictions, targets : dataset-like
            Must contain ``var`` and a size-1 ``batch`` dimension.
        mrms_dz : array-like, optional
            Per-frame 2-D fields that are contoured on top of both panels.

        Returns
        -------
        matplotlib.animation.FuncAnimation
        """
        self.var = var
        self.level = level
        self.inputs = inputs
        self.predictions = predictions
        self.targets = targets
        self.mrms_dz = mrms_dz

        # Only the initial state is needed here (to derive contour levels);
        # predictions / targets are still passed so that a missing or
        # non-singleton batch dimension is reported up front.
        init_ds, _, _ = self.drop_batch_dim(inputs, predictions, targets)

        level_txt = ''
        if level != 'none':
            level_txt = f', level={level}'

        self.titles = [f'WoFS {display_name_mapper.get(var, var)}{level_txt}',
                       f'WoFS-Cast {display_name_mapper.get(var, var)}{level_txt}']

        fig, self.axes = plt.subplots(dpi=200, figsize=(12, 6), ncols=2,
                                      gridspec_kw={'height_ratios': [1], 'bottom': 0.15})

        plt.tight_layout()

        # Default contour levels come from the 1st-99th percentile range of
        # the initial state; variable-specific levels may override them.
        _, levels = self.get_target_and_pred_pair(init_ds, init_ds, t=0, level=level,
                                                  return_rng=True)

        self.cmap, self.levels = self.get_colormap_and_levels(var, levels)

        # Dedicated axes for the shared colorbar; the colorbar itself is
        # created lazily on the first drawn frame.
        self.cbar_ax = fig.add_axes([0.15, 0.075, 0.7, 0.02])
        self.cbar = None

        self.fig = fig
        self.N = len(predictions.time)

        return FuncAnimation(fig, self.update, frames=self.N, interval=200)

    def drop_batch_dim(self, inputs, predictions, targets):
        """Squeeze out ``batch`` and order dims as (time, level, lat, lon).

        Only the last time step of ``inputs`` is kept. Dimensions that a
        variable does not have are ignored when transposing.
        """
        dims = ('time', 'level', 'lat', 'lon')
        init_ds = inputs.squeeze(dim='batch', drop=True).isel(time=[-1]).transpose(*dims, missing_dims='ignore')
        preds = predictions.squeeze(dim='batch', drop=True).transpose(*dims, missing_dims='ignore')
        tars = targets.squeeze(dim='batch', drop=True).transpose(*dims, missing_dims='ignore')

        return init_ds, preds, tars

    def get_target_and_pred_pair(self, preds, targets, t, level=0, return_rng=False):
        """Return ``[target_field, prediction_field]`` for time index ``t``.

        ``level`` selects a level index, or reduces over levels ('max' /
        'min'), or is 'none' for variables without a level dimension.
        ``RAIN_AMOUNT`` is divided by 25.4 and ``T2`` is converted from Kelvin
        to Fahrenheit. With ``return_rng=True`` a second value is returned:
        10 evenly spaced levels between the 1st and 99th percentile of both
        fields.
        """
        def select(ds):
            field = ds[self.var]
            if level == 'max':
                return field.isel(time=t).max(dim='level').values
            if level == 'min':
                return field.isel(time=t).min(dim='level').values
            if level == 'none':
                return field.isel(time=t).values
            return field.isel(time=t, level=level).values

        zs = [select(targets), select(preds)]

        if self.var == 'RAIN_AMOUNT':
            zs = [z / 25.4 for z in zs]

        if self.var == 'T2':
            zs = [(9.0 / 5.0 * (z - 273.15)) + 32.0 for z in zs]

        if return_rng:
            global_min = np.percentile(zs, 1)
            global_max = np.percentile(zs, 99)
            rng = np.linspace(global_min, global_max, 10)
            return zs, rng

        return zs

    def get_colormap_and_levels(self, var, levels):
        """Return ``(cmap, levels)`` for ``var``.

        ``levels`` is passed through unchanged for variables that have no
        dedicated contour levels.
        """
        if var == 'COMPOSITE_REFL_10CM':
            cmap = WoFSColors.nws_dz_cmap
            levels = WoFSLevels.dz_levels_nws
        elif var == 'RAIN_AMOUNT':
            cmap = WoFSColors.rain_cmap
            levels = WoFSLevels.rain_rate_levels
        elif var == 'UP_HELI_MAX':
            cmap = WoFSColors.wz_cmap_extend
            levels = WoFSLevels.uh_2to5_levels_3000m
        elif var == 'T2':
            cmap = WoFSColors.temp_cmap
            levels = np.arange(40., 90., 2.5)
        elif var == 'QVAPOR':
            cmap = WoFSColors.temp_cmap
        elif var == 'W':
            cmap = WoFSColors.wz_cmap_extend
            levels = [2.5, 5, 10, 15, 20, 25, 30, 35, 40]
        else:
            cmap = WoFSColors.wz_cmap_extend

        return cmap, levels

    def _get_wind_overlays(self, t):
        """Return ``(winds, x, y)`` for the lowest-level wind at time ``t``.

        ``winds`` is ``[(u_target, v_target), (u_pred, v_pred)]`` subsampled
        by ``_WIND_STRIDE``; ``x`` / ``y`` are the matching grid coordinates.
        The overlay is best-effort: if 'U' / 'V' are missing, cannot be
        indexed, or do not match the ``domain_size`` grid, the result is
        ``([None, None], None, None)``.
        """
        stride = self._WIND_STRIDE
        no_overlay = ([None, None], None, None)
        try:
            # Select by dimension name: positional indexing would hit the
            # leading batch dimension instead of time.
            winds = [
                tuple(
                    np.squeeze(ds[name].isel(time=t, level=0).values)[::stride, ::stride]
                    for name in ('U', 'V')
                )
                for ds in (self.targets, self.predictions)
            ]
        except (KeyError, IndexError, ValueError):
            return no_overlay

        grid = np.arange(self.domain_size)
        x, y = np.meshgrid(grid, grid)
        x, y = x[::stride, ::stride], y[::stride, ::stride]

        if any(comp.shape != x.shape for wind in winds for comp in wind):
            return no_overlay

        return winds, x, y

    def update(self, t):
        """Draw frame ``t`` on both panels (called by ``FuncAnimation``)."""
        for ax in self.axes:
            ax.clear()

        # NOTE(review): frame 0 shows the first input time on both panels (so
        # its RMSE is 0) while the label uses dts[0] -- confirm this is the
        # intended first frame.
        if t == 0:
            zs = self.get_target_and_pred_pair(self.inputs, self.inputs, t=0, level=self.level)
        else:
            zs = self.get_target_and_pred_pair(self.predictions, self.targets, t=t, level=self.level)

        rmse = np.sqrt(np.mean((zs[0] - zs[1])**2))

        winds, x, y = self._get_wind_overlays(t)

        dis_name = display_name_mapper.get(self.var, self.var)
        units = units_mapper.get(self.var, self.var)
        # Fall back to the frame index when no time labels were supplied.
        time_label = self.dts[t] if self.dts is not None else f'step {t}'

        for i, (ax, z, wind) in enumerate(zip(self.axes, zs, winds)):

            z = z.squeeze()

            if self.var in self._MASKED_VARS:
                z = np.ma.masked_where(z < 1, z)

            # contourf does not use an `aspect` keyword, so the aspect ratio
            # is set on the axes (and re-applied each frame, after clear()).
            im = ax.contourf(z, origin='lower', cmap=self.cmap, levels=self.levels)
            ax.set_aspect('equal')

            if wind is not None:
                u, v = wind
                ax.quiver(x, y, u, v, alpha=0.5)

            ax.set_title(self.titles[i], fontweight='bold')
            if i == 1:
                ax.annotate(f'RMSE of {dis_name} ({units}): {rmse:.4f}',
                            xy=(0.01, 0.95), xycoords='axes fraction',
                            weight='bold', color='red',
                            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

            ax.annotate(f'Time: {time_label}', xy=(0.01, 0.01), xycoords='axes fraction',
                        weight='bold', color='red', fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

            if self.cbar is None:
                self.cbar = self.fig.colorbar(im, cax=self.cbar_ax, orientation='horizontal')
                self.cbar.set_label(f'{dis_name} ({units})')

            # Plot the MRMS overlays
            # NOTE(review): the RMSE below compares the plotted field with
            # `mrms_dz` whatever `var` is -- presumably only meaningful when
            # `var` is a reflectivity field; confirm.
            if self.mrms_dz is not None:
                this_rmse = np.sqrt(np.mean((z - self.mrms_dz[t])**2))

                ax.contour(self.mrms_dz[t],
                           origin='lower',
                           colors=['black', 'blue'],
                           levels=[35.0, 50.0], linewidths=[1.0, 1.5])

                ax.annotate(f'RMSE with MRMS: {this_rmse:.4f}',
                            xy=(0.01, 0.90), xycoords='axes fraction',
                            weight='bold', color='k',
                            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
            

In [8]:
# Load MRMS 

from pathlib import Path 
class MRMSDataLoader:
    """Load MRMS ``dz_consv`` fields for a sequence of valid times.

    Parameters
    ----------
    case_date : str
        Name of the case-date sub-directory under ``MRMS_PATH/<year>/``.
    datetime_rng : sequence of datetime-like
        Valid times to load; each must provide ``strftime``, and the first
        must provide ``year``.
    """

    MRMS_PATH = '/work/rt_obs/MRMS/RAD_AZS_MSH/'

    def __init__(self, case_date, datetime_rng):
        self.case_date = case_date
        self.datetime_rng = datetime_rng

    def find_mrms_files(self):
        """Return one entry per time in ``datetime_rng``.

        Each entry is the ``Path`` of the expected MRMS file
        (``wofs_MRMS_RAD_%Y%m%d_%H%M.nc`` under
        ``MRMS_PATH/<year>/<case_date>/``) if it exists, else ``None``.
        The year directory is taken from the first time only. An empty
        ``datetime_rng`` yields an empty list.
        """
        # len() rather than truthiness: pandas indexes have no unambiguous
        # boolean value.
        if len(self.datetime_rng) == 0:
            return []

        year = str(self.datetime_rng[0].year)
        case_dir = Path(self.MRMS_PATH).joinpath(year, self.case_date)
        candidates = (case_dir / date.strftime('wofs_MRMS_RAD_%Y%m%d_%H%M.nc')
                      for date in self.datetime_rng)

        return [path if path.is_file() else None for path in candidates]

    def resize(self, ds, n_lat=300, n_lon=300, domain_size=150):
        """Return the central ``domain_size`` x ``domain_size`` subset of ``ds``.

        The crop offsets are computed from ``n_lat`` / ``n_lon``, which are
        assumed to be the full grid dimensions of ``ds``.
        """
        start_lat, start_lon = (n_lat - domain_size) // 2, (n_lon - domain_size) // 2
        end_lat, end_lon = start_lat + domain_size, start_lon + domain_size

        # Subsetting the dataset to the central size x size grid
        ds_subset = ds.isel(lat=slice(start_lat, end_lat), lon=slice(start_lon, end_lon))

        return ds_subset

    def load(self, domain_size=150):
        """Return an array of shape ``(n_times, domain_size, domain_size)``.

        Each slice holds the centrally cropped ``dz_consv`` field for the
        matching time. Times whose file is missing are left as zeros.
        """
        files = self.find_mrms_files()

        data = np.zeros((len(files), domain_size, domain_size))

        for t, file in enumerate(files):
            if file is None:
                continue

            # The context manager closes the file even if reading fails.
            with xr.open_dataset(file, drop_variables=['lat', 'lon']) as ds:
                data[t, :, :] = self.resize(ds, domain_size=domain_size)['dz_consv'].values

        return data


In [9]:
# Valid times of the forecast frames, then the MRMS fields for those times.
dts = to_datetimes(path, n_times=14)
loader = MRMSDataLoader(case_date=case_date, datetime_rng=dts)
mrms_dz = loader.load()

In [10]:
# Build the side-by-side animation for one variable.
# Other values tried for `var`: 'COMPOSITE_REFL_10CM', 'U', 'QVAPOR'.
animator = WoFSCastAnimator(domain_size=150, dts=dts)
anim = animator(var='T2', level='none',
                inputs=inputs, predictions=predictions, targets=targets,
                mrms_dz=mrms_dz)

# To display the animation in a Jupyter notebook
from IPython.display import HTML
HTML(anim.to_jshtml())

# Optionally, to save the animation
#anim.save("wofscast_with_diffusion_best_model.gif", writer="pillow", fps=3)   

NameError: name 'predictions' is not defined