# Observational spatial analysis


In [None]:
%load_ext autoreload
%autoreload 2

import calendar
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import xarray as xr

from unseen import fileio, time_utils, eva
from acs_plotting_maps import cmap_dict, tick_dict  # NOQA

import spatial_plots

In [None]:
# Optional parameters
# (This cell is tagged "parameters")
# Defaults for settings that may be overridden when the notebook is executed;
# the assertion messages in the next cell refer to papermill `-p` options.
dpi = 300  # NOTE(review): not referenced in the visible cells; presumably figure resolution
shapefile = None  # forwarded to fileio.open_dataset(shapefile=...)
shape_overlap = 0.1  # forwarded to fileio.open_dataset(shape_overlap=...)
alpha = 0.05  # NOTE(review): not referenced in the visible cells; presumably a significance level
time_dim = "time"  # name of the time dimension in the observational dataset
lat_dim = "lat"  # NOTE(review): not referenced in the visible cells
lon_dim = "lon"  # NOTE(review): not referenced in the visible cells
similarity_test = "ks"  # NOTE(review): not referenced in the visible cells
gev_relative_fit_test = "lrt"  # NOTE(review): not referenced in the visible cells
time_agg = "maximum"  # must be a key of spatial_plots.func_dict (checked in the next cell)

In [None]:
# Required parameters
# At the top level of the notebook `locals()` is the global namespace, so
# membership tests on it show which parameters were actually supplied.
kwargs = locals()
assert "metric" in kwargs, "Must provide a metric name"
assert "var" in kwargs, "Must provide a variable name"
assert "obs_name" in kwargs, "Must provide a name"

# The reference period is optional (a later cell skips the selection when it is
# None), so default it to None instead of failing with a NameError when unset.
reference_time_period = kwargs.get("reference_time_period")
if isinstance(reference_time_period, str):
    # Space-separated string -> list of strings
    reference_time_period = reference_time_period.split(" ")

# File parameters are looked up through `kwargs` so that a missing parameter
# fails with the intended message rather than a NameError on the bare name.
assert os.path.isfile(
    kwargs.get("obs_file") or ""
), "Must provide an observations data file (papermill option -p obs_file [filepath])"
assert os.path.isfile(
    kwargs.get("gev_params_nonstationary_file") or ""
), "Must provide a nonstationary GEV parameters file (papermill option -p gev_params_nonstationary_file [filepath])"

assert os.path.isfile(
    kwargs.get("gev_params_nonstationary_drop_max_file") or ""
), "Must provide a nonstationary GEV parameters file (papermill option -p gev_params_nonstationary_drop_max_file [filepath])"
assert (
    "covariate_base" in kwargs
), "Must provide a nonstationary GEV covariate base year"
assert (
    time_agg in spatial_plots.func_dict
), f"Invalid time aggregation method: {time_agg} (options: {list(spatial_plots.func_dict.keys())})"

# Format parameters passed as strings
assert (
    "gev_trend_period" in kwargs
), "Must provide a GEV trend period (e.g., '[1981, 2010]')"
if isinstance(gev_trend_period, str):
    # Convert string to list
    # NOTE(review): eval() executes arbitrary code; only run this notebook with
    # trusted parameter values.
    gev_trend_period = eval(gev_trend_period)
    print(f"gev_trend_period: {gev_trend_period}")

assert (
    "plot_dict" in kwargs
), "Must provide spatial plot dictionary of labels, ticks and colormaps"
if isinstance(plot_dict, str):
    # Convert string to dictionary and check for required keys
    # NOTE(review): eval() executes arbitrary code; only run this notebook with
    # trusted parameter values. The string is evaluated in this namespace,
    # presumably so it can reference the imported cmap_dict / tick_dict.
    plot_dict = eval(plot_dict)
    for key in [
        "metric",
        "var",
        "var_name",
        "units",
        "units_label",
        "freq",
        "cmap",
        "cmap_anom",
        "ticks",
        "ticks_anom",
        "ticks_param_trend",
        "cbar_extend",
        "agcd_mask",
    ]:
        assert key in plot_dict, f"Missing key: {key} in plot_dict"

### Open dataset of metric in observational data

In [None]:
# Open the observational metric file; the optional shapefile settings from the
# parameters cell are passed straight through to the opener.
obs_ds = fileio.open_dataset(obs_file, shapefile=shapefile, shape_overlap=shape_overlap)
# Bare expression: shows the dataset summary as the cell output.
obs_ds

In [None]:
# Select reference time period (defined in metric config file)
if reference_time_period is not None:
    obs_ds = time_utils.select_time_period(obs_ds, reference_time_period)
