#### 1. Import all packages

In [16]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import cartopy
import dask
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
import intake
import fsspec
import requests
import aiohttp

In [17]:
import gcsfs

In [18]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size as (width, height).
plt.rcParams['figure.figsize'] = 12, 6
# Request the 'retina' figure format from the inline backend.
%config InlineBackend.figure_format = 'retina'

In [19]:
from dask_gateway import Gateway
from dask.distributed import Client

# Connect to the Dask Gateway server.
# NOTE(review): the address is hard-coded to a local endpoint; adjust it for other deployments.
gateway = Gateway("http://127.0.0.1:8000")
# The result of this call is discarded (it is not the last expression of the cell).
gateway.list_clusters()
# Start a new cluster and let it scale adaptively between 1 and 20 workers.
cluster = gateway.new_cluster()
cluster.adapt(minimum=1, maximum=20)
# Attach a distributed client to the new cluster.
client = Client(cluster)
# Display the cluster object as the cell output.
cluster

In [15]:
# Uncomment and run when finished to shut down the Dask Gateway cluster:
# cluster.shutdown()

#### 2. Open the catalog CSV file from its URL

In [20]:
# Load the CSV listing of CMIP6 zarr stores into a DataFrame.
df = pd.read_csv('https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv')
# Preview the first rows.
df.head()

Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year,version
0,HighResMIP,CMCC,CMCC-CM2-HR4,highresSST-present,r1i1p1f1,Amon,ps,gn,gs://cmip6/CMIP6/HighResMIP/CMCC/CMCC-CM2-HR4/...,,20170706
1,HighResMIP,CMCC,CMCC-CM2-HR4,highresSST-present,r1i1p1f1,Amon,rsds,gn,gs://cmip6/CMIP6/HighResMIP/CMCC/CMCC-CM2-HR4/...,,20170706
2,HighResMIP,CMCC,CMCC-CM2-HR4,highresSST-present,r1i1p1f1,Amon,rlus,gn,gs://cmip6/CMIP6/HighResMIP/CMCC/CMCC-CM2-HR4/...,,20170706
3,HighResMIP,CMCC,CMCC-CM2-HR4,highresSST-present,r1i1p1f1,Amon,rlds,gn,gs://cmip6/CMIP6/HighResMIP/CMCC/CMCC-CM2-HR4/...,,20170706
4,HighResMIP,CMCC,CMCC-CM2-HR4,highresSST-present,r1i1p1f1,Amon,psl,gn,gs://cmip6/CMIP6/HighResMIP/CMCC/CMCC-CM2-HR4/...,,20170706


#### 3. Use the `intake` package to open the catalog as a data store. Displaying the store directly shows how many unique values each column contains.

In [21]:
# Open the Pangeo CMIP6 catalog (described by a JSON file) as an intake datastore.
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")
# Display the catalog summary.
col

Unnamed: 0,unique
activity_id,18
institution_id,36
source_id,88
experiment_id,170
member_id,657
table_id,37
variable_id,700
grid_label,10
zstore,514818
dcpp_init_year,60


In [22]:
# List the distinct experiment identifiers present in the CSV catalog.
df['experiment_id'].unique()

