# Spatial UNSEEN analysis


In [None]:
%load_ext autoreload
%autoreload 2

import cartopy
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LatitudeFormatter, LongitudeFormatter
import geopandas as gp
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import numpy as np
import os
import xarray as xr

from unseen import fileio, independence, similarity, time_utils

import cfg, spatial_plots

In [None]:
# Optional parameters
# (This cell is tagged "parameters")
# Default values for the notebook's optional settings. The assertion messages
# in the next cell refer to papermill's -p option, so these are presumably
# overridden per run -- TODO confirm against the workflow.
dpi = 300  # NOTE(review): not referenced in the visible cells; presumably figure resolution -- confirm
shapefile = None  # passed to fileio.open_dataset when reading min_lead
shape_overlap = 0.1  # passed to fileio.open_dataset together with shapefile
alpha = 0.05  # threshold: pval_mask is True where the similarity p-value <= alpha
# todo: add to Makefile
init_dim = "init_date"  # used in groupby(f"{init_dim}.month"), isel and stack
lead_dim = "lead_time"  # used for lead-time masking, dropna and stack
ensemble_dim = "ensemble"  # stacked into the "sample" dimension
time_dim = "time"  # used to read the year of each stacked sample
lat_dim = "lat"  # NOTE(review): not referenced in the visible cells
lon_dim = "lon"  # NOTE(review): not referenced in the visible cells
similarity_test = "ks"  # selects the f"{similarity_test}_pval" variable of the similarity datasets
gev_relative_fit_test = "bic"  # NOTE(review): not referenced in the visible cells
time_agg = "maximum"  # passed to the spatial_plots functions
covariate_year = 2024  # passed to the spatial_plots functions that also take the non-stationary GEV parameters
base_period = [1961, 2020]  # for covariate
# reference_time_period = ["1960-01-01", "2020-12-31"]

plot_additive_bc = True  # gates the cells that load and plot the additive bias corrected model data
plot_multiplicative_bc = False  # gates the cells that load and plot the multiplicative bias corrected model data

In [None]:
# Required parameters
# These have no defaults in the parameters cell and must be supplied when the
# notebook is executed (the messages below refer to papermill's -p option).
kwargs = locals()
assert "metric" in kwargs, "Must provide a metric name"
assert "var" in kwargs, "Must provide a variable name"
assert "model_name" in kwargs, "Must provide a model name"

# (parameter name, description used in the error message).
# The similarity files are checked unconditionally because the similarity
# cells further down open all three of them regardless of the plot_* flags.
required_files = [
    ("obs_file", "an observations data file"),
    ("file_list", "the input model files list"),
    ("model_file", "a model data file"),
    ("independence_file", "an independence min lead file"),
    ("similarity_raw_file", "a raw data similarity test file"),
    ("similarity_add_bc_file", "an additive bias corrected similarity test file"),
    (
        "similarity_mulc_bc_file",
        "a multiplicative bias corrected similarity test file",
    ),
]
if plot_additive_bc:
    required_files.append(
        ("model_add_bc_file", "a model additive bias corrected data file")
    )
if plot_multiplicative_bc:
    required_files.append(
        ("model_mulc_bc_file", "a model multiplicative bias corrected data file")
    )

# Look each value up by name so that a parameter that was never supplied fails
# with the explanatory message instead of an unrelated NameError, and so the
# option named in the message always matches the variable that is checked.
for param_name, description in required_files:
    assert os.path.isfile(
        str(kwargs.get(param_name, ""))
    ), f"Must provide {description} (papermill option -p {param_name} [filepath])"

## Model ensemble


### Region selection


In [None]:
# # Plot region shapefile outline using the first file in the list
# if shapefile and file_list:
#     with open(file_list) as f:
#         all_files = f.read()
#         first_file = all_files.split("\n", 1)[0]

#     shapes = gp.read_file(shapefile)
#     isel_dict = {}
#     if model_name == "CAFE":
#         isel_dict["ensemble"] = 0

