## Regional Ocean: Animations of Surface Fields

In [None]:
%load_ext autoreload
%autoreload 2

import xarray as xr
import os
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

import regional_utils as utils

In [None]:
## Default parameter values; the "# Parameters" cell further down reassigns most of them.
## The trailing comment on each line is an example of a real value.
CESM_output_dir = ""  # "/glade/campaign/cgd/oce/projects/CROCODILE/workshops/2025/Diagnostics/CESM_Output/"
case_name = ""  # "CROCODILE_tutorial_nwa12_MARBL"

## No timeseries or base case used in this notebook
ts_dir = None
base_case_output_dir = None
base_case_name = None

## As regional domains vary so much in purpose, simulation length, and extent, we don't want to assume a minimum duration
## Thus, we ignore start and end dates and simply reduce/output over the whole time frame for all of the examples given.
start_date = None  # "0001-01-01"
end_date = None  # "0101-01-01"
base_start_date = None  # "0001-01-01"
base_end_date = None  # "0101-01-01"

obs_data_dir = None

## When savefigs is True the plots are also written to disk; fig_output_dir=None
## makes the notebook fall back to a per-user scratch directory.
savefigs = False
fig_output_dir = None

serial = False  # use dask LocalCluster

lc_kwargs = {}  # presumably keyword arguments for the dask LocalCluster — confirm

In [None]:
# Parameters
# NOTE(review): these values look machine-injected (e.g. by papermill) and override
# the defaults defined in the previous cell — confirm before editing by hand.
case_name = "CROCODILE_tutorial_nwa12_MARBL"
base_case_name = "CROCODILE_tutorial_nwa12_MARBL"
CESM_output_dir = "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CESM_output_for_testing/"
# Both bounds must be non-empty "YYYY-MM-DD" strings for the time subsetting to run.
start_date = "2000-01-01"
end_date = "2000-11-01"
base_start_date = ""
base_end_date = ""
ts_dir = "/glade/derecho/scratch/ajanney/archive/"
lc_kwargs = {"threads_per_worker": 1}
serial = True
savefigs = True
# None -> images go to the per-user scratch location built in the data-loading cell.
fig_output_dir = None
# NOTE(review): subset_kwargs and product are not referenced in the cells reviewed here.
subset_kwargs = {}
product = "/glade/work/ajanney/CUPiD/examples/regional_ocean/computed_notebooks//ocn/Regional_Ocean_Report_Card.ipynb"

In [None]:
# Directory holding this case's ocean history files.
# os.path.join avoids the doubled "/" that plain string formatting produced when
# CESM_output_dir already ends with a separator; the final "" component keeps the
# trailing slash that the previous string form had.
OUTDIR = os.path.join(CESM_output_dir, case_name, "ocn", "hist", "")
print("Output directory is:", OUTDIR)

In [None]:
## Directory containing this case's ocean history files
case_output_dir = os.path.join(CESM_output_dir, case_name, "ocn", "hist")

# Xarray time decoding things: decode times into cftime objects
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)

# Options shared by every open_mfdataset call below
open_kwargs = {
    "decode_timedelta": True,
    "decode_times": time_coder,
    "engine": "netcdf4",
}

## Static data includes hgrid, vgrid, bathymetry, land/sea mask
static_data = xr.open_mfdataset(
    os.path.join(case_output_dir, "*static.nc"), **open_kwargs
)

# ## Surface Data (currently disabled)
# sfc_data = xr.open_mfdataset(
#     os.path.join(case_output_dir, "*sfc*.nc"), **open_kwargs
# )

## Monthly data from the "*h.z*" history files
## (indexed along z_l when the surface level is selected for plotting)
monthly_data = xr.open_mfdataset(
    os.path.join(case_output_dir, "*h.z*.nc"), **open_kwargs
)

## Native Monthly Domain Data (and atm forcing)
native_data = xr.open_mfdataset(
    os.path.join(case_output_dir, "*native*.nc"), **open_kwargs
)

## Image/Gif Output Directory
# Without an explicit fig_output_dir, fall back to the user's scratch archive.
# os.environ["USER"] raises KeyError if the variable is unset.
if fig_output_dir is None:
    image_root = os.path.join("/glade/derecho/scratch/", os.environ["USER"], "archive")
