In [66]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import xarray as xr
xr.set_options(display_style='html')
import intake
%matplotlib inline
import fsspec
import seaborn as sns

import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [76]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import fsspec
import nc_time_axis

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plt.rcParams['figure.figsize'] = 12, 6

In [77]:
# Pangeo's public CMIP6 intake-esm catalog (cloud-hosted zarr stores).
CATALOG_URL = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
col = intake.open_esm_datastore(CATALOG_URL)
col

Unnamed: 0,unique
activity_id,18
institution_id,36
source_id,88
experiment_id,170
member_id,657
table_id,37
variable_id,700
grid_label,10
zstore,514818
dcpp_init_year,60


In [78]:
# every experiment_id in the catalog that mentions an SSP scenario, in catalog order
list(filter(lambda eid: 'ssp' in eid, col.df['experiment_id'].unique()))

['ssp585',
 'ssp245',
 'ssp370SST-lowCH4',
 'ssp370-lowNTCF',
 'ssp370SST-lowNTCF',
 'ssp370SST-ssp126Lu',
 'ssp370SST',
 'ssp370pdSST',
 'ssp119',
 'ssp370',
 'esm-ssp585-ssp126Lu',
 'ssp126-ssp370Lu',
 'ssp370-ssp126Lu',
 'ssp126',
 'esm-ssp585',
 'ssp245-GHG',
 'ssp245-nat',
 'ssp460',
 'ssp434',
 'ssp534-over',
 'ssp245-stratO3',
 'ssp245-aer',
 'ssp245-cov-modgreen',
 'ssp245-cov-fossil',
 'ssp245-cov-strgreen',
 'ssp245-covid',
 'ssp585-bgc']

In [79]:
# these runs currently have a large amount of data available in the catalog
expts = ['historical', 'ssp126', 'ssp245', 'ssp585']

# monthly precipitation, first ensemble member only
query = {
    'experiment_id': expts,
    'table_id': 'Amon',
    'variable_id': ['pr'],
    'member_id': 'r1i1p1f1',
}

# keep only source_ids that have every combination of the query facets
col_subset = col.search(require_all_on=["source_id"], **query)
by_source = col_subset.df.groupby("source_id")
by_source[["experiment_id", "variable_id", "table_id"]].nunique()

Unnamed: 0_level_0,experiment_id,variable_id,table_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ACCESS-CM2,4,1,1
AWI-CM-1-1-MR,4,1,1
BCC-CSM2-MR,4,1,1
CAMS-CSM1-0,4,1,1
CAS-ESM2-0,4,1,1
CESM2-WACCM,4,1,1
CMCC-CM2-SR5,4,1,1
CMCC-ESM2,4,1,1
CanESM5,4,1,1
EC-Earth3,4,1,1


In [80]:
def drop_all_bounds(ds):
    """Return ``ds`` without its cell-bounds variables.

    Every variable -- coordinate *or* data variable -- whose name contains
    ``'_bounds'`` or ``'_bnds'`` (e.g. ``lat_bnds``, ``time_bnds``) is dropped.
    Data variables are included because a bounds variable that was not
    decoded as a coordinate would otherwise survive and take part in the
    arithmetic done on the dataset later.

    Parameters
    ----------
    ds : xarray.Dataset

    Returns
    -------
    xarray.Dataset
        A new dataset; ``ds`` itself is not modified.
    """
    bound_names = [
        name for name in ds.variables
        # variable names are usually str, but xarray allows any hashable
        if isinstance(name, str) and ('_bounds' in name or '_bnds' in name)
    ]
    # drop_vars replaces the deprecated Dataset.drop for removing variables
    return ds.drop_vars(bound_names)

def open_dset(df):
    """Open the single zarr store listed in ``df`` and strip its bounds variables.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalog rows for one (source_id, experiment_id) group. Must contain
        exactly one row, with a ``zstore`` column holding the store URL.

    Returns
    -------
    xarray.Dataset
        The dataset opened with ``xr.open_zarr`` (consolidated metadata),
        passed through ``drop_all_bounds``.

    Raises
    ------
    ValueError
        If ``df`` does not have exactly one row (e.g. the same model has
        several grid variants). An explicit check is used instead of
        ``assert`` so it is not skipped under ``python -O`` and names the stores.
    """
    if len(df) != 1:
        raise ValueError(
            f"expected exactly one zarr store, got {len(df)}: {list(df['zstore'])}"
        )
    store_url = df['zstore'].iloc[0]
    ds = xr.open_zarr(fsspec.get_mapper(store_url), consolidated=True)
    return drop_all_bounds(ds)

def open_delayed(df):
    """Build a lazy ``dask.delayed`` call of ``open_dset`` on ``df``.

    No store is opened here; the returned object is evaluated when it is
    passed to ``dask.compute``.
    """
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# Nested mapping: source_id -> experiment_id -> delayed dataset open.
dsets = defaultdict(dict)

catalog_groups = col_subset.df.groupby(by=['source_id', 'experiment_id'])
for (source_id, experiment_id), group_df in catalog_groups:
    dsets[source_id][experiment_id] = open_delayed(group_df)

In [81]:
# evaluate every delayed open in a single dask pass; compute returns a 1-tuple
(dsets_,) = dask.compute(dict(dsets))

  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


In [82]:
# helpers for the area-weighted box mean computed by global_mean below

