In [None]:
import xarray as xr
import numpy as np
import scipy.stats
import src.evt
import metpy.calc
import metpy.units
import matplotlib.pyplot as plt
import seaborn as sns
import os.path
import cartopy.crs as ccrs
import matplotlib.patches as mpatches
import cmocean

## set plotting style: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

#### Functions

In [None]:
def load_whoi_data():
    """Read every NetCDF file in ../data into one in-memory dataset.

    The files were originally obtained from WHOI's data server.
    """

    ## combine the pre-computed PNW files and load them eagerly
    return xr.open_mfdataset("../data/*.nc").compute()


def load_google_data():
    """Load the pre-computed PNW ERA5 file `era5_google.nc` into memory.

    Variables are renamed to the short names used in the rest of the
    notebook: "t2m" (2m temperature), "d2m" (2m dewpoint) and "sp"
    (surface pressure).

    Returns:
        The loaded dataset with renamed variables.
    """

    ## open pre-computed PNW data
    data = xr.open_dataset("../data/PNW_ERA5/era5_google.nc").compute()

    ## rename for consistency with other data. The dewpoint has to be called
    ## "d2m": the wet-bulb computation reads data["d2m"] and stores its own
    ## result under "tw", so naming the dewpoint "tw" raises a KeyError there.
    data = data.rename(
        {
            "2m_temperature": "t2m",
            "2m_dewpoint_temperature": "d2m",
            "surface_pressure": "sp",
        }
    )

    return data


def get_jja(x):
    """Keep only the June, July and August time steps of `x`."""
    month = x.time.dt.month
    in_summer = (month <= 8) & (month >= 6)
    return x.sel(time=in_summer)


def plot_setup_simple(fig, projection, lon_range, lat_range):
    """Create a map axes on `fig`, limited to a lon/lat window.

    The axes use the given map projection; `lon_range` and `lat_range`
    are interpreted as PlateCarree (lon/lat) bounds. Coastlines are drawn
    and the new Axes object is returned.
    """

    ## new axes in the requested projection
    ax = fig.add_subplot(projection=projection)

    ## restrict the view to [lon bounds..., lat bounds...]
    ax.set_extent(list(lon_range) + list(lat_range), crs=ccrs.PlateCarree())

    ## thin coastlines for orientation
    ax.coastlines(linewidths=0.5)

    return ax


def plot_setup_pnw(fig):
    """Create the regional map axes (200-280E, 30-70N) used for the PNW plots.

    A magenta outline marks the box spanning 130W-110W, 40N-60N
    ("bartusek's box"). Returns the Axes object.
    """

    ## Orthographic view centred on 240E, 50N
    ## (alternative: ccrs.PlateCarree(central_longitude=240))
    proj = ccrs.Orthographic(central_longitude=240, central_latitude=50)

    ## NOTE(review): shrinks cartopy's private `_threshold`; presumably this
    ## makes projected lines smoother -- confirm against the cartopy version.
    proj._threshold /= 1000

    ax = plot_setup_simple(fig, proj, lon_range=[200, 280], lat_range=[30, 70])

    ## Outline bartusek's box (corner and size given in lon/lat)
    box = mpatches.Rectangle(
        xy=[-130, 40],
        width=20,
        height=20,
        facecolor="none",
        edgecolor="magenta",
        transform=ccrs.PlateCarree(),
        zorder=10,
    )
    ax.add_patch(box)

    return ax


def make_cb_range(amp, delta):
    """Make symmetric, zero-free contour levels for a diverging colormap
    (e.g. cmo.balance).

    Args:
        amp: largest absolute level; expected to be a multiple of `delta`.
        delta: spacing between neighbouring levels.

    Returns:
        1-D array [-amp, ..., -delta, delta, ..., amp]. If `amp` is not a
        multiple of `delta`, the outermost levels are the largest multiples
        of `delta` that do not exceed `amp` in magnitude.
    """
    ## Count the steps with integers and scale afterwards. np.arange with a
    ## fractional step can return one element too many or too few because of
    ## round-off (e.g. amp=0.3, delta=0.1), which made the levels asymmetric.
    ## The small tolerance absorbs ratios like 2.9999999999999996.
    n_steps = int(np.floor(amp / delta + 1e-9))
    steps = np.concatenate([np.arange(-n_steps, 0), np.arange(1, n_steps + 1)])
    return steps * delta


