# Calculate GMSL from FAIR draws using SESL for SSP-RCPs and then map to RFF-SPs

This notebook applies the semi-empirical sea level (SESL) model to estimate the global mean sea level (GMSL) impact of an emissions pulse, using the 10,000-draw RFF emissions ensemble. The "baseline" GMSL for each draw is a weighted average of the RCP-based, FACTS-derived GMSL projections used in IPCC AR6.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [10]:
import json
import pickle
import warnings
from pathlib import Path
import pathlib
from typing import Sequence, Union

import numpy as np

def fuse_to_gcsmap(path, fs=""):
    """Placeholder stub kept so that code referencing it still runs.

    Parameters
    ----------
    path : any
        Ignored.
    fs : any, optional
        Ignored.

    Returns
    -------
    None
    """
    return None

# Default pyTC settings. pyTCSettings.__init__ turns every key below into an
# instance attribute, unless a keyword argument with the same name overrides it.
settings_dict = dict(
    #############
    ## GENERAL
    #############
    # code metadata
    # Path("./coastal_gmsl_inputs/other").parent -> Path("coastal_gmsl_inputs")
    _PACKAGE_DIR=Path("./coastal_gmsl_inputs/other").parent,
    FS = "",
    # unit of analysis for damage regression/projection
    HIST_GEOG="state",
    PROJ_GEOG="cbsa",
    #############
    ## HAZARD
    #############
    # tracks
    GEOG="conus",
    HIST_TRACK_VERS="20210610",
    HIST_TRACK_NAME="ibtracs",
    IBTRACS_URL=(
        "https://www.ncei.noaa.gov/data/"
        "international-best-track-archive-for-climate-stewardship-ibtracs/v04r00/"
        "access/netcdf/"
    ),
    # synthetic track config
    SYNTH_TRACK_NAME="emanuel",
    SYNTH_TRACK_VERS="20220125",
    SYNTH_TRACK_VERS_HIGH="20220125_HT",
    SYNTH_REFERENCE_PERIOD=[2000, 2020],
    SYNTH_REFERENCE_SCENS=["reanal", "20th", "ssp245"],
    # seed indices 0, 1, 2
    EMANUEL_RADIUS_RESAMPLE_SEED_INDEX=range(3),
    HAZARD_TRACKS_SYNTH_PROCESSING_VERS="v0.1",
    TRACKRADIUSMODEL_VERS="v0.1",
    # projection  ### TODO: UPDATE
    GCM_MODELS=["ccsm4", "ipsl5", "hadgem5", "mpi5", "mri5", "miroc5", "gfdl5"],
    GCM_SCENARIOS=["rcp45", "rcp85"],
    GCM_PERIODS=["2008_2025", "2030_2040", "2045_2055", "2070_2080", "2085_2095"],
    GCM_BASE="2008_2025",
    GCM_NRUNS=100,
    # reanalysis  ### TODO: UPDATE
    REANAL_MODELS=["ncep"],
    REANAL_PERIODS=["1979_1989", "2008_2018"],
    REANAL_SCENARIOS=["reanal"],
    REANAL_BASE="2008_2018",
    REANAL_NRUNS=200,
    # surge
    HAZARD_SURGE_VERS="20220718",
    SURGE_MODEL_NAME="geoclaw",
    HEIGHT_REL_MAX_WTR_VERS="v0.1",
    CL_GAUGE_VERS="v0.3",
    # wind  ### TODO: UPDATE
    HAZARD_WIND_VERS="v0.5",
    WIND_MODEL_NAME="licrice",
    # SLR
    SLR_SYNTH_DATA_VERS="20210809",
    # placeholder value; must be replaced before any historical SLR data is used
    SLR_HIST_DATA_VERS="REPLACE_WITH_ACTUAL_VALUE_WHEN_USING_HIST_DATA",
    SLR_SSP_INTERPOLATION_VERS="v0.1",
    NOAA_TIDE_HARMONICS_VERS="20220715",
    #############
    ## ELEVATION
    #############
    DEM_CATALOGUE_VERS="20210420",
    SRTM15PLUS_VERS="V2.4",
    DATUM_CONVERSION_VERS="v0.2",
    #############
    ## EXPOSURE
    #############
    EXPOSURE_BINNED_VERS="20220524",
    MAX_HEIGHT_FOR_SURGE=20,  # meters
    EXPOSURE_BIN_WIDTH_H=0.1,
    EXPOSURE_BIN_WIDTH_V=0.1,
    PROTECTED_LOCATIONS_VERS="v0.1",
    #############
    ## DAMAGE
    #############
    DAMAGE_VERS="v0.2",
    #############
    ## DOSERESPONSE
    #############
    DOSERESPONSE_VERS="experiments/br2-experiment-us",
    DOSERESPONSE_NFOLDS=5,
    #############
    ## MC SAMPLING
    #############
    MC_NSAMPLES=100,
    SAMPLING_VERS="v0.2",
    #############
    ## DOSE
    #############
    DOSE_VERS="v0.1",
    # integer bin edges 0, 1, ..., 104
    DOSE_MAXS_BIN_EDGES=np.arange(0, 105, 1),
    #############
    ## GEOGRAPHY
    #############
    GLOBAL_PROTECTED_AREAS_VERS="v2.0",
)
############################################################################

##############
## SETTINGS CLASS
##############


class pyTCSettings():
    """
    Private settings class used by tests - meant to be inherited by Settings.

    Every key of the module-level ``settings_dict`` becomes an instance
    attribute. The attribute takes the matching keyword argument when one is
    given, and the ``settings_dict`` default otherwise. Keyword arguments that
    are not in ``settings_dict`` are also set as attributes. ``PARAMS`` always
    starts as an empty dict.

    Default values are not copied. Mutable defaults (lists, dicts, arrays) are
    therefore the same objects as those stored in ``settings_dict`` and are
    shared between instances.
    """

    def __init__(self, **kwargs):

        # defaults from settings_dict, each overridable by a same-named kwarg
        for k, v in settings_dict.items():
            setattr(self, k, kwargs.pop(k, v))
        # any remaining kwargs are extra settings not listed in settings_dict
        for k, v in kwargs.items():
            setattr(self, k, v)

        # container for parameter sets; filled in by callers after construction
        self.PARAMS = {}

    @classmethod
    def from_file(cls, filepath="pytc_settings.json", **kwargs):
        """Build settings from a JSON file; ``kwargs`` take precedence over it.

        Parameters
        ----------
        filepath : str or :class:`pathlib.Path`
            JSON file whose top-level object maps setting names to values.
        **kwargs
            Extra overrides applied on top of the file's contents.
        """
        # JSON is UTF-8 by specification; don't depend on the platform locale
        with open(filepath, "r", encoding="utf-8") as f:
            file_settings = json.load(f)
        file_settings.update(kwargs)

        return cls(**file_settings)

    @classmethod
    def from_pickle(cls, filepath):
        """Load a settings object previously written by :meth:`write_pickle`.

        Warning: unpickling can execute arbitrary code. Only load files you
        trust, such as ones produced by this project.
        """
        with open(filepath, "rb") as f:
            this_settings = pickle.load(f)

        return this_settings

    def write_pickle(self, outpath, overwrite=False):
        """Pickle this settings object to ``outpath`` using protocol 4.

        Raises
        ------
        ValueError
            If ``outpath`` already exists and ``overwrite`` is False.
        """
        outpath = Path(outpath)
        if not overwrite and outpath.is_file():
            raise ValueError(
                "The file you are trying to write exists and you have not specified "
                f"overwrite=True. Please set overwrite to True or remove {outpath}"
            )
        with open(outpath, "wb") as f:
            pickle.dump(self, f, protocol=4)

    def write(
        self,
        outdir: Union[str, Path],
        kinds: Sequence = ("pickle",),
        overwrite: bool = False,
    ):
        """Write the contents of this settings object out to settings.[filetype] for
        a user-defined set of filetypes, in directory ``outdir``. This is designed to be
        called whenever intermediate data is being written out, so that a record of the
        settings used to produce that data is preserved.

        TODO: Add more precise/readable/stable ways to store settings data, rather than
        dumping the entire thing to a pickle file

        Parameters
        ----------
        outdir : str or :class:`pathlib.Path`
            Existing directory in which to write `settings.[filetype]` files
        kinds : Sequence of str
            The kind of files to write. Each kind ``k`` dispatches to a
            ``write_{k}`` method; currently only 'pickle' exists, which dumps the
            entire Settings object to pickle. Any other kind raises
            AttributeError.
        overwrite : bool
            Whether to overwrite or raise errors if the file you try to write already
            exists.

        Raises
        ------
        ValueError
            If overwrite=False and a file you are trying to write already exists.
        """
        outdir = Path(outdir)
        for k in kinds:
            outpath = outdir / f"settings.{k}"
            write_func = getattr(self, f"write_{k}")
            write_func(outpath, overwrite=overwrite)

In [11]:
import json
from pathlib import Path
import numpy as np

# place variables from the global settings file that you'd like to override in
# `pytc_settings`
# Project-specific overrides of the pyTC defaults in `settings_dict`.
# Settings() copies this dict, updates it with the caller's kwargs, and passes
# the result to pyTCSettings.
pytc_settings = dict(
    GEOG="global",
    REANAL_MODELS=["era5", "ncep2"],
    REANAL_PERIODS=["1979_2019"],
    REANAL_SCENARIOS=["reanal"],
    GCM_MODELS=[
        "ccsm4",
        "ecearth",
        "gfdl5",
        "hadgem5",
        "ipsl5",
        "miroc5",
        "mpi5",
        "mri5",
    ],
    GCM_SCENARIOS=["20th", "rcp45", "rcp85"],
    GCM_PERIODS=["1999_2005", "2006_2030", "2079_2099"],
    SYNTH_REFERENCE_PERIOD=[1999, 2019],
    # seed indices 0..9
    EMANUEL_RADIUS_RESAMPLE_SEED_INDEX=range(10),
    SAMPLING_VERS="v0.4",
    DOSE_VERS="v0.1",
    EXPOSURE_BINNED_VERS="v0.14",
    MC_NSAMPLES=1000,
)

# Now define new variables that are specific to glo-co and might not exist in pyTC:
# -- glo-co related data and model versions
# -- glo-co specific objects

