In [None]:
import xarray as xr
import numpy as np
import scipy.stats
import src.evt
import metpy.calc
import metpy.units
import matplotlib.pyplot as plt
import seaborn as sns
import os.path

## set plotting style: apply seaborn's settings, overriding the rc params so
## that axes get a white background and no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

# Load 6-hourly data

In [None]:
def load_whoi_data(pattern="../data/*.nc"):
    """Load data originally obtained from WHOI's data server.

    Parameters
    ----------
    pattern : str, optional
        Glob pattern of the NetCDF files to combine into one dataset.
        Defaults to the pre-computed PNW files under ``../data``.

    Returns
    -------
    xarray.Dataset
        The combined dataset with all values loaded into memory.
    """

    ## open pre-computed PNW data; the context manager closes the underlying
    ## files as soon as the values have been read into memory
    with xr.open_mfdataset(pattern) as lazy_data:
        data = lazy_data.compute()

    return data

In [None]:
## read the dataset from disk once; the cells below derive everything from `data`
data = load_whoi_data()

#### Compute wet-bulb temperature

In [None]:
## compute wetbulb from temperature (K), dewpoint (K) and pressure (Pa);
## each input gets its unit attached before being handed to metpy
Tw = metpy.calc.wet_bulb_temperature(
    temperature=data["t2m"] * metpy.units.units.K,
    dewpoint=data["d2m"] * metpy.units.units.K,
    pressure=data["sp"] * metpy.units.units.Pa,
)

## add the result to the dataset as "tw" and drop the un-needed input vars
data = data.assign(tw=Tw).drop_vars(["d2m", "sp"])

# Prep data
Following Bartusek et al., compute the daily mean and subset to June–August (JJA).

In [None]:
## subset for JJA
def get_jja(x):
    """Return only the June, July and August time steps of ``x``.

    ``x`` must expose a datetime-like ``time`` coordinate (``x.time.dt.month``)
    and support boolean selection through ``x.sel(time=...)``.
    """
    month = x.time.dt.month
    in_summer = (month >= 6) & (month <= 8)
    return x.sel(time=in_summer)


## average within each calendar day, then keep only the JJA days
data_daily_mean = get_jja(data.resample(time="1D").mean())

Scatter $T_{2m}$ vs. $T_{w,2m}$

In [None]:
## one point per JJA day: dry-bulb temperature on x, wet-bulb temperature on y
fig, ax = plt.subplots(figsize=(2.5, 2.5))
ax.scatter(data_daily_mean.t2m, data_daily_mean.tw, s=1)

## label the axes so the figure can be read on its own
ax.set_xlabel(r"$T_{2m}$ ($K$)")
ax.set_ylabel(r"$T_{w,2m}$ ($K$)")

plt.show()

#### Compute annual max

In [None]:
## largest JJA daily-mean value within each calendar year, for every variable
data_annual_max = data_daily_mean.groupby("time.year").max()
## year labels of the annual maxima (used as the x-axis of the timeseries plot)
year = data_annual_max.year

#### Plot timeseries comparison

In [None]:
def normalize(x):
    """Standardize ``x``: subtract its mean and divide by its standard deviation."""
    return (x - x.mean()) / x.std()


def remove_mean(x):
    """Center ``x`` by subtracting its mean."""
    return x - x.mean()


## transformation applied to each series before plotting; swap in `normalize`
## to compare the two series in units of standard deviation instead of K
prep_fn = remove_mean

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(year, prep_fn(data_annual_max["tw"]).values, label=r"wet bulb")
ax.plot(year, prep_fn(data_annual_max["t2m"]).values, label=r"dry bulb")
ax.legend()
ax.set_xlabel("Year")
ax.set_ylabel(r"$K$")
## both curves are annual maxima of the JJA daily mean (wet bulb and dry bulb),
## so the title must not claim a daily max of T_2m only
ax.set_title(r"PNW annual max anomaly (JJA daily mean)")


plt.show()

## Fit extreme value distribution (GEV or GP)

In [None]:
## specify model type, one of {"gev", "gp"}
model_type = "gev"
var_name = "tw"

