In [None]:
import xarray as xr
import numpy as np
import scipy.stats
import src.evt
import metpy.calc
import metpy.units
import matplotlib.pyplot as plt
import seaborn as sns
import os.path
import cartopy.crs as ccrs
import matplotlib.patches as mpatches
import cmocean
import os

## plotting style: white axes background, grid lines switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

#### Functions

In [None]:
def load_whoi_data():
    """Open the pre-computed dataset (originally obtained from WHOI's data
    server) stored at ``../data/whoi_data_US.nc`` and return it.

    The path is relative to the notebook's working directory."""

    return xr.open_dataset("../data/whoi_data_US.nc")


def plot_setup_simple(fig, projection, lon_range, lat_range):
    """Create a map subplot on ``fig`` and return its Axes.

    Args:
        fig: figure to which the subplot is added.
        projection: map projection passed to ``fig.add_subplot``.
        lon_range: longitude limits; unpacked first into the extent.
        lat_range: latitude limits; unpacked after the longitudes.

    Returns:
        The new Axes, cropped to the requested region with coastlines drawn.
    """

    ## new axis using the requested map projection
    ax = fig.add_subplot(projection=projection)

    ## crop to the region; the limits are given in PlateCarree (lon/lat) coords
    region = list(lon_range) + list(lat_range)
    ax.set_extent(region, crs=ccrs.PlateCarree())

    ## thin coastlines
    ax.coastlines(linewidths=0.5)

    return ax


def plot_setup_US(fig):
    """Set up a map axis for the US region and outline Bartusek's box.

    The axis uses an orthographic projection centered at (255, 35) and is
    cropped to longitudes 230-290 and latitudes 15-60. Returns the Axes."""

    ## Orthographic projection centered on the plotted region
    us_proj = ccrs.Orthographic(central_longitude=255, central_latitude=35)

    ## NOTE(review): `_threshold` is a private cartopy attribute; shrinking it
    ## presumably makes projected lines render more smoothly -- confirm.
    us_proj._threshold /= 1000

    ax = plot_setup_simple(fig, us_proj, lon_range=[230, 290], lat_range=[15, 60])

    ## Bartusek's box: 20 x 20 degree outline with lower-left corner at (-130, 40)
    bartusek_box = mpatches.Rectangle(
        xy=[-130, 40],
        width=20,
        height=20,
        facecolor="none",
        edgecolor="magenta",
        transform=ccrs.PlateCarree(),
        zorder=10,
    )
    ax.add_patch(bartusek_box)

    return ax


def make_cb_range(amp, delta):
    """Make colorbar levels for a diverging colormap (e.g. cmo.balance).

    The levels are symmetric about zero, exclude zero itself, and are spaced
    by ``delta``: ``[-n*delta, ..., -delta, delta, ..., n*delta]``, where ``n``
    is the number of whole steps of size ``delta`` that fit in ``amp``.

    Args:
        amp: largest absolute level (expected to be a multiple of ``delta``).
        delta: spacing between adjacent levels.

    Returns:
        1-D ``np.ndarray`` of levels in increasing order.
    """

    ## Count steps explicitly instead of relying on np.arange with a float
    ## step: arange's length is ceil((stop - start) / step), so rounding error
    ## can append an extra level above `amp` on one side only
    ## (e.g. amp=0.3, delta=0.1). The small tolerance absorbs ratios such as
    ## 0.3 / 0.1 == 2.9999999999999996.
    n_steps = int(np.floor(amp / delta + 1e-9))

    ## Build the positive half from integer multiples, then mirror it so the
    ## two halves are exactly symmetric.
    positive = delta * np.arange(1, n_steps + 1)
    return np.concatenate([-positive[::-1], positive])


def get_mse(data):
    """Compute moist static energy with metpy.

    Reads the ``"geopotential"``, ``"temperature"`` and
    ``"specific_humidity"`` variables from ``data``, attaches units to each
    (m^2/s^2, kelvin and kg/kg respectively), converts geopotential to height
    and returns the result of ``metpy.calc.moist_static_energy``."""

    units = metpy.units.units

    ## geopotential -> height
    height = metpy.calc.geopotential_to_height(
        data["geopotential"] * units("m^2/s^2")
    )

    ## temperature and humidity, with units attached
    temperature = data["temperature"] * units.kelvin
    specific_humidity = data["specific_humidity"] * units("kg/kg")

    return metpy.calc.moist_static_energy(
        height=height,
        temperature=temperature,
        specific_humidity=specific_humidity,
    )

# Load and prep data

In [None]:
## load data, drop the variables that are not needed, and dequantify
data = load_whoi_data().drop_vars(["d2m", "sp"]).metpy.dequantify()

## maximum of every variable within each calendar year
data_annual_max = data.groupby("time.year").max()
year = data_annual_max["year"]

Function to compute return period

