In [None]:
from distributed import Client
# Start a Dask distributed client with default settings; dask-backed work below
# (e.g. the `.persist()` calls) is scheduled on it.
client = Client()
# Display the client's rich repr (cluster summary) as the cell output.
client

In [None]:
import time
# Wall-clock start time; the last cell of the notebook prints the elapsed time since here.
tic = time.time()

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

from xmip.preprocessing import combined_preprocessing
from xmip.utils import google_cmip_col

from xarrayutils.plotting import shaded_line_plot
xr.set_options(keep_attrs=True)
%matplotlib inline
%config InlineBackend.figure_format='retina'
plt.rcParams['figure.figsize'] = (10,5)

## Load CMIP6 data from Pangeo Cloud Storage

In [None]:
col = google_cmip_col()

# Search facets shared by the SST query here and the cell-area query below.
query = dict(
    source_id=[
        'IPSL-CM6A-LR',
        'MPI-ESM1-2-LR',
        'GFDL-ESM4',
        'EC-Earth3',
        'CMCC-ESM2',
        'CESM2',
    ],
    experiment_id=['historical', 'ssp126', 'ssp370', 'ssp245', 'ssp585'],
    grid_label='gn',
)

# Monthly-table (Omon) sea surface temperature, single realization r1i1p1f1.
cat = col.search(**query, variable_id='tos', member_id=['r1i1p1f1'], table_id='Omon')

# Loader options reused for every catalog in this notebook: xmip preprocessing,
# cftime time decoding, and no intake-esm aggregation of the assets.
kwargs = dict(
    preprocess=combined_preprocessing,
    xarray_open_kwargs=dict(use_cftime=True),
    aggregate=False,
)
ddict = cat.to_dataset_dict(**kwargs)

In [None]:
# Ocean cell areas (`areacello`, fixed-field table Ofx) for the same models/experiments,
# opened with the same loader options as the SST data.
cat_area = col.search(**query, table_id='Ofx', variable_id='areacello')
ddict_area = cat_area.to_dataset_dict(**kwargs)

## Postprocess loaded data with xmip

In [None]:
from xmip.postprocessing import match_metrics

# Attach the matching `areacello` field from `ddict_area` to each SST dataset
# (matching statistics are printed).
ddict_w_area = match_metrics(
    ddict,
    ddict_area,
    'areacello',
    print_statistics=True,
)

In [None]:
xr.set_options(use_new_combine_kwarg_defaults=True);  # opt in to xarray's new combine/concat defaults; `;` hides the repr

In [None]:
from xmip.postprocessing import concat_members

# Cut every dataset at the end of 2100 (scenario runs may extend further) so that
# all models end on a common date before members are combined.
ddict_trimmed = {k: ds.sel(time=slice(None, '2100')) for k, ds in ddict_w_area.items()}

# Combine ensemble members of each model/experiment into one dataset.
# Use the *trimmed* datasets; previously the trimming above was computed but unused.
ddict_combined_members = concat_members(
    ddict_trimmed,
    concat_kwargs={'coords': 'minimal', 'compat': 'override', 'join': 'override'},
)

## Organize datasets in an xarray `DataTree`

In [None]:
# Public import path (xarray.core.* is private and may move between releases).
from xarray import DataTree

# Map "<source_id>/<experiment_id>/" paths (built from each dataset's attributes)
# to datasets; DataTree turns the paths into a model -> experiment hierarchy.
tree_dict = {f"{ds.source_id}/{ds.experiment_id}/": ds for ds in ddict_combined_members.values()}

# The comprehension above silently keeps only the last dataset for a repeated path.
assert len(tree_dict) == len(ddict_combined_members), "duplicate source_id/experiment_id paths"

dt = DataTree.from_dict(tree_dict)
dt

In [None]:
dt.nbytes / 1e9  # size of all data in the tree, in GB (10**9 bytes)

## Select, for each model, a single member that is present in every experiment

In [None]:
dt_single_member = DataTree()
for model_name, model in dt.children.items():
    # member_id values available in each experiment of this model
    member_id_values = [set(experiment.ds.member_id.data) for experiment in model.children.values()]
    if not member_id_values:
        raise ValueError(f"{model_name}: no experiments found")

    # Members that exist in every experiment
    full_members = set.intersection(*member_id_values)
    if not full_members:
        raise ValueError(
            f"{model_name}: no member_id is shared by all experiments: {member_id_values}"
        )

    # Deterministic pick: the smallest shared member_id in sort order
    pick_member = min(full_members)
    dt_single_member[model_name] = model.sel(member_id=pick_member)

## Compute weighted global mean SST

In [None]:
def global_mean_sst(ds):
    """Area-weighted mean of `tos` over the horizontal dims ('x', 'y').

    Falsy datasets (xarray treats a Dataset without data variables as falsy,
    e.g. the intermediate model nodes of the tree) yield None. Otherwise a
    Dataset holding the persisted global-mean `tos` is returned.
    """
    if ds:
        # Cells with a missing area get zero weight.
        weights = ds.areacello.fillna(0)
        mean_tos = ds.tos.weighted(weights).mean(['x', 'y'])
        return mean_tos.persist().to_dataset()
    return None


timeseries = dt_single_member.map_over_datasets(global_mean_sst)
timeseries

In [None]:
# Quick sanity check of one model/scenario; the trailing `;` hides the [<Line2D>] repr.
timeseries['/IPSL-CM6A-LR/ssp585'].ds['tos'].plot();

## Compute anomalies relative to the 1950–1980 mean

