## Imports

In [1]:
import xarray as xr
import numpy as np
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
import datetime
from pathlib import Path
import shutil

## Setup

In [2]:
# Model output (NetCDF) that every plot in this notebook reads from.
nc_path = "/N/slate/jmelms/projects/FCN_dynamical_testing/data/output/ideal_default_60t.nc"

# Static metadata arrays (grid coordinates, land-sea mask, normalisation statistics).
lat_path = "/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/metadata/latitude.npy"
lon_path = "/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/metadata/longitude.npy"
lsm_path = "/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/metadata/land_sea_edges_mask.npy"
mean_path = "/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/metadata/global_means.npy"
std_path = "/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/metadata/global_stds.npy"

# Root directory for every image / gif written below.
img_out_path = Path("/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/code/analysis/imgs")
# parents=True: do not fail when intermediate directories are missing
# (e.g. a fresh checkout where code/analysis has no imgs/ folder yet).
img_out_path.mkdir(parents=True, exist_ok=True)

## Loading Data

In [5]:
# Static grid metadata, land-sea mask and normalisation statistics.
lat, lon, lsm = np.load(lat_path), np.load(lon_path), np.load(lsm_path)
mean, std = np.load(mean_path), np.load(std_path)
nlat, nlon = lat.size, lon.size
dt = 6 # time step in hours

# Open the model output and pull out its single (unnamed) data variable.
# squeeze(drop=True) removes every length-1 dimension; per the original note this
# is the "history" dimension, which is not used in these runs.
ds = xr.open_dataset(nc_path)
da = ds["__xarray_dataarray_variable__"].squeeze(drop=True)

# The output grid must match the metadata grid loaded above.
assert da.lat.size == nlat, "Latitude dimensions do not match"
assert da.lon.size == nlon, "Longitude dimensions do not match"

# Alternative for runs with real timestamps (kept for reference):
# ds["time"] = [datetime.datetime.fromtimestamp(t/10e8, tz=datetime.timezone.utc) for t in ds.time.values]

# Idealized runs use relative time: rescale the raw values to hours since initialization.
# NOTE(review): the 1e9 divisor implies the raw values are nanoseconds - confirm.
da["time"] = [raw_t / 1e9 / 3600 for raw_t in da.time.values]

# From here on `da` holds msl in hPa (all other channels are left untouched).
da.loc[{"channel": "msl"}] = da.sel(channel="msl") / 100 # converting to hPa
print(da)

<xarray.DataArray '__xarray_dataarray_variable__' (time: 61, channel: 73,
                                                   lat: 721, lon: 1440)> Size: 18GB
[4623282720 values with dtype=float32]
Coordinates:
  * lat      (lat) float64 6kB 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * lon      (lon) float64 12kB 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
  * channel  (channel) <U5 1kB 'u10m' 'v10m' 'u100m' ... 'r850' 'r925' 'r1000'
  * time     (time) float64 488B 0.0 6.0 12.0 18.0 ... 342.0 348.0 354.0 360.0


### Util Functions

In [4]:
def _frames2gif(frame_paths, output_path, fps):
    """Read the given image files, in the given order, and write them out as one gif."""
    images = [imageio.imread(frame_path) for frame_path in frame_paths]
    imageio.mimsave(output_path, images, fps=fps)


def dir2gif(img_dir, output_path, fps):
    """
    Combine every "*.png" in a directory into a gif.

    Inputs:
        img_dir (pathlib.Path): directory holding the frames; frames are ordered by sorted path
        output_path (path of the gif to write)
        fps (frames per second for gif)
    """
    _frames2gif(sorted(img_dir.glob("*.png")), output_path, fps)


