# Global temperature in 2024 with regard to climate change and internal variability
Christophe Cassou & Aurélien Liné

To plot the years 2022, 2023, and 2024, this notebook needs to be executed three times.

## Library import

In [None]:
import cartopy
import cartopy.crs as ccrs
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import statsmodels.api
import time
import xarray as xr

from importlib import reload
from scipy import stats
from sklearn.linear_model import LinearRegression
from matplotlib import patches
from matplotlib.lines import Line2D
from matplotlib.ticker import PercentFormatter
from matplotlib import colors as mclrs
from matplotlib import ticker as mtick

## Declaration of analysis parameters

In [None]:
# --- Data selection --------------------------------------------------------
variable = 'tas'
table = 'Amon'
model = None
experiment = [
    'hist-ssp126',
    'hist-ssp245',
    'hist-ssp370',
    'hist-ssp585',
]

# --- Restrictions on SMILES ------------------------------------------------
# If True, keeps only the r1 member for each couple of model-experiment.
only_one_member = False
# Minimum ensemble size (values tried so far: 1, 5).
ensemble_size_min = 1

# --- Targeted values and years ---------------------------------------------
targets = dict(yr=dict(year='2024', GWL=1.36))
observations = dict(WMO=dict(yr=1.52))
year_min = 2024

# --- Periods ---------------------------------------------------------------
# Pre-industrial period used to compute the GWL from.
piStr, piEnd = '1850', '1900'
# Reference period; it may differ from the pre-industrial one (e.g. 1991-2020).
refStr, refEnd = '1850', '1900'

# --- Global warming level detection -----------------------------------------
# Either 'cross' or 'state'.
gwl_method = 'cross'
# Only relevant if gwl_method == 'state'.
gwl_interval = 0.01
window_size = 11

# --- Constraint on interannual variability -----------------------------------
# One of 'obs', 'rea', 'obs_rea'.
constraint = 'obs_rea'
stats_str = '1950'
stats_end = '2014'

# --- Drivers of internal variability ------------------------------------------
driver = ['amv', 'nino34']

# --- Bootstrap -----------------------------------------------------------------
bootstrap = True
n_boot = 1000
confidence_range = .9

# --- Plotting parameters -------------------------------------------------------
show_gauss = True
extension = '.pdf'

# --- Input / output directories --------------------------------------------------
dataDir = 'Data/'
outDir = 'Outputs/'

In [None]:
# Execute year 2022
# NOTE(review): the overrides below sit inside a bare string literal, so running
# this cell as-is changes nothing. Presumably the triple quotes are removed to
# select this year -- confirm.
"""
targets = {'yr':  {'year': '2022', 'GWL': 1.26}}
observations = {'WMO': {'yr': 1.15}}
"""

In [None]:
# Execute year 2023
# NOTE(review): the overrides below sit inside a bare string literal, so running
# this cell as-is changes nothing. Presumably the triple quotes are removed to
# select this year -- confirm.
"""
targets = {'yr':  {'year': '2023', 'GWL': 1.31}}
observations = {'WMO': {'yr': 1.44}}
"""

In [None]:
# Execute year 2024
# These values match the defaults declared in the analysis-parameters cell.
# NOTE(review): the overrides below sit inside a bare string literal, so running
# this cell as-is changes nothing. Presumably the triple quotes are removed to
# select this year -- confirm.
"""
targets = {'yr':  {'year': '2024', 'GWL': 1.36}}
observations = {'WMO': {'yr': 1.52}}
"""

### Automatic routines to prepare for analysis

In [None]:
# Gather, without duplicates, every temporality key found in the observation sets.
temporality = set()
for _obs in observations:
    temporality.update(observations[_obs])
temporality = list(temporality)

# Lower and upper quantiles bounding the central `confidence_range` interval.
int_low = (1. - confidence_range) / 2.
int_high = 1. - int_low

temporality

In [None]:
# Reference lists used to recognise an "all models" / "all scenarios" selection.
allModels = ['ACCESS-ESM1-5', 'CNRM-CM6-1', 'CanESM5', 'IPSL-CM6A-LR', 'MIROC6', 'MPI-ESM1-2-LR']
allSSPs = ['hist-ssp126', 'hist-ssp245', 'hist-ssp370', 'hist-ssp585']

# Tag identifying this configuration. The literal 'tempo' is a placeholder that
# is replaced by each temporality when the output directories are created.
saveName = variable
saveName += '_tempo'
saveName += '_ref'+piStr+piEnd
saveName += '_win'+str(window_size)

if isinstance(model, list) and set(allModels).issubset(model):
    # The selection contains every model of `allModels`: restrict to exactly those.
    saveName += '_CMIP6MS'
    ensemble_lgd = 'CMIP6-MS'
    ensemble_size_min = 1
    model = allModels
elif model is None:
    # No restriction on the models.
    saveName += '_allCMIP6Models'
    if ensemble_size_min != 1:
        saveName += str(ensemble_size_min)
    ensemble_lgd = 'CMIP6'
else:
    # Any other selection (e.g. a subset of models). This case used to leave
    # `ensemble_lgd` undefined and the name without any model tag.
    # NOTE(review): the tag and legend formats below are new -- confirm they
    # suit the figures that use `ensemble_lgd`.
    _selected = [model] if isinstance(model, str) else list(model)
    saveName += '_' + '-'.join(_selected)
    ensemble_lgd = ', '.join(_selected)

if only_one_member:
    saveName += '_1mem'

if experiment is None or set(allSSPs).issubset(experiment):
    saveName += '_allSSP'
else:
    # e.g. ['hist-ssp126', 'hist-ssp245'] -> '_ssp126-245'
    saveName += '_'+'-'.join(experiment).replace('hist-', '').replace('-ssp', '-')

saveName

In [None]:
# Create one output directory per temporality, substituting the 'tempo'
# placeholder of `saveName`. `exist_ok=True` avoids the check-then-create race
# of testing `os.path.exists` first.
for _tempo in temporality:
    _tmp = outDir + saveName.replace('tempo', _tempo)
    os.makedirs(_tmp, exist_ok=True)

In [None]:
# Display names of the variables.
# NOTE(review): 'naestas' and 'naesctas' carry the same labels as 'naetas' and
# 'naectas', whereas their SST counterparts read 'Horseshoe SST' and
# 'Complement HS SST' -- possibly a copy-paste slip, confirm.
variable_dict = dict(
    tas='GSAT',
    pr='precipitations',
    nasst='North Atlantic SST',
    natas='North Atlantic SAT',
    naesst='North East Atlantic SST',
    naetas='North East Atlantic SAT',
    naessst='Horseshoe SST',
    naestas='North East Atlantic SAT',
    naecsst='North West Atlantic SST',
    naectas='North West Atlantic SAT',
    naescsst='Complement HS SST',
    naesctas='North West Atlantic SAT',
)

# Display names of the temporalities.
# NOTE(review): the labels mix French and English -- confirm whether intended.
temporality_dict = dict(
    yr='Annual',
    DJF='hivernale (DJF)',
    JFM='hivernale (JFM)',
    JJA='estivale (JJA)',
    MJJ='May-June-July',
    MJ='May-June',
    m01='January',
    m04="d'avril",
    m05='May',
    m06='June',
    m09='September',
)

# Legend labels of the observational products.
lgd_obs_dict = dict([
    ('Berkeley Earth', 'BEST'),
    ('ERA5', 'ERA5'),
    ('NOAAGlobalTempv6', 'NOAA'),
    ('WMO', 'Consolidated obs.'),
])

# For each driver: the variable, the temporality and the lag associated with it.
driver_dict = dict(
    amv=dict(variable='natas', tempo='yr', lag=0),
    nino34=dict(variable='nino34', tempo='OND', lag=1),
)
# 'amvGlob' is an alias: it points to the very same dictionary object as 'amv'.
driver_dict['amvGlob'] = driver_dict['amv']
driver_dict

In [None]:
# Distinct global warming levels targeted over all the temporalities.
gwl_lvls = list({targets[_tempo]['GWL'] for _tempo in temporality})
gwl_lvls

## Functions definition

In [None]:
# Display name and plotting colour of each experiment.
# Fixed labels: 'Pre-industrial Control' (was 'Pre-industrical Control') and
# 'SSP4-3.4' (was 'SSP4.3-4').
experiment_dict = {
    'historical': {'name': 'Historical', 'color': 'gray'},
    'piControl': {'name': 'Pre-industrial Control', 'color': 'black'},
    'ssp119': {'name': 'SSP1-1.9', 'color': 'green'},
    'ssp126': {'name': 'SSP1-2.6', 'color': 'blue'},
    'ssp245': {'name': 'SSP2-4.5', 'color': 'darkgoldenrod'},
    'ssp370': {'name': 'SSP3-7.0', 'color': 'red'},
    'ssp434': {'name': 'SSP4-3.4', 'color': 'purple'},
    'ssp585': {'name': 'SSP5-8.5', 'color': 'brown'},
}

# Each 'hist-sspXXX' key is an alias of the matching 'sspXXX' entry: it points
# to the same dictionary object.
for _ssp in ('ssp119', 'ssp126', 'ssp245', 'ssp370', 'ssp434', 'ssp585'):
    experiment_dict['hist-' + _ssp] = experiment_dict[_ssp]

In [None]:
def sort_by_IV_ensemble(members:list=None):
    '''
    Group member labels into a nested dictionary.

    Parameters
    ----------
    members: iterable of labels shaped as 'source_experiment_ripf', i.e. exactly
        three fields separated by '_' (otherwise the unpacking raises ValueError).

    Returns
    -------
    dict: {source: {experiment: {config: [labels]}}}, where config is 'r*i'
        followed by whatever comes after the first 'i' of the third field.
        Labels keep their input order inside each group.
    '''
    grouped = {}
    for label in members:
        source, scenario, ripf = label.split('_')
        # Members differing only by their realisation number share a config.
        config = 'r*i' + ripf.split('i')[1]
        by_scenario = grouped.setdefault(source, {})
        by_config = by_scenario.setdefault(scenario, {})
        by_config.setdefault(config, []).append(label)
    return grouped

In [None]:
# For each season length, map the season name to the month number that labels
# the matching rolling window in get_season().
season2m_dict = {'DJ':1,'JF':2,'FM':3,'MA':4,'AM':5,'MJ':6,'JJ':7,'JA':8,'AS':9,'SO':10,'ON':11,'ND':12}
season3m_dict = {'DJF':1,'JFM':2,'FMA':3,'MAM':4,'AMJ':5,'MJJ':6,'JJA':7,'JAS':8,'ASO':9,'SON':10,'OND':11,'NDJ':12}
season4m_dict = {'ONDJ':1,'NDJF':2,'DJFM':3,'JFMA':4,'FMAM':5,'MAMJ':6,'AMJJ':7,'MJJA':8,'JJAS':9,'JASO':10,'ASON':11,'SOND':12}
season5m_dict = {'NDJFM':1,'DJFMA':2,'JFMAM':3,'FMAMJ':4,'MAMJJ':5,'AMJJA':6,'MJJAS':7,'JJASO':8,'JASON':9,'ASOND':10,'SONDJ':11,'ONDJF':12}
season6m_dict = {'ASONDJ':1,'SONDJF':2,'ONDJFM':3,'NDJFMA':4,'DJFMAM':5,'JFMAMJ':6,'FMAMJJ':7,'MAMJJA':8,'AMJJAS':9,'MJJASO':10,'JJASON':11,'JASOND':12}
season12m_dict = {'FMAMJJASONDJ':1,'MAMJJASONDJF':2,'AMJJASONDJFM':3,'MJJASONDJFMA':4,'JJASONDJFMAM':5,'JASONDJFMAMJ':6,'ASONDJFMAMJJ':7,'SONDJFMAMJJA':8,'ONDJFMAMJJAS':9,'NDJFMAMJJASO':10,'DJFMAMJJASON':11,'JFMAMJJASOND':12}

def get_season(data, season):
    '''
    Extract a seasonal, monthly or annual series from `data`.

    Parameters
    ----------
    data: object with a 'time' dimension exposing `resample`, `rolling`,
        `groupby` and `sel` (presumably monthly xarray data -- confirm).
    season: selects the output
     - 'yr': mean over each 'YS' resampling bin,
     - '<N>m' (at most 3 characters, e.g. '3m'): centred rolling mean over N steps,
     - 'm<MM>' (at most 3 characters, e.g. 'm05'): the group of month MM,
     - a name listed in one of the season*m_dict tables (e.g. 'DJF'),
     - an int: the group of that month.

    Returns
    -------
    out: the extracted series, or `data` unchanged (after printing an error
        message) when the request is not understood.
    '''
    # Integer months must be tested first: the string tests below call len()
    # on `season`, which used to raise TypeError before this branch was reached.
    if isinstance(season, int):
        try:
            return data.groupby('time.month')[season]
        except Exception:
            # Best effort, as before: report and hand the input back.
            print('Error; Nothing happened')
            return data
    if season == 'yr':
        return data.resample(time='YS').mean(dim='time', keep_attrs = True, skipna=False)
    if len(season) <= 3 and season[-1] == 'm':
        return data.rolling(time = int(season[:-1]), center = True).mean(dim='time', keep_attrs = True, skipna=False)
    if len(season) <= 3 and season[0] == 'm':
        return data.groupby('time.month')[int(season[1:])]

    # (name -> labelling month, window length, centred window?)
    _tables = (
        (season2m_dict, 2, False),
        (season3m_dict, 3, True),
        (season4m_dict, 4, False),
        (season5m_dict, 5, True),
        (season6m_dict, 6, False),
        (season12m_dict, 12, False),
    )
    for _months, _window, _center in _tables:
        if season in _months:
            _seasons = data.rolling(time = _window, center = _center).mean(dim='time', keep_attrs = True, skipna=False)
            # np.isin replaces the deprecated np.in1d; the mask is 1-D either way.
            return _seasons.sel(time = np.isin(_seasons['time.month'], _months[season]))

    print('Error; Nothing happened')
    return data

