## Imports

In [86]:
import xarray as xr
import numpy as np
import imageio
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
import datetime
from pathlib import Path

## Setup

In [85]:
# Model output file that the rest of the notebook analyses.
nc_path = "/N/slate/jmelms/projects/FCN_dynamical_testing/data/output/full_output_corrected.nc"

# Every metadata array lives in the same directory; naming it once means that
# relocating the project is a one-line change instead of five.
metadata_dir = "/N/u/jmelms/BigRed200/projects/dynamical-tests-FCN/metadata"
lat_path = f"{metadata_dir}/latitude.npy"
lon_path = f"{metadata_dir}/longitude.npy"
lsm_path = f"{metadata_dir}/land_sea_edges_mask.npy"
mean_path = f"{metadata_dir}/global_means.npy"
std_path = f"{metadata_dir}/global_stds.npy"

# Figures go to ./imgs, resolved against the working directory at the time this
# cell runs; the directory is created if it does not exist yet.
img_out_path = Path("imgs").resolve()
img_out_path.mkdir(exist_ok=True)

## Loading Data

In [81]:
# Load the grid coordinates, the land-sea mask, and the mean/std arrays that
# are used further down to undo the standardization of the model output.
lat = np.load(lat_path)
lon = np.load(lon_path)
lsm = np.load(lsm_path)
mean = np.load(mean_path)
std = np.load(std_path)
nlat = lat.size
nlon = lon.size
dt = 6 # time step in hours

# Unit-conversion constants for the time axis. The original code divided by the
# literal `10e8`, which is 1e9 (not 1e8); naming it removes that ambiguity.
NS_PER_SECOND = 1e9
SECONDS_PER_HOUR = 3600

# load netcdf data
ds = xr.open_dataset(nc_path)

# get the data array
da = ds["__xarray_dataarray_variable__"]

# Drop the unused "history" dimension. Note that squeeze(drop=True) removes
# *every* length-1 dimension together with its coordinate, not just "history".
da = da.squeeze(drop=True)

# check to make sure that the lat/lon dimensions are the same as the input data
assert da.lat.size == nlat, "Latitude dimensions do not match"
assert da.lon.size == nlon, "Longitude dimensions do not match"

# change time dimension to datetime.datetime if using real times
# ds["time"] = [datetime.datetime.fromtimestamp(t/NS_PER_SECOND, tz=datetime.timezone.utc) for t in ds.time.values]

# Change the time coordinate to hours from initialization (relative times for
# the idealized simulation).
# NOTE(review): this assumes the stored time values are nanoseconds - confirm.
da["time"] = da.time.values / NS_PER_SECOND / SECONDS_PER_HOUR

# Unstandardize the data. `std` and `mean` are plain NumPy arrays, so this
# relies on broadcasting against da's dimensions.
# NOTE(review): confirm the shapes of `std`/`mean` line up with the channel
# axis, and that the file is not already in physical units - the msl values
# printed further down in this notebook are far outside a plausible range.
da = da * std + mean

print(da)

<xarray.DataArray '__xarray_dataarray_variable__' (time: 53, channel: 73,
                                                   lat: 721, lon: 1440)>