#     region_ds = fileio.open_dataset(
#         first_file,
#         metadata_file=metadata_file,
#         variables=[var],
#         lat_bnds=[-44, -10],
#         lon_bnds=[113, 155],
#         shapefile=shapefile,
#         shape_overlap=shape_overlap,
#         units={var: units_dict[var]},
#         isel=isel_dict,
#     )

#     fig = plt.figure(figsize=[12, 8])
#     ax = fig.add_subplot(111, projection=ccrs.PlateCarree())
#     region_ds[var].mean("time", keep_attrs=True).plot(
#         ax=ax,
#         transform=ccrs.PlateCarree(),
#         cmap="viridis_r",
#     )
#     ax.coastlines()
#     ax.add_feature(cartopy.feature.STATES)
#     ax.add_geometries(
#         shapes.geometry, ccrs.PlateCarree(), facecolor="none", edgecolor="orange"
#     )
#     plt.show()

### Model data


In [None]:
model_ds = fileio.open_dataset(model_file)

# Convert the event_time values to cftime objects in the dataset's own
# calendar. A plain astype("datetime64[ns]") cast does not work for
# non-standard calendars (e.g., HadGEM3-GC31-MM).
str_to_cftime = np.vectorize(time_utils.str_to_cftime)
calendar = model_ds.time.dt.calendar
event_times = str_to_cftime(model_ds.event_time, calendar)
model_ds["event_time"] = (model_ds.event_time.dims, event_times)
model_ds

In [None]:
if plot_additive_bc:
    model_add_bc_ds = fileio.open_dataset(model_add_bc_file)

    # Convert event_time date strings to cftime objects in the dataset's own
    # calendar, exactly as is done for model_ds. Casting to datetime64[ns]
    # does not work for non-standard calendars (e.g., HadGEM3-GC31-MM).
    add_bc_event_times = np.vectorize(time_utils.str_to_cftime)(
        model_add_bc_ds.event_time, model_add_bc_ds.time.dt.calendar
    )
    model_add_bc_ds["event_time"] = (
        model_add_bc_ds.event_time.dims,
        add_bc_event_times,
    )

In [None]:
if plot_multiplicative_bc:
    model_mulc_bc_ds = fileio.open_dataset(model_mulc_bc_file)

    # Convert event_time date strings to cftime objects in the dataset's own
    # calendar, exactly as is done for model_ds. Casting to datetime64[ns]
    # does not work for non-standard calendars (e.g., HadGEM3-GC31-MM).
    mulc_bc_event_times = np.vectorize(time_utils.str_to_cftime)(
        model_mulc_bc_ds.event_time, model_mulc_bc_ds.time.dt.calendar
    )
    model_mulc_bc_ds["event_time"] = (
        model_mulc_bc_ds.event_time.dims,
        mulc_bc_event_times,
    )

### Independence testing


In [None]:
# Open the independence test output; use_cftime=True asks xarray to decode
# times as cftime objects. The bare name on the last line displays the dataset
# in the notebook output.
ds_independence = xr.open_dataset(independence_file, use_cftime=True)
ds_independence

Plot the correlation coefficients for each initialisation month and lead time

In [None]:
# Facet grid of the correlation coefficient maps: one column per lead time and
# one row per month, with a single shared colorbar added at the end.
map_projection = {"projection": ccrs.PlateCarree()}
cm = ds_independence.r.plot(
    row="month",
    col=lead_dim,
    cmap=plt.cm.seismic,
    add_colorbar=False,
    transform=ccrs.PlateCarree(),
    subplot_kws=map_projection,
)

# Fix hidden axis ticks and labels
for ax in cm.axs.flat:
    ax.coastlines()
    for axis, formatter in (
        (ax.xaxis, LongitudeFormatter()),
        (ax.yaxis, LatitudeFormatter()),
    ):
        axis.set_major_formatter(formatter)
        axis.set_minor_locator(AutoMinorLocator())
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    # x ticks on every panel, y ticks only on the first column
    ax.xaxis.set_visible(True)
    if ax.get_subplotspec().is_first_col():
        ax.yaxis.set_visible(True)