def get_mse(data):
    """Moist static energy from the "geopotential", "temperature" and
    "specific_humidity" variables of `data` (computed with metpy)."""

    units = metpy.units.units

    ## attach units to the raw fields
    geopotential = data["geopotential"] * units("m^2/s^2")
    temperature = data["temperature"] * units.kelvin
    specific_humidity = data["specific_humidity"] * units("kg/kg")

    ## geopotential -> height, then MSE
    return metpy.calc.moist_static_energy(
        height=metpy.calc.geopotential_to_height(geopotential),
        temperature=temperature,
        specific_humidity=specific_humidity,
    )

# Load and prep data

In [None]:
## Toggle: read the ERA5 back-extension file instead of the WHOI files?
use_back_ext = False

## Load 2m-temperature, 2m-dewpoint, and surface pressure
data = load_google_data() if use_back_ext else load_whoi_data()

## Wet-bulb temperature from surface pressure, 2m-temperature and 2m-dewpoint
Tw = metpy.calc.wet_bulb_temperature(
    pressure=data["sp"] * metpy.units.units.Pa,
    temperature=data["t2m"] * metpy.units.units.K,
    dewpoint=data["d2m"] * metpy.units.units.K,
)

## store wet-bulb temp. as "tw"; the inputs "d2m" and "sp" are no longer needed
data = data.assign(tw=Tw).drop_vars(["d2m", "sp"])

## Following Bartusek et al. (2021): daily means, restricted to JJA
data_daily_mean = get_jja(data.resample(time="1D").mean())

## Annual maximum of the JJA daily means, and the matching year axis
data_annual_max = data_daily_mean.groupby("time.year").max()
year = data_annual_max["year"]

## Make some plots

Scatter $T_{2m}$ vs. $T_{w,2m}$

In [None]:
fig, ax = plt.subplots(figsize=(2.5, 2.5))

## one point per JJA day: dry-bulb on x, wet-bulb on y
ax.scatter(data_daily_mean["t2m"], data_daily_mean["tw"], s=1)

## equal scaling on both axes, sparse ticks
ax.set_aspect("equal")
ax.set(
    xlabel=r"dry bulb",
    ylabel=r"wet bulb",
    xticks=[280, 295],
    yticks=[278, 288],
)
plt.show()

Timeseries of $T_{2m}$ vs. $T_{w,2m}$

In [None]:
## should we normalize timeseries by std. dev.?
normalize = False


## functions to get max index
def get_topk_idx(k, var_name):
    """Positions along `year` of the k largest annual maxima, largest first."""
    return np.argsort(data_annual_max[var_name]).values[-k:][::-1]


def get_topk_year(k, var_name):
    """Years of the k largest annual maxima of `var_name`, largest first."""
    return data_annual_max.year[get_topk_idx(k, var_name)]


## define preprocess function for plotting, plus the matching y-axis label
## (a standardized series has no units, so "K" is only right when
## normalize is False)
if normalize:

    def prep_fn(x):
        """Remove the mean and divide by the std. dev."""
        return (x - x.mean()) / x.std()

    anom_units = "std. dev."

else:

    def prep_fn(x):
        """Remove the mean."""
        return x - x.mean()

    anom_units = r"$K$"


## make the plot
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(year, prep_fn(data_annual_max["tw"]).values, label=r"wet bulb")
ax.plot(year, prep_fn(data_annual_max["t2m"]).values, label=r"dry bulb")

## mark top-k years (ranked by dry-bulb temperature)
topk_years = get_topk_year(k=3, var_name="t2m")
ax.scatter(topk_years, prep_fn(data_annual_max["t2m"]).sel(year=topk_years), c="k")

## label
ax.legend()
ax.set_ylabel(anom_units)
## the series are annual maxima of JJA daily means, not daily maxima
ax.set_title(r"PNW annual max of daily mean (anomaly)")
## plain numbers for the ticks: first year of the record + the top-k years
ax.set_xticks([year.values[0], *topk_years.values])

plt.show()

## When did the annual max happen?