In [None]:
## Test sample: annual-max "t2m" at a single grid point
X = data_annual_max["t2m"].isel(dict(latitude=30, longitude=20))

def get_tr_max(X):
    """Estimate the return period of the maximum value, with and without
    leave-one-out (LOO) training.

    Two ``scipy.stats.genextreme`` models are fit with ``src.evt.fit_model``:
    one on all years of ``X`` and one on all years except the year holding
    the maximum. For each fit, the return period whose modeled return level
    is closest to the largest observed value is reported.

    Args:
        X: ``xr.DataArray`` at a single lon/lat point, with a ``year``
            dimension.

    Returns:
        ``np.array([tr_max, tr_max_LOO])``: estimated return period of the
        maximum from the full fit and from the LOO fit.
    """

    ## Return periods at which the fitted return levels are evaluated
    return_periods = np.logspace(0.01, 5, 100)

    ## Model class and parameter bounds shared by both fits
    bounds = dict(c=[-1, 1], loc=[200, 400], scale=[-1e5, 1e5])
    kwargs = dict(model_class=scipy.stats.genextreme, bounds=bounds)

    ## "Leave-one-out" version of the data: every year except the one holding
    ## the maximum (position of the max is looked up once)
    idx_max = X.argmax("year").item()
    LOO_idx = np.delete(np.arange(len(X.year)), idx_max)
    X_LOO = X.isel(year=LOO_idx)

    ## fit models
    model = src.evt.fit_model(X, **kwargs)
    model_LOO = src.evt.fit_model(X_LOO, **kwargs)

    ## Return levels for each model; both use the same `return_periods`, so
    ## `tr` from the first call is used to index both sets of levels
    Xr, tr = src.evt.get_return_levels(model, return_periods=return_periods)
    Xr_LOO, _ = src.evt.get_return_levels(model_LOO, return_periods=return_periods)

    ## Empirical return levels
    ## NOTE(review): `[-1]` assumes get_empirical_return_period returns the
    ## levels sorted ascending, so the last entry is the maximum -- confirm.
    _, Xr_empirical = src.evt.get_empirical_return_period(X)
    X_max = Xr_empirical[-1]

    ## Return period whose modeled level is closest to the observed maximum
    tr_max = tr[np.argmin(np.abs(X_max - Xr))]
    tr_max_LOO = tr[np.argmin(np.abs(X_max - Xr_LOO))]

    return np.array([tr_max, tr_max_LOO])

In [None]:
## Return period of the maximum "mse" value at one grid point (full fit, LOO fit)
get_tr_max(data_annual_max["mse"].isel(dict(latitude=35, longitude=20)))

In [None]:
## Return-period curves for the test sample `X`. They are computed here, at
## cell scope, because `get_tr_max` only returns two scalars and keeps the
## curves as local variables; without this the cell fails on a fresh kernel.
return_periods = np.logspace(0.01, 5, 100)
fit_kwargs = dict(
    model_class=scipy.stats.genextreme,
    bounds=dict(c=[-1, 1], loc=[200, 400], scale=[-1e5, 1e5]),
)

## fit on all years, and on all years except the one holding the maximum
idx_max = X.argmax("year").item()
X_LOO = X.isel(year=np.delete(np.arange(len(X.year)), idx_max))
model = src.evt.fit_model(X, **fit_kwargs)
model_LOO = src.evt.fit_model(X_LOO, **fit_kwargs)

## modeled return levels (same return periods for both fits) and empirical
## return periods
Xr, tr = src.evt.get_return_levels(model, return_periods=return_periods)
Xr_LOO = src.evt.get_return_levels(model_LOO, return_periods=return_periods)[0]
tr_empirical, Xr_empirical = src.evt.get_empirical_return_period(X)

## modeled return period whose level is closest to the observed maximum
tr_max = tr[np.argmin(np.abs(Xr_empirical.max() - Xr))]

fig, ax = plt.subplots(figsize=(4, 3))

## plot modeled return period
ax.plot(tr, Xr, c="k")

## plot LOO return period
ax.plot(tr, Xr_LOO, c="k", ls=":")

## plot empirical return period
ax.scatter(tr_empirical, Xr_empirical, c="r", s=10)

## dashed line from the empirical maximum to its modeled return period
ax.plot(
    [tr_empirical[np.argmax(Xr_empirical)], tr_max],
    [Xr_empirical.max(), Xr_empirical.max()],
    ls="--",
    lw=1,
    c="k",
)
ax.scatter(tr_max, Xr_empirical.max(), c="k", marker="x", s=50)

## limit
ax.set_ylim([None, Xr_empirical.max() + 2])
ax.set_xlim([1e0, 1e4])

## label axes
ax.set_xlabel("Return period (years)")
ax.set_ylabel(r"$T_{2m}$ ($K$)")
ax.set_xscale("log")
ax.set_title(f"Est. return time: {tr_max:.0f} years")

plt.show()