In [None]:
def normalisation(data:xr, scaler='std', period:tuple=None):
    '''
    Centre and scale `data`, separately for each group of members returned by
    sort_by_IV_ensemble (same source, experiment and configuration).

    Parameters
    ----------
    data: array of either DataArray or Dataset type from the xarray package,
        with a 'member' coordinate
    scaler: can either be a string, a float, an int, or None
     - 'std': divided by the standard deviation
     - float in (0., 1.] : divided by the interquantile range
     - int in (0, 100] : divided by the interpercentile range
     - None: only centred, not scaled
     Any other value falls back to 'std' after printing an error message.
    period: None or tuple (or list) of two elements (start, stop) giving the
        time slice on which the mean and the scale are computed

    ! Add an option to put back the forced response?

    Returns
    -------
    out: array of same type as data, normalised according to the given period.
        Variables without a 'member' dimension are left out, and members are
        ordered group by group.
    '''
    if isinstance(data, xr.Dataset):
        _das = [data.get(_var) for _var in data.data_vars]; _to_ds = True
    elif isinstance(data, xr.DataArray):
        _das = [data]; _to_ds = False
    else:
        raise TypeError('Input data type should either be a Dataset or a DataArray from the xarray module.')
    _start = _stop = None
    if isinstance(period, (tuple, list)) and len(period) == 2:
        _start, _stop = period

    # Resolve `scaler` once, before the loops. It used to be re-interpreted and
    # mutated for every group: the default 'std' printed a spurious error, and
    # an int scaler was converted to a float without computing the scale.
    if isinstance(scaler, int):
        if 0 < scaler <= 100:
            scaler = scaler / 100.
        else:
            scaler = 'std'; print('ERROR: if scaler is of type int, it must be in (0, 100].\nStandard deviation used instead.')
    if isinstance(scaler, float):
        if not (0. < scaler <= 1.):
            scaler = 'std'; print('ERROR: if scaler is of type float, it must be in (0., 1.].\nStandard deviation used instead.')
    elif not (scaler is None or (isinstance(scaler, str) and scaler == 'std')):
        scaler = 'std'; print('ERROR: scaler not understood.\nStandard deviation used instead.')

    _dict = sort_by_IV_ensemble(list(data.member.values))
    _out = list()
    for _da in _das:
        if 'member' not in _da.dims:
            continue
        _dims = ['member']
        if 'time' in _da.dims:
            _dims.append('time')
        _store = list()
        for s in _dict:
            for e in _dict[s]:
                for c in _dict[s][e]:
                    _tmp = _da.sel(member=_dict[s][e][c])
                    # Statistics are taken on the reference period when one is given.
                    _tmp_ref = _tmp.sel(time=slice(_start, _stop)) if 'time' in _tmp.dims and period is not None else _tmp
                    if scaler is None:
                        _scl = 1.
                    elif isinstance(scaler, float):
                        _scl = _tmp_ref.quantile(1.-(1.-scaler)/2., dim=_dims) - _tmp_ref.quantile((1.-scaler)/2., dim=_dims)
                    else:
                        _scl = _tmp_ref.std(dim=_dims, keep_attrs=True)
                    _tmp = (_tmp - _tmp_ref.mean(dim=_dims, keep_attrs=True)) / _scl
                    # Where a non-scalar scale is (almost) zero, return 0 instead
                    # of the result of the division.
                    if getattr(_scl, 'size', 1) != 1:
                        _tmp = xr.where(_scl < 1e-10, 0., _tmp)
                    _store.append(_tmp)
        _out.append(xr.concat(_store, dim='member'))
    if _to_ds:
        out = xr.merge(_out)
        out.attrs = data.attrs
    else:
        out = _out[0]
    return out

In [None]:
def get_weight(members, method='model', dim='member'):
    '''
    Build one weight per member from the grouping of sort_by_IV_ensemble.

    Each member gets 1 / (number of experiments of its source
    * number of configurations of its experiment * size of its own group).

    Parameters
    ----------
    members: iterable of member labels, or an xarray DataArray/Dataset from
        which `members[dim].values` is read
    method: accepted but not used by this function
    dim: name of the dimension (and coordinate) of the returned array

    Returns
    -------
    out: xarray.DataArray of weights along `dim`, ordered group by group
    '''
    if isinstance(members, (xr.DataArray, xr.Dataset)):
        members = members[dim].values
    grouped = sort_by_IV_ensemble(members)

    labels = []
    weights = []
    for by_experiment in grouped.values():
        for by_config in by_experiment.values():
            for group in by_config.values():
                labels += group
                share = 1. / (
                      len(by_experiment)
                    * len(by_config)
                    * len(group)
                )
                weights.extend([share] * len(group))

    return xr.DataArray(weights, dims=[dim], coords={dim: labels})

In [None]:
def ensemble_estimated_forcing(data:xr.DataArray=None, min_size=1):
    '''
    Replace each member by the mean of its group of members (same source,
    experiment and configuration, as grouped by sort_by_IV_ensemble).

    Groups with fewer than `min_size` members are filled with NaN instead.

    Make it xarray.Dataset friendly!
    '''
    grouped = sort_by_IV_ensemble(list(data.member.values))
    pieces = []
    for by_experiment in grouped.values():
        for by_config in by_experiment.values():
            for group in by_config.values():
                if len(group) < min_size:
                    # Group too small: keep its members, but with missing values.
                    pieces.append(xr.full_like(data.sel(member=group), np.nan))
                    continue
                group_mean = data.sel(member=group).mean(dim='member', keep_attrs=True)
                for label in group:
                    # One copy of the group mean per member, labelled with that member.
                    piece = group_mean.assign_coords({'member': label})
                    piece.attrs = data.attrs
                    pieces.append(piece)

    return xr.concat(pieces, dim='member')


def remove_forced_response(data:xr.DataArray=None, min_size=1):
    '''
    Subtract from `data` the output of ensemble_estimated_forcing (the mean of
    each member's group).

    Make it xarray.Dataset friendly!
    '''
    forced = ensemble_estimated_forcing(data, min_size=min_size)
    return data - forced

In [None]:
import scipy.signal

def get_trend(data, dim='time', groupby='time.month', **kwargs):
    '''
    Computes the linear trend on a given dimension.

    The trend is the input minus its `scipy.signal.detrend` residual.

    Parameters
    ----------
    data: xarray DataArray or Dataset
    dim: dimension along which the trend is fitted
    groupby: when `dim` is 'time' and `groupby` is not None, the trend is
        fitted separately on each group of `data.groupby(groupby)`
    **kwargs: forwarded to `scipy.signal.detrend`

    Returns
    -------
    out: array of same type as data holding the fitted trend. Variables that
        lack `dim`, or whose name contains `dim`, are left out.

    Raises
    ------
    TypeError: if data is neither a Dataset nor a DataArray
    ValueError: if no variable can be detrended along `dim`
    '''
    if isinstance(data, xr.Dataset):
        _das = [data.get(_var) for _var in data.data_vars]; _to_ds = True
    elif isinstance(data, xr.DataArray):
        _das = [data]; _to_ds = False
    else:
        raise TypeError('Input data type should either be a Dataset or a DataArray from the xarray module.')

    def _fitted(_arr):
        # detrend() works on the last axis, hence the transpose done by the caller.
        return _arr - xr.DataArray(name = _arr.name,
            data = scipy.signal.detrend(_arr, **kwargs),
            dims = _arr.dims,
            coords = _arr.coords)

    _trend = []
    for _da in _das:
        # An unnamed DataArray used to raise TypeError on the name test.
        _name = '' if _da.name is None else str(_da.name)
        if dim not in _da.dims or dim in _name:
            # Skipped. Such a variable used to re-append the previous trend,
            # or raise NameError when it came first.
            continue
        _tmp_da = _da.transpose(..., dim)
        if dim == 'time' and groupby is not None:
            _da_trend = xr.concat(
                [_fitted(_group) for _, _group in _tmp_da.groupby(groupby)], dim='time')
        else:
            _da_trend = _fitted(_tmp_da)
        # Back to the dimension order and the coordinates of the input.
        _da_trend = _da_trend.transpose(*_da.dims)
        _da_trend = _da_trend.reindex_like(_da)
        _da_trend.attrs = _da.attrs
        _trend.append(_da_trend)

    if not _trend:
        raise ValueError('No variable with a "'+str(dim)+'" dimension to compute a trend on.')

    if _to_ds:
        out = xr.merge(_trend)
        out.attrs = data.attrs
    else:
        out = _trend[0]

    return out

In [None]:
def local_weights(da, method = 'member-per-model'):
    '''
    Weights of the members of `da` that have at least one truthy value.

    Parameters
    ----------
    da: xarray DataArray with a 'member' dimension (and a 'time' dimension for
        the occurrence-per-model method)
    method:
     - 'mpm' or 'member-per-model': the weights of get_weight, 0 for the
       members without any truthy value,
     - 'opm' or 'occurrence-per-model': the same weights divided by the sum of
       `da` over time (0 where the result is undefined).

    Returns
    -------
    out: xarray.DataArray of weights

    Raises
    ------
    NotImplementedError: for any other method (this used to fail with a
        NameError on an undefined exception name)
    '''
    if method not in ['mpm', 'member-per-model', 'opm', 'occurrence-per-model']:
        raise NotImplementedError('Weighting method "'+str(method)+'" is not coded yet.')
    # Keep only the members for which `da` is truthy at least once.
    _members = xr.where(da, 1., np.nan).dropna(dim='member', how='all').member
    _weights = get_weight(_members, method='model').reindex_like(da, fill_value=0.)
    if method in ['opm', 'occurrence-per-model']:
        _weights = (_weights/da.sum(dim='time')).fillna(0.)
    return _weights

In [None]:
def compute_bootstrap(data, weight, n_boot=1000, dim='member'):
    '''
    Draw `n_boot` resamples of `data` along `dim`, with replacement.

    Each resample has the size of the original `dim` and is drawn with
    probabilities weight / weight.sum(). The resamples are stacked along a new
    'boot' dimension numbered from 1 to n_boot, and the original `dim`
    coordinate is put back on the result, position by position.
    '''
    original_labels = data[dim]
    pool = data.rename({dim: 'sample'})
    draws = []
    for number in range(1, n_boot + 1):
        picked = np.random.choice(
            pool.sample,
            size=len(pool.sample),
            replace=True,
            p=weight/weight.sum(),
        )
        draw = pool.sel(sample=picked).assign_coords({'boot': number})
        # Drop the resampled labels so that the draws can be concatenated.
        draws.append(draw.drop_vars('sample'))
    stacked = xr.concat(draws, dim='boot')
    return stacked.rename({'sample': dim}).assign_coords({dim: original_labels})

In [None]:
def return_period(data, weight, event, bootstrap = False, data_boot = None, low = .05, high = .95):
    '''
    Estimate the return period of `event` within the weighted sample `data`.

    Parameters
    ----------
    data: 1-D xarray DataArray along the 'sample' dimension
    weight: xarray DataArray of weights along 'sample'
    event: value whose return period is sought
    bootstrap: if True, also estimate the return period on each resample
    data_boot: resamples with 'boot' and 'sample' dimensions (needed if bootstrap)
    low, high: quantiles of the bootstrap distribution to report

    Returns
    -------
    list: [empirical, normal fit, skew-normal fit, bootstrap low quantile,
        bootstrap median, bootstrap high quantile]; np.nan where an estimate
        cannot be computed.
    '''
    _mean = data.weighted(weight).mean()
    _std = data.weighted(weight).std()
    _skew = ((data ** 3.).weighted(weight).mean() - 3. * _mean * _std ** 2. - _mean ** 3. ) / (_std ** 3.)

    # Sort so that the values beyond the event, on its side of the mean, come last.
    _above = bool(event > _mean)
    # The weights must follow the same ordering as the data. They used to be
    # cumulated in their original order, which mismatched the sorted sample.
    _weight_sorted = weight.sortby(data, ascending = _above)
    data = data.sortby(data, ascending = _above)

    _index = int(abs(data-event).argmin().values)
    if _index < 1e-5*len(data) or data[_index] == data.min():
        _p = 0. ; _computed = np.nan
    elif _index > (1-1e-5)*len(data) or data[_index] == data.max():
        _p = 1. ; _computed = np.nan
    else:
        _p = (_weight_sorted.isel(sample=slice(None, _index+1)).sum()/weight.sum()).values ; _computed = int(round(1. / (1. - _p), 0))

    # Parametric estimates, evaluated at `event`. They used to read the
    # notebook globals `observations[_obs][_tempo]` instead of the argument.
    _p = stats.norm.cdf(event, loc=_mean, scale=_std)
    _normal = int(round(1. / (1. - _p) if _above else 1. / _p, 0)) if _p not in [0., 1.] else np.nan
    # NOTE(review): `_skew` is the sample skewness but is passed as the shape
    # parameter of the skew-normal, with loc/scale set to the sample mean and
    # standard deviation -- confirm that this approximation is intended.
    _p = stats.skewnorm.cdf(event, _skew, loc=_mean, scale=_std)
    _skew_normal = int(round(1. / (1. - _p) if _above else 1. / _p, 0)) if _p not in [0., 1.] else np.nan

    def _period_at(_probs, _q):
        # Return period for the quantile `_q` of the bootstrap probabilities;
        # np.nan when it is undefined (probability of 1, NaN, empty input).
        try:
            return int(round(1. / (1. - float(np.quantile(_probs, _q))), 0))
        except (ZeroDivisionError, OverflowError, ValueError, IndexError):
            return np.nan

    if bootstrap:
        _p_boot = list()
        for _boot in data_boot.boot:
            _da_boot = data_boot.sel(boot=_boot).dropna('sample', how='all')
            _da_boot = _da_boot.sortby(_da_boot, ascending = _above)
            _index = int(abs(_da_boot-event).argmin().values)
            if _index < 1e-5*len(_da_boot) or _da_boot[_index] == _da_boot.min():
                _p_boot.append(0.)
            elif _index > (1-1e-5)*len(_da_boot) or _da_boot[_index] == _da_boot.max():
                _p_boot.append(1.)
            else:
                # Unweighted: the proportion of the resample up to the event.
                _p_boot.append((_index+1)/len(_da_boot))
        _bootstrap_low = _period_at(_p_boot, low)
        _bootstrap_med = _period_at(_p_boot, .5)
        _bootstrap_high = _period_at(_p_boot, high)
    else:
        _bootstrap_low = _bootstrap_med = _bootstrap_high = np.nan

    return [_computed, _normal, _skew_normal, _bootstrap_low, _bootstrap_med, _bootstrap_high]