def plotting_loop(dim_name, iterable, data, fig, obj, update_func, output_dir, fname_prefix, keep_images=False, fps=1):
    """
    Inputs:
        dim_name (str): name of the dimension to loop over
        iterable (normally time or variable)
        data (xarray dataset)
        fig (matplotlib figure)
        obj (prepared matplotlib image object on axis)
        update_func (function that updates the plot; called as update_func(fig, obj, data_slice, i))
        output_dir (directory to save images)
        fname_prefix (prefix for image filenames)
        keep_images (if True, images will not be deleted after the loop)
        fps (frames per second for gif)

    Outputs
        a gif of the plot loop at the specified output_dir

    Return:
        fig
    """
    frame_dir = output_dir / "imgs"
    frame_dir.mkdir(exist_ok=True, parents=True)

    # Remember exactly which frames this call produced. Globbing the directory instead
    # would also pick up stale PNGs left by earlier runs (other prefixes, or a longer
    # run with the same prefix and keep_images=True) and splice them into the gif.
    frame_paths = []
    for i, idx in enumerate(iterable):
        data_slice = data.sel({dim_name: idx}).values
        update_func(fig, obj, data_slice, i)
        frame_path = frame_dir / f"{fname_prefix}_{i:03}.png"
        fig.savefig(frame_path)
        frame_paths.append(frame_path)

    _frames2gif(frame_paths, output_dir / f"{fname_prefix}.gif", fps)

    if not keep_images:
        shutil.rmtree(frame_dir)

    return fig

## Example Usage of Utils

In [5]:
# Synthetic demo input: 100 random 100x100 fields whose spread oscillates with the frame index.
# One seeded generator (instead of a fresh unseeded one per frame) keeps the demo reproducible.
rng = np.random.default_rng(0)
fields = [rng.normal(0, 2*np.sin(i)+3, (100, 100)) for i in range(100)]
data = xr.DataArray(np.stack(fields), dims=["time", "lat", "lon"])

fig, ax = plt.subplots()
# The initial image is a placeholder; every frame is overwritten through set_data below.
im = ax.imshow(fields[-1], vmin=-10, vmax=10)

def example_plot_updater(fig, plot_obj, data, i):
    """Swap the image contents for frame `i` and refresh the figure title."""
    plot_obj.set_data(data)
    fig.suptitle(f"Field at t={i}h")

plotting_loop("time", range(100), data, fig, im, example_plot_updater, img_out_path / "random", "ex", keep_images=True, fps=1)

# Close (rather than clear) the figure so no empty figure is left open or rendered as cell output.
plt.close(fig)

<Figure size 640x480 with 0 Axes>

## Plotting gif of each var

In [13]:
def context(var, units):
    """
    Render one channel of the module-level `da` as a gif animated over the time axis.

    Inputs:
        var (str): channel name selected from `da`
        units (str): unit string shown in the colorbar label

    Outputs
        `<var>.gif` in img_out_path / "era5_60t" (written through plotting_loop)
    """
    # params
    data = da.sel(channel=var)
    output_dir = img_out_path / "era5_60t"
    cbar_label = f"{var} ({units})"
    iter_dim = "time"
    iterable_vals = da.time.values
    # one fixed colour range for the whole animation so frames are comparable
    vmin, vmax = data.values.min(), data.values.max()
    cmap = cm.viridis # type: ignore
    figsize = (9, 5)
    xax_label = "longitude [degrees_east]"
    yax_label = "latitude [degrees_north]"
    # tick positions are array indices; labels are looked up from the lon/lat arrays
    xticks = np.linspace(0, nlon, 9, dtype=int)[1:-1]
    yticks = np.linspace(0, nlat-1, 5, dtype=int)
    xticklabs = lon[xticks].astype(int)
    yticklabs = lat[yticks].astype(int)
    adjust = {
        "top": 0.83,
        "bottom": 0.13,
        "left": 0.02, #"left": 0.025,
        "right": 0.93,
    }

    # plot setup
    fig, ax = plt.subplots(figsize=figsize)
    try:
        im = ax.imshow(np.zeros((nlat, nlon)), vmin=vmin, vmax=vmax, cmap=cmap)
        fig.subplots_adjust(**adjust)
        # colorbar axis placed just right of the map, spanning its full height
        cax = fig.add_axes([ax.get_position().x1+0.02,ax.get_position().y0,0.02,ax.get_position().y1-ax.get_position().y0]) # type: ignore
        cbar = fig.colorbar(im, cax=cax, extend="both")
        cbar.set_label(cbar_label)
        ax.set_xlabel(xax_label)
        ax.set_ylabel(yax_label)
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(xticklabs)
        ax.set_yticklabels(yticklabs)
        fig.patch.set_facecolor('xkcd:pale grey')

        # make plotting function
        def plot_updater(fig, plot_obj, data, i):
            plot_obj.set_data(data)
            # i is the frame index; use the shared time step `dt` (hours) instead of a literal 6
            fig.suptitle(f"{var} at t={i*dt:03}h", fontsize=20, x=0.205, y=0.94)

        # run plotting loop
        plotting_loop(iter_dim, iterable_vals, data, fig, im, plot_updater, output_dir, var, keep_images=False, fps=4)
    finally:
        # close the figure even if a frame fails, so a 73-channel loop does not pile up open figures
        plt.close(fig)