# glo-co specific settings. Settings() sets each key as an attribute on the
# returned settings object; a same-named keyword argument takes precedence.
gloco_settings = dict(
    ##############
    # MISC
    ##############
    # absolute path, resolved against the current working directory when this cell runs
    GLOCO_PACKAGE_DIR=pathlib.Path().absolute() / "coastal_gmsl_inputs/other/",
    FAIR_SCENARIOS=["ssp245", "ssp460", "ssp370"],
    CIL_COLORS_3=["#3393b0", "#ff8c00", "#ff6553"],
    CIL_COLORS_ALT_3=["#3393b0", "#ff8c00", "#880808"],
    COASTAL_COLORS_3=["#3393b0", "#a52a2a", "#696969"],
    GENERIC_COLORS_5=["#d7191c", "#fdae61", "black", "#abd9e9", "#2c7bb6"],
    INTEG_BOTTOM_CODING_GDPPC=234.235646874999,
    INTEG_ETA=2,
    ISO_DROP_NO_SOCIOECON=["ATA", "CA-", "SP-"],
    # territory ISO code -> ISO code of the country it is mapped to
    ISO_TERR_TO_CTRY_MAPPER={
        "ALA": "FIN",
        "BES": "NLD",
        "BVT": "NOR",
        "IOT": "GBR",
        "CXR": "AUS",
        "CL-": "FRA",
        "CCK": "AUS",
        "ATF": "FRA",
        "HMD": "AUS",
        "PCN": "GBR",
        "SGS": "GBR",
        "SJM": "NOR",
        "TKL": "NZL",
        "UMI": "USA",
        "KO-": "SRB",
        "VAT": "ITA",
    },
    PPP_BASELINE_YEAR=2019,
    ##############
    # SAMPLING
    ##############
    BASINS=["AL", "SI", "SP", "IO", "WP", "EP"],
    INTENSITY_THRESHOLDS=["low", "high"],
    N_BATCHES=500,  # how many "batches" to create when binning SLR draws by GMSL
    # how many batches to use in final results (subsampling from N_BATCHES)
    CLIP_BATCHES=15,
    ##############
    # PROJECTION
    ##############
    PROJ_YEAR_RANGE=[2018, 2099],
    ##############
    # DOSE RESPONSE
    ##############
    # 0 followed by 1e-15, 1e-14, ..., 1e-5
    DOSERESPONSE_ALPHAS=np.concatenate([[0], np.logspace(-15, -5, 11, base=10)]),
    # DOSERESPONSE_ALPHAS=[0, 1e-5],
    # 1.0, 1.25, ..., 3.0
    DOSERESPONSE_TWEEDIE_POWER=np.arange(1, 3.1, 0.25),
    # DOSERESPONSE_TWEEDIE_POWER=[1, 1.5],
    DOSERESPONSE_PARAM_GRID_ALL={
        "only_observed": [
            False,
            True,
        ],  # whether to only include storms that were fully observed over land
        "cutoff_year": [1950, 1980, 2001],  # when to cut off regression dataset
        "covariates": [
            [],
            ["gdppc_r"],
            ["lr_wind"],
            ["gdppc_r", "lr_wind"],
        ],  # which covariates to use
        "normalization": [None, "treated", "total"],
    },
    DOSERESPONSE_PARAM_GRID_POLY={
        "power": [False, True],  # Power or polynomial
        "maxs": range(15),
        "pddi": [False, "linear"],
    },
    DOSERESPONSE_PARAM_GRID_BINNED={"maxs": [1, 2], "pddi": [False, "linear"]},
    # DOSERESPONSE_PARAM_GRID_BINNED={"maxs": [1, 2], "pddi": [False]},
    ##############
    # VERSIONS
    ##############
    # these two also exist in settings_dict; Settings() overwrites them with these values
    SYNTH_TRACK_VERS="20201118",
    SYNTH_TRACK_VERS_HIGH="20201217",
    DIVA_VERS="20200630",
    DIVASEGSLR_VERS="v0.1",
    LITPOP_DOWNLOAD_VERS="20200714",
    LITPOP_VERS="v0.5",
    EMDAT_DOWNLOAD_VERS="20210421",
    EMDAT_VERS="v0.2",
    PWT_DOWNLOAD_VERS="20210505",
    PWT_VERS="v4.0",
    ISIMIP_DOWNLOAD_VERS="20210505",
    IIASA_INT_VERS="v4.0",
    CIA_DOWNLOAD_VERS="20201215",
    GDP_POP_GDPPC_K_VERS="v0.7",
    GEG15_VERS="v0.1",
    CIAM_VERS="v6.6.0",
    IR_VERS="v0.1",
    FAIR_RCP_VERS="v4.0_Jan212022",
    FAIR_RFF_CO2_VERS="v5.03_Feb072022",
    FAIR_RFF_CH4_VERS="v5.03_Feb072022",
    FAIR_RFF_N20_VERS="v5.03_Feb072022",
    FAIR_RFF_OUT_VERS="v5.03_Feb072022",
    FAIR_RFF_CO2_MEDIAN_VERS="v5.01_Jan72022",
    # selects params/sesl/<SESL_VERS>.json in Settings()
    SESL_VERS="v0.1",
    DAMAGE_PROJ_VERS="v0.20",
    AR6_GMSL_VERS="v1.0",
    SLR_BINNED_VERS="v1.6.0",

)


def Settings(**kwargs):
    """Build the combined pyTC + glo-co settings object used by this notebook.

    Steps:

    1. Start from ``pytc_settings``, override it with ``kwargs``, and build a
       ``pyTCSettings`` from the result.
    2. Set every ``gloco_settings`` entry as an attribute, again overridable via
       ``kwargs``.
    3. Derive glo-co specific directories and file paths.
    4. Load the SESL parameter JSON into ``out.PARAMS["sesl"]``.

    Parameters
    ----------
    **kwargs
        Overrides for any setting or any derived path below. This includes
        ``DIR_LOCAL``, which defaults to ``./coastal_gmsl_inputs``. Path-valued
        overrides may be ``str`` or :class:`pathlib.Path`.

    Returns
    -------
    pyTCSettings
        Settings object with all attributes populated.

    Raises
    ------
    FileNotFoundError
        If ``<DIR_GLOCO_PARAMS>/sesl/<SESL_VERS>.json`` does not exist.
    """
    # pyTC defaults (as overridden for this project), then caller overrides
    base_settings = pytc_settings.copy()
    base_settings.update(kwargs)
    out = pyTCSettings(**base_settings)

    # glo-co specific defaults, each overridable by a same-named kwarg
    for k, v in gloco_settings.items():
        setattr(out, k, kwargs.pop(k, v))

    # accept str or Path for the package directory when deriving paths from it
    package_dir = Path(out.GLOCO_PACKAGE_DIR)

    #######################################################################
    # Define glo-co specific paths/directories
    #######################################################################
    ##############
    # MISC
    ##############
    out.DIR_PARAMS_GLOCO = Path(
        kwargs.pop("DIR_PARAMS_GLOCO", package_dir / "params")
    )

    ################
    # HAZARD
    ################
    # overridable via kwargs; the default is unchanged
    out.DIR_LOCAL = Path(kwargs.pop("DIR_LOCAL", "./coastal_gmsl_inputs"))

    out.PATH_HAZARD_SLR_GMSL_RAW_HIST = Path(
        kwargs.pop(
            "PATH_HAZARD_SLR_GMSL_RAW_HIST",
            out.DIR_LOCAL / "other/dangendorf_2019_GMSL_hist.txt",
        )
    )
    out.PATH_HAZARD_SLR_GMSL_BASELINE_INT_SYNTH = Path(
        kwargs.pop(
            "PATH_HAZARD_SLR_GMSL_BASELINE_INT_SYNTH",
            out.DIR_LOCAL / "other/v1.0.zarr/"
        )
    )
    out.DIR_HAZARD_FAIR_RFF = Path(
        kwargs.pop(
            "DIR_HAZARD_FAIR_RFF",
            out.DIR_LOCAL,
        )
    )
    out.PATH_HAZARD_GMST_FAIR_RCP = Path(
        kwargs.pop(
            "PATH_HAZARD_GMST_FAIR_RCP",
            out.DIR_LOCAL / "other/ar6_fair162_control_pulse_2020-2030-2040-2050-2060-2070-2080_emis_conc_rf_temp_lambdaeff_emissions-driven_naturalfix_v4.0_Jan212022.nc",
        )
    )
    out.PATH_HAZARD_GMST_FAIR_RCP_MEDIAN = Path(
        kwargs.pop(
            "PATH_HAZARD_GMST_FAIR_RCP_MEDIAN",
            out.DIR_LOCAL / "other/ar6_fair162_medianparams_control_pulse_2020-2080_10yrincrements_conc_rf_temp_lambdaeff_emissions-driven_2naturalfix_v4.0_Jan212022.nc",
        )
    )
    out.PATH_HAZARD_SLR_GMSL_FAIR_RFF = Path(
        kwargs.pop(
            "PATH_HAZARD_SLR_GMSL_FAIR_RFF",
            f"/shares/gcp/integration/rff2/climate/ar6_rff_iter0-19_fair162_control_pulse_2020-2030-2040-2050-2060-2070-2080_gmsl_emissions-driven_naturalfix_{out.FAIR_RFF_OUT_VERS}.zarr",
        )
    )
    out.DIR_HAZARD_SLR_SESL_RAW = Path(
        kwargs.pop("DIR_HAZARD_SLR_SESL_RAW", out.DIR_LOCAL / "other")
    )

    ################
    # PARAMS
    ################
    # wrap in Path so that a str override still supports the `/` joins below
    out.DIR_GLOCO_PARAMS = Path(
        kwargs.pop("DIR_GLOCO_PARAMS", package_dir / "params")
    )
    sesl_params_path = out.DIR_GLOCO_PARAMS / "sesl" / f"{out.SESL_VERS}.json"
    with open(sesl_params_path, "r", encoding="utf-8") as f:
        out.PARAMS["sesl"] = json.load(f)

    return out


In [12]:
# NOTE: `pint-xarray` is not included in the dscim-epa environment; install it with `pip install pint-xarray` before running this notebook.

In [13]:
import dask.config
import numpy as np
import pandas as pd
import pint_xarray
import xarray as xr
import importlib  

# Build the combined settings object; this also loads the SESL parameter JSON
# into ps.PARAMS["sesl"].
ps = Settings()
# SESL parameters (as loaded from params/sesl/<SESL_VERS>.json)
sesl_p = ps.PARAMS["sesl"]

In [14]:
def load_dangendorf(path, ref_year=2005, window=19):
    """Load the Dangendorf et al. (2019) historical GMSL reconstruction.

    Reads the first two columns (decimal year, GMSL) of a fixed-width text file,
    skipping one header line. Decimal years are converted to timestamps. The
    series is then re-centred so that the ``window``-year centred rolling mean
    of annual means is zero at ``ref_year``. With the defaults (2005, 19 years)
    this is the 1996-2014 mean.

    Parameters
    ----------
    path : str or :class:`pathlib.Path`
        Path to the fixed-width reconstruction file.
    ref_year : int, optional
        Centre year of the reference window. Default 2005.
    window : int, optional
        Number of years in the centred rolling window. Default 19.

    Returns
    -------
    :class:`pandas.Series`
        GMSL anomalies indexed by ``date``, named ``gmsl_rel_{ref_year}_mm``.

    Raises
    ------
    KeyError
        If ``ref_year`` lies outside the years covered by the file.
    """
    raw = pd.read_fwf(
        path,
        skiprows=1,
        usecols=[0, 1],
        names=["year", "GMSL"],
    )
    whole_years = raw.year.astype(int)
    # parse from strings so the "%Y" format works regardless of pandas version
    dt = pd.to_datetime(whole_years.astype(str), format="%Y") + pd.to_timedelta(
        (raw.year - whole_years) * 365.25, unit="d"
    )
    msl_hist = pd.Series(raw.GMSL.values, index=dt)
    msl_hist.index.name = "date"

    # Calendar-year means. Reindex to every year in the span so that years with
    # no observations become NaN, as annual resampling would produce. This
    # keeps the count-based rolling window aligned with calendar years and
    # avoids the deprecated "y" resample alias.
    years = msl_hist.index.year
    msl_hist_yr = msl_hist.groupby(years).mean()
    msl_hist_yr = msl_hist_yr.reindex(range(years.min(), years.max() + 1))

    # centre on the `window`-year mean around `ref_year` (1996-2014 by default)
    msl_hist_rolling = msl_hist_yr.rolling(window, center=True).mean()
    msl_hist -= msl_hist_rolling.loc[ref_year]

    msl_hist.name = f"gmsl_rel_{ref_year}_mm"
    return msl_hist

In [None]:
from dask.distributed import Client, LocalCluster

# Start a local dask cluster with default settings and connect a client to it.
cluster = LocalCluster()
client = Client(cluster)
# last expression in the cell: the notebook displays the client summary
client

In [15]:
# Filename template for RFF FAIR outputs. The {gas_stub} and {version}
# placeholders are presumably filled with str.format, and the "*" suggests
# glob matching -- confirm against the caller.
FAIR_RFF_STUB = (
    "ar6_rff_fair162_control_pulse_{gas_stub}*_emis_conc_rf_temp_lambdaeff_"
    "ohc_emissions-driven_naturalfix_{version}.nc"
)
# NOTE(review): FAIR_RFF_STUB contains no "iter0-19" substring, so this replace
# is a no-op and FAIR_RFF_MEDIAN_STUB == FAIR_RFF_STUB. The RFF GMSL path built
# in Settings() does contain an "iter0-19" token; confirm the intended filename
# of the median-parameter outputs.
FAIR_RFF_MEDIAN_STUB = FAIR_RFF_STUB.replace("iter0-19", "medianparams")

# Metadata text (description, method, version history, authorship); presumably
# attached as attrs to the output datasets -- confirm where these are used.
DESCRIPTION = "Simulations of GMSL relative to a 1991-2009 mean, from 2020 to 2500, consistent with FAIR GMST simulations "
DESCRIPTION_RCP = DESCRIPTION + "of the RCP scenarios"
DESCRIPTION_RFF = DESCRIPTION + "of the RFF emissions ensemble"

