## Regional Ocean: Animations of Surface Fields

In [None]:
%load_ext autoreload
%autoreload 2

import xarray as xr
import os
import regional_utils
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

In [None]:
# Default values for every parameter the injected "Parameters" cell may override.
# Each injectable name needs a default here so the notebook also runs standalone.
CESM_output_dir = ""  # e.g. "/glade/campaign/cgd/oce/projects/CROCODILE/workshops/2025/Diagnostics/CESM_Output/"
case_name = ""  # e.g. "CROCODILE_tutorial_nwa12_MARBL"

## No timeseries or base case used in this notebook
ts_dir = None
base_case_output_dir = None
base_case_name = None

## Regional domains vary widely in purpose, simulation length, and extent, so we do not assume a minimum duration.
## Instead, start and end dates are ignored and every example reduces/outputs over the whole available time frame.
start_date = None  # "0001-01-01"
end_date = None  # "0101-01-01"
base_start_date = None  # "0001-01-01"
base_end_date = None  # "0101-01-01"

obs_data_dir = None

savefigs = False
fig_output_dir = None
# Root directory for saved images/GIFs. None makes the data-loading cell fall back
# to the user's scratch archive for this case.
img_output_dir = None
# Injected by the parameters cell; given a default so its absence is not an error.
subset_kwargs = {}

serial = False  # use dask LocalCluster

lc_kwargs = {}

In [None]:
# Parameters
# Values injected for this run; they override the defaults in the cell above.

# -- case identity --
case_name = "CROCODILE_tutorial_nwa12_MARBL"
base_case_name = "CROCODILE_tutorial_nwa12_MARBL"

# -- input locations --
CESM_output_dir = "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CESM_output_for_testing/"
ts_dir = "/glade/derecho/scratch/ajanney/archive/"

# -- date range (left empty for this run) --
start_date = end_date = ""
base_start_date = base_end_date = ""

# -- execution --
serial = True
lc_kwargs = {"threads_per_worker": 1}

# -- output --
savefigs = True
img_output_dir = None

# -- miscellaneous --
subset_kwargs = {}
product = "/glade/work/ajanney/CUPiD/examples/regional_ocean/computed_notebooks//ocn/Regional_Ocean_Report_Card.ipynb"

In [None]:
# Directory holding this case's ocean history files. os.path.join (as used for
# case_output_dir) avoids the doubled "/" that string formatting produces when
# CESM_output_dir already ends with a separator; the trailing "" component keeps
# the trailing slash of the printed path.
OUTDIR = os.path.join(CESM_output_dir, case_name, "ocn", "hist", "")
print("Output directory is:", OUTDIR)

In [None]:
case_output_dir = os.path.join(CESM_output_dir, case_name, "ocn", "hist")

# Decode CF time axes into cftime objects. Those support .strftime, which the
# plotting cell uses to build figure titles (numpy datetime64 values do not), and
# they can represent non-standard model calendars.
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)


def _open_history(pattern):
    """Open all history files in ``case_output_dir`` matching the glob ``pattern``.

    The matching files are combined into a single lazily loaded dataset, with
    times decoded via ``time_coder`` and timedeltas decoded as well.
    """
    return xr.open_mfdataset(
        os.path.join(case_output_dir, pattern),
        decode_timedelta=True,
        decode_times=time_coder,
    )


## Static data includes hgrid, vgrid, bathymetry, land/sea mask
static_data = _open_history("*static.nc")

## Surface Data
sfc_data = _open_history("*sfc*.nc")

## Monthly Domain Data
# NOTE(review): "*z*" matches any history filename containing a "z"; confirm it
# only selects the intended (presumably z-level) files.
monthly_data = _open_history("*z*.nc")

## Native Monthly Domain Data (and atm forcing)
native_data = _open_history("*native*.nc")

## Image/Gif Output Directory
if img_output_dir is None:
    # Fallback: <scratch>/<user>/archive/<case>/ocn/cupid_images.
    # Requires the USER environment variable to be set.
    image_output_dir = os.path.join(
        "/glade/derecho/scratch/",
        os.environ["USER"],
        "archive",
        case_name,
        "ocn",
        "cupid_images",
    )
else:
    image_output_dir = os.path.join(img_output_dir, case_name, "ocn", "cupid_images")
# exist_ok=True makes creation idempotent and avoids the race between a separate
# existence check and the directory creation.
os.makedirs(image_output_dir, exist_ok=True)
print("Image output directory is:", image_output_dir)

In [None]:
# Inspect the zonal surface stress variable (its repr is this cell's output).
native_data.data_vars["tauuo"]

In [None]:
# Display the whole native-grid dataset (variables, coordinates, attributes) for inspection.
native_data

In [None]:
# Forcing fields plotted by the loop below, read from native_data.
atm_variables = [
    "hfds",
    "tauuo",
    "tauvo",
    # "friver",  (currently disabled)
]


def choose_geocoords(dimensions):
    """Return the ``(lon_var, lat_var)`` coordinate names for a field's grid.

    Each supported pairing of horizontal dimensions (``xh``/``xq`` with
    ``yh``/``yq``) maps to its own pair of geographic coordinate variables.
    Pairings are tested in a fixed order and the first match wins.

    Raises:
        ValueError: if ``dimensions`` contains none of the supported pairs.
    """
    lookup = (
        (("xh", "yh"), ("geolon", "geolat")),
        (("xq", "yq"), ("geolon_c", "geolat_c")),
        (("xq", "yh"), ("geolon_u", "geolat_u")),
        (("xh", "yq"), ("geolon_v", "geolat_v")),
    )
    for (x_dim, y_dim), coord_pair in lookup:
        if x_dim in dimensions and y_dim in dimensions:
            return coord_pair
    raise ValueError(f"Could not determine geocoords for dims: {dimensions}")