cm.fig.set_constrained_layout(True)
cm.fig.get_layout_engine().set(h_pad=0.2)
cm.add_colorbar(pad=0.02)

Plot maps of the null correlation bounds 

In [None]:
# Facet grid of the null correlation bounds: one column per quantile and one
# row per month, with a single shared colorbar added at the end.
map_projection = {"projection": ccrs.PlateCarree()}
cm = ds_independence.ci.plot(
    row="month",
    col="quantile",
    cmap=plt.cm.seismic,
    add_colorbar=False,
    transform=ccrs.PlateCarree(),
    subplot_kws=map_projection,
)

# Fix hidden axis ticks
for ax in cm.axs.flat:
    ax.coastlines()
    for axis, formatter in (
        (ax.xaxis, LongitudeFormatter()),
        (ax.yaxis, LatitudeFormatter()),
    ):
        axis.set_major_formatter(formatter)
        axis.set_minor_locator(AutoMinorLocator())
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    # x ticks on every panel, y ticks only on the first column
    ax.xaxis.set_visible(True)
    if ax.get_subplotspec().is_first_col():
        ax.yaxis.set_visible(True)

cm.fig.set_constrained_layout(True)
cm.fig.get_layout_engine().set(h_pad=0.2)
cm.add_colorbar(pad=0.02)

Plot map of first independent lead time (first lead time where the correlation coefficient is within the null correlation bounds)

In [None]:
# Map of the first independent lead time, labelled with model_name.
# NOTE(review): outfile=independence_plot is presumably the path the figure is
# saved to; it is not set in the parameter cells, so it must be supplied at
# run time -- confirm.
independence.spatial_plot(
    ds_independence,
    dataset_name=model_name,
    outfile=independence_plot,
)

In [None]:
# Read the minimum independent lead time; it is used in the following cells to
# mask out the dependent lead times.
min_lead_options = {
    "variables": "min_lead",
    "shapefile": shapefile,
    "shape_overlap": shape_overlap,
    "spatial_agg": min_lead_spatial_agg,
}
min_lead_ds = fileio.open_dataset(independence_file, **min_lead_options)
min_lead = min_lead_ds["min_lead"].load()

print(min_lead)

In [None]:
# Mask (set to NaN) values whose lead time is below min_lead, grouping the
# data by the month of init_dim.
is_independent = model_ds[lead_dim] >= min_lead
model_ds = model_ds.groupby(f"{init_dim}.month").where(is_independent)

In [None]:
if plot_additive_bc:
    # Mask (set to NaN) values whose lead time is below min_lead, grouping the
    # data by the month of init_dim.
    add_bc_is_independent = model_add_bc_ds[lead_dim] >= min_lead
    model_add_bc_ds = model_add_bc_ds.groupby(f"{init_dim}.month").where(
        add_bc_is_independent
    )

In [None]:
if plot_multiplicative_bc:
    # Mask (set to NaN) values whose lead time is below min_lead, grouping the
    # data by the month of init_dim.
    mulc_bc_is_independent = model_mulc_bc_ds[lead_dim] >= min_lead
    model_mulc_bc_ds = model_mulc_bc_ds.groupby(f"{init_dim}.month").where(
        mulc_bc_is_independent
    )

## Similarity analysis
We can look at p-values for the KS-test and Anderson-Darling test for each lead time.

p > 0.05 means the null hypothesis (that the two samples are from the same population) can't be rejected.

In [None]:
# Similarity test results for the raw model data (path coerced to str first)
similarity_ds = fileio.open_dataset(str(similarity_raw_file))

In [None]:
# Similarity test results for the additive bias corrected data.
# NOTE: opened regardless of plot_additive_bc.
similarity_add_bc_ds = fileio.open_dataset(str(similarity_add_bc_file))

In [None]:
# Similarity test results for the multiplicative bias corrected data.
# NOTE: opened regardless of plot_multiplicative_bc.
similarity_mulc_bc_ds = fileio.open_dataset(str(similarity_mulc_bc_file))