METHOD_RCP = 'A Semi-Empirical Sea Level (SESL) model (github.com/bobkopp/SESL) probabilistically converts a GMST time-series to a GMSL time series. This is applied to both control and pulse scenarios and the difference is taken. This "pulse delta" is then added to a baseline trajectory taken from GMSL simulations used in IPCC AR6 (provided via personal correspondance from Bob Kopp). Draws of SESL/FAIR model "pulse delta" and draws of AR6 "possibilistic" projections are aligned before summing by quantile-matching the 2300 GMSL projected under the control FAIR scenario with that of the AR6 projections. The parameter distribution used for the SESL model was provided via personal correspondance from Bob Kopp. All other SESL input data is taken from the github repo. The AR6 projections end in 2300. To project 2300-2500, we align the SESL control scenarios to the AR6 projections (using the previously defined quantile-matching pairs) and allow SLR to evolve based on the SESL control runs from 2300-2500. Finally, to convert from the AR6 reference period (1996-2014) to the reference period used in the rest of the coastal impacts work (1991-2009), we use a reconstruction of historical sea levels from Dangendorf et al. 2019 (https://www.nature.com/articles/s41558-019-0531-8#MOESM2). We take the means of these two periods and use that offset to adjust the GMSL projections. For the "median" runs, we use a temperature simulation from median FAIR parameters. For SESL, we take the median parameters from the two temperature reconstruction datasets (Marcott and Mann), calculate the resulting sea level values, and then take the mean of these two outputs.'

# NOTE(review): the text below says RFF draws are ordered by integrated
# radiative forcing, but HISTORY (version 5.02) says matching switched to
# SESL-GMSL as the index. This description may be stale -- confirm.
METHOD_RFF = (
    METHOD_RCP
    + """

Because the baseline trajectories are only available for the RCPs, we emulate a baseline for each RFF emissions ensemble member. We do this by taking the weighted average of the GMSL of the two bounding RCPs surrounding each RFF scenario. The ordering is determined by integrated radiative forcing from 2016 (the first year of deviation in the RCPs). This forcing is as output from FAIR. When the RFF draw falls outside the range of the RCPs included in the AR6 outputs, the GMSL from the closest RCP (in integrated forcing space) is chosen."""
)

# version changelog of these outputs
HISTORY = """version 3.0: RCP runs. Initial model. Version starts at 3.0 to align with current version of FAIR GMST outputs.
version 3.1: RCP runs. Offsetting to a 0 GMSL in 2000 baseline (previously was 2005). This is to match the 0 point of LocalizeSL and the projections.
version 4.0: RCP runs. Correct bad AR6 baseline input due to bug bringing all scenarios to the mean in 2100 for workflow 0. Version bump occurs to keep pace with FAIR temperature version increase.
version 4.0_Jan212022: RCP runs. Model GMSL from pulses of other GHGs. Updated FACTS distributions.
version 5.0:  RFF runs. Same as v4.0 but for RFF outputs. First version to output both RFF and RCP-based GMSL datasets.
version 5.0.1:  RFF runs. Uses v5.0.1 of RFF FAIR outputs (fixed FAIR bug from v5.0 related to RFF outputs only)
version 5.02:  RFF runs. Uses v5.02 of RFF FAIR outputs. Update RFF-RCP matching algorithm to use SESL-GMSL as the index rather than radiative forcing. Also update such that if the 5 RCPs paired with the same FAIR parameters do not bound the RFF draw, search first across FAIR parameter draws and then across nearby years. See User Manual for more details.
version 5.02_Jan72022: RFF runs. Model GMSL from multiple pulse years.
version 5.02_Jan222022:  RFF runs. Model GMSL from pulses of other GHGs. Updated FACTS distributions
version 5.03_Feb072022: RFF runs with RFF-FaIR climate param pairings. Model GMSL from pulses of CO2, CH4, N2O.
"""

AUTHOR = "Ian Bolliger"
CONTACT = "ibolliger@rhg.com"

## RCP-based SESL workflow

In [16]:
def flatten_runtype(da):
    """Fold the ``runtype`` dimension into ``pulse_year``.

    The control run taken at ``pulse_year=2020`` is placed first, under a
    sentinel ``pulse_year`` label of 0. It is followed by the pulse runs for
    every pulse year. Control runs at other pulse years are dropped, which
    presumably assumes the control does not depend on pulse_year -- confirm.
    :func:`unflatten_runtype` rebuilds ``runtype`` by broadcasting the sentinel
    control back across the pulse years.
    """
    control_run = da.sel(pulse_year=2020, runtype="control", drop=True)
    control_run = control_run.expand_dims(pulse_year=[0])
    pulse_runs = da.sel(runtype="pulse", drop=True)
    pieces = (control_run, pulse_runs)
    return xr.concat(pieces, dim="pulse_year")


def unflatten_runtype(da):
    """Rebuild a ``runtype`` (control / pulse) dimension from a flattened array.

    ``pulse_year=0`` is treated as the control run and broadcast over the
    remaining pulse years. This assumes label 0 is the first ``pulse_year``
    entry, as produced by ``flatten_runtype``. The other pulse years are the
    pulse runs.
    """
    remaining_years = da.pulse_year[1:]
    control_run = da.sel(pulse_year=0, drop=True)
    control_run = control_run.expand_dims(pulse_year=remaining_years)
    pulse_runs = da.drop_sel(pulse_year=0)
    runtype_index = pd.Index(["control", "pulse"], name="runtype")
    return xr.concat([control_run, pulse_runs], dim=runtype_index)


def quantile_map_sesl_and_baseline(baseline, sesl_sl):
    """Pair each SESL simulation with a baseline quantile of matching rank.

    SESL simulations are ranked by their (squeezed) sea-level value. With ``n``
    simulations, the k-th ranked simulation receives the ``baseline`` quantile
    over ``sample`` at the midpoint of the k-th of ``n`` equal-width bins on
    [0, 1]. The returned array is sorted by ``simulation``.
    """
    # simulation labels ordered from lowest to highest SESL value
    ranked_sims = sesl_sl.simulation.sortby(sesl_sl.squeeze(drop=True))

    # bin midpoints on [0, 1], one per simulation
    n_sims = len(ranked_sims)
    edges = np.linspace(0, 1, n_sims + 1)
    midpoints = (edges[:-1] + edges[1:]) / 2

    matched = baseline.quantile(q=midpoints, dim="sample")
    matched = matched.rename(quantile="simulation")
    matched["simulation"] = ranked_sims
    return matched.sortby("simulation")


def get_bound_wts(trg_vals, src_vals, dim="rcp", year=None):
    """Find the bounding source members of each target and an interpolation weight.

    ``lb`` is the label (along ``dim``) of the nearest source value <= target,
    and ``ub`` the label of the nearest source value >= target. If the target
    is below every source value, ``lb`` falls back to the label of the smallest
    source value, which is then also ``ub``. If the target is above every
    source value, ``ub`` falls back to the largest, which is then also ``lb``.

    ``ub_wt = (target - lb_val) / (ub_val - lb_val)``, or 1 when both bounds
    have the same value.

    Parameters
    ----------
    trg_vals : xarray object
        Target values.
    src_vals : xarray object
        Source values with dimension ``dim``.
    dim : str, optional
        Dimension of ``src_vals`` to search over. Default ``"rcp"``.
    year : optional
        Unused; kept for interface compatibility.

    Returns
    -------
    :class:`xarray.Dataset`
        Variables ``lb``, ``ub`` and ``ub_wt``.

    Raises
    ------
    AssertionError
        If a non-null target has neither bound (e.g. all source values missing).
    """
    gap = trg_vals - src_vals

    # nearest source at or below the target
    below = gap.where(gap >= 0, np.inf)
    lacks_lower = np.isinf(below).all(dim=dim)
    lower_fallback = src_vals.idxmin(dim).broadcast_like(lacks_lower)
    lower = below.idxmin(dim).where(~lacks_lower, lower_fallback)

    # nearest source at or above the target
    above = gap.where(gap <= 0, -np.inf)
    lacks_upper = np.isinf(above).all(dim=dim)
    upper_fallback = src_vals.idxmax(dim).broadcast_like(lacks_upper)
    upper = above.idxmax(dim).where(~lacks_upper, upper_fallback)

    assert not (lacks_lower & lacks_upper & trg_vals.notnull()).any()

    lower_val = src_vals.sel({dim: lower}, drop=True)
    upper_val = src_vals.sel({dim: upper}, drop=True)

    span = upper_val - lower_val
    zero_span = span == 0
    assert (upper_val.where(zero_span, 0) == lower_val.where(zero_span, 0)).all()

    upper_wt = ((trg_vals - lower_val) / span).where(span != 0, 1)

    return xr.Dataset({"lb": lower, "ub": upper, "ub_wt": upper_wt})


def quantile_map_rff(trg_vals, src_vals, baseline_rcp, wt_ds, dim="runid"):
    """Emulate a baseline for each target draw by blending its two bounding RCPs.

    For each label along ``dim``, and for each bounding RCP in ``wt_ds``
    (``lb`` and ``ub``), pick the ``simulation`` whose ``src_vals`` under that
    RCP is closest (in absolute difference) to the target's ``trg_vals``. That
    simulation's ``baseline_rcp`` trajectory is taken. The two trajectories are
    combined as ``ub_wt * ub + (1 - ub_wt) * lb``.

    Parameters
    ----------
    trg_vals : xarray object
        Target values with dimension ``dim``.
    src_vals : xarray object
        Source values with ``rcp`` and ``simulation`` dimensions.
    baseline_rcp : xarray object
        Baseline trajectories with ``rcp`` and ``simulation`` dimensions.
    wt_ds : :class:`xarray.Dataset`
        Output of ``get_bound_wts``: ``lb``, ``ub`` (RCP labels) and ``ub_wt``.
    dim : str, optional
        Dimension of the target draws. Default ``"runid"``.

    Returns
    -------
    xarray object
        Weighted blend of the lower- and upper-bound baselines.
    """
    lb_vals = []
    ub_vals = []
    for this_sp in trg_vals[dim]:
        this_trg_vals = trg_vals.sel({dim: this_sp})
        for lst, bound_rcps in [(lb_vals, wt_ds.lb), (ub_vals, wt_ds.ub)]:
            this_rcp = bound_rcps.sel({dim: this_sp})
            this_src_vals = src_vals.sel(rcp=this_rcp, drop=True)
            this_baseline_rcp = baseline_rcp.sel(rcp=this_rcp, drop=True)
            # simulation whose source value is closest to the target
            this_sim = np.abs(this_trg_vals - this_src_vals).idxmin("simulation")
            lst.append(this_baseline_rcp.sel(simulation=this_sim))

    lb_vals = xr.concat(lb_vals, dim=dim)
    ub_vals = xr.concat(ub_vals, dim=dim)

    return wt_ds.ub_wt * ub_vals + (1 - wt_ds.ub_wt) * lb_vals

In [17]:
from math import ceil

import numpy as np
import pandas as pd
import xarray as xr

from math import ceil, floor
from typing import Sequence, Union

import numpy as np
import pandas as pd
import xarray as xr
from scipy.linalg import toeplitz
from scipy.io import loadmat


def calc_temp(
    data: xr.Dataset,
    T_err: Union[str, None],
    T_num: int,
    n_samples: int,
    tau_ar1: float = 10,
    random_state: Union[int, None] = 0,
) -> xr.DataArray:
    """Simulate draws of historical temperature time series.

    Parameters
    ----------
    data : :class:`xarray.Dataset`
        Output of ``load_data_SESL`` function. Its ``T`` variable has a ``kind``
        dimension with ``"val"`` (mean estimate) and ``"err"`` (standard
        deviation) entries, and a ``T_year`` coordinate.
    T_err : str or None,
        Approach to simulating the time series. Supported values:
            - ``ar1ts``: multivariate normal draws whose covariance between years
              t1 and t2 is ``err[t1] * err[t2] * exp(-abs(t2 - t1) / tau_ar1)``
            - ``no``: don't add uncertainty; return the mean estimate
        Other SESL approaches (``ar1``, ``default``), and ``None``, are not
        implemented and raise :class:`NotImplementedError`.
    T_num : int
        Number of simulated time series to create (``ar1ts`` only)
    n_samples : int
        Number of draws of the parameter posteriors (``ar1ts`` only)
    tau_ar1 : float, optional
        If ``T_err==ar1ts``, this is the ``timescale`` parameter. Otherwise, ignored.
    random_state : int, optional
        If set, controls the random state for the :function:`numpy.random.default_rng`
        function used to generate time series draws. If None, will result in
        non-deterministic outputs. Ignored unless ``T_err == "ar1ts"``.

    Returns
    -------
    :class:`xarray.DataArray`
        For ``ar1ts``: dims ``(year, sample, T_sim_id)``, holding ``T_num`` sims
        for each of ``n_samples`` samples. For ``no``: the mean estimate, with
        ``T_year`` renamed to ``year``.

    Raises
    ------
    NotImplementedError
        If ``T_err`` is not ``"ar1ts"`` or ``"no"``.
    """
    if T_err == "ar1ts":
        rng = np.random.default_rng(random_state)
        T = data.T
        err_vec = T.sel(kind="err").values
        # outer product of per-year standard deviations
        err_sq = np.expand_dims(err_vec, 1) * np.expand_dims(err_vec, 0)
        yr_vec = T.T_year.values
        yr_vec_neg_diff_norm = -np.abs(
            (np.expand_dims(yr_vec, 1) - np.expand_dims(yr_vec, 0)) / tau_ar1
        )
        # exponentially decaying correlation between years
        cov_ar1 = err_sq * np.exp(yr_vec_neg_diff_norm)
        # shape (T_num, n_samples, n_years)
        sims = rng.multivariate_normal(
            T.sel(kind="val"), cov_ar1, size=(T_num, n_samples)
        )
        return xr.DataArray(
            sims.T,
            coords={
                "T_sim_id": np.arange(sims.shape[0]),
                "year": T.T_year.values,
                "sample": np.arange(n_samples),
            },
            dims=["year", "sample", "T_sim_id"],
        )
    if T_err == "no":
        return data.T.sel(kind="val").rename(T_year="year")
    raise NotImplementedError(
        f"T_err={T_err!r} is not supported; use 'ar1ts' or 'no'"
    )