In [None]:
def get_ref_value(ds, start='1950', end='1980'):
    """Return the time-mean of `ds` over the reference period.

    Parameters
    ----------
    ds : xarray.Dataset
        Data with a `time` dimension.
    start, end : str
        Label bounds passed to ``slice`` in ``ds.sel(time=...)``. They default to
        the 1950-1980 baseline used throughout this notebook.
    """
    return ds.sel(time=slice(start, end)).mean('time')


anomaly = DataTree()
for model_name, model in timeseries.children.items():
    # Each model is referenced to the baseline of its own historical run.
    base_period = get_ref_value(model["historical"].ds)
    # DataTree - Dataset: the baseline is subtracted across this model's subtree.
    anomaly[model_name] = model - base_period

In [None]:
def replace_time(ds, source_id=None):
    """Rebuild `ds`'s time axis as regular month starts and tag it with a model name.

    The new axis starts at the year/month of the first time step. It uses
    month-start frequency and has one entry per original step, so datasets from
    different models get identical time coordinates and can be concatenated
    along a new `source_id` dimension.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose `time` values have `.year`/`.month` (cftime here).
    source_id : str, optional
        Value assigned to the scalar `source_id` coordinate. If omitted, it is
        taken from ``ds.attrs['source_id']``. Previously, the global loop
        variable `model_name` was read implicitly.

    Raises
    ------
    ValueError
        If no `source_id` is given and `ds` has no `source_id` attribute.
    """
    if source_id is None:
        source_id = ds.attrs.get('source_id')
        if source_id is None:
            raise ValueError("source_id not given and not found in ds.attrs")
    start_date = ds.time.data[0]
    new_time = xr.date_range(
        f"{start_date.year}-{start_date.month:02}",
        freq='1MS',
        periods=len(ds.time),
        use_cftime=True,
    )
    return ds.assign_coords(time=new_time, source_id=source_id)


# One list of per-model datasets per experiment (same experiments as the catalog query).
experiment_dict = {k: [] for k in query['experiment_id']}

for model_name, model in anomaly.children.items():
    for experiment_name, experiment in model.children.items():
        # Pass the model name explicitly instead of relying on a global.
        ds_new_cal = replace_time(experiment.ds, source_id=model_name)
        experiment_dict[experiment_name].append(ds_new_cal.load())

# Concatenate all models for each experiment; xr.concat cannot take an empty list,
# so experiments that no model provides are skipped.
plot_dict = {
    k: xr.concat(ds_lst, dim='source_id')
    for k, ds_lst in experiment_dict.items()
    if ds_lst
}

## Load observational dataset (thanks [pangeo-forge](https://pangeo-forge.org)!)

In [None]:
store = 'https://ncsa.osn.xsede.org/Pangeo/pangeo-forge/HadISST-feedstock/hadisst.zarr'

# Open HadISST lazily, use the CMIP variable name, and convert to a cftime standard calendar.
ds_obs = (
    xr.open_dataset(store, engine='zarr', chunks={})
    .rename({'sst': 'tos'})
    .convert_calendar('standard', use_cftime=True)
)

# Keep only values with |tos| < 50; anything larger is treated as missing
# (presumably fill/sentinel values — TODO confirm against the dataset metadata).
ds_obs = ds_obs.where(abs(ds_obs.tos) < 50)

# Approximate cell area as cos(latitude) * (110 km)^2; only relative weights matter here.
area = np.cos(np.deg2rad(ds_obs.latitude)) * (110e3 ** 2)

# Same processing as for the models: global weighted mean, then anomaly to the baseline.
ds_obs_ts = ds_obs.weighted(area).mean(['longitude', 'latitude'])
ds_obs_anomaly = ds_obs_ts - get_ref_value(ds_obs_ts)

# Give the observations the same extra dims as the model data before plotting.
plot_dict['observations'] = ds_obs_anomaly.expand_dims(['source_id', 'dcpp_init_year']).load()

## Here it is: global mean SST anomalies for all scenarios, compared with observations

In [None]:
fig, ax = plt.subplots()

color_dict = {
    'historical': '0.5',
    'ssp126': 'C2',
    'ssp245': 'gold',
    'ssp370': 'C1',
    'ssp585': 'C3',
    'observations': 'C5',
}
# Running-mean window in time steps (24 steps = 2 years of monthly data).
SMOOTHING_WINDOW = 2 * 12

for experiment, ds in plot_dict.items():
    color = color_dict[experiment]
    smooth = (
        ds['tos']
        .sel(time=slice(None, '2100'))
        .rolling(time=SMOOTHING_WINDOW)
        .mean()
        .squeeze('dcpp_init_year')
    )
    lw = 2 if experiment == 'observations' else 1.5
    # Line plus shaded spread across models (`source_id`); the legend shows the model count.
    shaded_line_plot(
        smooth,
        'source_id',
        ax=ax,
        spreads=[2.0],
        alphas=[0.2],
        line_kwargs=dict(color=color, label=f"{experiment} ({len(ds.source_id)})", lw=lw),
    )

# Use the explicit Axes/Figure API throughout instead of the pyplot state machine.
ax.set(
    xlabel='Year',
    ylabel='Global mean SST anomaly (vs. 1950–1980)',
    title='CMIP6 scenarios and HadISST observations',
)
ax.legend(loc='upper left')  # same as loc=2
ax.grid()
fig.tight_layout()

In [None]:
toc = time.time()
elapsed_seconds = int(toc - tic)  # whole seconds since the timer cell at the top
print(f"Elapsed time: {elapsed_seconds} seconds")