# Units of the values as they are stored in `da` when this cell runs:
#   - msl was already converted to hPa in the loading cell
#   - z is plotted here without the /9.81 conversion used elsewhere in the notebook,
#     i.e. as geopotential rather than geopotential height
units_table = {
    "u": "m s-1",
    "v": "m s-1",
    "z": "m2 s-2",
    "q": "kg/kg",
    "r": "%",
    "t": "K",
    "t2m": "K",
    "sp": "Pa",
    "msl": "hPa",
    "tcwv": "kg m-2",
    "2d": "K",
}
sfc = ("t2m", "sp", "msl", "tcwv", "2d")
for ch in da.channel.values:
    if ch in sfc:
        units = units_table[ch]

    else:
        # all alphabetical characters of the channel name, e.g. "u850" -> ["u"], "u10m" -> ["u", "m"];
        # the first one identifies the variable family used as the units_table key
        char = [c for c in ch if c.isalpha()]
        assert len(char) in (1,2), f"Received invalid channel name: {ch}"
        units = units_table[char[0]]

    context(ch, units)
    print(f"Finished plotting {ch}")

Finished plotting u10m
Finished plotting v10m
Finished plotting u100m
Finished plotting v100m
Finished plotting t2m
Finished plotting sp
Finished plotting msl
Finished plotting tcwv
Finished plotting u50
Finished plotting u100
Finished plotting u150
Finished plotting u200
Finished plotting u250
Finished plotting u300
Finished plotting u400
Finished plotting u500
Finished plotting u600
Finished plotting u700
Finished plotting u850
Finished plotting u925
Finished plotting u1000
Finished plotting v50
Finished plotting v100
Finished plotting v150
Finished plotting v200
Finished plotting v250
Finished plotting v300
Finished plotting v400
Finished plotting v500
Finished plotting v600
Finished plotting v700
Finished plotting v850
Finished plotting v925
Finished plotting v1000
Finished plotting z50
Finished plotting z100
Finished plotting z150
Finished plotting z200
Finished plotting z250
Finished plotting z300
Finished plotting z400
Finished plotting z500
Finished plotting z600
Finished plott

## Plotting Topo-Hallucination Figure for Ullrich et al. 2024 (Unused attempt #1)