def calc_T0(
    T_sims: xr.Dataset,
    historical_data: xr.Dataset,
    params: xr.Dataset,
    optim_T0: bool,
    model: str,
    T0_period_end: int = -1800,
) -> xr.Dataset:
    """Calculate draws of the time-varying ``T0`` parameter from historical temperature draws.

    ``T0`` is obtained by applying a lower-triangular Toeplitz matrix of powers
    of ``(1 - 1 / tau)`` to a series of ``T / tau`` values. This is a
    geometrically discounted running sum of past temperatures. The first
    element of that series is replaced by the mean simulated temperature over
    years ``<= T0_period_end``, plus ``params.T01`` when ``optim_T0`` is True.
    That element seeds the sum.

    Only the case where ``historical_data`` contains a ``T0burnin`` variable is
    implemented. There, ``tau`` is expressed in time steps:
    - for the first ``T0burnin`` years, using the spacing of the first two
      ``T_year`` values;
    - for the remaining years, using the spacing of the last two.

    Parameters
    ----------
    T_sims : :class:`xarray.DataArray`
        Output of :func:`calc_temp` (annotated as Dataset). Contains draws of
        historical temps along a ``year`` dimension.
    historical_data : :class:`xarray.Dataset`
        Output of :func:`pySESL.io.load_data_SESL`. Contains mean and SDs for temp and
        sea level reconstructions, the ``T_year`` coordinate and, in the supported
        case, ``T0burnin``.
    params : :class:`xarray.Dataset`
        Output of :func:`pySESL.io.load_params`. Contains posterior distributions of
        trained SESL model parameters; ``tau`` is used, and ``T01`` if ``optim_T0``.
    optim_T0 : bool
        Whether to use the optimized T0(0) posterior distribution from ``params``
    model : "CRdecay", "ConstRate", "CRovTau", "TwoTau", or "simpel"
        Which model was used to train SESL model and generate ``params``.
        ``"TwoTau"`` is not implemented.
    T0_period_end : int, optional
        Ending year of period used to calculate an initial T0(0). Default -1800.

    Returns
    -------
    :class:`xarray.DataArray`
        Result of ``xr.dot`` over ``year2``: T0 for each year and for each of the
        sims of historical GMST contained in ``T_sims`` (and each ``params`` draw).

    Raises
    ------
    NotImplementedError
        If ``historical_data`` has no ``T0burnin`` variable, or ``model == "TwoTau"``.
    """
    if optim_T0:
        # posterior draws of the initial-value offset
        T0_rnd = params.T01
    else:
        T0_rnd = 0
    tau1 = params.tau
    # tau2 = params.tau_c

    n_yrs = len(T_sims.year)

    def toepify(arr):
        # lower-triangular Toeplitz matrix: out[i, j] = arr[i - j] for i >= j, else 0
        return toeplitz(arr, np.concatenate((arr[:1], np.zeros(len(arr) - 1))))

    # if use_Mar_T0 was used
    if "T0burnin" in historical_data.data_vars:
        n_burnin = historical_data.T0burnin.item()
        # time-step length (in years) at the start and at the end of the record
        yrs1 = historical_data.T_year[1].item() - historical_data.T_year[0].item()
        yrs2 = historical_data.T_year[-1].item() - historical_data.T_year[-2].item()
        # tau expressed in time steps for the burn-in ("T0") and main ("T") segments
        tau1_1 = tau1 / yrs1
        tau1_2 = tau1 / yrs2
        tau1_ = xr.concat([tau1_1, tau1_2], dim=pd.Index(["T0", "T"], name="T_type"))
        # G[k] = (1 - 1/tau)^k for k = 0 .. n_yrs - 1, separately per T_type
        G = ((1 - 1 / tau1_) * xr.ones_like(T_sims.year)) ** xr.DataArray(
            np.arange(n_yrs), dims=["year"]
        )
        # decay matrix over (year, year2) for each T_type (and each tau draw)
        G_M = xr.apply_ufunc(
            toepify,
            G,
            input_core_dims=[["year"]],
            output_core_dims=[["year", "year2"]],
            vectorize=True,
        )

        # columns for burn-in years decay with the burn-in tau; later columns with
        # the main-segment tau
        GM_T0 = G_M.sel(T_type="T0").isel(year2=slice(None, n_burnin))
        GM_T = G_M.sel(T_type="T").isel(year2=slice(n_burnin, None))
        G_M1 = xr.concat((GM_T0, GM_T), dim="year2")

        # per-year input term T / tau, using the tau of the matching segment
        temp_1_T0 = T_sims.isel(year=slice(None, n_burnin)) / tau1_1
        temp_1_T = T_sims.isel(year=slice(n_burnin, None)) / tau1_2
        temp_1 = xr.concat((temp_1_T0, temp_1_T), dim="year")
        # initial value: mean temperature up to T0_period_end (+ optional offset)
        # NOTE(review): assumes T_sims.year is aligned with historical_data.T_year
        temp_1[{"year": 0}] = (
            T_sims.isel(year=(historical_data.T_year <= T0_period_end).values).mean(
                "year"
            )
            + T0_rnd
        )
    else:
        raise NotImplementedError

    if model == "TwoTau":
        raise NotImplementedError

    # T0[i] = sum_{j <= i} G_M1[i, j] * temp_1[j]
    T01 = xr.dot(G_M1, temp_1.rename(year="year2"), dims=["year2"])
    return T01


def calc_sl(
    T_sims: xr.DataArray,
    T0_sims: xr.DataArray,
    params: xr.Dataset,
    model: str,
    period: Sequence,
    interp_method: str = "nearest",
) -> tuple:
    """Project sea level with the ``CRdecay`` SESL model.

    ``T_sims`` and ``T0_sims`` are first annualized and clipped with
    ``resize_T``. The annual rate is then
    ``dsea = c * (1 - 1 / tau_c) ** k + a * (T - T0)``, where ``k`` counts
    years from the first retained year. Sea level is the cumulative sum of
    ``dsea`` over ``year``.

    Parameters
    ----------
    T_sims : :class:`xarray.DataArray`
        Output of :func:`calc_temp`. Draws of historical temps.
    T0_sims : :class:`xarray.DataArray`
        Output of :func:`calc_T0`. T0 associated with the historical T draws.
    params : :class:`xarray.Dataset`
        Output of :func:`pySESL.io.load_params`. Trained SESL parameter draws;
        ``a``, ``c`` and ``tau_c`` are used.
    model : "CRdecay", "ConstRate", "CRovTau", "TwoTau", or "simpel"
        Which model was used to train SESL model and generate ``params``. Only
        ``"CRdecay"`` is implemented.
    period : length-2 array-like
        Period of data to include in results
    interp_method : str, optional
        Interpolation method used to annualize ``T_sims`` and ``T0_sims``.

    Returns
    -------
    sea : :class:`xarray.DataArray`
        Sea level by year for each parameter sample X temperature sample
    dsea : :class:`xarray.DataArray`
        Annual change in sea level by year for each parameter sample X temperature
        sample
    c : :class:`xarray.DataArray`
        Value of ``c`` parameter by year for each parameter sample
    T_sims, T0_sims : :class:`xarray.DataArray`
        The inputs after annualization/clipping by ``resize_T``

    Raises
    ------
    NotImplementedError
        If ``model`` is not ``"CRdecay"``.
    """
    if model != "CRdecay":
        raise NotImplementedError

    T_sims, T0_sims = resize_T(period, T_sims, T0_sims, interp_method=interp_method)

    # 0, 1, 2, ... along the retained years
    year_offset = xr.DataArray(
        np.arange(T_sims.year.size), coords={"year": T_sims.year}, dims=["year"]
    )
    decay = (1 - 1 / params.tau_c) ** year_offset
    c = params.c * decay

    temp_excess = T_sims - T0_sims
    dsea = c + params.a * temp_excess
    sea = dsea.cumsum("year")

    return sea, dsea, c, T_sims, T0_sims


def resample_ics(ics, sim_ids, sesl_trained_params):
    """Resample SESL initial conditions and parameters onto the FAIR simulation index.

    ``ics.T0_2000`` and ``ics.T_ref`` are stacked over (``T_sim_id``, ``sample``,
    ``T_data``) into one ``simulation`` dimension, truncated to ``len(sim_ids)``
    entries. ``sesl_trained_params`` and ``ics.c_2000`` are stacked over
    (``sample``, ``T_data``) and selected at the same (sample, T_data) pairs, so
    every output entry uses consistent parameters. All pieces are relabelled
    with ``sim_ids`` and merged.

    Parameters
    ----------
    ics : :class:`xarray.Dataset`
        Must contain ``T0_2000``, ``T_ref`` and ``c_2000`` (presumably the output
        of ``get_ics`` -- confirm).
    sim_ids : array-like
        Labels of the FAIR simulations. Its length sets how many stacked entries
        are kept; the stacked initial conditions are assumed to have at least
        that many entries.
    sesl_trained_params : :class:`xarray.Dataset` or :class:`xarray.DataArray`
        SESL parameter draws with ``sample`` and ``T_data`` dimensions.

    Returns
    -------
    :class:`xarray.Dataset`
        Merged ``T0_2000``, ``c_2000``, ``T_ref`` and parameters, indexed by
        ``simulation = sim_ids``.
    """
    def resample_full(ds):
        # flatten (T_sim_id, sample, T_data) into "simulation" and keep the first N
        return ds.stack(simulation=["T_sim_id", "sample", "T_data"]).isel(
            simulation=slice(None, len(sim_ids))
        )

    T0_2000 = resample_full(ics.T0_2000)
    T_ref = resample_full(ics.T_ref)
    assert (T0_2000.simulation == T_ref.simulation).all()

    def resample_partial(ds):
        # pick the (sample, T_data) pair used by each retained simulation
        out = ds.stack(simulation=["sample", "T_data"]).sel(
            simulation=pd.MultiIndex.from_arrays(
                (T0_2000.sample.values, T0_2000.T_data.values),
                names=["sample", "T_data"],
            )
        )
        if isinstance(out, xr.Dataset):
            # relabel each variable separately, then merge back together
            arrays = []
            for name in list(out.keys()):
                t1 = out[name]
                t1["simulation"] = sim_ids
                arrays.append(t1.copy())
            out = xr.merge(arrays)
        else:
            out["simulation"] = sim_ids
        return out

    param_sims = resample_partial(sesl_trained_params)
    c_2000 = resample_partial(ics.c_2000)
    T0_2000["simulation"] = sim_ids
    T_ref["simulation"] = sim_ids
    out_params = xr.merge((T0_2000, c_2000, T_ref, param_sims))
    return out_params