In [None]:
def map_global_warming_levels(da, window, levels=None, method='cross', interval=0.1):
    """
    da: DataArray containing global average temperature with a 'time' dimension.
    window: Number of years (N) on which a time average is applied in order to compute the moving global average temperature.
    levels: List of target levels; defaults to [1.5, 2., 3., 4.] when None.
    method:
     - 'cross': returns the N years for which the global average temperature first crossed the dedicated warming level,
     - 'state': returns all the sets of N years for which the global average temperature was near the dedicated level (interval/2).
    interval: Interval used around each level if the method is 'state'.

    Returns a boolean DataArray with an extra 'warming_level' dimension.
    Raises ValueError if `method` is neither 'cross' nor 'state'.
    """
    if method not in ('cross', 'state'):
        raise ValueError("method must be either 'cross' or 'state'.")
    levels = [1.5, 2., 3., 4.] if levels is None else list(levels)

    _rolled = da.rolling(time=window, center=True, min_periods=window).mean(skipna=False)
    _lst = list()
    if method == 'cross':
        # Position along time, as a labelled array so that it broadcasts on
        # 'time' whatever the order of the dimensions of `da`.
        _position = xr.DataArray(np.arange(da.sizes['time']), dims='time')
        for _level in levels:
            # Share of each window spent at or above the level; NaN where none.
            _tmp = xr.where(_rolled >= _level, 1., 0.).rolling(time=window, center=True, min_periods=window).mean(skipna=False)
            _tmp = _tmp.where(_tmp != 0.)
            # The position is added so that the minimum over time singles out one year.
            _tmp = _tmp + _position
            _lst.append(xr.where(_tmp == _tmp.min(dim='time'), 1., 0.).rolling(time=window, center=False, min_periods=1).mean(skipna=False))
    else:
        for _level in levels:
            _near = xr.where((_rolled > _level-interval/2.) & (_rolled < _level+interval/2.), 1., 0.)
            _lst.append(_near.rolling(time=window, center=True, min_periods=1).mean())
    return xr.where(xr.concat(_lst, pd.Index(levels, name='warming_level')) != 0., True, False)

In [None]:
def get_stats(data_dict, stats_str, stats_end, window_size, center=True):
    '''
    Statistics of the annual anomalies of each source, relative to their
    rolling mean.

    Parameters
    ----------
    data_dict: {source: {'path': file to open, 'var': variable to read}}
    stats_str, stats_end: bounds of the time slice used
    window_size: length of the rolling mean subtracted from the annual means
    center: whether the rolling window is centred

    Returns
    -------
    xarray.Dataset with variables 'std', 'qdw' (5 % quantile) and 'qup' (95 %
    quantile) along 'source', sorted by increasing 'std'.
    '''
    _std_list = list() ; _qdw_list = list() ; _qup_list = list()
    for _var in data_dict:
        # The file is closed once its statistics have been turned into floats.
        with xr.open_dataset(data_dict[_var]['path']) as _ds:
            _data = _ds.get(data_dict[_var]['var'])
            _data_yr = _data.sel(time=slice(stats_str, stats_end)).resample(time='YS').mean(dim = 'time', keep_attrs = True)
            _data_yr = _data_yr - _data_yr.rolling(time=window_size, center=center).mean()
            _std_list.append(float(_data_yr.std().values))
            _qdw_list.append(float(_data_yr.quantile(.05).values))
            _qup_list.append(float(_data_yr.quantile(.95).values))
    _ds_stats_out = xr.Dataset(
        data_vars={
            'std': (['source'], _std_list),
            'qdw': (['source'], _qdw_list),
            'qup': (['source'], _qup_list),
        },
        # Labels come from the argument; the global `obs_dict` used to be read here.
        coords = {'source': list(data_dict.keys())}
    )
    _ds_stats_out = _ds_stats_out.sortby(_ds_stats_out['std'])
    return _ds_stats_out

In [None]:
def plot_stats(x, std, qdw, qup, ax, color):
    '''
    Draw one marker and one dashed horizontal guide per statistic:
    `std` on ax[0] ('o'), `qdw` ('v') and `qup` ('^') on ax[1].
    '''
    marker_style = dict(mfc='w', mec=color, ms=7)
    guide_style = dict(ls='--', lw=.75, c=color, zorder=0)
    layout = (
        (0, std, 'o'),
        (1, qdw, 'v'),
        (1, qup, '^'),
    )
    for panel, value, marker in layout:
        ax[panel].plot(x, value, marker, **marker_style)
        ax[panel].axhline(y=value, **guide_style)

In [None]:
from statsmodels.stats.weightstats import DescrStatsW

def mean(data, weights=None):
    '''
    Mean of `data` as computed by DescrStatsW, optionally weighted.

    Parameters
    ----------
    data: array_like, 1-D or 2-D
        dataset

    weights: None or 1-D ndarray
        weights for each observation, with same length as zero axis of data
    '''
    described = DescrStatsW(data, weights=weights)
    return described.mean


def var(data, weights=None):
    '''Variance of `data` as computed by DescrStatsW, optionally weighted.'''
    described = DescrStatsW(data, weights=weights)
    return described.var


def std(data, weights=None):
    '''Standard deviation of `data` as computed by DescrStatsW, optionally weighted.'''
    described = DescrStatsW(data, weights=weights)
    return described.std


def median(data, weights=None):
    '''0.5 quantile of `data` as computed by DescrStatsW.quantile, optionally weighted.'''
    described = DescrStatsW(data, weights=weights)
    return described.quantile(0.5)


def quantile(data, probs, weights=None, return_pandas=False):
    '''
    Quantile(s) `probs` of `data` as computed by DescrStatsW.quantile,
    optionally weighted. `return_pandas` is forwarded unchanged.
    '''
    described = DescrStatsW(data, weights=weights)
    return described.quantile(probs, return_pandas=return_pandas)


def cov(data, weights=None):
    '''Covariance of `data` as computed by DescrStatsW, optionally weighted.'''
    described = DescrStatsW(data, weights=weights)
    return described.cov


def demeaned(data, weights=None):
    '''`data` with its mean removed, as given by DescrStatsW.demeaned.'''
    described = DescrStatsW(data, weights=weights)
    return described.demeaned


def ttest_mean(data, value=0, alternative='two-sided', weights=None):
    '''
    t-test on the mean of `data`, delegated to DescrStatsW.ttest_mean.

    Parameters
    ----------

    value: float or array
        The hypothesized value for the mean (0 by default).

    alternative: str
        The alternative hypothesis, H1, has to be one of the following:
         - 'two-sided': H1: mean not equal to value (default)
         - 'larger' : H1: mean larger than value
         - 'smaller' : H1: mean smaller than value

    weights: None or array
        Forwarded to DescrStatsW.

    Returns
    -------

    tstat: float
        Test statistic

    pvalue: float
        pvalue of the t-test

    df: int or float
    '''
    described = DescrStatsW(data, weights=weights)
    return described.ttest_mean(value=value, alternative=alternative)

In [None]:
def boxplot(data, ax = None,
            weights=None,
            label = None,
            yTitle = None,
            bar = 'median', box = .5, ext = .1,
            alpha = 1,
            color = None,
            edgecolor = None,
            hatch = None,
            outliers = True,
            extrema = False,
            markersize = 3,
            dx=0, width=0.5,
            rotation=0,
            orientation='vertical'):
    """
    Create a boxplot of the dispersion of *data*.

    Parameters
    ----------

    data : array-like, shape (n, ), or list of such arrays (one box each)
        Input data. Groups may have different lengths. Groups holding a single
        value are not drawn.

    ax : matplotlib.axes.Axes
        The axes object to draw the boxes into (current axes if None).

    weights : None or list with one weight array per group

    label : list of tick labels, one per group

    yTitle : title of the value axis

    bar : 'mean' or 'median' or 'both'
        Statistic(s) drawn as a thick bar across the box.

    box, ext : float
        The box spans the quantiles [box/2, 1 - box/2] and the whiskers the
        quantiles [ext/2, 1 - ext/2].

    alpha, color, edgecolor, hatch : appearance of the boxes (lists, one item
        per group, for the last three)

    outliers : if True, values outside the whiskers are drawn as black dots

    extrema : if True, the minimum and the maximum are drawn as coloured dots

    markersize : size of the dots

    dx, width : offset of the boxes along the category axis, and their width

    rotation : rotation of the tick labels

    orientation : 'vertical' or 'horizontal'

    Returns
    -------
    The axes used.

    Raises
    ------
    ValueError : if orientation is neither 'vertical' nor 'horizontal'
    """
    if orientation not in ('vertical', 'horizontal'):
        raise ValueError("orientation must be either 'vertical' or 'horizontal'.")

    if not isinstance(data, list):
        data = [data]

    _ax = ax or plt.gca()
    # len() rather than np.shape(): groups of unequal lengths have no common shape.
    _n = len(data)

    _color = color or ['black'] * _n
    _edgecolor = edgecolor or ['black'] * _n
    _hatch = hatch or [None] * _n
    _label = label or [''] * _n
    for i in range(_n):
        _tmp = np.array(data[i])
        if len(np.shape(_tmp)) > 1:
            _tmp = np.concatenate(data[i])
        _weights = None if weights is None else weights[i]
        if _weights is not None and len(np.shape(_weights)) > 1:
            # Flatten the weights of this group the same way as its data
            # (the group used to be indexed a second time here).
            _weights = np.concatenate(_weights)
        if len(_tmp) == 1:
            continue

        # compute statistics
        _median = median(_tmp, weights=_weights)
        _mean = mean(_tmp, weights=_weights)
        _boxlow = quantile(_tmp, box/2, weights=_weights)
        _boxhigh = quantile(_tmp, 1 - box/2, weights=_weights)
        _low = quantile(_tmp, ext/2, weights=_weights)
        _high = quantile(_tmp, 1 - ext/2, weights=_weights)
        _minima = np.min(_tmp)
        _maxima = np.max(_tmp)

        # plot
        _pos = dx + i
        _fill = True if _hatch[i] is None else None
        # Coordinates along the category axis (_cat) and the value axis (_val)
        # of: lower whisker, upper whisker, lower cap, upper cap, mean bar,
        # median bar, minimum, maximum.
        _cat = [
            (_pos, _pos),
            (_pos, _pos),
            (_pos - width / 4., _pos + width / 4.),
            (_pos - width / 4., _pos + width / 4.),
            (_pos - width / 2., _pos + width / 2.),
            (_pos - width / 2., _pos + width / 2.),
            _pos,  # extrema now follow the `dx` offset like the rest of the box
            _pos,
        ]
        _val = [
            (_low, _boxlow),
            (_boxhigh, _high),
            (_low, _low),
            (_high, _high),
            (_mean, _mean),
            (_median, _median),
            _minima,
            _maxima,
        ]
        if orientation == 'vertical':
            _X = _cat; _Y = _val
            _ax.add_patch(patches.Rectangle(
                (_pos - width / 2., _boxlow), width, _boxhigh - _boxlow,
                facecolor = _color[i], edgecolor = _edgecolor[i], hatch = _hatch[i], fill = _fill, alpha = alpha))
        else:
            _X = _val; _Y = _cat
            _ax.add_patch(patches.Rectangle(
                (_boxlow, _pos - width / 2.), _boxhigh - _boxlow, width,
                facecolor = _color[i], edgecolor = _edgecolor[i], hatch = _hatch[i], fill = _fill, alpha = alpha))

        _ax.plot(_X[0], _Y[0], 'k', ls = '--', lw = 1.5)
        _ax.plot(_X[1], _Y[1], 'k',ls = '--', lw = 1.5)
        _ax.plot(_X[2], _Y[2], 'k', lw = 1.5)
        _ax.plot(_X[3], _Y[3], 'k', lw = 1.5)

        if bar == 'mean':
            _ax.plot(_X[4], _Y[4], color = _color[i], lw = 7)
            _ax.plot(_X[4], _Y[4], color = 'w', lw = 2)
        elif bar == 'median':
            _ax.plot(_X[5], _Y[5], color = _color[i], lw = 7)
            # The white overlay follows the orientation (it used to be drawn
            # with vertical-layout coordinates in both cases).
            _ax.plot(_X[5], _Y[5], color = 'w', lw = 2)
        elif bar == 'both':
            _ax.plot(_X[4], _Y[4], color = 'k', lw = 7)
            _ax.plot(_X[5], _Y[5], color = _color[i], lw = 7)
            _ax.plot(_X[5], _Y[5], color = 'w', lw = 2)

        if outliers:
            for _value in _tmp:
                if (_value < _low or _value > _high):
                    if orientation == 'vertical':
                        _ax.plot(_pos, _value, 'o', c = 'k', markersize = markersize)
                    else:
                        _ax.plot(_value, _pos, 'o', c = 'k', markersize = markersize)
        if extrema:
            _ax.plot(_X[6], _Y[6], 'o', c=_color[i], markersize=markersize)
            _ax.plot(_X[7], _Y[7], 'o', c=_color[i], markersize=markersize)

    # Ticks stay on the integer positions, whatever `dx`.
    if orientation == 'vertical':
        _ax.set_xticks(np.arange(_n))
        _ax.set_xticklabels(_label, rotation=rotation)
        _ax.set_ylabel(yTitle)
    else:
        _ax.set_yticks(np.arange(_n))
        _ax.set_yticklabels(_label, rotation=rotation)
        _ax.set_xlabel(yTitle)

    return _ax