else:
    image_root = fig_output_dir
image_output_dir = os.path.join(image_root, case_name, "ocn", "cupid_images")

# exist_ok=True avoids the check-then-create race when notebooks run in parallel
os.makedirs(image_output_dir, exist_ok=True)
print("Image output directory is:", image_output_dir)

In [None]:
## Apply time boundaries
# Subset only when both bounds are non-empty. The defaults are None and injected
# values may be "", so a truthiness test is used: len(None) would raise TypeError.
if start_date and end_date:
    import cftime

    # Take the calendar from the decoded time axis itself so the bounds are always
    # comparable with the cftime objects in the index, whatever calendar the model
    # used (noleap, 360_day, proleptic_gregorian, ...).
    calendar = monthly_data.time.dt.calendar

    def _to_cftime(date_string):
        """Convert a "YYYY-MM-DD" string into a cftime datetime on `calendar`."""
        year, month, day = (int(part) for part in date_string.split("-"))
        return cftime.datetime(year, month, day, calendar=calendar)

    start_date_time = _to_cftime(start_date)
    end_date_time = _to_cftime(end_date)

    # Label-based slice: both end points are inclusive
    monthly_data = monthly_data.sel(time=slice(start_date_time, end_date_time))
    native_data = native_data.sel(time=slice(start_date_time, end_date_time))

In [None]:
# Surface-forcing variables and the ocean field each one is plotted next to
atm_to_ocn_variables = {"hfds": "thetao", "tauuo": "uo", "tauvo": "vo"}
atm_variables = list(atm_to_ocn_variables)

# Attach every geolon*/geolat* variable of the static file as a coordinate so the
# fields can be plotted against geographic longitude/latitude.
geo_coords = {
    name: static_data[name]
    for name in static_data.variables
    if "geolon" in name or "geolat" in name
}
native_data = native_data.assign_coords(geo_coords)
monthly_data = monthly_data.assign_coords(geo_coords)


def _plot_monthly_grid(data, lon, lat, cmap, title, filename):
    """Draw one panel per time step of `data`, optionally save it, and show it.

    Parameters
    ----------
    data : DataArray with a "time" dimension, faceted four panels per row.
    lon, lat : names of the coordinates used for the x and y axes.
    cmap : colormap passed through to the plot call.
    title : figure-level title.
    filename : file name (inside image_output_dir) used when savefigs is true.

    Returns
    -------
    The object returned by `data.plot(...)`.
    """
    grid = data.plot(
        x=lon, y=lat, col="time", col_wrap=4, robust=True, figsize=(16, 10), cmap=cmap
    )
    # y=1.02 places the title just above the figure canvas ...
    plt.suptitle(title, fontsize=16, y=1.02)
    if savefigs:
        # ... so the bounding box must be expanded or the title is clipped on disk
        plt.savefig(os.path.join(image_output_dir, filename), bbox_inches="tight")
    plt.show()
    # Release the figure; a no-op when the backend already closed it on show()
    plt.close()
    return grid


for var, sfc_var in atm_to_ocn_variables.items():
    field = native_data[var]
    sfc_field = monthly_data[sfc_var].isel(z_l=0)  # first z_l level

    # NOTE(review): the coordinate names and colormap are chosen from the forcing
    # variable and reused for the ocean field — confirm both fields share the same
    # horizontal dims and that sharing the colormap is intended.
    geocoords = utils.chooseGeoCoords(field.dims)
    lon = geocoords["longitude"]
    lat = geocoords["latitude"]

    cmap = utils.chooseColorMap(var)

    # Plot atmospheric variable for each month
    g1 = _plot_monthly_grid(
        field,
        lon,
        lat,
        cmap,
        f"Atmosphere: {var} monthly fields",
        f"{var}_monthly_grid.png",
    )

    # Plot ocean surface variable for each month
    g2 = _plot_monthly_grid(
        sfc_field,
        lon,
        lat,
        cmap,
        f"Ocean: {sfc_var} monthly fields",
        f"{sfc_var}_monthly_grid.png",
    )