def get_ics(n_fair_sims, sesl_trained_params, sesl_hyperparams, sesl_input_dir):
    """Get the initial conditions needed to start SESL projections in 2000.

    For each temperature reconstruction listed in ``sesl_hyperparams["T_data"]``,
    historical temperature draws are simulated and run through the trained SESL
    model, and three quantities are extracted:

    - ``T0_2000``: the SESL reference temperature ``T0`` interpolated to 2000;
    - ``c_2000``: the decaying ``c`` term of the SESL model in 2000;
    - ``T_ref``: the mean simulated temperature over the (inclusive) period
      ``sesl_hyperparams["T_bias_correction_period"]``.

    Parameters
    ----------
    n_fair_sims : int
        Number of simulations the initial conditions will later be paired with
        (see :func:`resample_ics`). Enough historical temperature draws are made
        that draws x parameter samples x reconstructions >= ``n_fair_sims``.
    sesl_trained_params : :class:`xarray.Dataset`
        Posterior SESL parameters with ``sample`` and ``T_data`` dims, where
        ``T_data`` is labelled ``"Mn"`` / ``"Mar"``.
    sesl_hyperparams : dict
        SESL configuration. Keys read here: ``T_data``, ``SL_data``, ``use_cov``,
        ``use_Mar_T0``, ``T_err_sc``, ``cov_tau``, ``no_neg_cov``, ``baseperiod``,
        ``T0_temp_level``, ``T0_period``, ``T_err``, ``tau_ar1``, ``optim_T0``,
        ``model``, ``period``, ``calibperiod`` and ``T_bias_correction_period``.
    sesl_input_dir : :class:`pathlib.Path`
        Directory containing the SESL ``.mat`` input files.

    Returns
    -------
    :class:`xarray.Dataset`
        Variables ``T0_2000``, ``c_2000`` and ``T_ref`` with a ``T_data``
        dimension labelled ``"Mn"`` or ``"Mar"``.

    Raises
    ------
    NotImplementedError
        If a reconstruction name does not start with ``"Mann"`` or ``"Marc"``.
    """
    # figure out how many historical temp draws to use
    n_sesl_samps = len(sesl_trained_params.sample)
    n_sesl_data_samps = len(sesl_trained_params.T_data)

    # T_num * n_sesl_samps * n_sesl_data_samps >= n_fair_sims
    T_num = ceil(n_fair_sims / n_sesl_samps / n_sesl_data_samps)

    T0_2000 = []
    T_ref = []
    c_2000 = []

    # inclusive annual range of the bias-correction period
    T_ref_range = np.arange(
        sesl_hyperparams["T_bias_correction_period"][0],
        sesl_hyperparams["T_bias_correction_period"][1] + 1,
    )
    # SL is simulated over the span covering both "period" and "calibperiod"
    sl_period = [
        min(sesl_hyperparams["period"][0], sesl_hyperparams["calibperiod"][0]),
        max(sesl_hyperparams["period"][1], sesl_hyperparams["calibperiod"][1]),
    ]
    # short labels used for T_data in the trained-parameter dataset
    short_names = {"Mann": "Mn", "Marc": "Mar"}

    for dat in sesl_hyperparams["T_data"]:
        # fail fast on unsupported reconstructions, before loading any data
        if dat[:4] not in short_names:
            raise NotImplementedError(dat)
        dat_short = short_names[dat[:4]]
        dat_params = sesl_trained_params.sel(T_data=dat_short, drop=True)

        historical_data = load_data_SESL(
            sesl_input_dir / (sesl_hyperparams["SL_data"] + ".mat"),
            sesl_input_dir / (dat + ".mat"),
            sesl_hyperparams["use_cov"],
            sesl_hyperparams["use_Mar_T0"],
            Mar_fpath=sesl_input_dir / "Marcott13_RegEM-HC3_20.mat",
            T_err_sc=sesl_hyperparams["T_err_sc"],
            cov_tau=sesl_hyperparams["cov_tau"],
            no_neg_cov=sesl_hyperparams["no_neg_cov"],
            baseperiod=sesl_hyperparams["baseperiod"],
            T0_temp_level=sesl_hyperparams["T0_temp_level"],
            T0_period_st=sesl_hyperparams["T0_period"][0],
        )

        # ``n_sesl_samps`` replaces the deprecated ``Dataset.dims["sample"]`` lookup
        T_sims = calc_temp(
            historical_data,
            sesl_hyperparams["T_err"],
            T_num,
            n_sesl_samps,
            tau_ar1=sesl_hyperparams["tau_ar1"],
        )

        T_ref.append(T_sims.interp(year=T_ref_range).mean("year"))

        T0_sims = calc_T0(
            T_sims,
            historical_data,
            dat_params,
            sesl_hyperparams["optim_T0"],
            sesl_hyperparams["model"],
            T0_period_end=sesl_hyperparams["T0_period"][1],
        )
        T0_2000.append(T0_sims.interp(year=2000).drop_vars("year"))

        _, _, c, _, _ = calc_sl(
            T_sims,
            T0_sims,
            dat_params,
            sesl_hyperparams["model"],
            sl_period,
        )

        c_2000.append(c.sel(year=2000, drop=True))

    dim = pd.Index(sesl_hyperparams["T_data"], name="T_data")
    T0_2000, c_2000, T_ref = (
        xr.concat(x, dim=dim) for x in (T0_2000, c_2000, T_ref)
    )

    # relabel T_data with the short names used by the parameter dataset
    new_T_data = T0_2000.T_data.str[:3]
    new_T_data = new_T_data.where(new_T_data == "Mar", "Mn")
    T0_2000["T_data"] = new_T_data
    c_2000["T_data"] = new_T_data
    T_ref["T_data"] = new_T_data

    return xr.Dataset({"T0_2000": T0_2000, "c_2000": c_2000, "T_ref": T_ref})


def bias_correct_temps(temps, bc_period, T_ref, first_year=None):
    """Shift temperatures onto the reference level of the SESL initial conditions.

    The mean of ``temps`` over ``bc_period`` is subtracted and ``T_ref`` added, so
    that over that window the corrected series is centered on ``T_ref``. The
    result is then truncated to start at ``first_year``.

    Parameters
    ----------
    temps : :class:`xarray.DataArray`
        Temperature projections with a ``year`` dimension.
    bc_period : length-2 sequence
        Start and end years (label-based, inclusive) of the bias-correction window.
    T_ref : :class:`xarray.DataArray` or float
        Target mean temperature over ``bc_period`` (e.g. ``T_ref`` from
        :func:`get_ics` after :func:`resample_ics`).
    first_year : int, optional
        First year kept in the output; None (default) keeps all years.

    Returns
    -------
    :class:`xarray.DataArray`
        The bias-corrected, truncated temperatures.
    """
    window_mean = temps.sel(year=slice(*bc_period)).mean("year")
    corrected = temps - window_mean + T_ref
    return corrected.sel(year=slice(first_year, None))


def project_sesl(temps, params):
    """Project GMSL forward in time with the SESL model.

    Sea level starts at zero in the first year of ``temps`` and is stepped
    forward one year at a time, with ``T0`` initialized from ``params.T0_2000``
    and ``c`` from ``params.c_2000``::

        sl[t]   = sl[t-1] + a * (T[t] - T0[t]) + c[t]      (t > first year)
        T0[t+1] = T0[t] + (T[t] - T0[t]) / tau
        c[t+1]  = c[t] * (1 - 1 / tau_c)

    ``temps`` must already be corrected to the reference level of
    ``params.T0_2000`` (see :func:`bias_correct_temps`). Each step reads
    ``sl`` at label ``year - 1``, so years must be consecutive integers.

    Parameters
    ----------
    temps : :class:`xarray.DataArray`
        Temperature projections with a ``year`` dimension.
    params : :class:`xarray.Dataset`
        Per-simulation SESL parameters ``a``, ``tau``, ``tau_c`` and initial
        conditions ``T0_2000``, ``c_2000`` (see :func:`resample_ics`).

    Returns
    -------
    :class:`xarray.DataArray`
        Projected sea level with the coordinates of ``temps``.
    """
    a = params.a
    # loop-invariant factors, computed once instead of every year
    inv_tau = 1 / params.tau
    c_decay = 1 - 1 / params.tau_c

    # ``temps`` is only read below, so no defensive copy of it is needed
    sl = 0 * temps

    T0 = params.T0_2000.broadcast_like(temps.isel(year=0, drop=True)).copy()
    c = params.c_2000.broadcast_like(params.tau_c).copy()

    first_year = sl.year[0]
    for yr in sl.year:
        TminusT0 = temps.sel(year=yr, drop=True) - T0

        # update SL (the first year stays at 0)
        if yr > first_year:
            sl.loc[{"year": yr}] = sl.sel(year=yr - 1, drop=True) + a * TminusT0 + c

        # relax T0 toward the current temperature
        T0 += inv_tau * TminusT0

        # decay c
        c *= c_decay

    return sl

def resize_T(
    period: Sequence, *das: xr.DataArray, interp_method: str = "nearest"
) -> Sequence:
    """Interpolate DataArrays sampled at (possibly varying) intervals onto an annual
    time axis clipped to ``period``.

    The output years are derived from the ``year`` coordinate of the first array.
    Its samples within ``[max(first year, period[0]), period[1]]`` are kept, and the
    range is padded by ``floor((d_start - 1) / 2)`` years before the first kept
    sample and ``ceil((d_end - 1) / 2)`` years after the last. Here ``d_start`` and
    ``d_end`` are the spacings between the first two and last two kept samples. A
    single kept sample is treated as annual (no padding).

    Parameters
    ----------
    period : length-2 array-like
        Starting and ending years for the output.
    *das : :class:`xarray.DataArray`
        Arrays with a ``year`` dimension. All are interpolated onto the output
        years derived from ``das[0]``.
    interp_method : str, optional
        Interpolation method for :meth:`xarray.DataArray.interp`. Default
        "nearest". Values outside the input range are extrapolated
        (``fill_value="extrapolate"``).

    Returns
    -------
    list of :class:`xarray.DataArray`
        Same length and order as ``das``, interpolated and clipped.

    Raises
    ------
    ValueError
        If no year of ``das[0]`` falls within the clipped period.
    """
    da = das[0]

    fyr = max(da.year[0].item(), period[0])
    lyr = period[1]

    out_range = da.year.isel(year=(da.year >= fyr) & (da.year <= lyr))
    if out_range.size == 0:
        raise ValueError(f"No input years fall within [{fyr}, {lyr}]")

    if out_range.size > 1:
        diffs = out_range.diff("year")
        yrs_st = diffs[0]
        yrs_end = diffs[-1]
    else:
        # a single kept sample carries no spacing information: treat it as annual
        yrs_st = yrs_end = 1

    yrs_out = np.arange(
        out_range[0] - floor((yrs_st - 1) / 2),
        out_range[-1] + ceil((yrs_end - 1) / 2) + 1,
    )

    return [
        x.interp(
            year=yrs_out,
            method=interp_method,
            kwargs={"fill_value": "extrapolate"},
        )
        for x in das
    ]


