In [1]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm 
import intake
import fsspec
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina' 

  from tqdm.autonotebook import tqdm


In [2]:
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")

In [3]:
[eid for eid in col.df['experiment_id'].unique() if 'ssp' in eid]

['ssp370SST-lowNTCF',
 'ssp370SST-lowCH4',
 'ssp370-lowNTCF',
 'ssp370SST-ssp126Lu',
 'ssp370pdSST',
 'ssp370SST',
 'esm-ssp585',
 'ssp126-ssp370Lu',
 'esm-ssp585-ssp126Lu',
 'ssp585',
 'ssp370',
 'ssp370-ssp126Lu',
 'ssp245',
 'ssp119',
 'ssp126',
 'ssp245-nat',
 'ssp245-GHG',
 'ssp434',
 'ssp460',
 'ssp534-over',
 'ssp245-stratO3',
 'ssp245-aer',
 'ssp245-covid',
 'ssp245-cov-strgreen',
 'ssp245-cov-modgreen',
 'ssp245-cov-fossil',
 'ssp585-bgc']

In [4]:
# there is currently a significant amount of data for these runs
#expts = ['historical', 'ssp245', 'ssp585']
expts = ['historical','ssp245', 'ssp126','ssp370','ssp585']
query = dict(
    experiment_id=expts,
    table_id='Amon',                           
    variable_id=['tas'],
    member_id = 'r1i1p1f1',                     
)

col_subset = col.search(require_all_on=["source_id"], **query)
col_subset.df.groupby("source_id")[
    ["experiment_id", "variable_id", "table_id"]
].nunique()

Unnamed: 0_level_0,experiment_id,variable_id,table_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ACCESS-CM2,5,1,1
ACCESS-ESM1-5,5,1,1
AWI-CM-1-1-MR,5,1,1
BCC-CSM2-MR,5,1,1
CAMS-CSM1-0,5,1,1
CESM2-WACCM,5,1,1
CMCC-CM2-SR5,5,1,1
CMCC-ESM2,5,1,1
CanESM5,5,1,1
EC-Earth3,5,1,1


In [5]:
#ds = intake.open_esm_datastore('/Volumes/Transcend/tas_Amon_CAS-ESM2-0_ssp245_r1i1p1f1_gn_201501-210012.nc')
def drop_all_bounds(ds):
    drop_vars = [vname for vname in ds.coords
                 if (('_bounds') in vname ) or ('_bnds') in vname]
    return ds.drop(drop_vars)

def open_dset(df):
    assert len(df) == 1
    ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True)
    return drop_all_bounds(ds)

def open_delayed(df):
    return dask.delayed(open_dset)(df)

from collections import defaultdict
dsets = defaultdict(dict) 

for group, df in col_subset.df.groupby(by=['source_id', 'experiment_id']):
    dsets[group[0]][group[1]] = open_delayed(df)

In [6]:
dsets_ = dask.compute(dict(dsets))[0]

In [7]:
# calculate global means

def get_lat_name(ds):
    for lat_name in ['lat', 'latitude']:
        if lat_name in ds.coords:
            return lat_name
    raise RuntimeError("Couldn't find a latitude coordinate")

def global_mean(ds):
    lat = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(lat))
    weight /= weight.mean()
    other_dims = set(ds.dims) - {'time'}
    return (ds * weight).mean(other_dims)

In [8]:
expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                        coords={'experiment_id': expts})

dsets_aligned = {}

for k, v in tqdm(dsets_.items()):
    expt_dsets = v.values()
    if any([d is None for d in expt_dsets]):
        print(f"Missing experiment for {k}")
        continue
    
    for ds in expt_dsets:
        #ds.coords['year'] = ds.time.dt.year
        #ds.coords['month']=ds.time.dt.month
        ds.coords['rawmonth']=ds.time.dt.month+(ds.time.dt.year-1850)*12
        #print(ds.coords)
    # workaround for
    # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
    dsets_mon_mean = [v[expt].pipe(global_mean)
                             .swap_dims({'time': 'rawmonth'})
                             .drop('time')
                             #.drop('mon')
    #                         .coarsen(year=12).mean()
                      for expt in expts]
    
    #print(dsets_mon_mean)
    
    # align everything with the 4xCO2 experiment
    dsets_aligned[k] = xr.concat(dsets_mon_mean, join='outer',
                                 dim=expt_da)

  0%|          | 0/27 [00:00<?, ?it/s]