In [37]:
def context(var, units, t1, fname, vlims=None, geopotential_conversion=False, clevs=None):
    """
    Save a two-panel figure comparing one channel of `da` at the first time step
    (top) and at time index `t1` (bottom, with labelled contours).

    Inputs:
        var (str): channel name selected from the module-level `da`
        units (str): unit string used in the title and colorbar label
        t1 (int): positional index into the time axis for the bottom panel
        fname (str): file name of the PNG written to img_out_path / "paper_figures" / "gif_input"
        vlims ((vmin, vmax) or None): colour limits; data min/max of the two panels when omitted
        geopotential_conversion (bool): divide the data by 9.81 before plotting
        clevs (sequence or None): contour levels for the bottom panel; [1000, 1026] when omitted
    """
    with plt.rc_context({
        "axes.labelsize": 7,
        "axes.titlesize": 10,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7
    }):

        # params
        data = da.sel(channel=var).isel(time=[0, t1])
        if geopotential_conversion:
            data = data / 9.81
        # No Pa -> hPa step for msl here: the loading cell already stores msl in hPa,
        # so dividing by 100 again would push the field far outside the hPa colour limits.

        output_dir = img_out_path / "paper_figures" / "gif_input"
        # savefig does not create missing directories
        output_dir.mkdir(parents=True, exist_ok=True)
        title_str = f"{var} [{units}]"
        cbar_label = f"[{units}]"
        if vlims:
            # set vlims
            vmin, vmax = vlims

        else:
            vmin, vmax = data.values.min(), data.values.max()

        # previously hard-coded; `clevs` was accepted but never used
        contour_levels = [1000, 1026] if clevs is None else clevs

        cmap = cm.viridis # type: ignore
        figsize = (4.2, 4.4)
        xax_label = "longitude [degrees_east]"
        yax_label = "latitude [degrees_north]"
        # tick positions are array indices; labels are looked up from the lon/lat arrays
        xticks = np.linspace(0, nlon, 7, dtype=int)[1:-1]
        yticks = np.linspace(0, nlat-1, 5, dtype=int)
        xticklabs = lon[xticks].astype(int)
        yticklabs = lat[yticks].astype(int)
        adjust = {
            "top": 0.92,
            "bottom": 0.04,
            "left": 0.15,
            "right": 0.80,
            "hspace": 0.10,
            "wspace": 0.06,
        }

        # plot setup
        fig, axs = plt.subplot_mosaic("a\nc", figsize=figsize)
        try:
            ax1, ax2 = axs["a"], axs["c"]
            ax1.set_title("t = 0h")
            ax2.set_title(f"t = {t1 * dt}h")
            fig.subplots_adjust(**adjust)
            im1 = ax1.imshow(data.isel(time=0), vmin=vmin, vmax=vmax, cmap=cmap)
            im2 = ax2.imshow(data.isel(time=1), vmin=vmin, vmax=vmax, cmap=cmap)
            C2 = ax2.contour(data.isel(time=1), colors="k", levels=contour_levels, linewidths=0.3)
            ax2.clabel(C2, C2.levels, inline=True, fmt=lambda x: f"{x:n}", fontsize=4)

            # one shared colorbar to the right, spanning both panels
            cax = fig.add_axes([ax2.get_position().x1+0.05,ax2.get_position().y0,0.02,ax1.get_position().y1-ax2.get_position().y0]) # type: ignore
            cbar = fig.colorbar(im2, cax=cax, orientation="vertical", fraction=.1)
            # cbar.set_ticks(np.linspace(vmin, vmax, 7).round(0))
            # cbar.set_label(cbar_label)
            cbar.ax.set_ylabel(cbar_label, rotation="horizontal", loc="center", labelpad=-45, fontsize=9)
            for ax in (ax1, ax2):
                ax.set_xlabel(xax_label)
                ax.set_ylabel(yax_label)
                ax.set_xticks(xticks)
                ax.set_yticks(yticks)
                ax.set_xticklabels(xticklabs)
                ax.set_yticklabels(yticklabs)

            # only the bottom panel keeps its x-axis label
            ax1.set_xlabel("")
            fig.patch.set_facecolor('xkcd:pale grey')
            fig.suptitle(title_str, x=(ax1.get_position().x0+ax1.get_position().x1)/2)

            # save through the figure itself rather than pyplot's "current figure"
            fig.savefig(output_dir/fname, dpi=300)
        finally:
            # close the figure even on failure; this function is called once per frame
            plt.close(fig)

# Unit strings keyed by variable family / surface channel name.
units_table = {
    "u": "m s-1",
    "v": "m s-1",
    "z": "m",
    "q": "kg/kg",
    "r": "%",
    "t": "K",
    "t2m": "K",
    "sp": "Pa",
    "msl": "Pa",
    "tcwv": "kg m-2",
    "2d": "K",
}

# Time indices 0..40, one output frame per index.
times = list(range(41))

# Colour limits from the z500 height range over those times ...
data_at_times = da.sel(channel="z500").isel(time=times) / 9.81
vlims = (data_at_times.values.min(), data_at_times.values.max())
# vlims = (4600, 6000) #custom z500
# ... which are immediately replaced by fixed limits for the msl frames below.
vlims = (970,1050)

for t in times:
    # context("z500", r"m", t, fname=f"z500_{t:02}.png", vlims=vlims, geopotential_conversion=True)
    context("msl", "hPa", t, fname=f"msl_{t:02}.png", vlims=vlims)

In [38]:
### Turn above into a gif

# dir2gif collects every PNG found in gif_dir, so the folder should only hold the frames wanted here.
gif_dir = img_out_path / "paper_figures" / "gif_input"
output_loc = img_out_path / "msl.gif"
dir2gif(gif_dir, output_loc, fps=2)

In [9]:
### explore the z field... weird so far
# NOTE(review): the heading originally said z500, but the selection below is z1000.
# Prints the channel list, then shows the z1000 time series at lat=0, lon=70 divided by 9.81.

print(da.channel)
da.sel(channel="z1000").sel(lat=0, lon=70).values / 9.81