In [None]:
# Spatial plot of the raw-data similarity test results.
# NOTE(review): similarity_raw_plot is not set in the parameter cells; it is
# presumably an output path supplied at run time -- confirm.
similarity.similarity_spatial_plot(
    similarity_ds,
    dataset_name=model_name,
    outfile=similarity_raw_plot,
)

In [None]:
# Spatial plot of the additive bias corrected similarity test results.
# NOTE: runs regardless of plot_additive_bc.
# NOTE(review): similarity_add_bc_plot is not set in the parameter cells; it
# is presumably an output path supplied at run time -- confirm.
similarity.similarity_spatial_plot(
    similarity_add_bc_ds,
    dataset_name=f"{model_name} (additive bias corrected)",
    outfile=similarity_add_bc_plot,
)

In [None]:
# Spatial plot of the multiplicative bias corrected similarity test results.
# NOTE: runs regardless of plot_multiplicative_bc.
# NOTE(review): similarity_mulc_bc_plot is not set in the parameter cells; it
# is presumably an output path supplied at run time -- confirm.
similarity.similarity_spatial_plot(
    similarity_mulc_bc_ds,
    dataset_name=f"{model_name} (multiplicative bias corrected)",
    outfile=similarity_mulc_bc_plot,
)

In [None]:
# Attach the similarity test outcome to the model dataset: pval_mask is True
# where the chosen test's p-value is at or below alpha.
pval_name = f"{similarity_test}_pval"
model_ds["pval_mask"] = similarity_ds[pval_name] <= alpha

In [None]:
if plot_additive_bc:
    # pval_mask is True where the chosen test's p-value is at or below alpha
    pval_name = f"{similarity_test}_pval"
    model_add_bc_ds["pval_mask"] = similarity_add_bc_ds[pval_name] <= alpha

In [None]:
if plot_multiplicative_bc:
    # pval_mask is True where the chosen test's p-value is at or below alpha
    pval_name = f"{similarity_test}_pval"
    model_mulc_bc_ds["pval_mask"] = similarity_mulc_bc_ds[pval_name] <= alpha

## Observational data

In [None]:
# Open the observations dataset; the bare name on the last line displays it in
# the notebook output.
obs_ds = fileio.open_dataset(obs_file)
obs_ds

In [None]:
# Select observations within the model initialisation times # todo: check this
# Years of the first-lead time value for the first and last initialisation
first_last_init_times = model_ds.time.isel({lead_dim: 0, init_dim: [0, -1]})
model_init_time_bnds = first_last_init_times.dt.year.values
print("model init start/finish", model_init_time_bnds)

obs_year = obs_ds.time.dt.year
within_model_period = (obs_year >= model_init_time_bnds[0]) & (
    obs_year <= model_init_time_bnds[1]
)
obs_ds = obs_ds.where(within_model_period, drop=True)
# Remove time steps that are entirely missing
obs_ds = obs_ds.dropna("time", how="all")
obs_ds

In [None]:
# Largest observed value over the whole dataset, and the entry (or entries)
# where it occurs.
obs_max_event = obs_ds[var].max().load().item()
is_obs_max = obs_ds[var].load() == obs_max_event
obs_max_event_loc = obs_ds[var].where(is_obs_max, drop=True).squeeze()
obs_max_event_loc.load()

## ACS Spatial Maps

In [None]:
# Store plot related variables using the InfoSet class

# Keyword arguments shared by every InfoSet built in this cell
shared_info_kwargs = dict(
    name=model_name,
    metric=metric,
    obs_file=obs_file,
    project_dir=project_dir,
    masked=False,
)

# Raw model data (no bias correction)
info = cfg.InfoSet(
    file=model_file,
    ds=model_ds,
    bias_correction=None,
    **shared_info_kwargs,
)
info.date_range_obs = cfg.date_range_str(obs_ds.time, info.freq)

if plot_additive_bc:
    info_add_bc = cfg.InfoSet(
        file=model_add_bc_file,
        ds=model_add_bc_ds,
        bias_correction="additive",
        **shared_info_kwargs,
    )
    info_add_bc.date_range_obs = cfg.date_range_str(obs_ds.time, info.freq)