def load_data_SESL(
    sl_fpath: Union[str, Path],
    T_fpath: Union[str, Path],
    use_cov: bool,
    use_Mar_T0: bool,
    Mar_fpath: Union[str, Path, None] = None,
    T_err_sc: float = 1,
    cov_tau: float = 100,
    no_neg_cov: bool = True,
    baseperiod: Sequence[int] = (1400, 1800),
    T0_temp_level: float = 100,
    T0_period_st: int = -2000,
) -> xr.Dataset:
    """Load historical temperature and sea level reconstructions for SESL.

    Parameters
    ----------
    sl_fpath : str or :class:`pathlib.Path`
        Path to the sea level reconstruction ``.mat`` file. Its ``sl`` array holds
        year, value and error columns; its ``C`` array holds the covariance.
    T_fpath : str or :class:`pathlib.Path`
        Path to the temperature reconstruction ``.mat`` file. Its ``T`` array holds
        year, value and error columns at a constant time step.
    use_cov : bool
        If True, taper the sea level covariance with ``cov_tau`` and (optionally)
        clip it at zero. The covariance is returned as ``sl_C`` either way.
    use_Mar_T0 : bool
        If True, prepend the long-running Marcott temperature reconstruction
        (``Mar_fpath``) to the ``T_fpath`` series, so ``T0`` can be spun up before
        the ``T_fpath`` reconstruction starts.
    Mar_fpath : str or :class:`pathlib.Path` or None, optional
        Path to the Marcott *temperature* reconstruction ``.mat`` file. Required if
        ``use_Mar_T0`` is True; ignored otherwise.
    T_err_sc : float, optional
        Scaling factor applied to the temperature errors. Default 1.
    cov_tau : float, optional
        Taper time scale: the covariance is multiplied elementwise by
        ``exp(-|t_i - t_j| / cov_tau)``. Only used if ``use_cov`` is True, in which
        case it must not be None. Default 100.
    no_neg_cov : bool, optional
        If True (default) and ``use_cov`` is True, clip negative covariances to 0.
    baseperiod : length-2 sequence of int, optional
        Start and end years (inclusive) over which the sea level values are
        re-centered to zero mean. Default (1400, 1800).
    T0_temp_level : float, optional
        If ``use_Mar_T0`` is True, length (years) of the window, beginning at the
        first ``T_fpath`` year, over which the Marcott series is shifted to match
        the ``T_fpath`` mean. Default 100.
    T0_period_st : int, optional
        If ``use_Mar_T0`` is True, first year of the combined temperature series
        that is kept. Default -2000.

    Returns
    -------
    :class:`xarray.Dataset`
        Contains the following variables, where ``kind`` is ``"val"`` or ``"err"``:

        - ``sl`` with dims (``sl_year``, ``kind``);
        - ``T`` with dims (``T_year``, ``kind``);
        - ``sl_C`` with dims (``sl_year``, ``sl_year_cov``).

        If ``use_Mar_T0`` is True, ``T`` is the combined Marcott + ``T_fpath``
        series. ``T0burnin`` then counts its rows that precede the first
        ``T_fpath`` year.

    Raises
    ------
    ValueError
        If ``use_Mar_T0`` is True but ``Mar_fpath`` is None, or if the temperature
        reconstruction does not have a constant time step.
    NotImplementedError
        If ``use_cov`` is True and ``cov_tau`` is None.
    """
    if use_Mar_T0 and Mar_fpath is None:
        raise ValueError("Mar_fpath must be provided when use_Mar_T0 is True")

    # load SL proxy data. Values/errors are divided by 10 and the covariance by
    # 100 (= 10**2) consistently; presumably mm -> cm (TODO confirm source units).
    sl_data = loadmat(sl_fpath, squeeze_me=True)
    sl = sl_data["sl"]
    proxy_sl = pd.DataFrame(
        {
            "val": (sl[:, 1] / 10).astype(np.float64),
            "err": (sl[:, 2] / 10).astype(np.float64),
        },
        index=pd.Index(sl[:, 0].astype(np.int16), name="year"),
    )
    C = (sl_data["C"] / 100).astype(np.float64)
    # machine-epsilon diagonal jitter (presumably for numerical stability)
    C += np.eye(len(C)) * np.finfo(C.dtype).eps

    if use_cov:
        if cov_tau is None:
            raise NotImplementedError("use_cov=True requires a cov_tau taper scale")
        years = proxy_sl.index.values
        Csc = np.exp(
            -np.abs(np.expand_dims(years, 0) - np.expand_dims(years, 1)) / cov_tau
        )
        C *= Csc
        if no_neg_cov:
            C = np.maximum(C, 0)

    # rebase proxy SL data to base period
    proxy_sl["val"] -= proxy_sl.loc[baseperiod[0] : baseperiod[1], "val"].mean()

    # convert to long format: index (year, kind)
    proxy_sl = proxy_sl.stack()
    proxy_sl.index = proxy_sl.index.rename("kind", level=-1)
    proxy_sl.name = "sl"

    # load T reconstruction data
    T = loadmat(T_fpath, squeeze_me=True)["T"]
    T = pd.DataFrame(
        T[:, 1:3],
        columns=["val", "err"],
        index=pd.Index(T[:, 0], name="year").astype(np.int16),
    )

    # require a common timestep
    dyr = np.diff(T.index)
    if len(np.unique(dyr)) != 1:
        raise ValueError("Temperature reconstruction must have a constant time step")
    dyr = dyr[0]

    # scale by predefined scaling factor
    T["err"] *= T_err_sc

    # convert to long format
    T_long = T.stack()
    T_long.index = T_long.index.rename("kind", level=-1)
    T_long.name = "T"

    # aggregate into Dataset
    data = xr.merge(
        (
            proxy_sl.to_xarray().rename(year="sl_year"),
            T_long.to_xarray().rename(year="T_year"),
        )
    )

    # Use Mar data for early T values if using for initializing T0
    if use_Mar_T0:
        T_mar = loadmat(Mar_fpath)["T"]
        T_mar = pd.DataFrame(
            T_mar[:, 1:],
            columns=["val", "err"],
            index=pd.Index(T_mar[:, 0], name="year").astype(np.int16),
        )
        # shift Marcott so its mean over the first ``T0_temp_level`` years of the
        # T_fpath reconstruction matches that reconstruction's mean there
        T_mar_overlap_mean = T_mar.loc[
            T.index.min() : T.index.min() + T0_temp_level, "val"
        ].mean()
        T_overlap_mean = T.loc[: T.index.min() + T0_temp_level, "val"].mean()
        T_mar["val"] = T_mar["val"] - T_mar_overlap_mean + T_overlap_mean
        T_mar = T_mar.loc[: T.index.min() - int((dyr - 1) / 2)]

        T0_temp = pd.concat((T_mar, T))

        # only care about part after beginning of burnin period
        T0_temp = T0_temp.loc[T0_period_st:]
        T0burnin = (T0_temp.index < T.index.min()).sum()

        # convert to long format
        T0_temp = T0_temp.stack()
        T0_temp.index = T0_temp.index.rename("kind", level=-1)
        T0_temp.name = "T"

        # replace T (and its T_year coordinate) with the combined series
        data = data.drop_vars(["T", "T_year"]).assign(
            {"T": T0_temp.to_xarray().rename(year="T_year")}
        )
        data["T0burnin"] = T0burnin

    C = xr.DataArray(
        C,
        dims=["sl_year", "sl_year_cov"],
        coords={"sl_year": data.sl_year.values, "sl_year_cov": data.sl_year.values},
        name="sl_C",
    )
    return xr.merge((data, C))


def _load_params_from_struct(struct):
    """Build a Dataset of SESL posterior draws from one MATLAB parameter struct.

    Each field ``a``, ``c``, ``tau``, ``tau_c`` and ``T01`` becomes a variable
    indexed by an integer ``sample`` dimension, whose length is the number of
    draws stored in field ``a``.
    """
    columns = {name: struct[name].item() for name in ("a", "c", "tau", "tau_c", "T01")}
    n_draws = len(columns["a"])
    sample_index = pd.Index(np.arange(n_draws), name="sample")
    frame = pd.DataFrame(columns, index=sample_index)
    return frame.to_xarray()

def load_param_file(fpath: Union[str, Path, "IO[bytes]"]) -> xr.Dataset:
    """Load the posterior SESL parameter distribution trained with the MATLAB
    version of the codebase.

    Parameters
    ----------
    fpath : str, :class:`pathlib.Path`, or binary file object
        The ``.mat`` parameter file, or an open binary handle to it (the notebook
        passes an open file). It is passed straight to ``loadmat``.

    Returns
    -------
    :class:`xarray.Dataset`
        Parameter draws for the Marcott-based (``"Mar"``) and Mann-based
        (``"Mn"``) calibrations, concatenated along a ``T_data`` dimension.
    """
    data = loadmat(fpath, squeeze_me=True)["P"]
    labels = ["Mar", "Mn"]
    return xr.concat(
        [_load_params_from_struct(data[label].item()) for label in labels],
        dim=pd.Index(labels, name="T_data"),
    )


In [18]:
# RCP workflow: load FAIR RCP temperatures (full ensemble plus a single run using
# median parameters), pair them with SESL posterior parameters and initial
# conditions, bias-correct the temperatures, and project GMSL with SESL.

# get temp projections from FAIR
# (``.load()`` reads the data into memory before the file handle is closed)
with open(Path(ps.PATH_HAZARD_GMST_FAIR_RCP), "rb") as f:
    fair_temps_rcp = flatten_runtype(
        xr.open_dataset(f).temperature.drop("scalar").load()
    )

with open(Path(ps.PATH_HAZARD_GMST_FAIR_RCP_MEDIAN), "rb") as f:
    # add a length-1 ``simulation`` dim so the median run goes through the same
    # functions as the full ensemble
    fair_temps_rcp_med = flatten_runtype(
        xr.open_dataset(f).temperature.load()
    ).expand_dims(simulation=1)
# Load posterior of SESL parameters (pre-trained)
with open(
    Path(ps.DIR_HAZARD_SLR_SESL_RAW / "Parameters.mat"), "rb"
) as f:
    params = load_param_file(f)
# median over posterior draws, kept as a length-1 ``sample`` dim
params_med = params.median("sample").expand_dims(sample=1)

# Get T0 initial condition for SESL projections
# (``sesl_p`` is the SESL hyperparameter dict, presumably defined in an earlier cell)
ics_rcp = get_ics(
    len(fair_temps_rcp.simulation), params, sesl_p, ps.DIR_HAZARD_SLR_SESL_RAW
)
ics_rcp_med = get_ics(1, params_med, sesl_p, ps.DIR_HAZARD_SLR_SESL_RAW)

# Resample the SESL initial condition and bias correction factors to match up with the fair simulations
param_sims_rcp = resample_ics(ics_rcp, fair_temps_rcp.simulation, params)
param_sims_rcp_med = resample_ics(
    ics_rcp_med, fair_temps_rcp_med.simulation, params_med
)
# Shift FAIR temperatures to the SESL reference level (T_ref) and start them in
# 2000, the year the SESL initial conditions refer to
fair_temps_rcp = bias_correct_temps(
    fair_temps_rcp,
    sesl_p["T_bias_correction_period"],
    param_sims_rcp.T_ref,
    first_year=2000,
)
fair_temps_rcp_med = bias_correct_temps(
    fair_temps_rcp_med,
    sesl_p["T_bias_correction_period"],
    param_sims_rcp_med.T_ref,
    first_year=2000,
)

# run sim
sl_rcp = project_sesl(fair_temps_rcp, param_sims_rcp)
sl_rcp_med = project_sesl(fair_temps_rcp_med, param_sims_rcp_med)

## Merge in AR6 baselines

### Load

In [19]:
# Load AR6 (FACTS) GMSL baselines for workflows wf_1f and wf_2f, convert to cm,
# drop RCPs that are entirely missing, and flatten (workflow, samples) into one
# ``sample`` dimension.
baselines = (
    xr.open_zarr(ps.PATH_HAZARD_SLR_GMSL_BASELINE_INT_SYNTH, chunks=None)
    .sea_level_change.pint.quantify()
    .pint.to("cm")
    .pint.dequantify()
    .sel(workflow=["wf_1f", "wf_2f"])
    .dropna("rcp", how="all")
    .stack(sample=["workflow", "samples"])
    .rename(years="year")
)

# replace the stacked MultiIndex with plain integer sample labels, and interpolate
# the projection years to an annual grid
baselines["sample"] = np.arange(len(baselines.sample))
baselines = baselines.interp(year=np.arange(baselines.year[0], baselines.year[-1] + 1))
# last AR6 year; SESL projections are used to extend beyond it further below
final_year = baselines.year.max().item()

### Transform to a year 2000 baseline