<xarray.DataArray 'channel' (channel: 73)> Size: 1kB
array(['u10m', 'v10m', 'u100m', 'v100m', 't2m', 'sp', 'msl', 'tcwv', 'u50',
       'u100', 'u150', 'u200', 'u250', 'u300', 'u400', 'u500', 'u600', 'u700',
       'u850', 'u925', 'u1000', 'v50', 'v100', 'v150', 'v200', 'v250', 'v300',
       'v400', 'v500', 'v600', 'v700', 'v850', 'v925', 'v1000', 'z50', 'z100',
       'z150', 'z200', 'z250', 'z300', 'z400', 'z500', 'z600', 'z700', 'z850',
       'z925', 'z1000', 't50', 't100', 't150', 't200', 't250', 't300', 't400',
       't500', 't600', 't700', 't850', 't925', 't1000', 'r50', 'r100', 'r150',
       'r200', 'r250', 'r300', 'r400', 'r500', 'r600', 'r700', 'r850', 'r925',
       'r1000'], dtype='<U5')
Coordinates:
  * channel  (channel) <U5 1kB 'u10m' 'v10m' 'u100m' ... 'r850' 'r925' 'r1000'


array([114.99505 , 102.44791 , 105.383224,  96.64572 , 109.236946,
        62.01703 ,  96.85756 ,  55.672173,  92.633064,  38.948067,
        97.40587 ,  61.332233,  97.26207 ,  47.8136  ,  98.4535  ,
        62.16858 ,  91.624146,  46.210552,  94.88063 ,  62.80506 ,
        92.74042 ,  52.233124,  97.79895 ,  68.605194,  95.016266,
        58.956474,  94.228584,  67.6941  ,  91.05917 ,  57.555794,
        92.29206 ,  68.408226,  89.651924,  53.55202 ,  82.93445 ,
        57.0054  ,  80.54857 ,  46.827316,  79.50152 ,  56.666687,
        82.823425,  51.620926,  82.49542 ,  61.317875,  84.96036 ,
        57.23197 ,  86.208145,  66.66867 ,  91.403015,  66.98376 ,
        92.04274 ,  72.36972 ,  93.9425  ,  67.87189 ,  92.82409 ,
        76.58098 , 101.40775 ,  75.9887  ,  97.00587 ,  78.122536,
        98.39008 ], dtype=float32)

## 4 Panel Plot for Ullrich et al. 2024 (Hopefully-useful attempt #2)