In [None]:
## find the maximum for each year
## NOTE: argmax is applied to each year's group separately, so the result is
## the position of the maximum within that year's (JJA-only) time steps,
## not a position along the full `time` axis of data_daily_mean.
argmax = lambda x: x.argmax("time")
argmax_for_year = data_daily_mean.groupby("time.year").map(argmax)


def get_topk_time_idx(k, var_name):
    """Dates on which the k largest annual maxima of `var_name` occurred.

    The top-k years are determined for `var_name` itself (largest first),
    and for each of those years the date of that year's maximum is looked
    up inside that year. Previously `k` was ignored, the years came from
    the global `topk_years` (ranked by "t2m"), and a within-year position
    was applied to the full time axis, which returned dates in the wrong
    year.

    Args:
        k: number of years to return.
        var_name: variable in `data_daily_mean` / `data_annual_max`.

    Returns:
        The `time` coordinate of the selected days, ordered like the years.
    """
    series = data_daily_mean[var_name]
    years = get_topk_year(k, var_name).values

    peak_times = []
    for yr in years:
        ## restrict to one year so argmax and the time axis refer to the same days
        in_year = series.sel(time=(series.time.dt.year == yr))
        peak_pos = int(in_year.argmax("time"))
        peak_times.append(in_year.time.values[peak_pos])

    return series.sel(time=np.array(peak_times)).time


## print out top time index as month/day, one group of lines per variable
## (the loop leaves the global `var_name` set to "tw" afterwards)
for var_name in ["t2m", "tw"]:
    topk_times = get_topk_time_idx(k=3, var_name=var_name)
    for t in topk_times:
        print(f"{t.dt.month.values.item()}/{t.dt.day.values.item()}")
    print()

# ## scatter times for max of each var
# fig,ax = plt.subplots(figsize=(3,3))
# ax.set_aspect("equal")
# ax.scatter(argmax_for_year["t2m"], argmax_for_year["tw"])
# ax.plot(np.linspace(30,85),np.linspace(30,85), c="k")
# ax.set_xticks([29,60, 91], labels=["Jun 30", "Jul 31", "Aug 31"])
# ax.set_yticks([29,60, 91], labels=["Jun 30", "Jul 31", "Aug 31"])
# ax.set_xlabel("dry bulb")
# ax.set_ylabel("wet bulb")
# plt.show()

## Load event data and climatology

In [None]:
def load_data_and_clim(start_date, end_date):
    """Open the data file and the climatology file for one date range.

    Both files are looked up in ../data/PNW_ERA5 by name
    (data_<start>-<end>.nc and clim_<start>-<end>.nc). The data are
    averaged to daily values; the climatology is averaged over "hour" and
    given the data's time axis. Returns the tuple (data, clim).
    """

    data_dir = "../data/PNW_ERA5"

    ## open both files
    data = xr.open_dataset(
        os.path.join(data_dir, f"data_{start_date}-{end_date}.nc")
    )
    clim = xr.open_dataset(
        os.path.join(data_dir, f"clim_{start_date}-{end_date}.nc")
    )

    ## daily means for the data; hourly mean per day-of-year for the climatology
    data = data.resample({"time": "D"}).mean()
    clim = clim.mean("hour").rename({"dayofyear": "time"})

    ## put the climatology on the data's time axis
    ## (assumes both have the same number of days -- TODO confirm)
    clim["time"] = data["time"]

    return data, clim

In [None]:
## open data for one event window
## NOTE: this rebinds `data`, which earlier held the t2m / tw dataset;
## data_daily_mean and data_annual_max were derived before and are unaffected.
## alternative window (2021 event):
# data, clim = load_data_and_clim(start_date="2021-06-20", end_date="2021-07-10")
data, clim = load_data_and_clim(start_date="2013-06-22", end_date="2013-07-12")

## compute MSE for both the data and the climatology
data["mse"] = get_mse(data)
clim["mse"] = get_mse(clim)

## compute anomalies (data minus climatology)
data_anom = data - clim

## Make plots

In [None]:
## day shown in the anomaly maps (used to select along `time`)
plot_time = "2013-06-26"

#### $T_{2m}$

In [None]:
fig = plt.figure(figsize=(8, 4))
ax = plot_setup_pnw(fig)