In [None]:
def get_interval_averages(x, y, z, zmin, zmax):
    '''
    Weighted averages of `x`, `y` and `z` over the (member, time) points where
    zmin < z < zmax.

    The weights come from local_weights with the 'occurrence-per-model' method.
    Returns the three averages as a tuple (x, y, z).
    '''
    above = xr.where(z > zmin, True, False)
    below = xr.where(z < zmax, True, False)
    # Points of z inside the interval, flattened over member and time.
    inside = z.where(above).where(below).stack(xy=('member', 'time')).dropna(dim='xy')

    def _same_points(field):
        # Restrict `field` to the points kept for z.
        return field.stack(xy=('member', 'time')).sel(xy=inside.xy).unstack()

    x_in = _same_points(x)
    y_in = _same_points(y)
    z_in = inside.unstack()
    weights = local_weights(z_in, method='occurrence-per-model')
    return (
        x_in.weighted(weights).mean(),
        y_in.weighted(weights).mean(),
        z_in.weighted(weights).mean(),
    )

In [None]:
import matplotlib
import matplotlib.colors
import matplotlib.pyplot

def get_diverging_cmap(vmin = -1., vmax = 1., origin = 0., neg = 'Blues', pos = 'YlOrBr'):
    """Build a colormap for the range [vmin, vmax] that switches palette at `origin`.

    Parameters
    ----------
    vmin, vmax : float
        Bounds of the plotted range.
    origin : float
        Value at which the palette switches from `neg` to `pos`.
    neg, pos : str
        Names of the Matplotlib colormaps used below / above `origin`.
        `neg` is sampled in reverse order (from 1 down to 0).

    Returns
    -------
    matplotlib.colors.ListedColormap or None
        * vmin < origin < vmax: 256 colours, split between `neg` and `pos`
          proportionally to the share of the range on each side of `origin`.
        * origin <= vmin < vmax: only `pos`, starting at the fraction
          (vmin-origin)/(vmax-origin) of that colormap.
        * vmin < vmax <= origin: only `neg`, sampled from 1 down to the
          fraction (vmax-vmin)/(origin-vmin).
        * otherwise (e.g. vmax <= vmin): an error message is printed and
          None is returned.
    """
    # pyplot.get_cmap(name, lut) replaces matplotlib.cm.get_cmap(name, lut),
    # which was deprecated in Matplotlib 3.7 and removed in 3.9.
    get_cmap = matplotlib.pyplot.get_cmap
    if vmin < origin and vmax > origin:
        vrange = vmax - vmin
        # Number of the 256 entries that fall below `origin`.
        zero = int((origin-vmin) / vrange * 256)
        negcolors = get_cmap(neg, zero)
        poscolors = get_cmap(pos, 256 - zero)
        newcolors = np.vstack((
            negcolors(np.linspace(1, 0, zero)),
            poscolors(np.linspace(0, 1, 256 - zero)),
        ))
        return matplotlib.colors.ListedColormap(newcolors)
    elif vmin >= origin and vmax > vmin:
        zero = int((vmin-origin) / (vmax-origin) * 256)
        poscolors = get_cmap(pos, 256-zero)
        newcolors = poscolors(np.linspace((vmin-origin)/(vmax-origin), 1, 256-zero))
        return matplotlib.colors.ListedColormap(newcolors)
    elif vmax <= origin and vmax > vmin:
        zero = int((vmax-vmin) / (origin-vmin) * 256)
        negcolors = get_cmap(neg, zero)
        newcolors = negcolors(np.linspace(1, (vmax-vmin) / (origin-vmin), zero))
        return matplotlib.colors.ListedColormap(newcolors)
    else:
        print('ERROR in get_diverging_cmap().')
        return None

## Loading data

### Loading GSAT

In [None]:
# Takes about 1'30''
# Load every global 'tas' simulation file that matches the requested models /
# experiments and that runs at least until `year_min`, then stack them along 'member'.
list_of_simulations = sorted(glob.glob(
    dataDir+'Simulations/tas_Amon_*_hist-*_global_*.nc'
))
data_list = list()
for _in_file in list_of_simulations:
    # Parse the file name only: splitting the full path would shift the fields
    # as soon as `dataDir` contains an underscore.
    _fields = os.path.basename(_in_file).split('_')
    if model is not None and _fields[2] not in model:
        continue
    if experiment is not None and _fields[3] not in experiment:
        continue
    _in_ds = xr.open_dataset(_in_file).get('tas')
    if int(_in_ds.time.dt.year[-1]) < year_min:
        continue  # simulation ends before the targeted year
    if 'member' not in _in_ds.dims:
        # Member label built from the third to fifth file-name fields.
        _in_ds = _in_ds.assign_coords({'member': '_'.join(_fields[2:5])})
    data_list.append(_in_ds)
da_tas_glo = xr.concat(data_list, dim='member').dropna(dim='member',how='all')
da_tas_glo

### Loading regional variable of interest

In [None]:
# Takes about 1'30'' if variable is not 'tas'
# Load the variable of interest. For 'tas' the global array is reused as is;
# otherwise the matching index files are read and stacked along 'member'.
if variable == 'tas':
    da_region = da_tas_glo
else:
    # The index variables listed here are stored under the 'mon' table, all others under 'Amon'.
    if variable in ('nasst', 'natas', 'naesst', 'naetas', 'naessst', 'naecsst', 'naescsst'):
        _table = 'mon'
    else:
        _table = 'Amon'
    list_of_simulations = glob.glob(
        dataDir+'Simulations/'+variable+'_'+_table+'_*_hist-*_index_*.nc'
    )
    list_of_simulations.sort()
    data_list = list()
    for _in_file in list_of_simulations:
        # Parse the file name only: splitting the full path would shift the
        # fields as soon as `dataDir` contains an underscore.
        _fields = os.path.basename(_in_file).split('_')
        if model is not None and _fields[2] not in model:
            continue
        if experiment is not None and _fields[3] not in experiment:
            continue
        if only_one_member and 'r1i' not in str(_fields[4]):
            continue
        _in_ds = xr.open_dataset(_in_file).get(variable)
        if int(_in_ds.time.dt.year[-1]) < year_min:
            continue  # simulation ends before the targeted year
        _in_ds['time'] = _in_ds['time'].astype('datetime64[ns]')
        if 'member' not in _in_ds.dims:
            _in_ds = _in_ds.assign_coords({'member': '_'.join(_fields[2:5])})
        # Drop 'height' when present; errors='ignore' makes its absence a
        # no-op without hiding unrelated exceptions.
        _in_ds = _in_ds.drop_vars('height', errors='ignore')
        data_list.append(_in_ds)
    da_region = xr.concat(data_list, dim='member').dropna(dim='member', how='all')
    # NOTE(review): the 'M' alias is deprecated in recent pandas ('ME'), and the
    # driver cell resamples with 'MS' - confirm which time labels are intended.
    da_region = da_region.resample(time='M').mean(dim = 'time', keep_attrs = True)
da_region

## Determining largest set of common members for regional and global simulations, and merging of tables

In [None]:
# Members present in both the global and the regional arrays.
common_members_full = list(
    set(da_tas_glo.member.values) & set(da_region.member.values)
)
# Optionally keep only the members whose experiment field was requested.
if experiment is None:
    common_members = common_members_full
else:
    common_members = [
        _member for _member in common_members_full
        if _member.split('_')[1] in experiment
    ]
common_members.sort()

# Restrict both arrays to the shared members, in the same order.
da_tas_glo = da_tas_glo.sel(member=common_members).sortby('member')
da_region = da_region.sel(member=common_members).sortby('member')
ensemble_dict = sort_by_IV_ensemble(common_members)
common_members

In [None]:
# Discard the (source, experiment, configuration) ensembles that have fewer
# than `ensemble_size_min` members; nothing to do when the minimum is 1.
if ensemble_size_min != 1:
    ensemble_members = list()
    for _per_experiment in ensemble_dict.values():
        for _per_config in _per_experiment.values():
            for _group in _per_config.values():
                if len(_group) >= ensemble_size_min:
                    ensemble_members.extend(_group)
    ensemble_members.sort()
    da_tas_glo = da_tas_glo.sel(member=ensemble_members).sortby('member')
    da_region = da_region.sel(member=ensemble_members).sortby('member')
    ensemble_dict = sort_by_IV_ensemble(ensemble_members)
    print(len(ensemble_dict), 'models kept.')
ensemble_dict

In [None]:
# Print, for each model, the size of every (experiment, configuration) ensemble.
for _model, _per_experiment in ensemble_dict.items():
    print(_model)
    for _experiment, _per_config in _per_experiment.items():
        for _group in _per_config.values():
            print('', _experiment, len(_group))
    print()

In [None]:
# Regroup the members by experiment, across all models and configurations.
ssp_ensemble_dict = dict()
for _per_experiment in ensemble_dict.values():
    for _experiment, _per_config in _per_experiment.items():
        for _group in _per_config.values():
            if _experiment in ssp_ensemble_dict:
                ssp_ensemble_dict[_experiment] = ssp_ensemble_dict[_experiment] + _group
            else:
                ssp_ensemble_dict[_experiment] = _group
ssp_ensemble_dict.keys()

In [None]:
# Unique '<model>_<last member field>' labels, in order of first appearance:
# the experiment field is left out, so a realisation shared by several
# scenarios is counted once.
hist_members = list()
for _model, _per_ssp in ensemble_dict.items():
    for _per_ripf in _per_ssp.values():
        for _group in _per_ripf.values():
            for _member in _group:
                _label = _model+'_'+_member.split('_')[-1]
                if _label not in hist_members:
                    hist_members.append(_label)
len(hist_members), hist_members

In [None]:
# Build the parenthesised description of the data set, on its own line, for
# use in titles, e.g. '\n(<n> models, <m> hist. members, all SSP)'.
# Also sets `mean_lgd`, the legend label of the ensemble / multi-model mean.
_title_parts = list()
if len(ensemble_dict) != 1:
    _title_parts.append(str(len(ensemble_dict))+' models')
    mean_lgd = 'Multi-model mean'
else:
    # Name of the only model kept; read from ensemble_dict so that it also
    # works when `model` is None.
    _title_parts.append(str(next(iter(ensemble_dict))))
    mean_lgd = 'Ensemble mean'
if not only_one_member:
    _title_parts.append(str(len(hist_members))+' hist. members')
else: ###
    _title_parts.append(str(len(ensemble_dict))+' hist. members')
if experiment is not None:
    if len(experiment) == 1:
        _title_parts.append(experiment_dict[experiment[0]]['name'])
    elif set(experiment) == set(['hist-ssp126', 'hist-ssp245', 'hist-ssp370', 'hist-ssp585']):
        _title_parts.append('all SSP')
# Joining the parts guarantees a ', ' between every item and none at the end;
# an empty description yields an empty title instead of '\n()'.
data_title = '\n('+', '.join(_title_parts)+')' if _title_parts else ''
data_title

In [None]:
# One yearly series of the regional variable per requested temporality,
# each tagged with its 'temporality' coordinate.
_regional_by_tempo = [
    get_season(da_region, _tempo)
    .resample(time='YS').mean(dim = 'time', keep_attrs = True)
    .assign_coords({'temporality':_tempo})
    for _tempo in temporality
]
# Raw data set: GSAT from get_season(..., 'yr') next to the regional variable.
ds_raw = xr.Dataset({
    'GSAT': get_season(da_tas_glo, 'yr'),
    variable: xr.concat(_regional_by_tempo, dim='temporality'),
})
ds_raw

## Calculating the anomaly over time and by level of warming

### Calculation of the anomaly with respect to the pre-industrial period

In [None]:
# Anomalies computed by normalisation() over the pre-industrial period piStr-piEnd.
ds_ano_pi = normalisation(
    ds_raw,
    scaler=None,
    period=(piStr, piEnd),
)
ds_ano_pi

### Calculation of the anomaly with respect to the reference period

In [None]:
# Anomalies computed by normalisation() over the reference period refStr-refEnd.
ds_ano = normalisation(
    ds_raw,
    scaler=None,
    period=(refStr, refEnd),
)
ds_ano

### Calculation of weighting (three-dimensional: model, scenario, model configuration)

In [None]:
# Member weights returned by get_weight() for the pre-industrial anomalies;
# reused further down for weighted statistics and return periods.
da_weight = get_weight(ds_ano_pi)
da_weight

## Interannual statistics

### Observations

In [None]:
# Observational data sets: variable name inside each file and file location.
_obs_dir = dataDir+'Observations/'
obs_dict = {
    'Berkeley Earth':   dict(var='tas',      path=_obs_dir+'tas_BEST_1m_global_185001-202401.nc'),
    'GISTEMPv4':        dict(var='tas',      path=_obs_dir+'tas_index_global_GISS_188001-202412.nc'),
    'HadCRUT5':         dict(var='tas_mean', path=_obs_dir+'gsat_1y_HadCRUT5_185001-202407.nc'),
    'NOAAGlobalTempv6': dict(var='tas',      path=_obs_dir+'tas_NOAAGT_1m_global_185001-202404.nc'),
}

# Statistics returned by get_stats() for stats_str-stats_end with a centred window.
ds_stats_obs = get_stats(
    obs_dict,
    stats_str=stats_str, stats_end=stats_end,
    window_size=window_size, center=True,
)
ds_stats_obs

### Reanalyses

In [None]:
# Reanalysis data sets: variable name inside each file and file location.
_rea_dir = dataDir+'Observations/'
rea_dict = {
    'ERA5':         dict(var='tas', path=_rea_dir+'tas_ERA5_1m_global_194001-202312.nc'),
    'JRA-30C':      dict(var='tas', path=_rea_dir+'tas_JRA30C_1m_global_194709-202501.nc'),
    'NCEP-NCAR':    dict(var='tas', path=_rea_dir+'tas_NCEP_1m_global_194801-202405.nc'),
    'NOAA 20C':     dict(var='tas', path=_rea_dir+'tas_NOAA20C_1m_global_183601-201512.nc'),
}

# Statistics returned by get_stats() for stats_str-stats_end with a centred window.
ds_stats_rea = get_stats(
    rea_dict,
    stats_str=stats_str, stats_end=stats_end,
    window_size=window_size, center=True,
)
ds_stats_rea