# Drop time steps with no valid data at all, using the configurable time
# dimension name rather than a hard-coded "time".
obs_ds = obs_ds.dropna(time_dim, how="all")
# Bare expression: shows the dataset summary as the cell output.
obs_ds

In [None]:
# Convert event time strings to cftime objects
# The calendar is read from the configurable time coordinate (`time_dim`)
# rather than a hard-coded `.time` attribute.
event_times = np.vectorize(time_utils.str_to_cftime)(
    obs_ds.event_time, obs_ds[time_dim].dt.calendar
)
# Write the converted values back on the same dimensions as the original.
obs_ds["event_time"] = (obs_ds.event_time.dims, event_times)
# Bare expression: shows the dataset summary as the cell output.
obs_ds

In [None]:
# Largest value of the metric over every dimension, as a Python scalar.
metric_da = obs_ds[var]
obs_max_event = metric_da.max().load().item()

# Keep only the point(s) equal to that value and drop length-one dimensions.
is_record = metric_da.load() == obs_max_event
obs_max_event_loc = metric_da.where(is_record, drop=True).squeeze()

# Load and display the record event.
obs_max_event_loc.load()

## Spatial Maps

In [None]:
# Store plot related variables
# `info` is passed to every spatial_plots.* call below and supplies the figure
# directory and file stem used in output names. The entries of `plot_dict`
# (labels, ticks, colormaps) are expanded as extra keyword arguments.
# NOTE(review): `fig_dir` is not validated in the parameter cells; it must be
# supplied when the notebook is run.
info = spatial_plots.InfoSet(
    name=obs_name,
    obs_name=obs_name,
    fig_dir=fig_dir,
    file=obs_file,
    obs_ds=obs_ds,
    **plot_dict,
)

## Year when record event occurred

In [None]:
# Map of the year in which the record event occurred; `time_agg` is passed
# through (presumably to pick the maximum or minimum as the record).
spatial_plots.plot_event_year(info, obs_ds, time_agg)

## Most common month of event

In [None]:
# Map of the most common month in which the event occurs.
spatial_plots.plot_event_month_mode(info, obs_ds)

In [None]:
# Event month distribution (based on all grid points)
event_months = obs_ds.event_time.dt.month

# Blank out months wherever the metric itself is missing.
has_data = ~np.isnan(obs_ds[var])
months = xr.where(has_data, event_months, np.nan)

# Bin edges at the half-integers 0.5, 1.5, ..., 12.5 give one bin per month.
month_edges = np.arange(0.5, 13)
months.plot.hist(bins=month_edges)

# Label the ticks with abbreviated month names (index 0 is an empty string).
month_labels = list(calendar.month_abbr)[1:]
plt.xticks(np.arange(1, 13), month_labels)


## Map of metric median

In [None]:
# Map of the metric aggregated over time with the "median" method.
spatial_plots.plot_time_agg(info, obs_ds, "median")

## Map of metric maximum/minimum

In [None]:
# Map of the metric aggregated over time with the configured `time_agg` method
# (default "maximum").
spatial_plots.plot_time_agg(info, obs_ds, time_agg)

# GEV analysis

In [None]:
# Set up inputs for the GEV analysis (the parameter files are opened below).
# Year of each observed time step.
# NOTE(review): `covariate` is not referenced again in the visible cells.
covariate = obs_ds[time_dim].dt.year
# Years of the GEV trend period as a 1-D array on a "time" dimension; passed to
# spatial_plots.plot_aep further down.
times = xr.DataArray(gev_trend_period, dims="time")

In [None]:
# Non-stationary GEV parameters for `var`, opened with the same shapefile
# settings as the observations.
dparams_ns = fileio.open_dataset(
    gev_params_nonstationary_file, shapefile=shapefile, shape_overlap=shape_overlap
)[var]

# Bare expression: shows the array summary as the cell output.
dparams_ns

In [None]:
# Non-stationary GEV parameters from the "drop max" file (per the section
# heading below, fitted with the maximum event excluded), opened with the same
# shapefile settings as the observations.
dparams_ns_drop_max = fileio.open_dataset(
    gev_params_nonstationary_drop_max_file,
    shapefile=shapefile,
    shape_overlap=shape_overlap,
)[var]

# Bare expression: shows the array summary as the cell output.
dparams_ns_drop_max

# GEV parameters
### Non-stationary GEV parameters

In [None]:
# Maps of the non-stationary GEV parameters, written to the figure directory.
eva.spatial_plot_gev_parameters(
    dparams_ns,
    dataset_name=obs_name,
    outfile=f"{info.fig_dir}/gev_parameters_{info.filestem()}.png",
)

### Non-stationary GEV parameters (excluding maximum event)