def get_lat_name(ds):
    """Return the name of the latitude coordinate of ``ds``.

    ``'lat'`` is preferred over ``'latitude'`` when both are present.

    Raises
    ------
    RuntimeError
        If neither candidate name is among ``ds.coords``.
    """
    found = next(
        (candidate for candidate in ('lat', 'latitude') if candidate in ds.coords),
        None,
    )
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds, lon_range=(117, 125), lat_range=(-11, -1)):
    """Area-weighted mean of ``ds`` over a lon/lat box, keeping only ``time``.

    Despite its name this is a *regional* mean. By default it averages the box
    lon 117..125, lat -11..-1 (the box the original code hard-coded).

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Needs a 1-D latitude dimension coordinate named ``lat`` or
        ``latitude``, and a longitude one named ``lon`` or ``longitude``.
    lon_range : (float, float)
        Longitude bounds passed to ``.sel`` as-is. NOTE(review): assumed to
        use the same 0..360 or -180..180 convention as ``ds``.
    lat_range : (float, float)
        Latitude bounds in either order; they are oriented to match the
        dataset's latitude axis.

    Returns
    -------
    Same type as ``ds``, reduced over every dimension except ``time``.

    Raises
    ------
    RuntimeError
        If no latitude or longitude coordinate is found.
    ValueError
        If the box selects no grid points (a NaN series would otherwise
        be returned silently).
    """
    lat_name = get_lat_name(ds)
    lon_name = next((name for name in ('lon', 'longitude') if name in ds.coords), None)
    if lon_name is None:
        raise RuntimeError("Couldn't find a longitude coordinate")

    # Label slices follow the index order. A north-to-south latitude axis
    # needs the bounds reversed, or the selection comes back empty.
    lat_lo, lat_hi = sorted(lat_range)
    lats = ds[lat_name].values
    if lats.size > 1 and lats[0] > lats[-1]:
        lat_slice = slice(lat_hi, lat_lo)
    else:
        lat_slice = slice(lat_lo, lat_hi)

    box = ds.sel({lon_name: slice(*lon_range), lat_name: lat_slice})
    if box.sizes[lat_name] == 0 or box.sizes[lon_name] == 0:
        raise ValueError(
            f"box lon={lon_range}, lat={lat_range} selects no grid points"
        )

    # cos(latitude) weights approximate grid-cell area on a regular lat/lon grid
    weight = np.cos(np.deg2rad(box[lat_name]))
    other_dims = [dim for dim in box.dims if dim != 'time']
    # weighted() divides by the sum of weights over non-NaN points, unlike
    # (box * weight).mean(), which divides by the raw point count.
    return box.weighted(weight).mean(other_dims)

In [83]:
expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                       coords={'experiment_id': expts})


def _annual_region_mean(ds):
    """Reduce one monthly dataset to an annual box-mean series indexed by ``year``.

    ``year`` is attached with ``assign_coords``, so the datasets held in
    ``dsets_`` are not mutated and this cell can be re-run safely.
    """
    # workaround for
    # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
    # NOTE(review): coarsen(year=12) assumes each series is monthly, starts in
    # January and covers whole years; otherwise windows straddle calendar years.
    return (ds.assign_coords(year=ds.time.dt.year)
              .pipe(global_mean)
              .swap_dims({'time': 'year'})
              .drop_vars('time')
              .coarsen(year=12).mean())


dsets_aligned = {}

for source_id, expt_dsets in tqdm(dsets_.items()):
    # An experiment can be missing as an absent key, not only as a None
    # value; checking the requested list avoids a KeyError further down.
    missing = [expt for expt in expts if expt_dsets.get(expt) is None]
    if missing:
        print(f"Missing experiment for {source_id}: {missing}")
        continue

    dsets_ann_mean = [_annual_region_mean(expt_dsets[expt]) for expt in expts]

    # Stack the experiments along a new experiment_id dimension. join='outer'
    # keeps the union of years (historical vs. SSP periods), padding with NaN.
    dsets_aligned[source_id] = xr.concat(dsets_ann_mean, join='outer',
                                         dim=expt_da)

  0%|          | 0/29 [00:00<?, ?it/s]

In [84]:
# trigger the lazy per-model pipelines; compute returns a 1-tuple
with progress.ProgressBar():
    (dsets_aligned_,) = dask.compute(dsets_aligned)

[########################################] | 100% Completed |  9min 50.7s


In [85]:
source_ids = [*dsets_aligned_]
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})

# drop model-specific non-index coordinates so the models stack cleanly
per_model = [aligned.reset_coords(drop=True) for aligned in dsets_aligned_.values()]
big_ds = xr.concat(per_model, dim=source_da)

big_ds

In [86]:
# Keep 1850-2100 inclusive. `year` is numeric: it comes from time.dt.year and
# is averaged by coarsen, so it may be float. The slice bounds must therefore
# be numbers; string labels like "1850" are not valid against a numeric index.
big_ds = big_ds.sel(year=slice(1850, 2100))

In [87]:
from pathlib import Path

# NOTE(review): machine-specific location kept from the original; change
# OUTPUT_DIR when running this notebook elsewhere.
OUTPUT_DIR = Path('/home/c4ubuntu/projDir/Indonesia/Data/CMIP6/Pangeo')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # avoid failing after the ~10 min compute
big_ds.to_netcdf(OUTPUT_DIR / 'CMIP6_PR_Yearly.nc')