## plot 2m-temperature anomaly (filled contours every 2 from -20 to 20, no 0 level)
t2m_plot = ax.contourf(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["2m_temperature"].sel(time=plot_time),
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(20, 2),
    extend="both",
)

## plot Z500 anomaly as black contours every 50 from -300 to 300.
## Geopotential at level=500 is divided by 9.8 (presumably g in m/s^2)
## to turn it into a height.
z500_plot = ax.contour(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["geopotential"].sel(time=plot_time, level=500) / 9.8,
    colors="k",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(300, 50),
    extend="both",
    linewidths=0.75,
)

## colorbar for the temperature shading only
cb = fig.colorbar(
    t2m_plot, ticks=[-20, -10, 0, 10, 20], fraction=0.02, label=r"$^{\circ}C$"
)

plt.show()

#### Moist static energy (MSE)

In [None]:
fig = plt.figure(figsize=(8, 4))
ax = plot_setup_pnw(fig)

## plot mse anomaly at level=500 (filled contours every 2 from -20 to 20, no 0 level)
## NOTE: the handle is called `sm_plot` here although it holds the MSE contours
sm_plot = ax.contourf(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["mse"].sel(level=500, time=plot_time),
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(20, 2),
    extend="both",
)

## plot Z500 anomaly as black contours every 50 from -300 to 300.
## Geopotential at level=500 is divided by 9.8 (presumably g in m/s^2)
## to turn it into a height.
z500_plot = ax.contour(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["geopotential"].sel(time=plot_time, level=500) / 9.8,
    colors="k",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(300, 50),
    extend="both",
    linewidths=0.75,
)

## colorbar for the MSE shading only
cb = fig.colorbar(
    sm_plot, fraction=0.02, label=r"$kJ~kg^{-1}$", ticks=[-20, -10, 0, 10, 20]
)

plt.show()

#### Plot specific humidity

In [None]:
fig = plt.figure(figsize=(8, 4))
ax = plot_setup_pnw(fig)

## plot q anomaly at level=850, scaled by 1000 to match the g/kg colorbar
## label (get_mse treats specific_humidity as kg/kg). Reversed colormap.
## NOTE: the handle is called `sm_plot` here although it holds the q contours
sm_plot = ax.contourf(
    data_anom.longitude,
    data_anom.latitude,
    1000 * data_anom["specific_humidity"].sel(level=850, time=plot_time),
    cmap="cmo.balance_r",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(8, 0.8),
    extend="both",
)

## plot Z500 anomaly as black contours every 50 from -300 to 300.
## Geopotential at level=500 is divided by 9.8 (presumably g in m/s^2)
## to turn it into a height.
z500_plot = ax.contour(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["geopotential"].sel(time=plot_time, level=500) / 9.8,
    colors="k",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(300, 50),
    extend="both",
    linewidths=0.75,
)

## colorbar for the humidity shading only
cb = fig.colorbar(sm_plot, fraction=0.02, label=r"$g~kg^{-1}$", ticks=[-8, -4, 0, 4, 8])

plt.show()

#### Plot soil moisture

In [None]:
fig = plt.figure(figsize=(8, 4))
ax = plot_setup_pnw(fig)

## plot soil moisture anomaly (volumetric soil water, layer 1), reversed colormap
sm_plot = ax.contourf(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["volumetric_soil_water_layer_1"].sel(time=plot_time),
    cmap="cmo.balance_r",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(0.2, 0.02),
    extend="both",
)

## plot Z500 anomaly as black contours every 50 from -300 to 300.
## Geopotential at level=500 is divided by 9.8 (presumably g in m/s^2)
## to turn it into a height.
z500_plot = ax.contour(
    data_anom.longitude,
    data_anom.latitude,
    data_anom["geopotential"].sel(time=plot_time, level=500) / 9.8,
    colors="k",
    transform=ccrs.PlateCarree(),
    levels=make_cb_range(300, 50),
    extend="both",
    linewidths=0.75,
)

## colorbar for the soil moisture shading only
cb = fig.colorbar(
    sm_plot, fraction=0.02, label=r"$m^3~m^{-3}$", ticks=[-0.2, -0.1, 0, 0.1, 0.2]
)