### Models

#### Projections

In [None]:
def _summary_row(field, members, source, size, weights=None):
    """Return a (1, 3) DataArray: mean, 5% and 95% quantile of `field` over `members`.

    The statistics are weighted when `weights` is given. `source` and `size`
    are attached as coordinates along the 's' dimension.
    """
    picked = field.sel(member=members)
    if weights is not None:
        picked = picked.weighted(weights)
    return xr.DataArray(
        [[picked.mean().values,
          picked.quantile(.05).values,
          picked.quantile(.95).values]],
        dims=['s', 'stats'],
        coords={'stats': ['Average', '5%', '95%'], 'source': ('s', [source]), 'size': ('s', [size])},
    )


# Yearly means over stats_str-stats_end, minus their centred running mean:
# what remains is the year-to-year residual of each member.
_data = da_region.sel(time=slice(stats_str, stats_end)).resample(time='YS').mean(dim = 'time', keep_attrs = True)
_data = _data - _data.rolling(time=window_size, center=True).mean()
proj_std = _data.std(dim='time')
proj_qdw = _data.quantile(.05, dim='time')
proj_qup = _data.quantile(.95, dim='time')

_fields_and_lists = list()
std_list = list() ; _fields_and_lists.append((proj_std, std_list))
qdw_list = list() ; _fields_and_lists.append((proj_qdw, qdw_list))
qup_list = list() ; _fields_and_lists.append((proj_qup, qup_list))

_members_hist = list()

# One row per model, based on its largest single (experiment, configuration) ensemble.
for _source in ensemble_dict:
    _candidates = [
        _group
        for _per_config in ensemble_dict[_source].values()
        for _group in _per_config.values()
    ]
    # max() keeps the first of the longest groups, as a strict '>' scan would.
    _members = max(_candidates, key=len, default=list())
    _members_hist += _members

    for _field, _rows in _fields_and_lists:
        _rows.append(_summary_row(_field, _members, _source, len(_members)))

# One extra 'CMIP6' row: weighted statistics over all the members selected above.
_weights = da_weight.sel(member=_members_hist)
for _field, _rows in _fields_and_lists:
    _rows.append(_summary_row(_field, _members_hist, 'CMIP6', len(_weights), weights=_weights))

_stats = xr.Dataset(
    {'std': xr.concat(std_list, dim='s'),
    '5%': xr.concat(qdw_list, dim='s'),
    '95%': xr.concat(qup_list, dim='s'),
    }
)
_stats = _stats.set_index(s=['source', 'size'])
# Rows ordered by increasing average standard deviation.
ds_gsat_model_interannualvariability_stats = _stats.sortby(_stats['std'].sel(stats='Average'))
ds_gsat_model_interannualvariability_stats

## Return period statistics

### Constrained projections

#### Constrained by observed interannual variability

In [None]:
# Range of inter-annual standard deviations spanned by the data sets named in
# `constraint` (observations and/or reanalyses).
_std_sources = list()
if 'obs' in constraint:
    _std_sources.append(ds_stats_obs['std'])
if 'rea' in constraint:
    _std_sources.append(ds_stats_rea['std'])

cont_min = np.inf
cont_max = - np.inf
for _std in _std_sources:
    cont_min = min(cont_min, _std.min().values)
    cont_max = max(cont_max, _std.max().values)
cont_min, cont_max

In [None]:
# Constraint to keep only models whose range of member inter-annual standard
# deviations overlaps the observations range
constrain_models = list()
tested_std = proj_std.set_index(m=['member', 'source_id'])
for _source in ds_gsat_model_interannualvariability_stats['source'].values:
    if _source in ['CMIP6']:
        continue  # aggregated row, not a model
    _tmp = tested_std.sel(source_id=_source)
    _entirely_above = _tmp.min() > cont_max
    _entirely_below = _tmp.max() < cont_min
    if _entirely_above or _entirely_below:
        continue
    constrain_models.append(_source)
len(constrain_models), constrain_models

In [None]:
# Constraint to keep only members with an inter-annual variance within the observations range
# (strictly between cont_min and cont_max). A single combined mask is used:
# DataArray.where takes (cond, other, drop), so chaining .where(cond, 1., np.nan)
# would pass np.nan as the truthy `drop` flag instead of a fill value.
_in_range = (proj_std > cont_min) & (proj_std < cont_max)
constrain_members = proj_std.where(_in_range).dropna(dim='member', how='all').member
# Record how many distinct models / members survive before dropping the helper coordinates.
constrain_members.attrs['n_models'] = len(set(constrain_members.source_id.values))
constrain_members.attrs['n_members'] = len(set(constrain_members.member.values))
constrain_members = constrain_members.drop_vars(['source_id', 'experiment_id', 'member_id'])
constrain_members

In [None]:
# Mask returned by map_global_warming_levels() for the pre-industrial GSAT anomalies.
_gwl_mask = map_global_warming_levels(
    ds_ano_pi.GSAT,
    window=window_size,
    levels=gwl_lvls,
    method=gwl_method,
    interval=gwl_interval,
)
ds_mask = xr.Dataset({'GWL': _gwl_mask})
ds_mask

In [None]:
# Members with at least one sample flagged in the GWL mask, and the others.
_gwl_hits = xr.where(ds_mask['GWL'], 1., np.nan).dropna('member', how='all')
reachedMem = set(_gwl_hits.member.values)
notReachedMem = sorted(set(ds_mask.member.values).difference(reachedMem))
print('Members that did not reach the targetted GWL ('+str(len(notReachedMem))+'):')
for _member in notReachedMem:
    print(_member)

In [None]:
# Boolean masks along 'member': member of a retained model ('cnstMod'),
# and member retained individually ('cnstMem').
_from_kept_model = ds_ano_pi.source_id.isin(constrain_models)
_is_kept_member = ds_ano_pi.member.isin(constrain_members)
ds_mask = ds_mask.assign({
    'cnstMod': xr.where(ds_ano_pi.member.where(_from_kept_model).notnull(), True, False),
    'cnstMem': xr.where(ds_ano_pi.member.where(_is_kept_member).notnull(), True, False),
})
ds_mask

In [None]:
def _models_reaching_gwl(mask_name):
    """Source ids with at least one member flagged by `mask_name` inside the GWL mask."""
    flagged = xr.where(ds_mask[mask_name], 1., np.nan).where(ds_mask['GWL'])
    return set(flagged.dropna('member', how='all').source_id.values)


setMod = _models_reaching_gwl('cnstMod')
setMem = _models_reaching_gwl('cnstMem')
# Models kept by the model-level constraint but left without any retained member at the GWL.
notReachedMod = sorted(setMod.difference(setMem))
print('Models where all members that have a correct inter-annual variability do not reach the targetted GWL:')
for _model in notReachedMod:
    print(_model)

#### Constrained by internal variability

In [None]:
# Load the simulated index of every driver listed in `driver`, for the members
# present in `da_region`, and merge them into one data set.
drivers_list = list()
for _driver in driver:
    list_of_simulations = glob.glob(dataDir+'Simulations/'+driver_dict[_driver]['variable']+'_mon_*_hist-ssp*_index_*.nc')
    list_of_simulations.sort()
    print(len(list_of_simulations), list_of_simulations)
    _tmp_driver = list()
    for _file in list_of_simulations:
        # Member label: third to fifth underscore-separated fields of the file name.
        _member = '_'.join(_file.split('/')[-1].split('_')[2:5])
        if _member in da_region.member:
            _da = xr.open_dataarray(_file).assign_coords({'member': _member})
            # Drop the auxiliary variables when present; errors='ignore' makes a
            # missing name a no-op without hiding unrelated exceptions.
            _da = _da.drop_vars(['height', 'month'], errors='ignore')
            _da['time'] = _da['time'].astype('datetime64[ns]')
            _tmp_driver.append(_da)
    _da_driver = xr.concat(_tmp_driver, dim='member').dropna(dim='member', how='all').resample(time='MS').mean(dim = 'time', keep_attrs = True)
    _da_driver.name = _driver
    _da_driver = get_season(_da_driver, driver_dict[_driver]['tempo'])
    if driver_dict[_driver]['lag'] > 0:
        # Lag: drop the last `lag` steps and relabel the remaining values with
        # the time stamps found `lag` steps later, then average per year.
        _da_driver = _da_driver.isel(time=slice(0, -driver_dict[_driver]['lag'])).assign_coords({'time': _da_driver.time.isel(time=slice(driver_dict[_driver]['lag'], None))}).resample(time='YS').mean(dim = 'time', keep_attrs = True)
    if 'amv' in _driver:
        # LOWESS smoothing along time, with a span of 10 time steps (frac = 10 / series length).
        _da_driver = xr.DataArray(
            data = np.apply_along_axis(
                func1d = statsmodels.api.nonparametric.lowess, axis =_da_driver.get_axis_num('time'), arr = _da_driver,
                exog = _da_driver.time.dt.year, return_sorted = False, frac = 10/len(_da_driver.time)),
            coords = _da_driver.coords, dims = _da_driver.dims, name = _da_driver.name, attrs = _da_driver.attrs)
    drivers_list.append(remove_forced_response(_da_driver, min_size = 5))
ds_driver = xr.merge(drivers_list).assign_coords({'source_id': da_region.source_id})
ds_driver # should take less than 1'30''

In [None]:
# Standard deviation over time of each driver index; used below as the
# threshold separating the '+', '~' and '-' tiers.
driver_std = ds_driver.std(dim='time')
driver_std

In [None]:
# For each driver, flag the samples above +1 std ('+'), strictly within
# +/-1 std ('~') and below -1 std ('-').
for _driver in ds_driver:
    _index = ds_driver[_driver]
    _sigma = driver_std[_driver]
    ds_mask = ds_mask.assign({
        _driver+'+': xr.where(_index > _sigma, True, False),
        _driver+'~': xr.where((_index < _sigma) & (_index > - _sigma), True, False),
        _driver+'-': xr.where(_index < - _sigma, True, False),
    })
ds_mask

In [None]:
# Load the observed driver indices (ERSSTv5 files) and process them like the
# simulated ones, then keep 1940-2024 as yearly means.
drivers_obs_list = list()
for _driver in ['amvGlob', 'nino34']:
    _file = dataDir+'Observations/'+_driver+'_mon_ERSSTv5_index_*.nc'
    if 'amv' in _driver:
        # The AMV file is named 'amv_yr_..._unfiltered.nc' instead of '<driver>_mon_....nc'.
        _file = _file.replace(_driver+'_', 'amv_').replace('_mon_', '_yr_').replace('.nc', '_unfiltered.nc')
    print(_file)
    _da = xr.open_mfdataset(_file).load().get(_driver)
    # Drop the auxiliary variables when present; errors='ignore' makes a
    # missing name a no-op without hiding unrelated exceptions.
    _da = _da.drop_vars(['height', 'month'], errors='ignore')
    _da['time'] = _da['time'].astype('datetime64[ns]')
    _da_driver_obs = _da ; _da_driver_obs.name = _driver
    if 'amv' not in _driver:
        _da_driver_obs = _da_driver_obs.resample(time='MS').mean(dim = 'time', keep_attrs = True)
        _da_driver_obs = get_season(_da_driver_obs, driver_dict[_driver]['tempo'])
    if driver_dict[_driver]['lag'] > 0:
        # Lag: drop the last `lag` steps and relabel the remaining values with
        # the time stamps found `lag` steps later, then average per year.
        _da_driver_obs = _da_driver_obs.isel(time=slice(0, -driver_dict[_driver]['lag'])).assign_coords({'time': _da_driver_obs.time.isel(time=slice(driver_dict[_driver]['lag'], None))}).resample(time='YS').mean(dim = 'time', keep_attrs = True)
    if 'amv' in _driver:
        # LOWESS smoothing along time, with a span of 10 time steps (frac = 10 / series length).
        _da_driver_obs = xr.DataArray(
            data = np.apply_along_axis(
                func1d = statsmodels.api.nonparametric.lowess, axis =_da_driver_obs.get_axis_num('time'), arr = _da_driver_obs,
                exog = _da_driver_obs.time.dt.year, return_sorted = False, frac = 10/len(_da_driver_obs.time)),
            coords = _da_driver_obs.coords, dims = _da_driver_obs.dims, name = _da_driver_obs.name, attrs = _da_driver_obs.attrs)
    elif 'nino' in _driver:
        # Remove the trend returned by get_trend() from the index.
        _da_driver_obs = _da_driver_obs - get_trend(_da_driver_obs)
    drivers_obs_list.append(_da_driver_obs.sel(time=slice('1940', '2024')).resample(time='YS').mean(dim = 'time', keep_attrs = True))
ds_driver_obs = xr.merge(drivers_obs_list)
ds_driver_obs

#### Total constraint

In [None]:
# Pre-industrial anomalies of the variable, kept only where the GWL mask is set.
ds_GWL = xr.Dataset({variable: ds_ano_pi[variable].where(ds_mask['GWL'])})

# Add one variable per driver tier ('<variable>_<driver><tier>') and one per
# combination of tiers of two different drivers
# ('<variable>_<driver><tier>_<driver2><tier2>').
for i, _driver in enumerate(ds_driver):
    for _tier in ['+', '~', '-']:
        ds_GWL = ds_GWL.assign({variable+'_'+_driver+_tier: ds_GWL[variable].where(ds_mask[_driver+_tier])})
        for _driver2 in ds_driver:
            # Only pair with drivers listed from position i onwards in `driver`,
            # presumably so that each pair is built once - TODO confirm that
            # ds_driver and `driver` share the same order.
            if _driver2 != _driver and _driver2 in driver[i:]:
                for _tier2 in ['+', '~', '-']:
                    ds_GWL = ds_GWL.assign({
                        variable+'_'+_driver+_tier+'_'+_driver2+_tier2: ds_GWL[variable].where(ds_mask[_driver+_tier]).where(ds_mask[_driver2+_tier2])
                        })