array(['highresSST-present', 'piControl', 'control-1950', 'hist-1950',
       'historical', 'amip', 'abrupt-4xCO2', 'abrupt-2xCO2',
       'abrupt-0p5xCO2', '1pctCO2', 'ssp585', 'esm-piControl', 'esm-hist',
       'hist-piAer', 'histSST-1950HC', 'ssp245', 'hist-1950HC', 'histSST',
       'piClim-2xVOC', 'piClim-2xNOx', 'piClim-2xdust', 'piClim-2xss',
       'piClim-histall', 'hist-piNTCF', 'histSST-piNTCF',
       'aqua-control-lwoff', 'piClim-lu', 'histSST-piO3', 'piClim-CH4',
       'piClim-NTCF', 'piClim-NOx', 'piClim-O3', 'piClim-HC',
       'faf-heat-NA0pct', 'ssp370SST-lowCH4', 'piClim-VOC',
       'ssp370-lowNTCF', 'piClim-control', 'piClim-aer', 'hist-aer',
       'faf-heat', 'faf-heat-NA50pct', 'ssp370SST-lowNTCF',
       'ssp370SST-ssp126Lu', 'ssp370SST', 'ssp370pdSST', 'histSST-piAer',
       'piClim-ghg', 'piClim-anthro', 'faf-all', 'hist-nat', 'hist-GHG',
       'ssp119', 'piClim-histnat', 'piClim-4xCO2', 'ssp370',
       'piClim-histghg', 'highresSST-future', 'esm-ssp585-

In [23]:
# Search criteria for the catalog; each key is a catalog column.
query = dict(
    experiment_id=['abrupt-4xCO2','piControl'], # pick the `abrupt-4xCO2` and `piControl` forcing experiments
    table_id='Amon',                            # choose to look at atmospheric variables (A) saved at monthly resolution (mon)
    variable_id=['tas', 'rsut','rsdt','rlut'],  # near-surface air temperature (tas) plus the rsut/rsdt/rlut fluxes used later for the radiative imbalance
    member_id = 'r1i1p1f1',                     # arbitrarily pick one realization for each model (i.e. just one set of initial conditions)
)

# Keep only the models (source_id) that satisfy every requested value of the query.
col_subset = col.search(require_all_on=["source_id"], **query)
# Sanity check: count distinct experiments, variables and tables per model.
col_subset.df.groupby("source_id")[
    ["experiment_id", "variable_id", "table_id"]
].nunique()

  for _, group in grouped:


Unnamed: 0_level_0,experiment_id,variable_id,table_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ACCESS-CM2,2,4,1
ACCESS-ESM1-5,2,4,1
AWI-CM-1-1-MR,2,4,1
BCC-CSM2-MR,2,4,1
BCC-ESM1,2,4,1
CAMS-CSM1-0,2,4,1
CESM2,2,4,1
CESM2-FV2,2,4,1
CESM2-WACCM,2,4,1
CESM2-WACCM-FV2,2,4,1


In [24]:
# Preview the first rows of the filtered catalog.
col_subset.df.head()

Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year,version
0,CMIP,CSIRO-ARCCSS,ACCESS-CM2,abrupt-4xCO2,r1i1p1f1,Amon,tas,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191108
1,CMIP,CSIRO-ARCCSS,ACCESS-CM2,abrupt-4xCO2,r1i1p1f1,Amon,rsut,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191108
2,CMIP,CSIRO-ARCCSS,ACCESS-CM2,abrupt-4xCO2,r1i1p1f1,Amon,rsdt,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191108
3,CMIP,CSIRO-ARCCSS,ACCESS-CM2,abrupt-4xCO2,r1i1p1f1,Amon,rlut,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191108
4,CMIP,CSIRO-ARCCSS,ACCESS-CM2,piControl,r1i1p1f1,Amon,tas,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191112


In [25]:
def drop_all_bounds(ds):
    """Drop bounds coordinates such as 'time_bounds' or 'lat_bnds'.

    These auxiliary coordinates can lead to issues when merging datasets.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose coordinate names are inspected.

    Returns
    -------
    xarray.Dataset
        The dataset without any coordinate whose name contains
        '_bounds' or '_bnds'.
    """
    bounds_markers = ('_bounds', '_bnds')
    drop_vars = [vname for vname in ds.coords
                 if any(marker in vname for marker in bounds_markers)]
    # `Dataset.drop` is deprecated for removing variables; `drop_vars` is
    # the supported replacement.
    return ds.drop_vars(drop_vars)

def open_dsets(df):
    """Open every zarr store listed in ``df.zstore`` and merge them.

    Each store is opened with consolidated metadata and stripped of its
    bounds coordinates before merging.

    Returns the merged xarray dataset, or ``None`` when the exact-join
    merge raises ``ValueError``.
    """
    opened = []
    for store_url in df.zstore:
        mapper = fsspec.get_mapper(store_url)
        store_ds = xr.open_zarr(mapper, consolidated=True)
        opened.append(store_ds.pipe(drop_all_bounds))

    try:
        return xr.merge(opened, join='exact')
    except ValueError:
        # Signal to callers that these stores could not be merged.
        return None

def open_delayed(df):
    """Return a ``dask.delayed`` call of `open_dsets` on ``df``.

    Nothing is opened until the returned object is computed, which lets
    many groups of datasets be opened in parallel.
    """
    lazy_open = dask.delayed(open_dsets)
    return lazy_open(df)

In [26]:
from collections import defaultdict

# Build a nested mapping {source_id: {experiment_id: delayed dataset}}.
dsets = defaultdict(dict)
# NOTE(review): the loop variable `df` rebinds the catalog DataFrame `df`
# loaded earlier; after the loop it holds the rows of the last group, which
# the following `%time open_dsets(df)` cell appears to rely on.
for group, df in col_subset.df.groupby(by=['source_id', 'experiment_id']):
    dsets[group[0]][group[1]] = open_delayed(df)

In [27]:
# Time an eager open of a single group; `df` is whatever the preceding groupby loop bound last.
%time open_dsets(df)

CPU times: user 279 ms, sys: 80.5 ms, total: 359 ms
Wall time: 2.8 s


Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,85.85 MiB
Shape,"(6000, 192, 288)","(407, 192, 288)"
Count,16 Tasks,15 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.24 GiB 85.85 MiB Shape (6000, 192, 288) (407, 192, 288) Count 16 Tasks 15 Chunks Type float32 numpy.ndarray",288  192  6000,

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,85.85 MiB
Shape,"(6000, 192, 288)","(407, 192, 288)"
Count,16 Tasks,15 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,69.19 MiB
Shape,"(6000, 192, 288)","(328, 192, 288)"
Count,20 Tasks,19 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.24 GiB 69.19 MiB Shape (6000, 192, 288) (328, 192, 288) Count 20 Tasks 19 Chunks Type float32 numpy.ndarray",288  192  6000,

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,69.19 MiB
Shape,"(6000, 192, 288)","(328, 192, 288)"
Count,20 Tasks,19 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,105.05 MiB
Shape,"(6000, 192, 288)","(498, 192, 288)"
Count,14 Tasks,13 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.24 GiB 105.05 MiB Shape (6000, 192, 288) (498, 192, 288) Count 14 Tasks 13 Chunks Type float32 numpy.ndarray",288  192  6000,

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,105.05 MiB
Shape,"(6000, 192, 288)","(498, 192, 288)"
Count,14 Tasks,13 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,74.88 MiB
Shape,"(6000, 192, 288)","(355, 192, 288)"
Count,18 Tasks,17 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.24 GiB 74.88 MiB Shape (6000, 192, 288) (355, 192, 288) Count 18 Tasks 17 Chunks Type float32 numpy.ndarray",288  192  6000,

Unnamed: 0,Array,Chunk
Bytes,1.24 GiB,74.88 MiB
Shape,"(6000, 192, 288)","(355, 192, 288)"
Count,18 Tasks,17 Chunks
Type,float32,numpy.ndarray


In [28]:
# Run all delayed opens in one call; dask.compute returns a tuple, so [0] is the nested dict of results.
dsets_ = dask.compute(dict(dsets))[0]

In [29]:
def get_lat_name(ds):
    """Return the name of the latitude coordinate of ``ds``.

    'lat' is preferred over 'latitude' when both are present.
    Raises ``RuntimeError`` if neither name is among ``ds.coords``.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds):
    """Compute the latitude-weighted mean of ``ds`` over all non-time dimensions.

    The weights are the cosine of latitude, normalised so that their mean
    is one; every dimension except 'time' is averaged away.
    """
    lat_name = get_lat_name(ds)
    weight = np.cos(np.deg2rad(ds[lat_name]))
    weight /= weight.mean()
    spatial_dims = set(ds.dims) - {'time'}
    weighted = ds * weight
    return weighted.mean(spatial_dims)

In [30]:
# Experiments to compare; the order matters because the concat below uses
# join='right', i.e. it aligns on the last entry (abrupt-4xCO2).
expts = ['piControl', 'abrupt-4xCO2']
expt_da = xr.DataArray(expts, dims='experiment_id',
                       coords={'experiment_id': expts})

# Maps source_id -> annual global means of both experiments.
dsets_aligned = {}

for k, v in tqdm(dsets_.items()):
    # Skip models for which an experiment is absent or could not be merged
    # (open_dsets returns None in that case).
    if any(v.get(expt) is None for expt in expts):
        print(f"Missing experiment for {k}")
        continue

    # Number the years relative to the first time step of each experiment
    # so that both experiments share a common 'year' axis.
    for expt in expts:
        ds = v[expt]
        ds.coords['year'] = ds.time.dt.year - ds.time.dt.year[0]

    # workaround for
    # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
    # `drop_vars` replaces the deprecated `drop` for removing the old
    # 'time' coordinate; coarsening by 12 averages the monthly values of
    # each year.
    dsets_ann_mean = [v[expt].pipe(global_mean)
                             .swap_dims({'time': 'year'})
                             .drop_vars('time')
                             .coarsen(year=12).mean()
                      for expt in expts]

    # align everything with the 4xCO2 experiment
    dsets_aligned[k] = xr.concat(dsets_ann_mean, join='right',
                                 dim=expt_da)

  3%|▎         | 1/37 [00:00<00:13,  2.60it/s]

Missing experiment for ACCESS-ESM1-5


 43%|████▎     | 16/37 [00:03<00:06,  3.49it/s]

Missing experiment for FIO-ESM-2-0
Missing experiment for GFDL-CM4


100%|██████████| 37/37 [00:10<00:00,  3.50it/s]


## Out of memory! The cells below were not run, because computing them would probably exhaust my laptop's memory

In [None]:
# Compute every entry of `dsets_aligned`; dask.compute returns a tuple, so [0] is the resulting dict.
dsets_aligned_ = dask.compute(dsets_aligned)[0]

In [None]:
# Label the new concatenation dimension with the model names.
source_ids = list(dsets_aligned_)
source_da = xr.DataArray(
    source_ids,
    dims='source_id',
    coords={'source_id': source_ids},
)

# Drop non-index coordinates from each model, then stack the models along source_id.
per_model = [model_ds.reset_coords(drop=True)
             for model_ds in dsets_aligned_.values()]
big_ds = xr.concat(per_model, dim=source_da)
big_ds

In [None]:
# Radiative imbalance: rsdt minus rsut minus rlut.
big_ds['imbalance'] = big_ds['rsdt'] - big_ds['rsut'] - big_ds['rlut']

# Anomalies of both fields relative to the piControl mean over all years.
fields = big_ds[['tas', 'imbalance']]
ds_mean = fields.sel(experiment_id='piControl').mean(dim='year')
ds_anom = fields - ds_mean

# add some metadata
ds_anom['tas'].attrs.update(
    long_name='Global Mean Surface Temp Anom',
    units='K',
)
ds_anom['imbalance'].attrs.update(
    long_name='Global Mean Radiative Imbalance',
    units='W m$^{-2}$',
)

ds_anom

In [None]:
# Temperature anomaly against year, one panel per model, five panels per row.
ds_anom.tas.plot.line(col='source_id', x='year', col_wrap=5)

In [None]:
# Limit to the Gregory 150-year period.
# `.sel` with a slice is label-based and includes both endpoints, so 0..149 spans 150 years.
first_150_years = slice(0, 149)
ds_anom.tas.sel(year=first_150_years).plot.line(col='source_id', x='year', col_wrap=5)

In [None]:
# Same panel layout for the radiative imbalance over the `first_150_years` slice.
ds_anom.imbalance.sel(year=first_150_years).plot.line(col='source_id', x='year', col_wrap=5)

In [None]:
# Keep only the abrupt-4xCO2 anomalies within `first_150_years` and drop the remaining non-index coordinates.
ds_abrupt = ds_anom.sel(year=first_150_years, experiment_id='abrupt-4xCO2').reset_coords(drop=True)

def calc_ecs(ds):
    """Estimate the equilibrium climate sensitivity (ECS) of one model.

    Fits the straight line ``imbalance = a * tas + b`` by least squares
    and returns ``-0.5 * b / a``, i.e. half of the temperature anomaly at
    which the fitted imbalance reaches zero.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with the variables ``tas`` and ``imbalance`` along the
        ``year`` dimension.

    Returns
    -------
    xarray.DataArray
        Scalar array holding the ECS estimate.
    """
    # Some sources don't have all 150 years.  Drop a year when *either*
    # variable is missing so the x and y samples stay paired; dropping NaNs
    # from each variable separately could misalign them or give polyfit
    # inputs of different lengths.
    paired = ds[['tas', 'imbalance']].dropna('year')
    a, b = np.polyfit(paired.tas, paired.imbalance, 1)
    ecs = -0.5 * (b/a)
    return xr.DataArray(ecs)

# Compute one ECS value per model and store it as a new variable of the dataset.
ds_abrupt['ecs'] = ds_abrupt.groupby('source_id').apply(calc_ecs)
ds_abrupt

In [None]:
# Scatter imbalance against temperature anomaly, one panel per model; keep the returned grid to overlay the fits.
fg = ds_abrupt.plot.scatter(x='tas', y='imbalance', col='source_id', col_wrap=5)

def calc_and_plot_ecs(x, y, **kwargs):
    """Fit a line to (x, y), draw it and annotate the panel with the ECS.

    Parameters
    ----------
    x, y : array-like of float
        Paired samples (temperature anomaly and radiative imbalance).
        A pair is ignored when either of its values is NaN.
    **kwargs
        Accepted for compatibility with the caller; not used.

    The fitted line ``y = a * x + b`` is drawn in black between x = 0 and
    x = 10, and ``ECS = -0.5 * b / a`` is written as a text label.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Use one joint mask so that x and y stay paired; masking each array by
    # its own NaNs could shift the pairs or leave arrays of unequal length.
    valid = ~(np.isnan(x) | np.isnan(y))
    x = x[valid]
    y = y[valid]
    a, b = np.polyfit(x, y, 1)
    ecs = -0.5 * b/a
    # Freeze the axis limits so the fitted line does not rescale the panel.
    plt.autoscale(False)
    plt.plot([0, 10], np.polyval([a, b], [0, 10]), 'k')
    plt.text(4, 6, f'ECS = {ecs:3.2f}', fontdict={'weight': 'bold', 'size': 12})
    plt.grid()

# Run calc_and_plot_ecs on the 'tas' and 'imbalance' data of every panel of the grid.
fg.map(calc_and_plot_ecs, 'tas', 'imbalance')


In [None]:
# Histogram of the ECS estimates; the trailing semicolon suppresses the text repr of the return value.
ds_abrupt.ecs.plot.hist();

In [None]:
# Bar chart of the ECS estimates, sorted in ascending order.
ds_abrupt.ecs.to_dataframe().sort_values('ecs').plot(kind='bar')