plt.show()

# Fit extreme value distribution (GEV or GP)

In [None]:
## specify model type, one of {"gev", "gp"}
model_type = "gev"
var_name = "tw"

## select variable
X = data_annual_max[var_name].values

## specify whether to hold out 2021/2022 data
hold_out_2021 = False

## training data
if hold_out_2021:
    ## pick the held-out years by label rather than by position, so the
    ## right years are dropped even if the record does not end in 2022
    is_held_out = np.isin(data_annual_max.year.values, [2021, 2022])
    X_train = X[~is_held_out]

else:
    X_train = X

## select EVT distribution
if model_type == "gev":
    model_class = scipy.stats.genextreme
    thresh = None

elif model_type == "gp":
    model_class = scipy.stats.genpareto
    ## NOTE(review): `thresh` is not passed to any call in this cell, so the
    ## GP model is fit to the same values as the GEV -- confirm this is intended
    thresh = 318

else:
    ## fail here instead of with a NameError on `model_class` further down
    raise ValueError(f'model_type must be "gev" or "gp", got {model_type!r}')

## Empirical PDF (normalized histogram)
pdf_empirical, bin_edges = src.evt.get_empirical_pdf(X)

## Empirical return period
tr_empirical, Xr_empirical = src.evt.get_empirical_return_period(X)

## Fit model and get return levels
model = src.evt.fit_model(X_train, model_class)
Xr, tr = src.evt.get_return_levels(model, return_periods=np.logspace(0.01, 5, 100))

# ## Compute confidence interval
# bounds_fp = f"../data/PNW_ERA5/{model_type}_{var_name}_bounds.pkl"
# Xr_lb, Xr_ub = src.evt.load_return_period_bnds(
#     X, model_class=model_class, n_samples=1000, save_fp=bounds_fp
# )

## compute estimated return time for max event: the modeled return period
## whose return level is closest to the largest observed value. Uses .max()
## (as the return-time plot does) rather than assuming the last entry of
## Xr_empirical is the largest.
tr_max = tr[np.argmin(np.abs(Xr_empirical.max() - Xr))]

# Plot results

## PDF

In [None]:
## grid of values on which to evaluate the fitted pdf
## (spans the histogram bins, extended by 1 on the upper end)
X_test = np.linspace(bin_edges.min(), bin_edges.max() + 1, num=100)

fig, ax = plt.subplots(figsize=(4, 3))

## empirical pdf (shaded steps) with the fitted pdf drawn on top
ax.stairs(pdf_empirical, edges=bin_edges, color="gray", fill=True, alpha=0.3)
ax.plot(X_test, model.pdf(X_test), c="k")

## mark the largest observed value on the x-axis
ax.scatter(X.max(), 0, marker="x", c="r", s=50)

## axis labels
ax.set(xlabel=r"Annual max ($K$)", ylabel="Prob.")

plt.show()

## Return times

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## plot modeled return period
ax.plot(tr, Xr, c="k")

## shade confidence bounds
# ax.fill_between(tr, Xr_ub, Xr_lb, color="k", alpha=0.1)

## plot empirical return period
ax.scatter(tr_empirical, Xr_empirical, c="r", s=10)

## dashed line from the largest observed event (at its empirical return
## period) to the modeled return period for the same level, marked with an x
ax.plot(
    [tr_empirical[np.argmax(Xr_empirical)], tr_max],
    [Xr_empirical.max(), Xr_empirical.max()],
    ls="--",
    lw=1,
    c="k",
)
ax.scatter(tr_max, Xr_empirical.max(), c="k", marker="x", s=50)

## limit
ax.set_ylim([None, Xr_empirical.max() + 0.5])
ax.set_xlim([1e0, 1e3])

## label axes; the y-label follows `var_name` (the variable that was fit)
## instead of always claiming T_2m
ylabel_for_var = {"t2m": r"$T_{2m}$ ($K$)", "tw": r"$T_{w,2m}$ ($K$)"}
ax.set_xlabel("Return period (years)")
ax.set_ylabel(ylabel_for_var.get(var_name, f"{var_name} ($K$)"))
ax.set_xscale("log")
ax.set_title(f"Est. return time: {tr_max:.0f} years")

plt.show()