if plot_multiplicative_bc:
    info_mulc_bc = cfg.InfoSet(
        file=model_mulc_bc_file,
        ds=model_mulc_bc_ds,
        bias_correction="multiplicative",
        **shared_info_kwargs,
    )
    info_mulc_bc.date_range_obs = cfg.date_range_str(obs_ds.time, info.freq)

In [None]:
# Stack datasets
def _stack_samples(ds):
    """Drop lead_dim labels with missing values, then stack into a leading "sample" dim."""
    stacked = ds.dropna(lead_dim).stack(
        {"sample": [ensemble_dim, init_dim, lead_dim]}, create_index=False
    )
    return stacked.transpose("sample", ...)


model_ds_stacked = _stack_samples(model_ds)

# Model additive bias corrected data
if plot_additive_bc:
    model_add_bc_ds_stacked = _stack_samples(model_add_bc_ds)
    # The bias corrected data must line up with the raw data
    assert model_ds_stacked[var].shape == model_add_bc_ds_stacked[var].shape

# Model multiplicative bias corrected data
if plot_multiplicative_bc:
    model_mulc_bc_ds_stacked = _stack_samples(model_mulc_bc_ds)
    # The bias corrected data must line up with the raw data
    assert model_ds_stacked[var].shape == model_mulc_bc_ds_stacked[var].shape

In [None]:
# Covariate inputs for the GEV-based plots (the GEV parameter files themselves
# are opened in the next cell).
# Year of each stacked sample's time value.
# NOTE(review): covariate is not referenced in the visible cells.
covariate = model_ds_stacked[time_dim].dt.year
# base_period = [model_ds.time.dt.year.min().load().item(), obs_ds.time.dt.year.max().load().item()]
# The two base_period years as a DataArray along info.time_dim; passed to
# spatial_plots.plot_map_aep as `times`.
times = xr.DataArray(base_period, dims=info.time_dim)

In [None]:
# GEV parameters for the raw model data (stationary and non-stationary files)
dparams_s = xr.open_dataset(gev_params_stationary_file)[var]
dparams_ns = xr.open_dataset(gev_params_nonstationary_file)[var]
# Ensure model ds and GEV parameters have the same lat/lon grid (hard to fit a
# GEV at ocean points - usually apply a mask that drops some lat/lons).
# The raw stacked data is subset from itself: selecting from
# model_add_bc_ds_stacked here would replace the raw data with bias corrected
# values and raise a NameError whenever plot_additive_bc is False.
model_ds_stacked = model_ds_stacked.sel(lat=dparams_ns.lat, lon=dparams_ns.lon)

if plot_additive_bc:
    # Same again for the additive bias corrected data and its own GEV files
    dparams_add_bc_s = xr.open_dataset(gev_params_stationary_add_bc_file)[var]
    dparams_add_bc_ns = xr.open_dataset(gev_params_nonstationary_add_bc_file)[var]
    model_add_bc_ds_stacked = model_add_bc_ds_stacked.sel(
        lat=dparams_add_bc_ns.lat, lon=dparams_add_bc_ns.lon
    )

if plot_multiplicative_bc:
    # Same again for the multiplicative bias corrected data and its own GEV files
    dparams_mulc_bc_s = xr.open_dataset(gev_params_stationary_mulc_bc_file)[var]
    dparams_mulc_bc_ns = xr.open_dataset(gev_params_nonstationary_mulc_bc_file)[var]
    model_mulc_bc_ds_stacked = model_mulc_bc_ds_stacked.sel(
        lat=dparams_mulc_bc_ns.lat, lon=dparams_mulc_bc_ns.lon
    )
print(dparams_s)
print(dparams_ns)

# Year when max/min event occurred

In [None]:
# Raw model data: draw the map with mask=None, then again with mask=True
event_year_args = (info, model_ds_stacked, time_agg)
for mask in (None, True):
    spatial_plots.plot_map_event_year(*event_year_args, mask=mask)