## select variable (annual maxima as a plain array)
X = data_annual_max[var_name].values

## specify whether to hold out 2021/2022 data
hold_out_2021 = False

## training data
## NOTE(review): dropping the last two entries assumes the record ends in
## 2022 — confirm against the loaded data
if hold_out_2021:
    X_train = X[:-2]

else:
    X_train = X

## select EVT distribution
if model_type == "gev":
    model_class = scipy.stats.genextreme
    thresh = None

elif model_type == "gp":
    model_class = scipy.stats.genpareto
    thresh = 318

else:
    ## fail here with a clear message instead of a NameError further down
    raise ValueError(f"model_type must be 'gev' or 'gp', got {model_type!r}")


## Load data
## TO-DO
# X = load_prepped_data(model_type=model_type, thresh=thresh)

## Empirical PDF (normalized histogram) -- uses all data, including any
## held-out years
pdf_empirical, bin_edges = src.evt.get_empirical_pdf(X)

## Empirical return period -- uses all data, including any held-out years
tr_empirical, Xr_empirical = src.evt.get_empirical_return_period(X)

## Fit model and get return levels
model = src.evt.fit_model(X_train, model_class)
Xr, tr = src.evt.get_return_levels(model, return_periods=np.logspace(0.01, 5, 100))

## Compute confidence interval from the same sample the model was fit to
## NOTE(review): results look to be cached at `bounds_fp`, whose name does not
## encode `hold_out_2021`; a stale file may need deleting after toggling the
## flag — confirm against src.evt.load_return_period_bnds
bounds_fp = f"../data/PNW_ERA5/{model_type}_{var_name}_bounds.pkl"
Xr_lb, Xr_ub = src.evt.load_return_period_bnds(
    X_train, model_class=model_class, n_samples=1000, save_fp=bounds_fp
)

## compute estimated return time for max event: the modeled return period
## whose return level is closest to the largest observed value (taken with
## .max() so the result does not depend on how Xr_empirical is ordered)
tr_max = tr[np.argmin(np.abs(Xr_empirical.max() - Xr))]

## Plot PDF

In [None]:
## evaluation grid for the fitted curve: spans the histogram bins and
## extends 1 beyond the upper edge
X_test = np.linspace(bin_edges.min(), bin_edges.max() + 1, num=100)

## setup plot
fig, ax = plt.subplots(figsize=(4, 3))

## shaded step histogram of the empirical pdf
ax.stairs(pdf_empirical, bin_edges, fill=True, color="gray", alpha=0.3)

## density of the fitted distribution on the evaluation grid
ax.plot(X_test, model.pdf(X_test), color="k")

## red cross on the x-axis at the largest value in X
ax.scatter(X.max(), 0, marker="x", c="r", s=50)

## label
ax.set(xlabel=r"Annual max ($K$)", ylabel="Prob.")

plt.show()

## Plot return level

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## plot modeled return period
ax.plot(tr, Xr, c="k")
# ax.fill_between(tr, Xr_ub, Xr_lb, color="k", alpha=0.1)

## plot empirical return period
ax.scatter(tr_empirical, Xr_empirical, c="r", s=10)

## dashed horizontal line from the largest observation (at its empirical
## return period) to the modeled return period of that same level
ax.plot(
    [tr_empirical[np.argmax(Xr_empirical)], tr_max],
    [Xr_empirical.max(), Xr_empirical.max()],
    ls="--",
    lw=1,
    c="k",
)
ax.scatter(tr_max, Xr_empirical.max(), c="k", marker="x", s=50)

## limit
ax.set_ylim([None, Xr_empirical.max() + 0.5])
ax.set_xlim([1e0, 1e3])

## label axes; the y-label follows the variable selected via `var_name`
## (falls back to the raw variable name if no pretty label is defined)
var_label = {"t2m": r"$T_{2m}$", "tw": r"$T_{w,2m}$"}.get(var_name, var_name)
ax.set_xlabel("Return period (years)")
ax.set_ylabel(rf"{var_label} ($K$)")
ax.set_xscale("log")
ax.set_title(f"Est. return time: {tr_max:.0f} years")

plt.show()