## Regional Ocean: Basic Region and Surface Field Visualization

Note: This notebook is meant to be run with the cupid-analysis kernel (see [CUPiD Installation](https://ncar.github.io/CUPiD/index.html#installing)). This notebook is often run by default as part of [CESM post-processing steps](https://ncar.github.io/CUPiD/run_cesm.html), but you can also run it manually.

In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# from mom6_tools.MOM6grid import MOM6grid
from mom6_tools.m6plot import (
    chooseColorLevels,
    chooseColorMap,
    boundaryStats,
    myStats,
    label,
)
from cartopy import crs as ccrs, feature as cfeature
import cartopy
import warnings
import os
from glob import glob

In [None]:
# ---- Case to analyze -------------------------------------------------------
# Root of the CESM archive and the case name; the loading cell reads
# <CESM_output_dir>/<case_name>/ocn/hist/. Example values:
#   CESM_output_dir: "/glade/u/home/manishrv/scratch/archive/"
#                    or "/glade/derecho/scratch/ajanney/archive/"
#   case_name:       "gustavo-glorys-2" or "CrocCaribGio"
CESM_output_dir = ""
case_name = ""

# ---- Optional comparison inputs (not referenced by the cells below) ---------
# Example ts_dir:
#   "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CESM_output_for_testing"
ts_dir = None
base_case_output_dir = None  # None => use CESM_output_dir
base_case_name = None  # e.g. "b.e23_alpha17f.BLT1850.ne30_t232.092"

## Regional domains vary so much in application that no minimum run length is
## assumed: the dates stay unset and fields are reduced over the whole record.
## Date format example: "0001-01-01" (start) to "0101-01-01" (end).
start_date = None
end_date = None
base_start_date = None
base_end_date = None

# Example: "/glade/campaign/cesm/development/cross-wg/diagnostic_framework/CUPiD_obs_data"
obs_data_dir = None

# False => the cluster cell starts a Dask LocalCluster; True => no cluster/client.
serial = False  # use dask LocalCluster

# Not referenced by the cells below.
savefigs = False

# Keyword arguments forwarded to dask's LocalCluster(**lc_kwargs).
lc_kwargs = {}

In [None]:
# Directory that holds this case's ocean history files (reported for reference).
OUTDIR = "{}/{}/ocn/hist/".format(CESM_output_dir, case_name)
print("Output directory is:", OUTDIR)

In [None]:
# Set cupid_run = 0 when running interactively: a PBS-managed Dask cluster is
# then requested for this notebook, which may take a few minutes to start.
# With cupid_run = 1 a LocalCluster is started unless `serial` is True.

cupid_run = 1

from dask.distributed import Client

if cupid_run != 1:
    from dask_jobqueue import PBSCluster

    # Run on Casper or another system that can allocate individual cores
    # rather than only whole nodes.
    cluster = PBSCluster(
        cores=12,
        processes=12,
        memory="120GB",
        account="P93300012",
        queue="casper",
        walltime="02:00:00",
    )
    client = Client(cluster)
    cluster.scale(1)
    print(cluster)
else:
    from dask.distributed import LocalCluster

    # No client at all in serial mode; otherwise a local cluster built from lc_kwargs.
    client = None
    if not serial:
        cluster = LocalCluster(**lc_kwargs)
        client = Client(cluster)

client

In [None]:
def latlon_proj_plot(
    field,
    grid,
    lon_var="geolon",
    lat_var="geolat",
    area_var=None,
    xlabel=None,
    xunits=None,
    ylabel=None,
    yunits=None,
    title="",
    suptitle="",
    clim=None,
    colormap=None,
    norm=None,
    extend=None,
    centerlabels=False,
    nbins=None,
    axis=None,
    add_cbar=True,
    cbar_label=None,
    figsize=(16, 9),
    dpi=150,
    sigma=2.0,
    annotate=True,
    ignore=None,
    save=None,
    debug=False,
    show=False,
    logscale=False,
    projection=None,
    coastlines=True,
    res=None,
    coastcolor=(0, 0, 0),
    landcolor=(0.75, 0.75, 0.75),
    coast_linewidth=0.5,
    fontsize=22,
    gridlines=False,
):
    """Draw a 2-D model field on a cartopy map with optional colorbar and stats.

    Parameters
    ----------
    field : array-like (e.g. xarray.DataArray)
        2-D field on the same horizontal grid as ``grid[lon_var]`` /
        ``grid[lat_var]``.
    grid : xarray.Dataset
        Provides the coordinate variables and the cell-area variable used for
        the statistics (``"areacello"`` unless ``area_var`` is given).
    lon_var, lat_var : str
        Names of the longitude / latitude variables in ``grid``.
    area_var : str or None
        Name of the cell-area variable; defaults to ``"areacello"``.
    xlabel, xunits, ylabel, yunits : str or None
        Axis labels; a label is set only when both its name and units are given.
    title : str
        Axis title, drawn at ``1.5 * fontsize``.
    suptitle : str
        Figure-level title (one per figure, so later calls overwrite it).
    clim, colormap, norm, extend, nbins, sigma, logscale
        Colour-scale controls. If ``colormap`` is None, the colormap and norm
        come from ``mom6_tools.m6plot.chooseColorMap`` / ``chooseColorLevels``;
        without ``clim`` the range is mean +/- ``sigma`` standard deviations,
        clipped to the data range. If ``colormap`` is given, it is used as-is
        together with ``norm``.
    centerlabels : bool
        Put colorbar ticks at the centres of the ``clim`` intervals. This needs
        a ``clim`` with more than two edges and ``add_cbar=True``.
    axis : cartopy GeoAxes or None
        Axis to draw on. If None, a new figure/axis is created using
        ``projection`` (default: Robinson centred on the grid's mid-longitude).
    add_cbar, cbar_label
        Colorbar toggle and label.
    figsize, dpi
        Size and resolution of a newly created figure.
    annotate : bool
        Write max/min (and mean/rms/sd when available) above the axis.
    ignore : scalar or None
        Value to mask out, in addition to NaNs.
    save : str or None
        Path to save the figure to; the figure is closed afterwards.
    debug : bool
        Passed through to ``myStats``.
    show : bool
        Only used when this function created the figure. True calls
        ``plt.show(block=False)``; False calls ``plt.show(block=True)``.
    coastlines, res, coastcolor, coast_linewidth
        Coastline options (``res`` defaults to ``"50m"``).
    landcolor
        Axis background colour, visible wherever the field is masked.
    fontsize : float
        Global matplotlib font size (set via ``plt.rc``).
    gridlines : bool
        Draw labelled lat/lon gridlines.

    Returns
    -------
    The mesh returned by ``axis.pcolormesh``.

    Raises
    ------
    ValueError
        If ``centerlabels`` is requested together with ``add_cbar=False``.
    """
    ## Preplotting
    # NOTE: this changes the global matplotlib font size for later figures too.
    plt.rc("font", size=fontsize)

    # Mask NaNs and the ignored value so they affect neither stats nor the plot.
    # (Converting once also avoids re-evaluating lazy/dask-backed fields.)
    values = np.asarray(field)
    invalid = np.isnan(values)
    if ignore is not None:
        invalid = invalid | (values == ignore)
    maskedField = np.ma.masked_array(values, mask=invalid)

    # Diagnose statistics
    area_cell = grid["areacello" if area_var is None else area_var].to_numpy()
    sMin, sMax, sMean, sStd, sRMS = myStats(maskedField, area_cell, debug=debug)

    # Choose colormap
    if nbins is None and (clim is None or len(clim) == 2):
        nbins = 35
    if colormap is None:
        colormap = chooseColorMap(sMin, sMax)
        if clim is None and sStd is not None:
            # Default range: mean +/- sigma std, clipped to the data range.
            lower = max(sMean - sigma * sStd, sMin)
            upper = min(sMean + sigma * sStd, sMax)
        else:
            lower, upper = sMin, sMax
        cmap, norm, extend = chooseColorLevels(
            lower,
            upper,
            colormap,
            clim=clim,
            nbins=nbins,
            extend=extend,
            logscale=logscale,
        )
    else:
        cmap = colormap

    ## Set up figure and axis
    created_own_axis = axis is None
    if created_own_axis:
        if projection is None:
            lon = grid[lon_var]
            central_longitude = float((lon.max().values + lon.min().values) / 2)
            projection = ccrs.Robinson(central_longitude=central_longitude)
        new_fig = plt.figure(dpi=dpi, figsize=figsize)
        axis = new_fig.add_subplot(1, 1, 1, projection=projection)
    fig = axis.figure

    ## Plot Color Mesh (masked cells are left undrawn)
    pm = axis.pcolormesh(
        grid[lon_var],
        grid[lat_var],
        maskedField,
        cmap=cmap,
        norm=norm,
        transform=ccrs.PlateCarree(),
    )

    ## Add Land and Coastlines
    if res is None:
        res = "50m"  # Natural Earth resolutions: "10m", "50m", "110m"
    if coastlines:
        axis.coastlines(resolution=res, color=coastcolor, linewidth=coast_linewidth)
    axis.set_facecolor(landcolor)

    ## Colorbar, placed just right of the axis and matching its height
    cb = None
    if add_cbar:
        bbox = axis.get_position()
        cbar_width = 0.01  # width as fraction of figure
        cbar_padding = 0.01
        cbar_ax = fig.add_axes(
            [bbox.x1 + cbar_padding, bbox.y0, cbar_width, bbox.height]
        )
        cb = plt.colorbar(pm, cax=cbar_ax, extend=extend)
        if cbar_label is not None:
            cb.set_label(cbar_label)
    if centerlabels and clim is not None and len(clim) > 2:
        if not add_cbar:
            raise ValueError(
                "Argument Mismatch: add_cbar must be true if you also specify centerlabels to be true."
            )
        # Convert so list-valued clim is averaged element-wise, not concatenated.
        edges = np.asarray(clim, dtype=float)
        cb.set_ticks(0.5 * (edges[:-1] + edges[1:]))

    ## Statistics annotations above the axis
    if annotate:
        axis.annotate(
            "max=%.5g\nmin=%.5g" % (sMax, sMin),
            xy=(0.0, 1.01),
            xycoords="axes fraction",
            verticalalignment="bottom",
        )
        # myStats may return None for the moments (see the sStd check above).
        if all(s is not None for s in (sMean, sRMS, sStd)):
            axis.annotate(
                "mean=%.5g\nrms=%.5g" % (sMean, sRMS),
                xy=(1.0, 1.01),
                xycoords="axes fraction",
                verticalalignment="bottom",
                horizontalalignment="right",
            )
            axis.annotate(
                " sd=%.5g\n" % (sStd),
                xy=(1.0, 1.01),
                xycoords="axes fraction",
                verticalalignment="bottom",
                horizontalalignment="left",
            )

    if xlabel and xunits:
        axis.set_xlabel(label(xlabel, xunits))
    if ylabel and yunits:
        axis.set_ylabel(label(ylabel, yunits))
    if title:
        axis.set_title(title, fontsize=fontsize * 1.5)
    if suptitle:
        # Lift the suptitle above the annotation text when annotations are drawn.
        if annotate:
            fig.suptitle(suptitle, y=1.01)
        else:
            fig.suptitle(suptitle)

    if gridlines:
        gl = axis.gridlines(
            draw_labels=True,
            lw=0.5,
            color="gray",
            alpha=0.5,
        )
        gl.top_labels = False
        gl.right_labels = False

    # Save BEFORE showing: with the notebook inline backend, plt.show() closes
    # the figures it displays, so saving afterwards wrote a new, empty figure.
    if save is not None:
        fig.savefig(save)
    # Only show if we created our own axis (show=True -> non-blocking).
    if created_own_axis:
        plt.show(block=not show)
    if save is not None:
        plt.close(fig)

    return pm

## Load in Model Output and Peek at Variables

In [None]:
case_output_dir = os.path.join(CESM_output_dir, case_name, "ocn", "hist")

# Xarray time decoding: decode time coordinates as cftime objects.
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)


def _open_history(pattern):
    """Open every file in ``case_output_dir`` matching glob *pattern* as one dataset."""
    return xr.open_mfdataset(
        os.path.join(case_output_dir, pattern),
        decode_timedelta=True,
        decode_times=time_coder,
    )


## Static data includes hgrid, vgrid, bathymetry, land/sea mask
static_data = _open_history("*static.nc")

## Surface Data
sfc_data = _open_history("*sfc*.nc")

## Monthly Domain Data
# NOTE(review): "*z*.nc" matches any file name containing a "z" (including via
# the case name); confirm it cannot also pick up the static/sfc files.
monthly_data = _open_history("*z*.nc")

#### Static Information About the Domain

In [None]:
# Display the static dataset (grid, bathymetry, land/sea mask) opened above.
static_data

#### Daily Surface Fields

In [None]:
# Display the surface-field dataset opened from the *sfc*.nc files.
sfc_data

#### Full Domain Fields, Averaged Monthly

In [None]:
# Display the dataset opened from the *z*.nc files.
monthly_data

## Look at Regional Domain

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
def visualize_regional_domain(grd_xr, save=None):
    """Draw a three-panel overview of a regional ocean domain.

    Panels, left to right:
    1. A global Mercator map with the domain outline in red.
    2. A zoomed PlateCarree view showing every 10th line of the cell-corner
       grid, with the domain edges in red.
    3. The ``wet`` land/ocean mask, drawn with ``latlon_proj_plot``.

    Parameters
    ----------
    grd_xr : xarray.Dataset
        Static grid dataset providing ``geolon``, ``geolon_c``, ``geolat_c`` and
        ``wet``, plus whatever ``latlon_proj_plot`` reads (``geolat``,
        ``areacello``).
    save : str or None
        If given, the figure is saved to this path and closed.
    """
    from matplotlib.colors import ListedColormap

    ## Grab useful variables
    central_longitude = float(
        (grd_xr["geolon"].max().values + grd_xr["geolon"].min().values) / 2
    )
    lon = grd_xr["geolon_c"].values
    lat = grd_xr["geolat_c"].values
    data_crs = ccrs.PlateCarree()

    # Closed outline of the corner-point grid, walked around its index edges:
    # first row, last column, last row (reversed), first column (reversed).
    boundary_lon = np.concatenate(
        [lon[0, :], lon[1:, -1], lon[-1, -2::-1], lon[-2::-1, 0]]
    )
    boundary_lat = np.concatenate(
        [lat[0, :], lat[1:, -1], lat[-1, -2::-1], lat[-2::-1, 0]]
    )

    ## Set up figure, need to add subplots individually because of projections
    fig = plt.figure(dpi=200, figsize=(14, 8))

    ### -------------------------- ###
    ### Global Plot showing region ###
    ### -------------------------- ###
    ax = fig.add_subplot(
        1, 3, 1, projection=ccrs.Mercator(central_longitude=central_longitude)
    )
    ax.coastlines(linewidth=0.5)
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.add_feature(cfeature.LAND, facecolor="lightgray")
    ax.add_feature(cfeature.OCEAN, facecolor="lightblue")
    ax.plot(boundary_lon, boundary_lat, color="red", transform=data_crs)
    ax.set_global()
    ax.set_title("Regional Domain", fontsize=10)

    ### --------------------------------------- ###
    ### Zoomed in showing relative grid density ###
    ### --------------------------------------- ###
    ax1 = fig.add_subplot(
        1, 3, 2, projection=ccrs.PlateCarree(central_longitude=central_longitude)
    )
    ax1.coastlines(linewidth=0.5, resolution="50m")
    ax1.add_feature(cfeature.LAND, facecolor="lightgray")
    ax1.add_feature(cfeature.OCEAN, facecolor="lightblue")

    skip = 10
    for i in range(0, lon.shape[1], skip):
        ax1.plot(lon[:, i], lat[:, i], color="black", linewidth=0.5, transform=data_crs)
    for j in range(0, lat.shape[0], skip):
        ax1.plot(lon[j, :], lat[j, :], color="black", linewidth=0.5, transform=data_crs)
    # Domain edges (also covers the last row/column the stride may miss).
    ax1.plot(boundary_lon, boundary_lat, color="red", linewidth=0.8, transform=data_crs)

    ax1.set_title("Approx. Grid (~1/10th Density)", fontsize=10)
    gl = ax1.gridlines(
        draw_labels=True,
        lw=0.5,
        color="gray",
        alpha=0.5,
    )
    gl.top_labels = False
    gl.right_labels = False
    gl.xlabel_style = {"size": 8}
    gl.ylabel_style = {"size": 8}

    ### ----------------------------- ###
    ### Final plot with land-sea mask ###
    ### ----------------------------- ###
    ax2 = fig.add_subplot(
        1, 3, 3, projection=ccrs.PlateCarree(central_longitude=central_longitude)
    )
    # Keep the mesh itself: its sticky edges are what block the margins below.
    # (Previously the Text returned by set_title was used, which had no effect.)
    pm = latlon_proj_plot(
        field=grd_xr["wet"],
        grid=grd_xr,
        axis=ax2,
        annotate=False,
        colormap=ListedColormap(["tan", "cornflowerblue"]),
        coast_linewidth=0.2,
        add_cbar=False,
    )
    ax2.set_title("Land/Ocean Mask", fontsize=10)
    pm.sticky_edges.x[:] = []
    pm.sticky_edges.y[:] = []
    ax2.margins(x=0.05, y=0.05)

    fig.subplots_adjust(wspace=0.2)

    if save is not None:
        fig.savefig(save)
        plt.close(fig)


visualize_regional_domain(static_data)

In [None]:
# Time mean and temporal standard deviation of sea surface height, each on its
# own full-size map.
time_bound = [
    sfc_data["time_bounds"].values[0][0].strftime("%Y-%m-%d"),
    sfc_data["time_bounds"].values[-1][-1].strftime("%Y-%m-%d"),
]
# Centre the Mercator projection on the domain, as the other map cells do; the
# default central_longitude=0 may split domains that straddle the antimeridian.
central_longitude = float(
    (static_data["geolon"].max().values + static_data["geolon"].min().values) / 2
)
cmap = "RdBu"

for var in ["SSH"]:
    if var not in sfc_data.variables:
        print(f"Variable '{var}' not in given dataset. It will not be plotted.")
        continue
    long_name = sfc_data[var].long_name
    units = sfc_data[var].units
    panels = (
        (
            sfc_data[var].mean(dim="time"),
            f"{long_name} averaged {time_bound[0]} to {time_bound[-1]}",
        ),
        (
            sfc_data[var].std(dim="time"),
            f"StDev {long_name} {time_bound[0]} to {time_bound[-1]}",
        ),
    )
    for reduced, plot_title in panels:
        latlon_proj_plot(
            reduced,
            static_data,
            projection=ccrs.Mercator(central_longitude=central_longitude),
            title=plot_title,
            fontsize=12,
            annotate=False,
            cbar_label=f"{units}",
            colormap=cmap,
        )

## Plotting Surface Fields

In [None]:
## Plotting basic state variables
variables = ["SSH", "tos", "sos", "speed"]
# Build a new list instead of calling variables.remove() inside the loop:
# removing while iterating skips the element after each removed one, so a
# missing variable could survive the filter and fail in the plotting cell.
missing = [var for var in variables if var not in sfc_data.variables]
for var in missing:
    print(f"Variable '{var}' not in given dataset. It will not be plotted.")
variables = [var for var in variables if var not in missing]

In [None]:
# Values shared by every figure, computed once rather than per variable.
time_bound = [
    sfc_data["time_bounds"].values[0][0].strftime("%Y-%m-%d"),
    sfc_data["time_bounds"].values[-1][-1].strftime("%Y-%m-%d"),
]
central_longitude = float(
    (static_data["geolon"].max().values + static_data["geolon"].min().values) / 2
)

for var in variables:
    long_name = sfc_data[var].long_name
    units = sfc_data[var].units

    fig = plt.figure(dpi=200, figsize=(14, 8))
    # Panels left to right: first time step, last time step, time mean.
    ax1, ax2, ax3 = [
        fig.add_subplot(
            1, 3, k, projection=ccrs.Mercator(central_longitude=central_longitude)
        )
        for k in (1, 2, 3)
    ]

    # Draw the mean first: its automatically chosen colormap/norm is reused by
    # the snapshot panels so all three share one colour scale.
    mesh = latlon_proj_plot(
        sfc_data[var].mean(dim="time"),
        static_data,
        fontsize=8,
        cbar_label=f"{units}",
        axis=ax3,
    )
    shared_scale = {
        "colormap": mesh.get_cmap(),
        "norm": mesh.norm,
        "clim": mesh.get_clim(),
    }
    for ax, snapshot in (
        (ax1, sfc_data[var].isel(time=0)),
        (ax2, sfc_data[var].isel(time=-1)),
    ):
        latlon_proj_plot(
            snapshot,
            static_data,
            fontsize=8,
            cbar_label=f"{units}",
            axis=ax,
            **shared_scale,
        )

    # One figure-level title for all panels. A per-panel `suptitle` argument
    # writes this same single title, so the calls would overwrite each other.
    # y=1.01 keeps it above the stats annotations drawn over each panel.
    fig.suptitle(
        f"{long_name}: initial at {time_bound[0]} | at {time_bound[-1]} | "
        f"averaged {time_bound[0]} to {time_bound[-1]}",
        y=1.01,
    )

In [None]:
# Release Dask resources. `client` is None only in serial mode; the previous
# condition skipped the PBS-backed client (cupid_run == 0), leaving that job
# running until its walltime.
if client is not None:
    client.shutdown()