In [None]:
# Additive bias corrected: draw the map with mask=None, then with mask=True
if plot_additive_bc:
    event_year_args = (info_add_bc, model_add_bc_ds_stacked, time_agg)
    for mask in (None, True):
        spatial_plots.plot_map_event_year(*event_year_args, mask=mask)

In [None]:
# Multiplicative bias corrected: draw the map with mask=None, then with mask=True
if plot_multiplicative_bc:
    event_year_args = (info_mulc_bc, model_mulc_bc_ds_stacked, time_agg)
    for mask in (None, True):
        spatial_plots.plot_map_event_year(*event_year_args, mask=mask)

## Most common month for max/min event

In [None]:
# Raw model data: draw the map with mask=None, then again with mask=True
month_mode_args = (info, model_ds_stacked)
for mask in (None, True):
    spatial_plots.plot_map_event_month_mode(*month_mode_args, mask=mask)

In [None]:
# Additive bias corrected: draw the map with mask=None, then with mask=True
if plot_additive_bc:
    month_mode_args = (info_add_bc, model_add_bc_ds_stacked)
    for mask in (None, True):
        spatial_plots.plot_map_event_month_mode(*month_mode_args, mask=mask)

In [None]:
# Multiplicative bias corrected: draw the map with mask=None, then with mask=True
if plot_multiplicative_bc:
    month_mode_args = (info_mulc_bc, model_mulc_bc_ds_stacked)
    for mask in (None, True):
        spatial_plots.plot_map_event_month_mode(*month_mode_args, mask=mask)


## Map of metric median

In [None]:
# Raw model data: median map with mask=None, then again with mask=True
median_args = (info, model_ds_stacked, "median")
for mask in (None, True):
    spatial_plots.plot_map_time_agg(*median_args, mask=mask)

In [None]:
# Additive bias corrected: median map with mask=None, then with mask=True
if plot_additive_bc:
    median_args = (info_add_bc, model_add_bc_ds_stacked, "median")
    for mask in (None, True):
        spatial_plots.plot_map_time_agg(*median_args, mask=mask)

In [None]:
# Multiplicative bias corrected: median map with mask=None, then with mask=True
if plot_multiplicative_bc:
    median_args = (info_mulc_bc, model_mulc_bc_ds_stacked, "median")
    for mask in (None, True):
        spatial_plots.plot_map_time_agg(*median_args, mask=mask)

# Map of metric maximum or minimum

In [None]:
# Raw model data: time_agg map with mask=None, then again with mask=True
time_agg_args = (info, model_ds_stacked, time_agg)
for mask in (None, True):
    spatial_plots.plot_map_time_agg(*time_agg_args, mask=mask)

In [None]:
# Additive bias corrected: time_agg map with mask=None, then with mask=True
if plot_additive_bc:
    time_agg_args = (info_add_bc, model_add_bc_ds_stacked, time_agg)
    for mask in (None, True):
        spatial_plots.plot_map_time_agg(*time_agg_args, mask=mask)

In [None]:
# Multiplicative bias corrected: time_agg map with mask=None, then with mask=True
if plot_multiplicative_bc:
    time_agg_args = (info_mulc_bc, model_mulc_bc_ds_stacked, time_agg)
    for mask in (None, True):
        spatial_plots.plot_map_time_agg(*time_agg_args, mask=mask)

# Model-observation comparison (soft records/grey swans)

### Model minus observations (median anomaly)

In [None]:
# Model minus observations (median anomaly, no bias correction)
# Drawn with mask=None, then again with mask=True
median_anom_args = (info, model_ds_stacked, obs_ds, "median", "anom")
for mask in (None, True):
    spatial_plots.plot_map_obs_anom(*median_anom_args, mask=mask)

In [None]:
# Additive bias corrected: median anomaly with mask=None, then with mask=True
if plot_additive_bc:
    median_anom_args = (
        info_add_bc,
        model_add_bc_ds_stacked,
        obs_ds,
        "median",
        "anom",
    )
    for mask in (None, True):
        spatial_plots.plot_map_obs_anom(*median_anom_args, mask=mask)