In [None]:
# Same parameter maps for the fit that excludes the maximum event; the dataset
# label says so.
drop_max_label = f"{obs_name} (max event removed)"
eva.spatial_plot_gev_parameters(
    dparams_ns_drop_max,
    dataset_name=drop_max_label,
    outfile=f"{info.fig_dir}/gev_parameters_drop_max_{info.filestem()}.png",
)

### Stationary GEV parameters

In [None]:
# The stationary fit is optional: plot it only when its parameter file exists.
stationary_file_found = Path(gev_params_stationary_file).exists()
if stationary_file_found:
    stationary_ds = fileio.open_dataset(
        gev_params_stationary_file,
        shapefile=shapefile,
        shape_overlap=shape_overlap,
    )
    dparams_stationary = stationary_ds[var]

    stationary_fig = f"{info.fig_dir}/gev_parameters_stationary_{info.filestem()}.png"
    eva.spatial_plot_gev_parameters(
        dparams_stationary,
        dataset_name=obs_name,
        outfile=stationary_fig,
    )

### Best of stationary and non-stationary GEV parameters (anomaly with respect to the non-stationary GEV parameters)

In [None]:
# The "best" fit is optional: plot its difference from the non-stationary
# parameters only when its parameter file exists.
best_file_found = Path(gev_params_best_file).exists()
if best_file_found:
    # NOTE(review): opened with xr.open_dataset (no shapefile settings), unlike
    # the other parameter files, which use fileio.open_dataset. Confirm this is
    # intentional.
    best_ds = xr.open_dataset(gev_params_best_file)
    dparams_best = best_ds[var]

    # Anomaly: non-stationary minus best.
    dparams_diff = dparams_ns - dparams_best

    diff_label = f"{obs_name} (non-stationary - best)"
    diff_fig = f"{info.fig_dir}/gev_parameters_best_diff_{info.filestem()}.png"
    eva.spatial_plot_gev_parameters(
        dparams_diff,
        dataset_name=diff_label,
        outfile=diff_fig,
    )

### Plot GEV trend parameters

In [None]:
# Map of the trend in the GEV "location" parameter from the non-stationary fit.
spatial_plots.plot_gev_param_trend(info, dparams_ns, "location")

In [None]:
# Map of the trend in the GEV "scale" parameter from the non-stationary fit.
spatial_plots.plot_gev_param_trend(info, dparams_ns, "scale")

## Annual recurrence of observed max/min event

In [None]:
# Map of the recurrence of the observed record event under the non-stationary
# GEV fit.
# NOTE(review): the third positional argument is None here; its meaning is
# defined by spatial_plots.plot_obs_ari, which is not visible in this file.
spatial_plots.plot_obs_ari(
    info,
    obs_ds,
    None,
    dparams_ns,
    covariate_base,
    time_agg=time_agg,
)

## Annual exceedance probability 
### GEV-based exceedance probability

In [None]:
# Annual exceedance probability (AEP) passed to spatial_plots.plot_aep below.
# NOTE(review): the value appears to be a percentage. The next section heading
# describes this setting as the 1% AEP (1-in-100-year event); confirm against
# plot_aep.
# ARI: 10, 100, 1000 years (i.e., 10% AEP is equiv to a 1-in-10-year event)
aep = 1

### Plot of 1% AEP (1-in-100-year event) using the non-stationary GEV (past year, current year and the change per decade)

In [None]:
# Map of the AEP event from the non-stationary GEV fit, evaluated at the years
# in `times` (built from `gev_trend_period`).
spatial_plots.plot_aep(
    info,
    dparams_ns,
    times,
    aep=aep,
)

## Probability of breaking the observed record

In [None]:
# Map of the probability of breaking the observed record, using the
# non-stationary GEV fit and a horizon of n_years=10.
# NOTE(review): the third positional argument is None here; its meaning is
# defined by spatial_plots.plot_new_record_probability, which is not visible in
# this file.
spatial_plots.plot_new_record_probability(
    info,
    obs_ds,
    None,
    dparams_ns,
    covariate_base,
    time_agg,
    n_years=10,
)

In [None]:
# Repeat for max event removed
# Work on a copy of the InfoSet so the relabelled name and file stem do not
# leak into `info`.
info_copy = info.__copy__()  # Copy InfoSet object & update names
info_copy.long_name = f"{info.long_name} (max event removed)"
info_copy.file = info_copy.file.with_name(f"{info.filestem()}_drop_max.nc")

# Use the GEV parameters fitted without the maximum event. Passing the
# full-sample `dparams_ns` here would only reproduce the previous figure under
# a "(max event removed)" label.
spatial_plots.plot_new_record_probability(
    info_copy,
    obs_ds,
    None,
    dparams_ns_drop_max,
    covariate_base,
    time_agg,
    n_years=10,
)