In [56]:
def context(data, vars, units, fcst_hour, fname, vlims=None):
    """
    Save a 2x2 panel figure with one map per variable.

    Inputs:
        data (xarray DataArray): must support data.sel(channel=<var>) returning a 2D field
        vars (sequence of str): channel names, at most four, plotted row by row
        units (sequence of str): unit string per variable (panel title and colorbar label)
        fcst_hour: value shown in the figure title ("Projection Time: ... hours")
        fname (str): file name of the PNG written into img_out_path
        vlims (sequence of (vmin, vmax) or None): colour limits per variable. When omitted,
            a module-level `vlims` is used if one exists (how this function was originally
            called); otherwise each panel uses its own data min/max.

    Raises:
        ValueError: if vars, units and vlims differ in length or describe more than four panels
    """
    n_rows, n_cols = 2, 2

    if vlims is None:
        # backward compatibility: earlier calls relied on a global `vlims` instead of an argument
        vlims = globals().get("vlims")
    if vlims is None:
        vlims = [
            (float(data.sel(channel=var).min()), float(data.sel(channel=var).max()))
            for var in vars
        ]

    # zip() below would silently drop panels on a length mismatch, so fail loudly instead
    if not (len(vars) == len(units) == len(vlims)) or len(vars) > n_rows * n_cols:
        raise ValueError(
            f"vars, units and vlims must have the same length (<= {n_rows * n_cols}); "
            f"got {len(vars)}, {len(units)} and {len(vlims)}"
        )

    with plt.rc_context({
        "axes.labelsize": 9,
        "axes.titlesize": 12,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "figure.titlesize": 14,
        "font.family": "sans-serif",
        # "font.sans-serif": "Geneva"
    }):

        output_path = img_out_path / fname

        cmap = cm.viridis # type: ignore
        figsize = (8, 4.3)
        xax_label = "longitude [degrees_east]"
        yax_label = "latitude [degrees_north]"
        # tick positions are array indices; labels are looked up from the lon/lat arrays
        xticks = np.linspace(0, nlon, 7, dtype=int)[1:-1]
        yticks = np.linspace(0, nlat-1, 5, dtype=int)
        xticklabs = lon[xticks].astype(int)
        yticklabs = lat[yticks].astype(int)
        adjust = {
            "top": 0.88,
            "bottom": 0.1,
            "left": 0.08,
            "right": 0.9,
            "hspace": 0.5,
            "wspace": 0.55,
        }

        # plot setup
        fig, axs_caxs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize)
        try:
            fig.subplots_adjust(**adjust)
            # row-major order: top-left, top-right, bottom-left, bottom-right
            axs = axs_caxs.ravel()

            for var, unit, ax, (vmin, vmax) in zip(vars, units, axs, vlims):
                ax.set_title(f"{var} [{unit}]")
                im = ax.imshow(data.sel(channel=var), vmin=vmin, vmax=vmax, cmap=cmap)

                # thin colorbar axis directly to the right of this panel
                cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.01,ax.get_position().y1-ax.get_position().y0]) # type: ignore
                cbar = fig.colorbar(im, cax=cax, orientation="vertical", fraction=.1)
                # cbar.set_ticks(np.linspace(vmin, vmax, 7).round(0))
                # cbar.set_label(cbar_label)
                cbar.ax.set_ylabel(unit, rotation="vertical", loc="center", fontsize=9)
                ax.set_xlabel(xax_label)
                ax.set_ylabel(yax_label)
                ax.set_xticks(xticks)
                ax.set_yticks(yticks)
                ax.set_xticklabels(xticklabs)
                ax.set_yticklabels(yticklabs)

            # fig.patch.set_facecolor('xkcd:pale grey')
            # centre the title over the span of the top row of panels
            fig.suptitle(f"Projection Time: {fcst_hour} hours", x=((axs[0].get_position().x0 + axs[1].get_position().x1) / 2))

            # save through the figure itself rather than pyplot's "current figure"
            fig.savefig(output_path, dpi=300)
        finally:
            # close the figure even if plotting fails
            plt.close(fig)

# Unit strings for the values as stored in `da`: only msl was rescaled to hPa in the
# loading cell, so sp is still labelled in Pa.
units_table = {
    "u": "m s-1",
    "v": "m s-1",
    "z": "m",
    "q": "kg/kg",
    "r": "%",
    "t": "K",
    "t2m": "K",
    "sp": "Pa",
    "msl": "hPa",
    "tcwv": "kg m$^{-2}$",
    "2d": "K",
    "u10m": "m s$^{-1}$"
}

# The four panels, in plotting order, with their units and colour limits.
vars = ["t2m", "tcwv", "msl", "u10m"]
units = [units_table[var] for var in vars]
vlims = [
    (225, 325),
    (0, 70),
    (980, 1040),
    (-20, 20)
]
# Forecast lead time in hours; must match a value of the (relative, hourly) time coordinate.
fcst_hour = 120
data = da.sel(channel=vars, time=fcst_hour)
# `context` picks up the module-level `vlims` defined above.
context(data, vars, units, fcst_hour, fname="four_panel_0.png")

In [50]:
# Scratch cell: lists every matplotlib rcParams key
# (presumably used to look up the option names passed to plt.rc_context in the cells above).
plt.rcParams.keys()

KeysView(RcParams({'_internal.classic_mode': False,
          'agg.path.chunksize': 0,
          'animation.bitrate': -1,
          'animation.codec': 'h264',
          'animation.convert_args': ['-layers', 'OptimizePlus'],
          'animation.convert_path': 'convert',
          'animation.embed_limit': 20.0,
          'animation.ffmpeg_args': [],
          'animation.ffmpeg_path': 'ffmpeg',
          'animation.frame_format': 'png',
          'animation.html': 'none',
          'animation.writer': 'ffmpeg',
          'axes.autolimit_mode': 'data',
          'axes.axisbelow': 'line',
          'axes.edgecolor': 'black',
          'axes.facecolor': 'white',
          'axes.formatter.limits': [-5, 6],
          'axes.formatter.min_exponent': 0,
          'axes.formatter.offset_threshold': 4,
          'axes.formatter.use_locale': False,
          'axes.formatter.use_mathtext': False,
          'axes.formatter.useoffset': True,
          'axes.grid': False,
          'axes.grid.axis': 'b