In [None]:
# Multiplicative bias corrected: median anomaly with mask=None, then with mask=True
if plot_multiplicative_bc:
    median_anom_args = (
        info_mulc_bc,
        model_mulc_bc_ds_stacked,
        obs_ds,
        "median",
        "anom",
    )
    for mask in (None, True):
        spatial_plots.plot_map_obs_anom(*median_anom_args, mask=mask)

### Model minus observations (maximum anomalies)

In [None]:
# Anomaly metric names; the following cells pass them one at a time to
# spatial_plots.plot_map_obs_anom.
anom_metrics = ["anom", "anom_std", "anom_pct", "anom_2000yr"]

In [None]:
# Raw model data: every anomaly metric, first with mask=None then mask=True
for anom_metric, mask in [(m, k) for m in anom_metrics for k in (None, True)]:
    spatial_plots.plot_map_obs_anom(
        info,
        model_ds_stacked,
        obs_ds,
        time_agg,
        anom_metric,
        dparams_ns,
        covariate_year,
        mask=mask,
    )

In [None]:
# Additive bias corrected: every anomaly metric, first with mask=None then mask=True
if plot_additive_bc:
    for anom_metric, mask in [(m, k) for m in anom_metrics for k in (None, True)]:
        spatial_plots.plot_map_obs_anom(
            info_add_bc,
            model_add_bc_ds_stacked,
            obs_ds,
            time_agg,
            anom_metric,
            dparams_add_bc_ns,
            covariate_year,
            mask=mask,
        )

In [None]:
# Multiplicative bias corrected: every anomaly metric, first with mask=None then mask=True
if plot_multiplicative_bc:
    for anom_metric, mask in [(m, k) for m in anom_metrics for k in (None, True)]:
        spatial_plots.plot_map_obs_anom(
            info_mulc_bc,
            model_mulc_bc_ds_stacked,
            obs_ds,
            time_agg,
            anom_metric,
            dparams_mulc_bc_ns,
            covariate_year,
            mask=mask,
        )

## Annual recurrence of observed max/min event

In [None]:
# Raw model data: drawn with mask=None, then again with mask=True
ari_options = {"covariate": covariate_year, "time_agg": time_agg}
for mask in (None, True):
    spatial_plots.plot_map_obs_ari(
        info, model_ds_stacked, obs_ds, dparams_ns, mask=mask, **ari_options
    )

In [None]:
# Additive bias corrected: drawn with mask=None, then again with mask=True
if plot_additive_bc:
    ari_options = {"covariate": covariate_year, "time_agg": time_agg}
    for mask in (None, True):
        spatial_plots.plot_map_obs_ari(
            info_add_bc,
            model_add_bc_ds_stacked,
            obs_ds,
            dparams_add_bc_ns,
            mask=mask,
            **ari_options,
        )

In [None]:
# Multiplicative bias corrected: drawn with mask=None, then again with mask=True
if plot_multiplicative_bc:
    ari_options = {"covariate": covariate_year, "time_agg": time_agg}
    for mask in (None, True):
        spatial_plots.plot_map_obs_ari(
            info_mulc_bc,
            model_mulc_bc_ds_stacked,
            obs_ds,
            dparams_mulc_bc_ns,
            mask=mask,
            **ari_options,
        )

## GEV parameter trends

In [None]:
# Raw model data: scale then location, each with mask=None then mask=True
trend_args = (info, model_ds_stacked, dparams_ns)
for param in ("scale", "location"):
    for mask in (None, True):
        spatial_plots.plot_map_gev_param_trend(*trend_args, param=param, mask=mask)

In [None]:
# Additive bias corrected: scale then location, each with mask=None then mask=True
if plot_additive_bc:
    trend_args = (info_add_bc, model_add_bc_ds_stacked, dparams_add_bc_ns)
    for param in ("scale", "location"):
        for mask in (None, True):
            spatial_plots.plot_map_gev_param_trend(
                *trend_args, param=param, mask=mask
            )