AR6 projections use a 1996-2014 datum, but all of our projected damages refer to SLR above a 1991-2009 mean datum. We therefore use historical GMSL estimates from [Dangendorf et al. 2019 (Supplementary Data 1)](https://www.nature.com/articles/s41558-019-0531-8#MOESM2) to shift the datum. The baselines are offset by the difference between the 19-year mean GMSL centered on 2005 and the one centered on 2000.

In [20]:
# Annual means of the Dangendorf et al. (2019) historical GMSL series, smoothed
# with a centered 19-year rolling mean.
msl_hist = load_dangendorf(ps.PATH_HAZARD_SLR_GMSL_RAW_HIST)
msl_hist_yr = msl_hist.resample("y").mean()
msl_hist_yr.index = msl_hist_yr.index.year
msl_hist_rolling = msl_hist_yr.rolling(19, center=True).mean()
# 19-yr mean centered on 2005 (1996-2014, the AR6 datum) minus the one centered on
# 2000 (1991-2009, our datum); the /10 presumably converts mm to cm (TODO confirm).
offset_05_to_00 = (msl_hist_rolling[2005] - msl_hist_rolling[2000]) / 10
baselines += offset_05_to_00

In [21]:
# get baseline for median sample
# Median AR6 baseline, extended past ``final_year`` with the no-pulse SESL median
# projection shifted to match the baseline in ``final_year``. The shifted
# extension includes ``final_year`` itself, so that year is dropped from the
# baseline part. xarray arithmetic aligns with an inner join by default, so only
# RCPs present in the baseline survive in the extension.
baselines_med = baselines.median("sample")
extra_sims_med = sl_rcp_med.sel(year=slice(final_year, None)).squeeze(drop=True)
baselines_med = xr.concat(
    (
        baselines_med.isel(year=slice(None, -1)),
        extra_sims_med.sel(pulse_year=0, drop=True)
        + baselines_med.sel(year=final_year, drop=True)
        - extra_sims_med.sel(pulse_year=0, year=final_year, drop=True),
    ),
    dim="year",
)

### Do quantile mapping to match SESL and AR6 runs

In [22]:
# quantile map within each year
# Assign AR6 baseline values to the no-pulse SESL simulations by quantile mapping,
# independently for each (rcp, year) block. Only RCPs that have an AR6 baseline are
# used here. Per the template, the result is shaped like ``sl_rcp_chunked``
# (see ``quantile_map_sesl_and_baseline`` for the mapping details).
sl_rcp_chunked = (
    sl_rcp.sel(pulse_year=0, year=baselines.year, rcp=baselines.rcp)
    .drop("pulse_year")
    .chunk({"rcp": 1, "year": 1})
)
baseline_rcp = (
    baselines.chunk({"rcp": 1, "year": 1})
    .map_blocks(
        quantile_map_sesl_and_baseline, (sl_rcp_chunked,), template=sl_rcp_chunked
    )
    .load()
)

## Interpolate to missing rcp460

In [None]:
# Build baselines for RCPs that have SESL projections but no AR6 baseline (e.g.
# rcp460) by interpolating between neighbouring RCPs.

# SESL no-pulse projections of the missing RCPs. ``rcp`` is renamed to ``tmp`` so it
# does not collide with the ``rcp`` dim of the RCPs interpolated from.
sl_rcp_chunked_to_interp = (
    sl_rcp.sel(pulse_year=0, year=baselines.year)
    .drop_sel(rcp=baselines.rcp)
    .drop("pulse_year")
    .chunk({"rcp": 1, "year": 1})
    .rename(rcp="tmp")
)
# For each target value: lower/upper bounding RCPs (``lb``/``ub``, object dtype per
# the template) and the weight on the upper bound (``ub_wt``); see ``get_bound_wts``.
rcp_wt_ds = sl_rcp_chunked_to_interp.map_blocks(
    get_bound_wts,
    (sl_rcp_chunked.chunk({"rcp": -1}),),
    template=xr.Dataset(
        {
            "lb": sl_rcp_chunked_to_interp.astype(object),
            "ub": sl_rcp_chunked_to_interp.astype(object),
            "ub_wt": sl_rcp_chunked_to_interp,
        }
    ),
).load()

# map the bounding RCPs' quantile-mapped baselines onto the missing RCPs
baseline_rcp_extra = quantile_map_rff(
    sl_rcp_chunked_to_interp.rename(simulation="iter").load(),
    sl_rcp_chunked.load(),
    baseline_rcp,
    rcp_wt_ds.rename(simulation="iter"),
    dim="iter",
).rename(iter="simulation", tmp="rcp")

# combine with the directly mapped RCPs and restore the original RCP order
baseline_rcp = xr.concat((baseline_rcp, baseline_rcp_extra), dim="rcp").sel(
    rcp=sl_rcp.rcp
)
# labels of the interpolated RCPs (excluded from the RFF weighting further below)
interpolated = sl_rcp_chunked_to_interp.tmp.values

In [18]:
# NOTE(review): scratch/debug cell. It re-derives, for 2050 only, the bounding RCPs
# and upper-bound weight (apparently mirroring ``get_bound_wts``). None of its
# results are used by later cells, and it overwrites the global name ``dim``.
dim = 'rcp'
src_vals = sl_rcp_chunked.chunk({"rcp": -1}).sel(year = 2050)
trg_vals = sl_rcp_chunked_to_interp.sel(year = 2050)
diff = trg_vals - src_vals
# lower bound: the RCP with the largest value <= target (lowest RCP if none)
lb = diff.where(diff >= 0, np.inf)
no_lb = np.isinf(lb).all(dim=dim)
lb = lb.idxmin(dim).where(~no_lb, src_vals.idxmin(dim).broadcast_like(no_lb))
# upper bound: the RCP with the smallest value >= target (highest RCP if none)
ub = diff.where(diff <= 0, -np.inf)
no_ub = np.isinf(ub).all(dim=dim)
ub = ub.idxmax(dim).where(~no_ub, src_vals.idxmax(dim).broadcast_like(no_ub))
assert not (no_lb & no_ub & trg_vals.notnull()).any()
lb_val = src_vals.sel({dim: lb}, drop=True)
ub_val = src_vals.sel({dim: ub}, drop=True)

full_range = ub_val - lb_val
assert (ub_val.where(full_range == 0, 0) == lb_val.where(full_range == 0, 0)).all()

# linear weight on the upper bound; 1 where both bounds coincide
ub_wt = ((trg_vals - lb_val) / full_range).where(full_range != 0, 1)


In [None]:
# Same RCP interpolation for the median-parameter baseline. Each missing RCP is
# bounded by neighbouring RCPs (``get_bound_wts``), and their median baselines are
# blended linearly with weight ``ub_wt`` on the upper bound.
sl_rcp_chunked_med = (
    sl_rcp_med.sel(pulse_year=0, year=baselines_med.year)
    .drop("pulse_year")
    .chunk({"year": 1})
)
# missing RCPs, with ``rcp`` renamed to ``tmp`` to avoid a dim collision
sl_rcp_chunked_med_to_interp = (
    sl_rcp_chunked_med.drop_sel(rcp=baselines_med.rcp)
    .rename(rcp="tmp")
    .chunk({"tmp": 1})
)
sl_rcp_chunked_med = sl_rcp_chunked_med.sel(rcp=baselines_med.rcp)

rcp_wt_ds_med = sl_rcp_chunked_med_to_interp.map_blocks(
    get_bound_wts,
    (sl_rcp_chunked_med,),
    template=xr.Dataset(
        {
            "lb": sl_rcp_chunked_med_to_interp.astype(object),
            "ub": sl_rcp_chunked_med_to_interp.astype(object),
            "ub_wt": sl_rcp_chunked_med_to_interp,
        }
    ),
).load()
baseline_rcp_extra_med = (
    rcp_wt_ds_med.ub_wt * baselines_med.load().sel(rcp=rcp_wt_ds_med.ub, drop=True)
    + (1 - rcp_wt_ds_med.ub_wt) * baselines_med.sel(rcp=rcp_wt_ds_med.lb, drop=True)
).rename(tmp="rcp")
# combine and restore the original RCP order
baselines_med = xr.concat((baselines_med, baseline_rcp_extra_med), dim="rcp").sel(
    rcp=sl_rcp.rcp
)

## Add on the post-2300 years

Extend the baselines past 2300 by bias-correcting the SESL projections to match FACTS in 2300.

In [20]:
# Append the no-pulse SESL years after ``final_year``, shifted so that they continue
# from the baseline value in the last baseline year.
extra = sl_rcp.sel(pulse_year=0, year=slice(final_year + 1, None)).drop("pulse_year")
comp = sl_rcp.sel(pulse_year=0, year=final_year, drop=True)
baseline_rcp = xr.concat(
    (baseline_rcp, extra + baseline_rcp.isel(year=-1, drop=True) - comp), dim="year"
)

In [21]:
# add delta from SESL to baseline
# pulse_year=0 is the baseline itself. Each pulse run adds the SESL
# (pulse - no-pulse) GMSL difference to the baseline.
out_rcp = xr.concat(
    (
        baseline_rcp.expand_dims(pulse_year=[0]),
        baseline_rcp
        + sl_rcp.drop_sel(pulse_year=0)
        - sl_rcp.sel(pulse_year=0, drop=True),
    ),
    dim="pulse_year",
)

# same construction for the median-parameter run; size-1 dims are squeezed out
out_rcp_med = xr.concat(
    (
        baselines_med.expand_dims(pulse_year=[0]),
        baselines_med
        + sl_rcp_med.drop_sel(pulse_year=0)
        - sl_rcp_med.sel(pulse_year=0, drop=True),
    ),
    dim="pulse_year",
).squeeze()

out_rcp = xr.Dataset({"gmsl": out_rcp, "gmsl_median": out_rcp_med})

In [22]:
# free the full AR6 sample array; later cells do not use it
del baselines

## RFF-based SESL workflow

In [None]:
# Load FAIR temperature runs for the RFF-SP ensemble for each pulse gas and stack
# them along a ``gas`` dimension.
# NOTE(review): ``FAIR_RFF_STUB`` is used unqualified, unlike the ``ps.*`` settings;
# confirm it is defined in an earlier cell.
fair_rff = []
for gas_stub, version in [
    ("CO2_Fossil", ps.FAIR_RFF_CO2_VERS),
    ("CH4", ps.FAIR_RFF_CH4_VERS),
    ("N2O", ps.FAIR_RFF_N20_VERS),
]:
    print(version)
    this = xr.open_mfdataset(
        ps.DIR_HAZARD_FAIR_RFF.glob(
            FAIR_RFF_STUB.format(gas_stub=gas_stub, version=version)
        ),
        chunks={"runid": 5000},
    )[["temperature", "climate_param_index"]]
    fair_rff.append(flatten_runtype(this.temperature))

fair_rff = xr.concat(
    fair_rff, dim=pd.Index(["CO2_Fossil", "CH4", "N2O"], name="gas")
).to_dataset(name="temperature")
# taken from the last-opened gas (N2O); presumably identical across gases (confirm)
fair_rff["climate_param_index"] = this.climate_param_index.isel(pulse_year=0, drop=True)
fair_rff = fair_rff.expand_dims({"iter" : [1]}).persist()

# get temp projections from FAIR
fair_temps_rff = fair_rff.temperature
fair_params_rff = fair_rff.climate_param_index.load()

In [25]:
# RFF-SP <-> FAIR simulation crosswalk. Its ``simulation`` variable is used below to
# pick the SESL parameter set paired with each RFF run.
# NOTE(review): hard-coded relative path, unlike the ``ps.*`` paths used elsewhere.
CROSSWALK = xr.open_dataset(
'./coastal_gmsl_inputs/other/rffsp_fair_sequence.nc'
)

In [None]:
# Get T0 initial condition for SESL projections

# Option 1: Randomly assign SESL params - Captures more of the distribution, but not
# matched to FAIR params in same way as we did for RCP model.
# ics_rff = get_ics(
#     len(fair_temps_rff.simulation), params, sesl_p, ps.DIR_HAZARD_SLR_SESL_RAW
# )
# param_sims_rff = resample_ics(
#     ics_rff,
#     fair_temps_rff.simulation,
#     params,
# )

# Option 2: Use same pairings of FAIR-SESL params as for RCP. Captures less of the SESL
# param distribution, but maintains the random FAIR-SESL pairings from the RCP model
# (Option 2 is the one in use.)

# rff_sp, iter
param_sims_rff = param_sims_rcp.sel(simulation=CROSSWALK.simulation, drop=True).expand_dims({'iter' : [1]})

# Resample the SESL initial condition and bias correction factors to match up with the fair simulations
fair_temps_rff = bias_correct_temps(
    fair_temps_rff,
    sesl_p["T_bias_correction_period"],
    param_sims_rff.T_ref,
    first_year=2000,
).persist()

# Run sim: project_sesl is applied block-wise, with parameters chunked along
# ``runid`` to match the temperature chunks
sl_rff = fair_temps_rff.map_blocks(
    project_sesl,
    (param_sims_rff.chunk({"runid": 5000}),),
    template=fair_temps_rff,
).persist()

In [None]:
# In-memory copy of the CO2_Fossil temperatures (no pulse and 2020 pulse) for the
# diagnostics below, with (runid, iter) stacked into ``simulation``.
fair_temps_rff_diag = (
    fair_temps_rff.sel(gas="CO2_Fossil", pulse_year=[0, 2020])
    .load()
    .stack(simulation=["runid", "iter"])
)
# drop the large lazy/persisted RFF temperature arrays
del fair_temps_rff, fair_rff

## Calculate weighted average of AR6 baselines for RFF-SPs

In [None]:
# No-pulse SESL GMSL for each RFF run over the baseline years. The CO2_Fossil file is
# used; presumably the no-pulse runs are identical across gas files (confirm).
# ``split_large_chunks`` lets dask split oversized chunks produced by this indexing.
with dask.config.set(**{"array.slicing.split_large_chunks": True}):
    rff_wt_vals = (
        sl_rff.sel(year=baseline_rcp.year, pulse_year=0, gas="CO2_Fossil")
        .drop(["pulse_year", "gas"])
        .persist()
    )
# No-pulse SESL GMSL for the RCPs that have a direct AR6 baseline
rcp_wt_vals = (
    sl_rcp.sel(year=baseline_rcp.year, pulse_year=0)
    .drop_sel(rcp=interpolated)
    .drop("pulse_year")
)

In [None]:
# For each RFF run, the RCP SESL trajectories of its FAIR parameter set
# (``fair_params_rff`` indexes the RCP ``simulation`` dim), chunked like rff_wt_vals.
rcp_wt_vals_reshaped = rcp_wt_vals.sel(simulation=fair_params_rff, drop=True).chunk(
    rff_wt_vals.chunksizes
)

# bounding RCPs (lb/ub) and upper-bound weight (ub_wt) for each RFF value; see
# ``get_bound_wts``
wt_ds = rff_wt_vals.map_blocks(
    get_bound_wts,
    (rcp_wt_vals_reshaped,),
    template=xr.Dataset(
        {
            "lb": rff_wt_vals.astype(object),
            "ub": rff_wt_vals.astype(object),
            "ub_wt": rff_wt_vals,
        }
    ),
).persist()

In [None]:
# RFF baselines: map the AR6-based RCP baselines (directly mapped RCPs only) onto the
# RFF runs using the bounding weights, in 10-year chunks.
baseline_rff = (
    rff_wt_vals.chunk({"year": 10})
    .map_blocks(
        quantile_map_rff,
        (
            rcp_wt_vals.chunk({"year": 10}),
            baseline_rcp.drop_sel(rcp=interpolated).chunk({"year": 10}),
            wt_ds.chunk({"year": 10}),
        ),
        template=rff_wt_vals.chunk({"year": 10}),
    )
    .persist()
)

## Aggregate datasets

In [None]:
# Restrict the RFF SESL projections to the baseline years and match baseline_rff's
# chunking on shared dims (without letting dask split large chunks).
with dask.config.set(**{"array.slicing.split_large_chunks": False}):
    sl_rff = (
        sl_rff.sel(year=baseline_rff.year)
        .chunk({k: v for k, v in baseline_rff.chunksizes.items() if k in sl_rff.dims})
        .persist()
    )

In [None]:
# pulse_year=0 is the RFF baseline. Pulse runs add the SESL (pulse - no-pulse)
# difference to it.
out_rff = xr.concat(
    (
        baseline_rff.expand_dims(pulse_year=[0]),
        baseline_rff
        + sl_rff.drop_sel(pulse_year=0)
        - sl_rff.sel(pulse_year=0, drop=True),
    ),
    dim="pulse_year",
).persist()

## Diagnostics

In [None]:
# Diagnostics subset: CO2_Fossil pulse gas, no-pulse and 2020-pulse runs. RFF runs are
# stacked into ``simulation`` exactly as in ``fair_temps_rff_diag`` (runid, iter), so
# GMSL and GMST diagnostics align. The ``gas`` coordinate is built from
# "CO2_Fossil"/"CH4"/"N2O", so "CO2" would not match.
out_rff_diag = (
    out_rff.sel(pulse_year=[0, 2020], gas="CO2_Fossil")
    .load()
    .stack(simulation=["runid", "iter"])
)

In [None]:
# Years at which pulse diagnostics are summarized
DIAG_YEARS = [2100, 2150, 2300]

# For each workflow, quantiles across simulations of:
#   gmsl_pulse_qs: GMSL change due to the 2020 pulse (cm)
#   gmst_pulse_qs: GMST change due to the 2020 pulse (deg C)
#   rat_qs:        their ratio (cm per deg C)
#   gmsl_base_qs:  no-pulse GMSL (cm)
diag = {}
for key, out, temp in [
    ("rcp", out_rcp.gmsl, fair_temps_rcp),
    (
        "rff",
        out_rff_diag,
        fair_temps_rff_diag,
    ),
]:
    quantiles = [0.01, 0.05, 0.17, 0.5, 0.83, 0.95, 0.99]
    sl_diff = out.sel(pulse_year=2020, drop=True) - out.sel(pulse_year=0, drop=True)
    t_diff = temp.sel(pulse_year=2020, drop=True) - temp.sel(pulse_year=0, drop=True)
    # align temperature years to the GMSL years before dividing
    rat = sl_diff / t_diff.reindex(year=sl_diff.year)
    gmsl_pulse_qs = sl_diff.sel(year=DIAG_YEARS).quantile(q=quantiles, dim="simulation")
    gmst_pulse_qs = t_diff.sel(year=DIAG_YEARS).quantile(q=quantiles, dim="simulation")
    rat_qs = rat.sel(year=DIAG_YEARS).quantile(q=quantiles, dim="simulation")
    gmsl_base_qs = out.sel(pulse_year=0, year=DIAG_YEARS).quantile(
        q=quantiles, dim="simulation"
    )
    diag[key] = {
        "gmsl_pulse_qs": gmsl_pulse_qs,
        "gmst_pulse_qs": gmst_pulse_qs,
        "rat_qs": rat_qs,
        "gmsl_base_qs": gmsl_base_qs,
    }

### $\Delta$ GMSL from Pulse (cm)

#### 2100

In [None]:
# RCP: quantiles of pulse-induced GMSL change in 2100 (last index level as columns)
diag["rcp"]["gmsl_pulse_qs"].sel(year=2100).to_series().unstack()

In [None]:
# RFF: quantiles of pulse-induced GMSL change in 2100
diag["rff"]["gmsl_pulse_qs"].sel(year=2100).to_series()

#### 2150

In [None]:
# RCP: quantiles of pulse-induced GMSL change in 2150 (last index level as columns)
diag["rcp"]["gmsl_pulse_qs"].sel(year=2150).to_series().unstack()

In [None]:
# RFF: quantiles of pulse-induced GMSL change in 2150
diag["rff"]["gmsl_pulse_qs"].sel(year=2150).to_series()

#### 2300

In [None]:
# RCP: quantiles of pulse-induced GMSL change in 2300 (last index level as columns)
diag["rcp"]["gmsl_pulse_qs"].sel(year=2300).to_series().unstack()

In [None]:
# RFF: quantiles of pulse-induced GMSL change in 2300
diag["rff"]["gmsl_pulse_qs"].sel(year=2300).to_series()

### $\Delta$ GMST from Pulse ($^{\circ}C$)

#### 2100

In [None]:
# RCP: quantiles of pulse-induced GMST change in 2100 (last index level as columns)
diag["rcp"]["gmst_pulse_qs"].sel(year=2100).to_series().unstack()

In [None]:
# RFF: quantiles of pulse-induced GMST change in 2100
diag["rff"]["gmst_pulse_qs"].sel(year=2100).to_series()

#### 2150

In [None]:
# RCP: quantiles of pulse-induced GMST change in 2150 (last index level as columns)
diag["rcp"]["gmst_pulse_qs"].sel(year=2150).to_series().unstack()

In [None]:
# RFF: quantiles of pulse-induced GMST change in 2150
diag["rff"]["gmst_pulse_qs"].sel(year=2150).to_series()

#### 2300

In [None]:
# RCP: quantiles of pulse-induced GMST change in 2300 (last index level as columns)
diag["rcp"]["gmst_pulse_qs"].sel(year=2300).to_series().unstack()

In [None]:
# RFF: quantiles of pulse-induced GMST change in 2300
diag["rff"]["gmst_pulse_qs"].sel(year=2300).to_series()

### $\Delta$ GMSL / $\Delta$ GMST ($\frac{cm}{^{\circ}C}$)

#### 2100

In [None]:
# RCP: quantiles of the GMSL/GMST pulse-response ratio in 2100 (last level as columns)
diag["rcp"]["rat_qs"].sel(year=2100).to_series().unstack()

In [None]:
# RFF: quantiles of the GMSL/GMST pulse-response ratio in 2100
diag["rff"]["rat_qs"].sel(year=2100).to_series()

#### 2150

In [None]:
# RCP: quantiles of the GMSL/GMST pulse-response ratio in 2150 (last level as columns)
diag["rcp"]["rat_qs"].sel(year=2150).to_series().unstack()

In [None]:
# RFF: quantiles of the GMSL/GMST pulse-response ratio in 2150
diag["rff"]["rat_qs"].sel(year=2150).to_series()

#### 2300

In [None]:
# RCP: quantiles of the GMSL/GMST pulse-response ratio in 2300 (last level as columns)
diag["rcp"]["rat_qs"].sel(year=2300).to_series().unstack()

In [None]:
# RFF: quantiles of the GMSL/GMST pulse-response ratio in 2300
diag["rff"]["rat_qs"].sel(year=2300).to_series()

### Baseline GMSL, rel. 1991-2009 (cm)

#### 2100

In [None]:
# RCP: quantiles of no-pulse (baseline) GMSL in 2100 (last index level as columns)
diag["rcp"]["gmsl_base_qs"].sel(year=2100).to_series().unstack()

In [None]:
# RFF: quantiles of no-pulse (baseline) GMSL in 2100
diag["rff"]["gmsl_base_qs"].sel(year=2100).to_series()

#### 2150

In [None]:
# RCP: quantiles of no-pulse (baseline) GMSL in 2150 (last index level as columns)
diag["rcp"]["gmsl_base_qs"].sel(year=2150).to_series().unstack()

In [None]:
# RFF: quantiles of no-pulse (baseline) GMSL in 2150
diag["rff"]["gmsl_base_qs"].sel(year=2150).to_series()

#### 2300

In [None]:
# RCP: quantiles of no-pulse (baseline) GMSL in 2300 (last index level as columns)
diag["rcp"]["gmsl_base_qs"].sel(year=2300).to_series().unstack()

In [None]:
# RFF: quantiles of no-pulse (baseline) GMSL in 2300
diag["rff"]["gmsl_base_qs"].sel(year=2300).to_series()

## Reshape to add back runtype dim and crop years

In [None]:
# Restore the runtype dimension (via ``unflatten_runtype``) and keep years up to 2300.
out_rcp = unflatten_runtype(out_rcp.sel(year=slice(None, 2300)))
out_rff = unflatten_runtype(out_rff.sel(year=slice(None, 2300))).persist()

In [None]:
# Give out_rcp the same chunk sizes as out_rff along their shared dims.
out_rcp = out_rcp.chunk(
    {k: v for k, v in out_rff.chunksizes.items() if k in out_rcp.dims}
).persist()

# Wrap the RFF output as a Dataset, rename ``iter`` to ``simulation``, chunk by 100 years.
out_rff = (
    out_rff.to_dataset(name="gmsl")
    .rename(iter="simulation")
    .chunk({"year": 100})
    .persist()
)

## Add attrs and save

In [None]:
# Reference period of all GMSL outputs; used for every attr that mentions it so the
# metadata cannot drift out of sync.
REF_PERIOD = "1991-2009"

# Attributes shared by both output datasets. HISTORY, AUTHOR, CONTACT and the
# DESCRIPTION_*/METHOD_* strings are presumably defined in an earlier cell.
attr_all = {
    "units": "cm",
    "updated": pd.Timestamp.now(tz="US/Pacific").strftime("%c"),
    "reference_period": REF_PERIOD,
    "history": HISTORY,
    "author": AUTHOR,
    "contact": CONTACT,
}

out_rcp.attrs.update(
    {
        **attr_all,
        "description": DESCRIPTION_RCP,
        "method": METHOD_RCP,
        "version": ps.FAIR_RCP_VERS,
    }
)

out_rff.attrs.update(
    {
        **attr_all,
        "description": DESCRIPTION_RFF,
        "method": METHOD_RFF,
        "version": ps.FAIR_RFF_OUT_VERS,
    }
)

for ds in [out_rcp, out_rff]:
    ds.gmsl.attrs.update(
        {
            "description": "Simulations of 19-year centered mean of Global Mean Sea Level anomaly under SSP scenarios",
            "units": "cm",
            "reference_period": REF_PERIOD,
            "long_name": "GMSL sims rel. " + REF_PERIOD,
        }
    )
    ds.pulse_year.attrs.update({"description": "Year of GHG pulse"})

out_rcp.gmsl_median.attrs.update(
    {
        "description": "Simulation of 19-year centered mean of Global Mean Sea Level anomaly under SSP scenarios, using median SESL and FAIR parameters",
        "units": "cm",
        "reference_period": REF_PERIOD,
        "long_name": "GMSL med. rel. " + REF_PERIOD,
    }
)

In [None]:
# Write the RCP-based GMSL dataset.
# NOTE(review): the RFF write is commented out, so ``out_rff`` is computed but never
# saved by this notebook. Confirm this is intentional.
out_rcp.to_zarr(ps.PATH_HAZARD_SLR_GMSL_FAIR_RCP, mode="w")
# out_rff.to_zarr(ps.PATH_HAZARD_SLR_GMSL_FAIR_RFF, mode="w")

In [None]:
# Shut down dask. Close the client first, then the cluster it is connected to, so
# the client is not left talking to a scheduler that has already gone away.
client.close()
cluster.close()