## Regional Ocean: Basic Region and Surface Field Visualization

Note: This notebook is meant to be run with the cupid-analysis kernel (see [CUPiD Installation](https://ncar.github.io/CUPiD/index.html#installing)). This notebook is often run by default as part of [CESM post-processing steps](https://ncar.github.io/CUPiD/run_cesm.html), but you can also run it manually.

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from cartopy import crs as ccrs, feature as cfeature
import cartopy
import warnings
import os
from glob import glob

import regional_utils

In [None]:
# Default parameter values for this notebook. Every name defined here can be
# overridden before the analysis cells run (presumably via the parameter
# injection performed by CUPiD -- TODO confirm).

# Root of the CESM archive and the case to analyse, e.g.
# "/glade/derecho/scratch/<user>/archive/" and "CrocCaribGio"
CESM_output_dir = ""
case_name = ""

ts_dir = None  # "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CESM_output_for_testing"
base_case_output_dir = None  # None => use CESM_output_dir
base_case_name = None  # "b.e23_alpha17f.BLT1850.ne30_t232.092"

## As regional domains will vary so much in application, we don't want to assume a minimum duration
## Thus, we can ignore start and end dates and simply reduce over the whole time frame.
start_date = None  # "0001-01-01"
end_date = None  # "0101-01-01"
base_start_date = None  # "0001-01-01"
base_end_date = None  # "0101-01-01"

obs_data_dir = None  # "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CUPiD_obs_data"

serial = False  # use dask LocalCluster

savefigs = False
# Read by the data-loading cell; without a default here a manual run that
# relies only on this cell fails with a NameError.
image_output_dir = None  # if None, will be created in case directory

lc_kwargs = {}

In [None]:
# NOTE(review): this cell reassigns the same names as the default-parameter cell
# before it, so these values win. It contains hard-coded, user-specific absolute
# paths -- presumably parameters injected by a previous run that were saved with
# the notebook; confirm whether this cell should be committed.
CESM_output_dir = "/glade/u/home/manishrv/scratch/archive/"  # /glade/derecho/scratch/ajanney/archive/"  "
case_name = "gustavo-glorys-2"  # "CrocCaribGio"

ts_dir = None  # "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CESM_output_for_testing"
base_case_output_dir = None  # None => use CESM_output_dir
base_case_name = None  # "b.e23_alpha17f.BLT1850.ne30_t232.092"

## As regional domains will vary so much in application, we don't want to assume a minimum duration
## Thus, we can ignore start and end dates and simply reduce over the whole time frame.
start_date = None  # "0001-01-01"
end_date = None  # "0101-01-01"
base_start_date = None  # "0001-01-01"
base_end_date = None  # "0101-01-01"

obs_data_dir = None  # "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CUPiD_obs_data"

serial = False  # use dask LocalCluster

# Unlike the default cell, figure saving is switched on here and an explicit
# image root is given (it is extended with <case_name>/ocn/cupid_images later).
savefigs = True
image_output_dir = "/glade/derecho/scratch/ajanney/archive"  # if None, will be created in case directory

lc_kwargs = {}

In [None]:
# Build the history directory with os.path.join rather than string formatting:
# a CESM_output_dir that already ends in "/" no longer yields "//", and an empty
# CESM_output_dir no longer turns the result into a path rooted at "/".
# The trailing "" keeps the final path separator of the original value.
OUTDIR = os.path.join(CESM_output_dir, case_name, "ocn", "hist", "")
print("Output directory is:", OUTDIR)

In [None]:
# Set cupid_run to 0 when running interactively: a PBS-backed Dask cluster is
# then requested from within the notebook, which may need a few minutes.
# With cupid_run == 1 a LocalCluster is used instead, unless `serial` is set.

cupid_run = 1

if cupid_run != 1:

    from dask.distributed import Client
    from dask_jobqueue import PBSCluster

    # Run on Casper or another system that is able to allocate
    # individual cores and not just whole nodes
    cluster = PBSCluster(
        account="P93300012",
        queue="casper",
        walltime="02:00:00",
        cores=12,
        processes=12,
        memory="120GB",
    )
    client = Client(cluster)

    # Request a single PBS job for the workers
    cluster.scale(1)

    print(cluster)

else:

    from dask.distributed import Client, LocalCluster

    # Serial runs get no client at all; parallel runs get a local cluster
    client = None
    if not serial:
        cluster = LocalCluster(**lc_kwargs)
        client = Client(cluster)

# Last expression of the cell: shows the client (or nothing when it is None)
client

## Load in Model Output and Peek at Variables

#### Default File Structure in MOM6
This file structure will be different if you modify the diag_table.

- **static data**: contains horizontal grid, vertical grid, land/sea mask, bathymetry, lat/lon information
- **sfc data**: daily output of 2D surface fields (salinity, temp, SSH, velocities)
- **monthly data**: monthly-averaged output of the full 3D domain, regridded to a predefined grid (MOM6 default: WOA)
- **native data**: averaged monthly output of ocean state and atmospheric fluxes

In [None]:
# Ocean history files of the case: <CESM_output_dir>/<case_name>/ocn/hist
case_output_dir = os.path.join(CESM_output_dir, case_name, "ocn", "hist")

# Decode time coordinates into cftime objects
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)

# Open the three file groups with identical decoding options; the datasets are
# opened in the order listed and unpacked into their respective names.
# NOTE(review): "*z*.nc" matches any file name containing a "z" (including one
# contributed by the case name) -- confirm the pattern is specific enough.
static_data, sfc_data, monthly_data = (
    xr.open_mfdataset(
        os.path.join(case_output_dir, file_pattern),
        decode_timedelta=True,
        decode_times=time_coder,
    )
    for file_pattern in (
        "*static.nc",  # static data: hgrid, vgrid, bathymetry, land/sea mask
        "*sfc*.nc",  # surface data
        "*z*.nc",  # monthly domain data
    )
)

## Image/Gif Output Directory
# Images go to <root>/<case_name>/ocn/cupid_images, where <root> is
# image_output_dir when one was given and CESM_output_dir otherwise.
# NOTE(review): image_output_dir is overwritten with the resolved path, so
# re-running this cell appends the sub-directory a second time.
image_output_dir = os.path.join(
    CESM_output_dir if image_output_dir is None else image_output_dir,
    case_name,
    "ocn",
    "cupid_images",
)
# exist_ok=True replaces the separate existence check, which could race with
# another process creating the same directory between the check and makedirs.
os.makedirs(image_output_dir, exist_ok=True)
print("Image output directory is:", image_output_dir)

In [None]:
# NOTE(review): this cell duplicates the `static_data` display cell under the
# "Static Information About the Domain" heading that follows.
static_data

#### Static Information About the Domain

In [None]:
# Bare last expression: renders the summary of the static dataset
static_data

#### Daily Surface Fields

In [None]:
# Bare last expression: renders the summary of the surface dataset
sfc_data

#### Full Domain Fields, Averaged Monthly

In [None]:
# Bare last expression: renders the summary of the monthly full-domain dataset
monthly_data

## Look at Regional Domain

In [None]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# Helper from the local regional_utils module; presumably draws the regional
# domain from the static grid information -- see regional_utils for details.
regional_utils.visualize_regional_domain(static_data)

## Plotting Surface Fields

In [None]:
## Plotting basic state variables
variables = ["SSH", "tos", "sos", "speed"]  # "SSU", "SSV"]

# Iterate over a copy: removing items from the list that is being iterated makes
# the loop skip the element following each removal, so a missing variable could
# survive the filter and break the plotting cells below.
for var in list(variables):
    if var not in sfc_data.variables:
        print(f"Variable '{var}' not in given dataset. It will not be plotted.")
        variables.remove(var)

In [None]:
# The quantities below do not depend on the variable being plotted. Computing
# them once avoids re-evaluating the time bounds and the geolon min/max on
# every loop iteration.
time_bounds_values = sfc_data["time_bounds"].values
time_bound = [
    time_bounds_values[0][0].strftime("%Y-%m-%d"),
    time_bounds_values[-1][-1].strftime("%Y-%m-%d"),
]

# Centre the map projection on the middle of the domain's longitude range
central_longitude = float(
    (static_data["geolon"].max().values + static_data["geolon"].min().values) / 2
)
projection = ccrs.Mercator(central_longitude=central_longitude)

annotate = True

# One 2x2 figure per variable:
#   top-left: first time step     top-right: last time step
#   bottom-left: time mean        bottom-right: standard deviation in time
# NOTE(review): `savefigs` / `image_output_dir` are configured earlier, but no
# figure is written to disk in this cell -- confirm whether saving is intended.
for var in variables:
    field = sfc_data[var]
    short_name = var
    # Fall back to the variable name / an empty label when metadata is missing
    long_name = field.attrs.get("long_name", short_name)
    units = field.attrs.get("units", "")

    fig = plt.figure(dpi=200, figsize=(8, 8))
    ax1, ax2, ax3, ax4 = (
        fig.add_subplot(2, 2, position, projection=projection)
        for position in (1, 2, 3, 4)
    )

    # The time mean is drawn first so that its colormap, norm and colour limits
    # can be reused for the initial and final snapshots.
    mesh = regional_utils.plot_2D_latlon_field_plot(
        field.mean(dim="time"),
        static_data,
        fontsize=8,
        title=f"Avg {short_name}",
        cbar_label=f"{units}",
        axis=ax3,
        annotate=annotate,
    )

    colormap = mesh.get_cmap()
    norm = mesh.norm
    clim = mesh.get_clim()

    regional_utils.plot_2D_latlon_field_plot(
        field.isel(time=0),
        static_data,
        fontsize=8,
        title=f"Initial {short_name}",
        cbar_label=f"{units}",
        axis=ax1,
        colormap=colormap,
        norm=norm,
        clim=clim,
        annotate=annotate,
    )
    regional_utils.plot_2D_latlon_field_plot(
        field.isel(time=-1),
        static_data,
        fontsize=8,
        title=f"Final {short_name}",
        cbar_label=f"{units}",
        axis=ax2,
        colormap=colormap,
        norm=norm,
        clim=clim,
        annotate=annotate,
    )

    # The standard deviation has its own colour scale (nothing shared is passed)
    regional_utils.plot_2D_latlon_field_plot(
        field.std(dim="time"),
        static_data,
        fontsize=8,
        title=f"StdDev {short_name}",
        cbar_label=f"{units}",
        axis=ax4,
        annotate=annotate,
    )

    fig.suptitle(f"{long_name} from {time_bound[0]} to {time_bound[-1]}", fontsize=18)

    # fig.subplots_adjust(
    # hspace=0.1,  # Increase vertical space between plots
    # wspace=0.4,  # Increase horizontal space between plots
    # )

## 🚧 Add Slicing Function Here? 🚧

## Area Weighted Averages and Timeseries

In [None]:
# Pass the `variables` list filtered in the "Plotting Surface Fields" section
# instead of a second hard-coded list, so fields reported as missing from
# sfc_data are not requested again. A copy is passed in case the helper
# modifies its argument.
regional_utils.plot_area_averaged_timeseries(
    sfc_data, static_data, variables=list(variables)
)

In [None]:
# Shut down only the LocalCluster client created for a parallel CUPiD run;
# in serial mode no client exists, and the interactive client is left running.
if cupid_run == 1 and not serial:
    client.shutdown()