# Constrained versions of every variable built so far. `ds_GWL` is rebound
# inside the loop, so the iteration runs over the data set as it was when the
# loop started: the '_kMod' / '_kMem' variables are not themselves re-processed.
for _var in ds_GWL:
    ds_GWL = ds_GWL.assign({_var+'_kMod': ds_GWL[_var].where(ds_mask['cnstMod'])})
    ds_GWL = ds_GWL.assign({_var+'_kMem': ds_GWL[_var].where(ds_mask['cnstMem'])})

# Store, on each variable, the number of models and members that have at least one non-missing value.
for _var in ds_GWL:
    _tmp = ds_GWL[_var].dropna(dim='member', how='all')
    ds_GWL[_var].attrs['n_models'] = len(set(_tmp.source_id.values))
    ds_GWL[_var].attrs['n_members'] = len(set(_tmp.member.values))

ds_GWL

### Return period computation

In [None]:
# Return period of each observed event, first in the pre-industrial climate,
# then at the current warming level for several constrained subsets.
_rtn_list = list()

for _obs in observations:
    _tmp = list()
    # Each observation entry holds one {temporality: observed anomaly} pair.
    _tempo = list(observations[_obs].keys())[0]
    _event = observations[_obs][_tempo]

    ### PI
    _data = ds_ano[variable].sel(time=slice(piStr, piEnd), temporality=_tempo)

    # Flatten (member, time) into samples, give each sample its member weight,
    # and sort samples and weights together by increasing value.
    _da = _data.stack(sample=['member', 'time']).dropna('sample', how='all')
    _da_weight = da_weight.expand_dims(
        {'time': ds_ano.time.sel(time=slice(piStr, piEnd))}
    ).transpose(..., 'time').stack(sample=['member', 'time']).sel(sample=_da.sample)
    _da_weight, _da = _da_weight.isel(sample=_da.argsort().values), _da.isel(sample=_da.argsort().values)
    # Skip the costly resampling when bootstrap is disabled, as done for the
    # current-GWL subsets below.
    _data_boot = compute_bootstrap(_data, da_weight.sel(member=_data.member), n_boot=n_boot).stack(sample=['member', 'time']) if bootstrap else None

    _tmp.append(xr.DataArray(
        return_period(data = _da, weight = _da_weight, event = _event, bootstrap = bootstrap, data_boot = _data_boot, low=int_low, high=int_high),
        dims = ['return_period'],
        coords = {
            'return_period': ['computed', 'normal', 'skew-normal' , 'bootstrap-low', 'bootstrap-med', 'bootstrap-high'],
            'dataset': variable+'-past',
            'size': len(_da.member),
        }
    ))

    ### Current GWL
    for _data_name in [
        variable,
        variable+'_kMod',
        variable+'_amv+_nino34+_kMod',
        variable+'_amv+_nino34-_kMod',
        variable+'_amv~_nino34~_kMod',
        variable+'_kMem',
    ]:
        _data = ds_GWL.get(_data_name).sel(temporality=_tempo).isel(warming_level=0)
        _da_weight = local_weights(xr.where(_data, True, False), method = 'occurrence-per-model').expand_dims(dim={'time': _data.time})

        _da = _data.stack(sample=('member', 'time')).dropna('sample')
        _da_weight = _da_weight.stack(sample=('member', 'time')).sel(sample=_da.sample)

        # Sort samples and weights together by increasing value.
        _da_weight, _da  = _da_weight.isel(sample=_da.argsort().values), _da.isel(sample=_da.argsort().values)

        _data_boot = compute_bootstrap(_da, _da_weight, n_boot=n_boot, dim='sample') if bootstrap else None

        _tmp.append(xr.DataArray(
            return_period(data = _da, weight = _da_weight, event = _event, bootstrap = bootstrap, data_boot = _data_boot, low=int_low, high=int_high),
            dims = ['return_period'],
            coords = {
                'return_period': ['computed', 'normal', 'skew-normal' , 'bootstrap-low', 'bootstrap-med', 'bootstrap-high'],
                'dataset': _data_name,
                'size': len(_da.member),
            }
        ))
    # One DataArray per observation source, named after it.
    _rtn_list.append(xr.concat(_tmp, dim = 'dataset'))
    _rtn_list[-1].name = _obs

ds_rtn = xr.merge(_rtn_list)
ds_rtn # should take a bit more than 10' (sets=21 x n_boot=1000), less than 15'' without bootstrap

# Figures

### Inter-annual variability

In [None]:
# Inter-annual variability figure, drawn three times: without constraint, with
# the model-level constraint ('_kMod') and with the member-level one ('_kMem').
# Top panel: standard deviation; bottom panel: 5% and 95% quantiles.
obs_c = 'b'
rea_c = 'r'
for _cnstrn in (None, 'models', 'members'):
    _suffix = ''
    if _cnstrn == 'models':
        _suffix += '_kMod'
    elif _cnstrn == 'members':
        _suffix += '_kMem'

    fig, ax = plt.subplots(2,1, figsize=(10, 10))
    fig.suptitle('Statistics of annual residuals (10-year window)\nof global surface air temperature ('+stats_str+'-'+stats_end+')', fontsize=15)
    ax[1].axhline(y=0, ls='-', lw=.75, c='k')

    # xticks[-1] is always the x position of the next column to draw;
    # btm_lbls / upr_lbls are the bottom / top tick labels of the drawn columns.
    xticks = [0] ; btm_lbls = list() ; upr_lbls = list()

    if 'obs' in constraint:
        for _source in ds_stats_obs.source:
            btm_lbls.append(_source.values) ; upr_lbls.append('Obs.')
            plot_stats(xticks[-1], ds_stats_obs['std'].sel(source=_source), ds_stats_obs['qdw'].sel(source=_source), ds_stats_obs['qup'].sel(source=_source), ax, obs_c)
            xticks.append(xticks[-1]+1)
        # Vertical separator, then leave one empty slot before the next group.
        ax[0].axvline(x=xticks[-1], lw=.75, c='k')
        ax[1].axvline(x=xticks[-1], lw=.75, c='k')
        xticks[-1] += 1

    if 'rea' in constraint:
        for _source in ds_stats_rea.source:
            btm_lbls.append(_source.values) ; upr_lbls.append('Rea.')
            plot_stats(xticks[-1], ds_stats_rea['std'].sel(source=_source), ds_stats_rea['qdw'].sel(source=_source), ds_stats_rea['qup'].sel(source=_source), ax, rea_c)
            xticks.append(xticks[-1]+1)
        ax[0].axvline(x=xticks[-1], lw=.75, c='k')
        ax[1].axvline(x=xticks[-1], lw=.75, c='k')
        xticks[-1] += 1

    # Aggregated CMIP6 column, drawn as box plots.
    _source = 'CMIP6'
    _da = ds_gsat_model_interannualvariability_stats.sel(source='CMIP6').isel(size=0)
    btm_lbls.append('CMIP6') ; upr_lbls.append(int(_da['size']))
    _xlab_CMIP = len(upr_lbls) - 1
    _members = proj_std.member
    # The three cases are mutually exclusive: with a plain `if` on 'members',
    # its `else` would overwrite the 'models' mask with the unconstrained data.
    if _cnstrn == 'models':
        _loc_std = proj_std.where(ds_mask['cnstMod'].sel(member=_members))
        _loc_qdw = proj_qdw.where(ds_mask['cnstMod'].sel(member=_members))
        _loc_qup = proj_qup.where(ds_mask['cnstMod'].sel(member=_members))
    elif _cnstrn == 'members':
        _loc_std = proj_std.where(ds_mask['cnstMem'].sel(member=_members))
        _loc_qdw = proj_qdw.where(ds_mask['cnstMem'].sel(member=_members))
        _loc_qup = proj_qup.where(ds_mask['cnstMem'].sel(member=_members))
    else:
        _loc_std = proj_std
        _loc_qdw = proj_qdw
        _loc_qup = proj_qup
    _local_weights = local_weights(_loc_std)
    boxplot([_loc_std.values], ax=ax[0], dx=xticks[-1], weights=[_local_weights.values], bar='mean', outliers=False, color='k')
    boxplot([_loc_qdw.values], ax=ax[1], dx=xticks[-1], weights=[_local_weights.values], bar='mean', outliers=False, color='k')
    boxplot([_loc_qup.values], ax=ax[1], dx=xticks[-1], weights=[_local_weights.values], bar='mean', outliers=False, color='k')
    xticks.append(xticks[-1]+1)

    ax[0].axvline(x=xticks[-1], lw=.75, c='k')
    ax[1].axvline(x=xticks[-1], lw=.75, c='k')
    xticks[-1] += 1

    # One column per model: every member of its largest ensemble is a marker;
    # members or models rejected by the active constraint are drawn as 'x'.
    # _nMod starts from the CMIP6 'size' label and is reduced by the size of each rejected model.
    _nMod = upr_lbls[_xlab_CMIP] ; _nMem = 0
    for _source in ds_gsat_model_interannualvariability_stats['source'].values:
        if _source not in ['CMIP6']:
            _members = list()
            for _experiment in ensemble_dict[_source]:
                for _config in ensemble_dict[_source][_experiment]:
                    _tmp = ensemble_dict[_source][_experiment][_config]
                    if len(_tmp) > len(_members):
                        _members = _tmp
            n_proj = len(_members)

            # Markers for (std, 5% quantile, 95% quantile) of each member.
            _mks = [['o', 'v', '^'] for _ in range(n_proj)]
            _prefix=''
            if _source not in constrain_models and _cnstrn == 'models':
                _nMod -= n_proj
                _mks = [['x', 'x', 'x'] for _ in range(n_proj)]
            elif _cnstrn == 'members':
                _mks = list() ; _nKeep = 0
                for _member in _members:
                    if _member in constrain_members:
                        _nKeep += 1
                        _mks.append(['o', 'v', '^'])
                    else:
                        _mks.append(['x', 'x', 'x'])
                _prefix=str(_nKeep)+'/'
                _nMem += _nKeep
            _mks = np.array(_mks)
            btm_lbls.append(_source) ; upr_lbls.append(_prefix+str(n_proj))

            for i, _member in enumerate(_members):
                ax[0].plot(xticks[-1], proj_std.sel(member=_member), _mks[i,0][0], c='k', ms=3, zorder=2)
                ax[1].plot(xticks[-1], proj_qdw.sel(member=_member), _mks[i,1][0], c='k', ms=3, zorder=2)
                ax[1].plot(xticks[-1], proj_qup.sel(member=_member), _mks[i,2][0], c='k', ms=3, zorder=2)

            xticks.append(xticks[-1]+1)

    # Show 'kept/total' above the CMIP6 column when a constraint is active.
    if _cnstrn == 'models':
        upr_lbls[_xlab_CMIP] = str(_nMod)+'/'+str(upr_lbls[_xlab_CMIP])
    elif _cnstrn == 'members':
        upr_lbls[_xlab_CMIP] = str(_nMem)+'/'+str(upr_lbls[_xlab_CMIP])

    # The last entry is the position of a column that was never drawn.
    xticks = xticks[:-1]

    ax[0].set_ylabel('Standard deviation', fontsize=13)
    ax[0].xaxis.tick_top()
    ax[0].set_xticks(xticks, upr_lbls, rotation=90)
    ax[0].set_xlim(-1, xticks[-1]+1)
    ax[1].set_ylabel('5-95% Range', fontsize=13)
    ax[1].set_xlabel('Dataset', fontsize=13)
    ax[1].set_xticks(xticks, btm_lbls, rotation=90)
    ax[1].set_xlim(-1, xticks[-1]+1)
    outFig = 'interanualVariability-obsReaCMIPmodels'
    outFig += _suffix
    # Build the path once so that the message reports the file actually written.
    _out_path = outDir+saveName.replace('_tempo_', '_yr_')+'/'+outFig+extension
    plt.savefig(_out_path)
    print('Saved:', _out_path)

### Single year detection and attribution

#### Raw and constrained by observed inter-annual variability

In [None]:
### Count and show number of models!

colors = ['#364B9A', '#F67E4B', '#A50026']
linestyles = ['-', '-.', '--', '-']
lw = 3