In [None]:
# Multiplicative bias corrected: scale then location, each with mask=None then mask=True
if plot_multiplicative_bc:
    trend_args = (info_mulc_bc, model_mulc_bc_ds_stacked, dparams_mulc_bc_ns)
    for param in ("scale", "location"):
        for mask in (None, True):
            spatial_plots.plot_map_gev_param_trend(
                *trend_args, param=param, mask=mask
            )

## Annual exceedance probability

In [None]:
# ARI: 5, 10, 50, 100, 1000 years (i.e., 20% AEP is equiv to a 1 in 5 year event)
# Values are percentages, listed in the same order as the ARIs above
# (AEP in % = 100 / ARI in years).
aep_list = [20, 10, 2, 1, 0.1]

In [None]:
# Raw model data: every AEP in aep_list, each with mask=None then mask=True
aep_args = (info, model_ds_stacked, dparams_ns, times)
for aep in aep_list:
    for mask in (None, True):
        spatial_plots.plot_map_aep(*aep_args, aep=aep, mask=mask)

In [None]:
# Additive bias corrected: every AEP in aep_list, each with mask=None then mask=True
if plot_additive_bc:
    aep_args = (info_add_bc, model_add_bc_ds_stacked, dparams_add_bc_ns, times)
    for aep in aep_list:
        for mask in (None, True):
            spatial_plots.plot_map_aep(*aep_args, aep=aep, mask=mask)

In [None]:
# Multiplicative bias corrected: every AEP in aep_list, each with mask=None then mask=True
if plot_multiplicative_bc:
    aep_args = (info_mulc_bc, model_mulc_bc_ds_stacked, dparams_mulc_bc_ns, times)
    for aep in aep_list:
        for mask in (None, True):
            spatial_plots.plot_map_aep(*aep_args, aep=aep, mask=mask)

# Probability of breaking the observed record

In [None]:
# Raw model data (note: the unstacked model_ds is passed here).
# Drawn with mask=None, then again with mask=True.
record_args = (info, model_ds, obs_ds, dparams_ns, covariate_year, time_agg)
for mask in (None, True):
    spatial_plots.plot_map_new_record_probability(*record_args, ari=10, mask=mask)

In [None]:
# Additive bias corrected (note: the unstacked model_add_bc_ds is passed here)
if plot_additive_bc:
    record_args = (
        info_add_bc,
        model_add_bc_ds,
        obs_ds,
        dparams_add_bc_ns,
        covariate_year,
        time_agg,
    )
    for mask in (None, True):
        spatial_plots.plot_map_new_record_probability(
            *record_args, ari=10, mask=mask
        )

In [None]:
# Multiplicative bias corrected (note: the unstacked model_mulc_bc_ds is passed here)
if plot_multiplicative_bc:
    record_args = (
        info_mulc_bc,
        model_mulc_bc_ds,
        obs_ds,
        dparams_mulc_bc_ns,
        covariate_year,
        time_agg,
    )
    for mask in (None, True):
        spatial_plots.plot_map_new_record_probability(
            *record_args, ari=10, mask=mask
        )

# Subsampling analysis

In [None]:
# Raw model data: drawn with mask=None, then again with mask=True
subsample_args = (info, model_ds_stacked, obs_ds, time_agg)
for mask in (None, True):
    spatial_plots.plot_map_time_agg_subsampled(
        *subsample_args, n_samples=1000, mask=mask
    )

In [None]:
# Additive bias corrected: drawn with mask=None, then again with mask=True
if plot_additive_bc:
    subsample_args = (info_add_bc, model_add_bc_ds_stacked, obs_ds, time_agg)
    for mask in (None, True):
        spatial_plots.plot_map_time_agg_subsampled(
            *subsample_args, n_samples=1000, mask=mask
        )

In [None]:
# Multiplicative bias corrected: drawn with mask=None, then again with mask=True
if plot_multiplicative_bc:
    subsample_args = (info_mulc_bc, model_mulc_bc_ds_stacked, obs_ds, time_agg)
    for mask in (None, True):
        spatial_plots.plot_map_time_agg_subsampled(
            *subsample_args, n_samples=1000, mask=mask
        )