In [None]:
# expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
#                        coords={'experiment_id': expts})

# dsets_aligned = {}

# for k, v in tqdm(dsets_.items()):
#     expt_dsets = v.values()
#     if any([d is None for d in expt_dsets]):
#         print(f"Missing experiment for {k}")
#         continue
    
#     for ds in expt_dsets:
#         ds.coords['year'] = ds.time.dt.year
        
#     # workaround for
#     # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
#     dsets_ann_mean = [v[expt].pipe(global_mean)
#                              .swap_dims({'time': 'year'})
#                              .drop('time')
#                              .coarsen(year=12).mean()
#                       for expt in expts]
    
#     # align everything with the 4xCO2 experiment
#     dsets_aligned[k] = xr.concat(dsets_ann_mean, join='outer',
#                                  dim=expt_da)

In [9]:
with progress.ProgressBar():
    dsets_aligned_ = dask.compute(dsets_aligned)[0]

[########################################] | 100% Completed |  2min 50.9s


In [10]:
source_ids = list(dsets_aligned_.keys())
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})

big_ds = xr.concat([ds.reset_coords(drop=True)
                    for ds in dsets_aligned_.values()],
                    dim=source_da)

big_ds

In [12]:
df_all = big_ds.sel(rawmonth=slice(1,5412)).to_dataframe().reset_index()
df_all.head()

Unnamed: 0,experiment_id,rawmonth,source_id,tas
0,historical,1,ACCESS-CM2,285.191116
1,historical,1,ACCESS-ESM1-5,285.925713
2,historical,1,AWI-CM-1-1-MR,
3,historical,1,BCC-CSM2-MR,285.642202
4,historical,1,CAMS-CSM1-0,285.393356


In [None]:
sns.relplot(data=df_all,
            x="year", y="tas", hue='experiment_id',
            kind="line", ci="sd", aspect=2);

In [13]:
print(df_all.shape)
df_by_exp = df_all.groupby('experiment_id')
tas_exp = []
names = []
for name, group in df_by_exp:
    group_mon = group.groupby('rawmonth')
    tas = np.empty((13,251))
    tas[0,:]=np.arange(1850,2101)
    names.append(name)
    for mon, mongroup in group_mon:
        if mon<3012:
            montemp = np.nanmean(mongroup['tas'])
            i_month = int(mon%12)
            if i_month==0:
                i_month=12
            i_year = int(np.floor(mon/12))
            tas[i_month,i_year] = montemp
    #print(np.count_nonzero(tas))
    #print(np.count_nonzero(np.isnan(tas)))
    tas_exp.append(tas)
    #years = (np.asarray(months)/12)+1850
    #plt.plot(years,avg_temps,label=name,alpha=0.6)
    #plt.show()

(730620, 4)


  montemp = np.nanmean(mongroup['tas'])


In [None]:
# print(df_all.shape)
# df_by_exp = df_all.groupby('experiment_id')

# plt.figure(figsize=(8,6))
# for name, group in df_by_exp:
#     group_mon = group.groupby('rawmonth')
#     avg_temps = []
#     months = []
#     for mon, mongroup in group_mon:
#         montemp = np.nanmean(mongroup['tas'])
#         avg_temps.append(montemp)
#         months.append(mon)
#     temps = np.array(group['tas'])
#     print(temps.shape)
#     print(group.shape)
#     years = (np.asarray(months)/12)+1850
#     plt.plot(years,avg_temps,label=name,alpha=0.6)
#     #plt.show()
# #plt.xlim(2000,2020)
# #plt.ylim(270,296)
# plt.legend()

In [None]:
print(df_all['year'].shape)
#years = np.array(df_all['year'])
#print(years.shape)
plt.figure(figsize=(20,6))
plt.plot(years,'b.')

In [None]:
plt.figure(figsize=(20,6))
for name, group in df_by_exp:
    for 
    plt.plot(group['year'],label=name)
plt.legend()