for _past, _current in [
        [ds_ano[variable], ds_GWL[variable]],
        [ds_ano[variable].where(ds_ano.source_id.isin(constrain_models)), ds_GWL[variable+'_kMod']],
        [ds_ano[variable].where(ds_ano.member.isin(constrain_members)), ds_GWL[variable+'_kMem']],
    ]:
    reg_mod = dict()
    for _tempo in temporality:

        _range = (-0.75, 2.00); _nbins = int((_range[1]-_range[0]) * 20)

        rtn_box = [Line2D([0], [0], label='Return period:', ls='')]
        text_list = list()
        fig, ax = plt.subplots(figsize=(10, 7))

        _first_line = 'Past and current '+variable_dict[variable]+' probability ('+str(window_size-1)+'-year window)'
        if '_kMod' in _current.name or '_kMem' in _current.name: ###
            _second_line = 'from constrained CMIP6 ('+str(len(set(_current.dropna(dim='member', how='all').source_id.values)))+' models, '+str(len(_current.dropna(dim='member', how='all').member))+' members)'
        else:
            _second_line = 'from CMIP6 ('+str(len(set(_current.dropna(dim='member', how='all').source_id.values)))+' models, '+str(len(_current.dropna(dim='member', how='all').member))+' members)'
        plt.title(_first_line+'\n'+_second_line, fontsize=15)
        ax.set_xlabel(temporality_dict[_tempo]+' '+variable_dict[variable]+' anomaly (°C), reference period: '+piStr+'-'+piEnd, fontsize=15)
        ax.set_ylabel('Occurrence (%)', fontsize=15)
        ax.tick_params(axis='both', which='major', labelsize=15)
        ax.axvline(x=0, ls='-', c='k', lw=.5)

        ax.set_xlim(_range)
        if window_size == 21:
            ax.set_ylim((0.0, 2.0))
        elif window_size == 11:
            ax.set_ylim((0.0, 4.0))
        else:
            ax.set_ylim((0.0, 2.0))
        plt.locator_params(axis='y', nbins=10)
        ax.yaxis.set_major_formatter(PercentFormatter(xmax=_nbins/(_range[1]-_range[0]), decimals=1, symbol=0))
        _xlims= ax.get_xlim(); _ylims= ax.get_ylim()
        _saveName = _tempo.join(saveName.split('tempo'))+'-reg_anom_distributions_ref-gwl-obs'

        ###
        i = 0
        _da = _past.sel(time=slice(piStr, piEnd), temporality=_tempo)
        _da_weight = local_weights(_da)
        _da = _da.stack(sample=['member', 'time'])
        _da_weight = _da_weight.expand_dims(
            {'time': ds_ano.time.sel(time=slice(piStr, piEnd))}
        ).transpose(..., 'time').stack(sample=['member', 'time']).sel(sample=_da.sample)
        _da_weight, _da = _da_weight.isel(sample=_da.argsort().values), _da.isel(sample=_da.argsort().values)
        print('Data size:', len(_da.member))

        plt.hist(_da, weights=_da_weight,
                    density=True, bins=_nbins, range=_range,
                    histtype='stepfilled', alpha=0.5, color=colors[i], label=ensemble_lgd+' '+piStr+'-'+piEnd,
                    edgecolor=colors[i])
        ax.axvline(x=_da.weighted(_da_weight).mean(), ls='-', c=colors[i], lw=2, zorder=2)

        if show_gauss:
            _mean = _da.weighted(_da_weight).mean()
            _std = _da.weighted(_da_weight).std()
            _x = np.linspace(_da.min(), _da.max(), 100)
            ax.plot(_x, stats.norm.pdf(_x, loc=_mean, scale=_std), color=colors[i])

        ###
        i = 1
        _wrm_lvl = targets[_tempo]['GWL']
        _legend=ensemble_lgd+' GWL = {:.2f}°C'.format(targets[_tempo]['GWL'])

        _tmp = _current.sel(temporality=_tempo)

        _da = _tmp.sel(warming_level=_wrm_lvl)
        _da_weight = local_weights(_da).expand_dims(dim={'time': _current.time})
        
        _da = _da.stack(sample=['member', 'time']).dropna('sample')
        _da_weight = _da_weight.stack(sample=('member', 'time')).sel(sample=_da.sample)
        
        _da_weight, _da  = _da_weight.isel(sample=_da.argsort().values), _da.isel(sample=_da.argsort().values)
        print('Data size:', len(_da.member))

        plt.hist(_da, weights=_da_weight,
                    density=True, bins=_nbins, range=_range,
                    histtype='stepfilled', alpha=0.5, color=colors[i], label=_legend,
                    edgecolor=colors[i])
        ax.axvline(x=_da.weighted(_da_weight).mean(), ls='-', c=colors[i], lw=2, zorder=2)

        h, l = ax.get_legend_handles_labels()
        _lgd = ax.legend(h, l, fontsize=12, loc='upper left')

        if show_gauss:
            _mean = _da.weighted(_da_weight).mean()
            _std = _da.weighted(_da_weight).std()
            _x = np.linspace(_da.min(), _da.max(), 100)
            ax.plot(_x, stats.norm.pdf(_x, loc=_mean, scale=_std), color=colors[i])

        for iObs, _obs in enumerate(observations):
            if _tempo in observations[_obs].keys():
                _text = ('{0:.0f} years'.format(ds_rtn[_obs].sel(return_period='computed', dataset=_current.name)))
                if bootstrap:
                    try:
                        _text += (
                            ' [{0:.0f} - '.format(min(ds_rtn[_obs].sel(return_period='computed', dataset=_current.name)-1, ds_rtn[_obs].sel(return_period='bootstrap-low', dataset=_current.name)))
                            +'{0:.0f}]'.format(max(ds_rtn[_obs].sel(return_period='computed', dataset=_current.name)+1, ds_rtn[_obs].sel(return_period='bootstrap-high', dataset=_current.name)))
                            )
                    except OverflowError:
                        pass
                rtn_box.append(Line2D([0], [0], label=_text, ls=linestyles[iObs], color=colors[i], lw=lw))

        for iObs, _obs in enumerate(observations):
            if _tempo in observations[_obs].keys():
                ax.axvline(x=observations[_obs][_tempo], ls=linestyles[iObs], c='red', lw=lw,
                )
                ax.text(observations[_obs][_tempo], 0, ' '+targets[_tempo]['year']+' '+lgd_obs_dict[_obs], fontsize=12, c='red', va='bottom', ha='right', rotation=90)
                ax.add_artist(ax.legend(handles=rtn_box, loc='upper right', fontsize=12))

        h, l = ax.get_legend_handles_labels()
        _lgd = ax.legend(h, l, fontsize=12, loc='upper left')

        _saveFig = outDir+_tempo.join(saveName.split('tempo'))+'/'+_saveName+'_'+_current.name+extension
        plt.savefig(_saveFig)
        print('Saved:', _saveFig)

#### Constrained by drivers of internal variability

In [None]:
# Deep copy of the return-period dataset: the plotting loop of this cell adds the
# plotted samples and their weights to it before it is written to disk.
data2store = ds_rtn.copy(deep=True)

# Palette keyed by the phase of the mode of variability.
colors = dict(high='#EE99AA', neutral='dimgrey', low='#6699CC')
linestyles = ['-', '-.', '--', '-']
lw = 3

# (dataset-name suffix, legend suffix, palette key) of every unconstrained subset.
_fig_specs = [
    ('',               ' (all members)',         'neutral'),
    ('_nino34+',       ' El Niño',               'high'),
    ('_nino34~',       ' Neutral NINO3.4',       'neutral'),
    ('_nino34-',       ' La Niña',               'low'),
    ('_amv+',          ' AMV+',                  'high'),
    ('_amv~',          ' Neutral AMV',           'neutral'),
    ('_amv-',          ' AMV-',                  'low'),
    ('_amv+_nino34+',  ' AMV+ & El Niño',        'high'),
    ('_amv+_nino34-',  ' AMV+ & La Niña',        'low'),
    ('_amv~_nino34~',  ' Neutral AMV & NINO3.4', 'neutral'),
]
fig_dict = {
    variable+_suffix: {'legend': ensemble_lgd+_label, 'color': colors[_tone]}
    for _suffix, _label, _tone in _fig_specs
}

# The '_kMod' and '_kMem' variants of each dataset reuse the very same style entry
# as their unconstrained counterpart (the dict objects are shared, not copied).
_list_of_keys = list(fig_dict)
for _cont in ('_kMod', '_kMem'):
    for _key in _list_of_keys:
        fig_dict[_key+_cont] = fig_dict[_key]

# One figure per pair of datasets: the subset conditioned on the drivers of internal
# variability first, then the ensemble it is compared with. Every name listed here
# is used as a key of ds_GWL, ds_rtn (dataset coordinate) and fig_dict.
for _sets in [
    [variable+'_amv+_nino34+_kMod', variable+'_kMod'],
    [variable+'_amv+_nino34-_kMod', variable+'_kMod'],
]:
    reg_mod = dict()
    for _tempo in temporality:
        _wrm_lvl = targets[_tempo]['GWL']

        # Histogram support and resolution: 20 bins per unit of the x-axis.
        _range = (.8, 2.); _nbins = int((_range[1]-_range[0]) * 20)

        # Proxy handles of the return-period legend; the first entry is its heading.
        rtn_box = [Line2D([0], [0], label='Return period:', ls='')]
        text_list = list()
        fig, ax = plt.subplots(figsize=(10, 7))

        ax.set_xlim(_range)
        ax.set_ylim((0.0, 4.0) if window_size == 11 else (0.0, 2.0))
        ax.locator_params(axis='y', nbins=10)
        # xmax is the inverse of the bin width, so a density value is displayed as
        # the percentage of the (weighted) sample falling in one bin.
        ax.yaxis.set_major_formatter(PercentFormatter(xmax=_nbins/(_range[1]-_range[0]), decimals=1, symbol=0))
        _xlims= ax.get_xlim(); _ylims= ax.get_ylim()
        _saveName = _tempo.join(saveName.split('tempo'))+'-reg_anom_distributions_ref-gwl-obs_totalConstraint'
        if bootstrap:
            _saveName += '_boot'+str(n_boot)

        _kept_models = set() ; _kept_members = set()
        _n_members = 0
        for i, _set in enumerate(_sets):

            # Anomalies at the warming level closest to the target, and their weights
            # (weighting scheme defined by local_weights, outside this cell).
            _data = ds_GWL[_set].sel(temporality=_tempo).sel(warming_level=_wrm_lvl, method='nearest')
            _da_weight = local_weights(xr.where(_data, True, False), method = 'occurrence-per-model').expand_dims(dim={'time': _data.time}).transpose(..., 'time')
            # Keep the un-stacked data and weights so that the figure can be rebuilt
            # later from the netCDF file written at the end of this cell.
            if _set not in data2store:
                print(_set)
                data2store = data2store.assign({_set: _data.copy(deep=True), _set+'_weight': _da_weight.copy(deep=True)})
            # Flatten (member, time) into one sample axis, drop missing samples and
            # sort values and weights together in ascending order of the values.
            _da = _data.stack(sample=('member', 'time')).dropna('sample')
            _da_weight = _da_weight.stack(sample=('member', 'time')).sel(sample=_da.sample)
            _order = _da.argsort().values
            _da_weight, _da = _da_weight.isel(sample=_order), _da.isel(sample=_order)
            print('Data size:', len(_da.member))
            _n_members += int(len(_da.member) / window_size)
            _kept_models.update(list(_da.source_id.values))
            _kept_members.update(list(_da.member.values))

            ax.hist(_da, weights=_da_weight,
                    density=True, bins=_nbins, range=_range,
                    histtype='stepfilled', alpha=0.5, color=fig_dict[_set]['color'], label=fig_dict[_set]['legend'],
                    edgecolor=fig_dict[_set]['color'])
            # Weighted mean of the distribution, drawn as a vertical line.
            _mean = _da.weighted(_da_weight).mean()
            ax.axvline(x=_mean, ls='-', c=fig_dict[_set]['color'], lw=1.5, zorder=2)

            h, l = ax.get_legend_handles_labels()
            _lgd = ax.legend(h, l, fontsize=12, loc='upper left')

            if show_gauss:
                # Normal distribution with the weighted mean and standard deviation.
                _std = _da.weighted(_da_weight).std()
                _x = np.linspace(_da.min(), _da.max(), 100)
                ax.plot(_x, stats.norm.pdf(_x, loc=_mean, scale=_std), color=fig_dict[_set]['color'])

            for iObs, _obs in enumerate(observations):
                if _tempo in observations[_obs]:
                    _computed = ds_rtn[_obs].sel(return_period='computed', dataset=_set)
                    _text = ('{0:.0f} years'.format(_computed))
                    if bootstrap:
                        # The displayed interval always extends at least one year on
                        # each side of the computed return period.
                        try:
                            _text += (
                                ' [{0:.0f} - '.format(min(_computed-1, ds_rtn[_obs].sel(return_period='bootstrap-low', dataset=_set)))
                                +'{0:.0f}]'.format(max(_computed+1, ds_rtn[_obs].sel(return_period='bootstrap-high', dataset=_set)))
                                )
                        except OverflowError:
                            # Best effort: the interval is simply left out of the label.
                            pass
                    rtn_box.append(Line2D([0], [0], label=_text, ls=linestyles[iObs], color=fig_dict[_set]['color'], lw=lw))

        for iObs, _obs in enumerate(observations):
            if _tempo in observations[_obs]:
                ax.axvline(x=observations[_obs][_tempo], ls=linestyles[iObs], c='red', lw=lw,
                )
                ax.text(observations[_obs][_tempo], 0, ' '+targets[_tempo]['year']+' '+lgd_obs_dict[_obs]+' ', fontsize=12, c='red', va='bottom', ha='right', rotation=90)
                # add_artist keeps the return-period legend alive when the dataset
                # legend is (re)created below.
                ax.add_artist(ax.legend(handles=rtn_box, loc='upper right', fontsize=12))

        ax.set_title(variable_dict[variable]+' anomaly distribution at {:.2f}°C'.format(targets[_tempo]['GWL'])+' ('+str(window_size-1)+'-year window)\nfrom constrained CMIP6 ('+str(len(_kept_models))+' models, '+str(len(_kept_members))+' members)', fontsize=15)
        ax.set_xlabel(temporality_dict[_tempo]+' '+variable_dict[variable]+' anomaly (°C), reference period: '+piStr+'-'+piEnd, fontsize=15)
        ax.set_ylabel('Occurrence (%)', fontsize=15)
        ax.tick_params(axis='both', which='major', labelsize=15)
        ax.axvline(x=0, ls='-', c='k', lw=.5)

        h, l = ax.get_legend_handles_labels()
        _lgd = ax.legend(h, l, fontsize=12, loc='upper left')

        # The conditioned dataset name is part of the file name: without it both
        # entries of the outer loop were saved to the same path and the second
        # figure silently overwrote the first one.
        _saveFig = outDir+_tempo.join(saveName.split('tempo'))+'/'+_saveName+'_'+_sets[0]+extension
        plt.savefig(_saveFig, transparent=True)
        print('Saved:', _saveFig)
# Written once, after all figures; the path uses the last value taken by _tempo.
data2store.to_netcdf(outDir+_tempo.join(saveName.split('tempo'))+'/'+'forster_yr'+targets[_tempo]['year']+'.nc')
data2store

## Forster et al., 2025

### Detection and attribution

In [None]:
# Datasets written by the three executions of this notebook (one file per year).
_forster_dir = outDir+'yr'.join(saveName.split('tempo'))+'/'
ds_2022 = xr.open_dataset(_forster_dir+'forster_yr2022.nc')
ds_2023 = xr.open_dataset(_forster_dir+'forster_yr2023.nc')
ds_2024 = xr.open_dataset(_forster_dir+'forster_yr2024.nc')