array([[[[ 7.45058060e-09,  7.45058060e-09,  7.45058060e-09, ...,
           7.45058060e-09,  7.45058060e-09,  7.45058060e-09],
         [ 7.45058060e-09,  7.45058060e-09,  7.45058060e-09, ...,
           7.45058060e-09,  7.45058060e-09,  7.45058060e-09],
         [ 7.45058060e-09,  7.45058060e-09,  7.45058060e-09, ...,
           7.45058060e-09,  7.45058060e-09,  7.45058060e-09],
         ...,
         [ 7.45058060e-09,  7.45058060e-09,  7.45058060e-09, ...,
           7.45058060e-09,  7.45058060e-09,  7.45058060e-09],
         [ 7.45058060e-09,  7.45058060e-09,  7.45058060e-09, ...,
           7.45058060e-09,  7.45058060e-09,  7.45058060e-09],
         [ 7.45058060e-09,  7.45058060e-09,  7.45058060e-09, ...,
           7.45058060e-09,  7.45058060e-09,  7.45058060e-09]],

        [[-4.47034836e-08, -4.47034836e-08, -4.47034836e-08, ...,
   

## Check a few variables

In [84]:
def context(plot_every=None):
    """Inspect the mean-sea-level-pressure channel and optionally save frames.

    Reads the notebook-level names ``da``, ``lat``, ``lon``, ``nlat``, ``nlon``
    and ``img_out_path``. The function always creates the directory
    ``img_out_path / "mslp_time_series"`` and prints the "msl" field at time
    index 1, divided by 100 (Pa -> hPa, assuming ``da`` holds Pa).

    NOTE(review): the output stored with this notebook shows values of roughly
    9.7e7 for this field, nowhere near ~1000 hPa. Confirm the unstandardization
    upstream before trusting any figure produced here.

    Parameters
    ----------
    plot_every : int or None, optional
        If None (default) nothing is plotted, which matches the previous
        behaviour where the plotting loop was commented out. Otherwise every
        ``plot_every``-th time step is rendered to a PNG in the output
        directory (the commented-out loop used a stride of 10).

    Raises
    ------
    ValueError
        If ``plot_every`` is given but is smaller than 1.
    """
    vis_name = "mslp_time_series"
    output_dir = img_out_path / vis_name
    output_dir.mkdir(exist_ok=True, parents=True)

    channel = "mslp"
    units = "hPa"

    # 1 hPa = 100 Pa
    data = da.sel(channel="msl") / 100

    # Quick sanity check of a single time step.
    print(data.isel(time=1))

    if plot_every is None:
        return
    if plot_every < 1:
        raise ValueError("plot_every must be a positive integer")

    # Tick positions are grid indices; the labels are the matching coordinates.
    xticks = np.arange(0, nlon, 200)
    yticks = np.arange(0, nlat, 200)
    contour_levels = np.arange(900, 1100, 4)

    # One fixed normalisation shared by the image and the colorbar, so the
    # colours mean the same thing in every frame. (The commented-out version
    # let imshow autoscale while the colorbar used 800-1200.)
    norm = Normalize(vmin=800, vmax=1200)

    for i in range(0, data.sizes["time"], plot_every):
        frame = data.isel(time=i)
        t_str = str(frame.time.values)

        fig, (ax1, ax2) = plt.subplots(
            ncols=2, figsize=(10, 5), gridspec_kw={"width_ratios": [1, 0.05]}
        )
        image = ax1.imshow(frame.values, cmap="viridis", norm=norm)
        CS = ax1.contour(frame.values, levels=contour_levels, colors="black")
        ax1.clabel(CS, CS.levels, inline=True, fmt="%.0f", fontsize=10)
        fig.colorbar(image, label=f"{channel} ({units})", cax=ax2)
        fig.suptitle(f"{channel} ({units}) at {t_str}h")
        ax1.set_xlabel("Longitude")
        ax1.set_ylabel("Latitude")
        ax1.set_xticks(xticks, lon[xticks])
        ax1.set_yticks(yticks, lat[yticks])

        out_file = output_dir / f"{vis_name}_{i:04d}.png"
        fig.savefig(out_file)
        print(f"Saved {out_file}")
        # Close this specific figure so a long run does not accumulate them.
        plt.close(fig)

context()

<xarray.DataArray '__xarray_dataarray_variable__' (lat: 721, lon: 1440)>
array([[96591696., 96591696., 96591696., ..., 96591696., 96591696.,
        96591696.],
       [96624480., 96624640., 96624776., ..., 96624048., 96624192.,
        96624344.],
       [96683904., 96684224., 96684560., ..., 96682912., 96683240.,
        96683584.],
       ...,
       [98706080., 98706064., 98706056., ..., 98706112., 98706104.,
        98706104.],
       [98588984., 98588960., 98588936., ..., 98589048., 98589032.,
        98589008.],
       [98483912., 98483912., 98483912., ..., 98483912., 98483912.,
        98483912.]], dtype=float32)
Coordinates:
  * lat      (lat) float64 90.0 89.75 89.5 89.25 ... -89.25 -89.5 -89.75 -90.0
  * lon      (lon) float64 0.0 0.25 0.5 0.75 1.0 ... 359.0 359.2 359.5 359.8
    channel  <U5 'msl'
    time     float64 6.0