def choose_areacello(dimensions):
    """Return the name of the cell-area variable for a field's grid.

    The area variable is chosen from the pairing of horizontal dimensions
    (``xh``/``xq`` with ``yh``/``yq``). Pairings are tested in insertion
    order and the first match wins.

    Raises:
        ValueError: if ``dimensions`` contains none of the supported pairs.
    """
    area_by_dims = {
        ("xh", "yh"): "areacello",
        ("xq", "yq"): "areacello_bu",
        ("xq", "yh"): "areacello_cu",
        ("xh", "yq"): "areacello_cv",
    }
    found = next(
        (
            name
            for (x_dim, y_dim), name in area_by_dims.items()
            if x_dim in dimensions and y_dim in dimensions
        ),
        None,
    )
    if found is None:
        raise ValueError(f"Could not determine areacello for dims: {dimensions}")
    return found


# For each forcing field in atm_variables, draw a 2x2 figure of Mercator maps:
# initial state (top-left), final state (top-right), time mean (bottom-left) and
# temporal standard deviation (bottom-right).
for var in atm_variables:
    long_name = native_data[var].long_name
    short_name = var
    # Title date range: start of the first time bound to the end of the last one.
    # .strftime works because times were decoded as cftime objects (use_cftime=True).
    # NOTE(review): bounds are taken from sfc_data while every panel plots
    # native_data; confirm both files cover the same period. This value does not
    # depend on `var` and could be computed once before the loop.
    time_bound = [
        sfc_data["time_bounds"].values[0][0].strftime("%Y-%m-%d"),
        sfc_data["time_bounds"].values[-1][-1].strftime("%Y-%m-%d"),
    ]
    units = native_data[var].units

    # Center the projection on the midpoint of the `geolon` (xh/yh) range.
    # Loop-invariant: could be hoisted out of the loop.
    # NOTE(review): (max + min) / 2 may not be the visual center for a domain that
    # straddles the longitude wrap point; confirm for new domains.
    central_longitude = float(
        (static_data["geolon"].max().values + static_data["geolon"].min().values) / 2
    )

    # 2x2 grid positions: 1 = top-left, 2 = top-right, 3 = bottom-left, 4 = bottom-right.
    fig = plt.figure(dpi=200, figsize=(8, 8))
    ax1 = fig.add_subplot(
        2, 2, 1, projection=ccrs.Mercator(central_longitude=central_longitude)
    )
    ax2 = fig.add_subplot(
        2, 2, 2, projection=ccrs.Mercator(central_longitude=central_longitude)
    )
    ax3 = fig.add_subplot(
        2, 2, 3, projection=ccrs.Mercator(central_longitude=central_longitude)
    )
    ax4 = fig.add_subplot(
        2, 2, 4, projection=ccrs.Mercator(central_longitude=central_longitude)
    )

    annotate = True

    # Coordinate and cell-area variable names that match this field's grid location.
    lon_var, lat_var = choose_geocoords(native_data[var].dims)
    area_var = choose_areacello(native_data[var].dims)

    # The time mean (bottom-left) is drawn first so that its colormap, norm and
    # color limits can be reused by the initial/final panels, putting those three
    # on a common color scale. The returned object is presumably a matplotlib
    # mappable (it exposes get_cmap/norm/get_clim); verify in regional_utils.
    mesh = regional_utils.plot_2D_latlon_field_plot(
        native_data[var].mean(dim="time"),
        static_data,
        lon_var=lon_var,
        lat_var=lat_var,
        area_var=area_var,
        fontsize=8,
        title=f"Avg {short_name}",
        cbar_label=f"{units}",
        axis=ax3,
        annotate=annotate,
    )

    colormap = mesh.get_cmap()
    norm = mesh.norm
    clim = mesh.get_clim()

    # First time record (top-left), on the mean panel's color scale.
    regional_utils.plot_2D_latlon_field_plot(
        native_data[var].isel(time=0),
        static_data,
        lon_var=lon_var,
        lat_var=lat_var,
        area_var=area_var,
        fontsize=8,
        title=f"Initial {short_name}",
        cbar_label=f"{units}",
        axis=ax1,
        colormap=colormap,
        norm=norm,
        clim=clim,
        annotate=annotate,
    )
    # Last time record (top-right), on the mean panel's color scale.
    regional_utils.plot_2D_latlon_field_plot(
        native_data[var].isel(time=-1),
        static_data,
        lon_var=lon_var,
        lat_var=lat_var,
        area_var=area_var,
        fontsize=8,
        title=f"Final {short_name}",
        cbar_label=f"{units}",
        axis=ax2,
        colormap=colormap,
        norm=norm,
        clim=clim,
        annotate=annotate,
    )

    # Standard deviation over time (bottom-right) uses its own color scale:
    # no colormap/norm/clim are passed.
    regional_utils.plot_2D_latlon_field_plot(
        native_data[var].std(dim="time"),
        static_data,
        lon_var=lon_var,
        lat_var=lat_var,
        area_var=area_var,
        fontsize=8,
        title=f"StdDev {short_name}",
        cbar_label=f"{units}",
        axis=ax4,
        annotate=annotate,
    )

    fig.suptitle(f"{long_name} from {time_bound[0]} to {time_bound[-1]}", fontsize=18)
    # NOTE(review): in these lines the figure is neither saved (despite the
    # `savefigs` parameter) nor closed; confirm whether that happens further on.

    # fig.subplots_adjust(
    # hspace=0.1,  # Increase vertical space between plots
    # wspace=0.4,  # Increase horizontal space between plots
    # )