# Datasets drawn in each panel: the constrained ensemble, then the subset
# conditioned on the AMV / NINO3.4 phases selected for that year.
_ref_set = variable+'_kMod'
sets = {
    '2022': [_ref_set, variable+'_amv+_nino34-_kMod'],
    '2023': [_ref_set, variable+'_amv+_nino34-_kMod'],
    '2024': [_ref_set, variable+'_amv+_nino34+_kMod'],
}

# Palette keyed by the phase of the mode of variability.
colors = dict(high='#EE99AA', neutral='dimgrey', low='#6699CC')
linestyles = ['-', '-.', '--', '-']
lw = 3

# (dataset-name suffix, legend suffix, palette key) of every unconstrained subset.
_fig_specs = [
    ('',               ' (all members)',      'neutral'),
    ('_nino34+',       ' El Niño',            'high'),
    ('_nino34~',       ' Neutral ENSO',       'neutral'),
    ('_nino34-',       ' La Niña',            'low'),
    ('_amv+',          ' AMV+',               'high'),
    ('_amv~',          ' Neutral AMV',        'neutral'),
    ('_amv-',          ' AMV-',               'low'),
    ('_amv+_nino34+',  ' AMV+ & El Niño',     'high'),
    ('_amv+_nino34-',  ' AMV+ & La Niña',     'low'),
    ('_amv~_nino34~',  ' Neutral AMV & ENSO', 'neutral'),
]
fig_dict = {
    variable+_suffix: {'legend': ensemble_lgd+_label, 'color': colors[_tone]}
    for _suffix, _label, _tone in _fig_specs
}

# The '_kMod' and '_kMem' variants share the style entry of their base dataset.
_list_of_keys = list(fig_dict)
for _cont in ('_kMod', '_kMem'):
    for _key in _list_of_keys:
        fig_dict[_key+_cont] = fig_dict[_key]

# Three-panel figure (one panel per year): weighted distributions of the datasets
# listed in `sets`, read back from the netCDF files of the yearly executions.
# NOTE(review): the super-title hard-codes "15 models" although _kept_models is
# collected below, and spells "anthropogically"; both are left untouched here.
fig, ax = plt.subplots(1, 3, figsize=(31, 9), sharex=True, sharey=True)
fig.suptitle(variable_dict[variable]+' interannual anomalies distribution from CMIP6 (15 models) for evolutive anthropogically-forced warming (ANT_GWL)\nand conditional to the combined phases of OND [yr-1] ENSO and AMV modes of variability', fontsize=21, y=1.01)

# Legend entries gathered over the panels, without duplicated labels.
h, l = list(), list()
i = 0
# panel prefix, year, dataset of that year, warming level and observed anomaly
for  _p,    _year, _ds,     _gwl, _wmo in [
    ['a) ', '2024', ds_2024, 1.36, 1.52],
    ['b) ', '2023', ds_2023, 1.31, 1.44],
    ['c) ', '2022', ds_2022, 1.26, 1.15],
    ]:

    reg_mod = dict()
    for _tempo in temporality:
        _wrm_lvl = targets[_tempo]['GWL']

        # Histogram support and resolution: 20 bins per unit of the x-axis.
        _range = (.8, 2.); _nbins = int((_range[1]-_range[0]) * 20)

        # Proxy handles of the return-period legend; the first entry is its heading.
        rtn_box = [Line2D([0], [0], label='Return period:', ls='')]
        text_list = list()

        ax[i].set_xlim(_range)
        ax[i].set_ylim((0.0, 4.0) if window_size == 11 else (0.0, 2.0))
        # Target the panel explicitly instead of pyplot's implicit current axes.
        ax[i].locator_params(axis='y', nbins=10)
        # xmax is the inverse of the bin width, so a density value is displayed as
        # the percentage of the (weighted) sample falling in one bin.
        ax[i].yaxis.set_major_formatter(PercentFormatter(xmax=_nbins/(_range[1]-_range[0]), decimals=1, symbol=0))
        _xlims= ax[i].get_xlim(); _ylims= ax[i].get_ylim()
        _saveName = _tempo.join(saveName.split('tempo'))+'-reg_anom_distributions_ref-gwl-obs_totalConstraint'
        if bootstrap:
            _saveName += '_boot'+str(n_boot)

        _kept_models = set() ; _kept_members = set()
        _n_members = 0
        for _set in sets[_year]:
            print(_set)
            # Stored samples and weights: flatten (member, time) into one sample
            # axis, drop missing samples and sort both by ascending value.
            _da, _da_weight = _ds.get(_set), _ds.get(_set+'_weight')
            _da = _da.stack(sample=('member', 'time')).dropna('sample')
            _da_weight = _da_weight.stack(sample=('member', 'time')).sel(sample=_da.sample)
            _order = _da.argsort().values
            _da_weight, _da = _da_weight.isel(sample=_order), _da.isel(sample=_order)
            print('Data size:', len(_da.member))

            _n_members += int(len(_da.member) / window_size)
            _kept_models.update(list(_da.source_id.values))
            _kept_members.update(list(_da.member.values))

            ax[i].hist(_da, weights=_da_weight,
                        density=True, bins=_nbins, range=_range,
                        histtype='stepfilled', alpha=0.5, color=fig_dict[_set]['color'], label=fig_dict[_set]['legend'],
                        edgecolor=fig_dict[_set]['color'])

            if show_gauss:
                # Normal distribution with the weighted mean and standard deviation.
                _mean = _da.weighted(_da_weight).mean()
                _std = _da.weighted(_da_weight).std()
                _x = np.linspace(_da.min(), _da.max(), 100)
                ax[i].plot(_x, stats.norm.pdf(_x, loc=_mean, scale=_std), color=fig_dict[_set]['color'])

            for iObs, _obs in enumerate(observations):
                if _tempo in observations[_obs]:
                    if show_gauss:
                        # Printed inside the loop that binds _obs: it used to be
                        # read before this loop, i.e. from whatever value an
                        # earlier cell had left behind.
                        print('Return period (gaussian-fit):', _ds[_obs].sel(return_period='normal', dataset=_set).values)
                    _computed = _ds[_obs].sel(return_period='computed', dataset=_set)
                    _text = ('{0:.0f} years'.format(_computed))
                    if bootstrap:
                        # The displayed interval always extends at least one year on
                        # each side of the computed return period.
                        try:
                            _text += (
                                ' [{0:.0f} - '.format(min(_computed-1, _ds[_obs].sel(return_period='bootstrap-low', dataset=_set)))
                                +'{0:.0f}]'.format(max(_computed+1, _ds[_obs].sel(return_period='bootstrap-high', dataset=_set)))
                                )
                        except OverflowError:
                            # Best effort: the interval is simply left out of the label.
                            pass
                    rtn_box.append(Line2D([0], [0], label=_text, ls=linestyles[iObs], color=fig_dict[_set]['color'], lw=lw))

        for iObs, _obs in enumerate(observations):
            if _tempo in observations[_obs]:
                # The observed value of the panel's year comes from the panel list
                # above, not from `observations` (which only holds the current year).
                ax[i].axvline(x=_wmo, ls=linestyles[iObs], c='red', lw=lw,
                )
                ax[i].text(_wmo, 0, ' '+_year+' '+variable_dict[variable]+' Obs. = {:.2f}°C'.format(_wmo), fontsize=12, c='red', va='bottom', ha='right', rotation=90)
                # add_artist keeps the return-period legend when another legend is
                # created on the same axes afterwards.
                ax[i].add_artist(ax[i].legend(handles=rtn_box, loc='upper right', fontsize=12))

        ax[i].set_title(_p+'Year '+_year+' at ANT_GWL = {:.2f}°C'.format(_gwl), fontsize=15)
        ax[i].set_xlabel(temporality_dict[_tempo]+' '+variable_dict[variable]+' anomaly (°C), reference period: '+piStr+'-'+piEnd, fontsize=15)
        ax[i].set_ylabel('Occurrence (%)', fontsize=15)
        ax[i].tick_params(axis='both', which='major', labelsize=15)
        ax[i].axvline(x=0, ls='-', c='k', lw=.5)

    # Collect the histogram handles of this panel, skipping labels already seen.
    _h, _l = ax[i].get_legend_handles_labels()
    for j, _l1 in enumerate(_l):
        if _l1 not in l:
            h += [_h[j]] ; l += [_l1]
    i += 1


# Single dataset legend for the whole figure, placed in the first panel.
_lgd = ax[0].legend(h, l, fontsize=12, loc='upper left') # 15

_saveFig = outDir+'yr'.join(saveName.split('tempo'))+'/'+'forster2025_fig8_cassou_line'+extension
plt.savefig(_saveFig, transparent=True)
print('Saved:', _saveFig)

### Modes of internal variability

In [None]:
# Observed anomaly of each year; it sets the colour of the observation markers.
gsat_wmo = {'2022': 1.15, '2023': 1.44, '2024': 1.52}
show_quantiles = False ; show_obs = True
# Bounds of the colour scale and number of 0.1-wide colour intervals. The quotient
# is rounded before the conversion to int: it is a float (11.000000000000002 here)
# and a plain truncation would lose one interval whenever it falls just below an
# integer (e.g. 1.2 / 0.1 == 11.999999999999998).
vmin = 0.9 ; vmax = 2.0 ; nbins = int(round((vmax - vmin) / 0.1))

# Members for which every driver of ds_driver has at least one valid value ...
_driver_members = ds_driver.member.values
for _driver in ds_driver:
    _driver_members = list(set(_driver_members).intersection(set(ds_driver.get(_driver).dropna(dim='member', how='all').member.values)))
# ... restricted to the members kept by the 'cnstMod' mask
# (presumably the model-level constraint -- TODO confirm where ds_mask is built).
_kept_driver_members = list(set(_driver_members).intersection(set(xr.where(ds_mask['cnstMod'], 1, np.nan).dropna(dim='member').member.values)))
print(len(_kept_driver_members), 'kept members have all drivers in common.')
# Anomalies of the kept members at the first temporality and first warming level,
# with the weights returned by local_weights for that selection.
_LOCdata = ds_GWL.get(variable+'_kMod').isel(temporality = 0, warming_level = 0).sel(member = _kept_driver_members)
_LOCweights = local_weights(_LOCdata, method='occurrence-per-model')

# Scatter plot of the two drivers for the kept members, coloured by the anomaly.
fig, ax = plt.subplots(figsize=(10,7.5))
ax.set_title(variable_dict[variable]+' interannual anomalies from CMIP6 (15 models) at ANT_GWL = {:.2f}°C\nconditional to phases of OND [yr-1] ENSO and AMV modes of variability'.format(_LOCdata.warming_level), fontsize=15)
ax.set_xlabel('Annual yr[0] demeaned AMV (°C)', fontsize=15)
ax.set_ylabel('OND yr[-1] demeaned Niño 3.4 (°C)', fontsize=15)
# Zero lines of both drivers
ax.axvline(x=0,lw=.75,c='k')
ax.axhline(y=0,lw=.75,c='k')
ax.set_xlim(-0.6, 0.6)
ax.set_ylim(-3.75, 3.75)

# Discrete colour scale between vmin and vmax
_cmap = get_diverging_cmap()
_levels = mtick.MaxNLocator(nbins=nbins).tick_values(vmin, vmax)
_norm = mclrs.BoundaryNorm(_levels, ncolors=_cmap.N, clip=True)

_amv_kept = ds_driver.get('amv').sel(member = _kept_driver_members)
_nino34_kept = ds_driver.get('nino34').sel(member = _kept_driver_members)
sc = plt.scatter(
    _amv_kept, _nino34_kept,
    c = _LOCdata,
    s = 15, cmap = _cmap, alpha = 0.75, norm=_norm,
)
cb = fig.colorbar(sc, label='°C')
# NOTE(review): _tempo is not set in this cell; its value comes from an earlier one.
cb.set_label(label=temporality_dict[_tempo]+' yr[0] '+variable_dict[variable]+' anomaly (°C),\nreference period: '+piStr+'-'+piEnd, size=15)

# Proxy artists of the legend
_handles = []

if show_obs:
    # One marker per observed year, filled with the colour of its observed anomaly.
    for _year, _gsat in gsat_wmo.items():
        x = ds_driver_obs.amvGlob.sel(time=_year)
        y = ds_driver_obs.nino34.sel(time=_year)
        plt.plot(x, y, c=_cmap(_norm(_gsat)), marker='o', mec='k', ms=15)
        _va = 'top' if _year == '2022' else 'bottom'
        plt.text(x, y, '   '+_year, va = _va, fontsize=12)
    _handles.append(Line2D([0], [0], color='w', linestyle='', marker='o', mec='k', ms=15, label='Obs: ERSSTv5/WMO'))

if show_quantiles:
    # One star per inter-quantile interval of the weighted anomaly distribution.
    nquantiles = 10
    _weighted = _LOCdata.weighted(_LOCweights)
    for i in range(nquantiles):
        zmin = _weighted.quantile(i/nquantiles)
        zmax = _weighted.quantile((i+1)/nquantiles)

        x, y, z = get_interval_averages(
            x = ds_driver.get('amv').sel(member = _kept_driver_members),
            y = ds_driver.get('nino34').sel(member = _kept_driver_members),
            z = ds_GWL.get(variable).isel(temporality = 0, warming_level = 0).sel(member = _kept_driver_members),
            zmin = zmin, zmax = zmax)
        plt.plot(x, y, c=_cmap(_norm(z)), marker='*', mec='k', ms=15, alpha = 0.75)
    _handles.append(Line2D([0], [0], color='w', linestyle='', marker='*', mec='k', ms=15, label='GSAT deciles'))

# The model entry is always present, hence the legend is always drawn.
_handles.append(Line2D([0], [0], color='grey', linestyle='', marker='o', ms=5, label='CMIP6 occurrence'))
ax.legend(handles=_handles, loc='lower right', fontsize=12)

# Target path of the figure; the save itself is currently disabled.
_saveFig = outDir+'yr'.join(saveName.split('tempo'))+'/'+'forster2025_SI_cassou_line'+extension
#plt.savefig(_saveFig, transparent=True)
#print('Saved:', _saveFig)

In [None]:
# Final cell: prints a completion message once execution reaches the end of the notebook.
print('My work is done.')