In [1]:
import numpy as np
import xarray as xr
import dask
import cftime
import os
import intake
import pandas as pd
from collections import defaultdict
from tqdm.autonotebook import tqdm
from xmip.postprocessing import combine_datasets, concat_members,concat_experiments
import gcsfs
from datetime import datetime
fs = gcsfs.GCSFileSystem() #GCS filesystem handle, used below via fs.ls/fs.rm (original note: list stores, strip zarr from filename, load)

  from tqdm.autonotebook import tqdm


Define the helper functions used to open the datasets and compute the month means:

In [2]:
def get_ddict_of_model_datasets(input_path,ssps):
    """Open, per model directory, one combined dataset per scenario and stack them along 'ssp'.

    Parameters
    ----------
    input_path : str
        Path listed with ``fs.ls``; must contain directories organized by model.
    ssps : iterable of str
        Scenario names; a store belongs to a scenario when its path contains the name.

    Returns
    -------
    defaultdict
        Maps ``str(<dataset>.source_id)`` to the dataset concatenated along a new
        'ssp' dimension, labelled with the scenarios that were found for that model.
        Models without any matching store are left out.
    """
    chunk_spec = {'source_id':1,'member_id':1,'lon':1000,'lat':1000,'time':200}
    per_model = defaultdict(dict)

    for model_dir in tqdm(fs.ls(input_path)):
        found = [] #(scenario, dataset) pairs, in the order of `ssps`
        for scenario in ssps:
            stores = ['gs://'+entry for entry in fs.ls(model_dir) if scenario in entry]
            if not stores:
                continue
            opened = xr.open_mfdataset(stores,engine='zarr',combine='nested',coords='minimal',compat='override',chunks=chunk_spec)
            found.append((scenario,opened))

        if not found:
            continue

        stacked = xr.concat([dataset for _,dataset in found],dim='ssp',compat='override',coords='minimal')
        stacked['ssp'] = [scenario for scenario,_ in found]
        per_model[str(stacked.source_id)] = stacked

    return per_model

def get_ddict_of_datasets(input_path,ssps):
    """Lazily open every zarr store below ``input_path`` whose file name mentions one of ``ssps``.

    Parameters
    ----------
    input_path : str
        Path listed with ``fs.ls``; must contain directories organized by model.
    ssps : iterable of str
        Scenario names. A store is selected for a scenario when the scenario name
        occurs in the store's own name (the last path component).

    Returns
    -------
    defaultdict
        Maps ``'gs://' + <store path>`` to the opened dataset. Each dataset gets a
        length-1 ``experiment_id`` dimension and an ``experiment_id`` attribute set
        to the matching scenario.
    """
    ddict = defaultdict(dict)
    for model_path in tqdm(fs.ls(input_path)):
        store_paths = fs.ls(model_path) #list each model directory once instead of once per scenario
        for ssp in ssps:
            for store_path in store_paths:
                #match on the store name only, so that a scenario name occurring in a
                #parent directory of input_path cannot select every store
                store_name = store_path.rstrip('/').rsplit('/',1)[-1]
                if ssp not in store_name:
                    continue
                fn = 'gs://'+store_path
                ds = xr.open_dataset(fn,engine='zarr',use_cftime=True,chunks={'lat':500,'lon':500,'time':200})
                ds = ds.expand_dims(dim={'experiment_id':[str(ssp)]})
                ds.attrs['experiment_id'] = ssp
                ddict[fn] = ds
    return ddict
    
def add_years_months_to_coords(ds):
    """Attach 'year' and 'month' coordinates along 'time', derived element by element.

    Parameters
    ----------
    ds : dataset-like
        Object exposing ``ds.time.values`` (iterable of time stamps) and
        ``ds.assign_coords``.

    Returns
    -------
    The result of ``ds.assign_coords`` with the two new coordinates on 'time'.

    Notes
    -----
    Each time stamp is read through its own ``year``/``month`` attributes (or those
    of its ``dt`` attribute if it has one). Stamps exposing neither, such as
    ``numpy.datetime64`` scalars, are converted with ``pd.Timestamp`` first.
    """
    years = []
    months = []
    for ts in ds.time.values:
        stamp = getattr(ts,'dt',ts) #keep supporting stamps that expose a .dt accessor
        if not hasattr(stamp,'year'):
            stamp = pd.Timestamp(ts) #numpy.datetime64 scalars have no .year/.month
        #read both fields before appending so the two lists can never get out of step
        year,month = stamp.year,stamp.month
        years.append(year)
        months.append(month)
    return ds.assign_coords({'year':('time',years),'month':('time',months)})

def enough_numMonths_in_period(ds,period,tol=.95):
    """Check that ``ds`` holds at least a fraction ``tol`` of the months spanned by ``period``.

    Parameters
    ----------
    ds : object with a sized ``time`` attribute
    period : sequence
        First and last year of the period (both counted).
    tol : float
        Required fraction of ``12 * number_of_years`` time steps.

    Returns
    -------
    bool
    """
    first_year,last_year = period[0],period[-1]
    months_required = tol * 12 * (1 + last_year - first_year)
    return len(ds.time) >= months_required

def _select_years(ds,period):
    """Return the time steps of ``ds`` whose 'year' coordinate lies within ``period`` (inclusive)."""
    in_period = ((ds.year >= period[0])&(ds.year <= period[-1])).values
    #positional selection really drops the other time steps, so len(ds.time)
    #afterwards reflects the period only
    return ds.isel(time=np.flatnonzero(in_period))

def _period_month_means(key,ds_hist,ds_fut,hist,fut,hist_months,fut_months):
    """Concatenate the per-month means of both periods along 'period', or return None if incomplete."""
    if not enough_numMonths_in_period(ds_hist,hist) or not enough_numMonths_in_period(ds_fut,fut):
        print('dataset incomplete for historical and/or future periods, not including: '+key)
        return None
    means = xr.concat((ds_hist.groupby(hist_months).mean(),
                       ds_fut.groupby(fut_months).mean()
                       ),dim='period')
    means['period'] = ['hist','fut']
    return means

def derive_dict_of_month_means(ddict,hist,fut):
    """Compute, for every dataset, the mean per calendar month over two periods.

    Parameters
    ----------
    ddict : dict
        Maps a key (str) to a dataset with a 'time' dimension.
    hist, fut : sequence
        First and last year of the historical and future period.

    Returns
    -------
    defaultdict
        Maps each key to the month means of both periods, concatenated along a
        'period' dimension labelled ['hist','fut']. Datasets that do not cover
        enough months of either period (see ``enough_numMonths_in_period``) are
        reported with a message and left out.

    Notes
    -----
    The periods are first selected with label-based time slicing. If that (or the
    grouping on ``time.dt.month``) raises, years and months are derived per time
    stamp with ``add_years_months_to_coords`` and the periods are selected on
    those coordinates instead.
    """
    month_means = defaultdict(dict)
    for key,ds in tqdm(ddict.items()):
        try:
            ds_hist = ds.sel(time=slice(str(hist[0]),str(hist[-1])))
            ds_fut = ds.sel(time=slice(str(fut[0]),str(fut[-1])))
            means = _period_month_means(key,ds_hist,ds_fut,hist,fut,
                                        ds_hist.time.dt.month,ds_fut.time.dt.month)
        except Exception: #irregular calendar, issue with selecting years/months
            print('Cannot select years/months due to time index issues, computing these coordinates manually, for: '+key)
            ds = add_years_months_to_coords(ds)
            ds_hist = _select_years(ds,hist)
            ds_fut = _select_years(ds,fut)
            means = _period_month_means(key,ds_hist,ds_fut,hist,fut,
                                        ds_hist.month,ds_fut.month)

        if means is not None:
            month_means[key] = means

    return month_means

In [3]:
#Settings read by the cells below: variable, scenarios, bucket locations and averaging periods.
query_var = 'zos' #variable to process; used in the output store names and to select the variable that is rechunked
#ssps = ['ssp585']
ssps = ['ssp126','ssp245','ssp370','ssp585'] #SSPs to process; stores are selected per scenario by name
input_path = 'gs://leap-persistent/timh37/CMIP6/zos_1x1' #must contain directories organized by model
output_path = 'gs://leap-persistent/timh37/CMIP6/zos_1x1_month_means' #one zarr store per model is written here
hist = [1993,2022] #historical period: [first year, last year], both included
fut = [2071,2100] #future period: [first year, last year], both included

Generate dictionary of month means per model:

In [4]:
#placeholder only: immediately replaced by the dictionary returned on the next line
ddict = {}
ddict = get_ddict_of_datasets(input_path,ssps) #open datasets into a dictionary keyed by store path (one entry per store, not per model)
month_means = derive_dict_of_month_means(ddict,hist,fut) #compute month means for each dataset; incomplete datasets are reported and skipped
#NOTE(review): the author observed that concat_members also combines experiment ids here and did not know why;
#presumably this depends on which attributes xmip uses to match datasets -- confirm against the xmip documentation
month_means = concat_members(month_means)

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/1090 [00:00<?, ?it/s]

dataset incomplete for historical and/or future periods, not including: gs://leap-persistent/timh37/CMIP6/zos_1x1/CESM2-WACCM/CESM2-WACCM.gn.Omon.r2i1p1f1.zos.hist_ssp370.1980-2055
dataset incomplete for historical and/or future periods, not including: gs://leap-persistent/timh37/CMIP6/zos_1x1/CESM2-WACCM/CESM2-WACCM.gn.Omon.r3i1p1f1.zos.hist_ssp370.1980-2055
dataset incomplete for historical and/or future periods, not including: gs://leap-persistent/timh37/CMIP6/zos_1x1/CNRM-CM6-1/CNRM-CM6-1.gn.Omon.r10i1p1f2.zos.hist_ssp245.1980-2020
dataset incomplete for historical and/or future periods, not including: gs://leap-persistent/timh37/CMIP6/zos_1x1/CNRM-CM6-1/CNRM-CM6-1.gn.Omon.r7i1p1f2.zos.hist_ssp245.1980-2020
dataset incomplete for historical and/or future periods, not including: gs://leap-persistent/timh37/CMIP6/zos_1x1/CNRM-CM6-1/CNRM-CM6-1.gn.Omon.r8i1p1f2.zos.hist_ssp245.1980-2020
dataset incomplete for historical and/or future periods, not including: gs://leap-persistent/timh37/

In [5]:
#Gather the month-mean datasets of each model and join them along 'experiment_id',
#keyed by the model name.
out_ddict = defaultdict(dict)
source_ids = np.unique([v.source_id for v in month_means.values()])
for model in source_ids:
    same_model = [v for v in month_means.values() if v.source_id == model]
    out_ddict[str(model)] = xr.concat(same_model,dim='experiment_id',coords='different')
    

Store the month means of each model as a separate zarr store:

In [None]:
#Write one zarr store per model below output_path, overwriting any existing store of the same name.
for key,ds in tqdm(out_ddict.items()):
    fn = ds.source_id+'_'+query_var+'_month_means_hist'+str(hist[0])+'-'+str(hist[-1])+'_fut'+str(fut[0])+'-'+str(fut[-1])
    
    #cast the label coordinates to str before writing (presumably to avoid object-dtype encoding problems -- confirm)
    ds['member_id'] = ds['member_id'].astype('str')
    ds['experiment_id'] = ds['experiment_id'].astype('str')
    ds = ds.assign_coords({'source_id':ds.source_id}) #keep the model name as a coordinate in the stored dataset
    ds[query_var] = ds[query_var].chunk({'lat':500,'lon':500,'period':1,'member_id':50,'experiment_id':1,'month':12})

    #output_path is a gs:// URL, so join with '/' explicitly: os.path.join would use the
    #separator of the local operating system
    ds.to_zarr(output_path.rstrip('/')+'/'+fn,mode='w')
    ds.close()

  0%|          | 0/40 [00:00<?, ?it/s]

In [3]:
#WARNING: destructive. Recursively deletes everything below the month-means output prefix,
#i.e. all stores written by the "Store" cell. Run this deliberately only
#(presumably meant for clearing stale output before rewriting -- confirm before using Run All).
fs.rm('gs://leap-persistent/timh37/CMIP6/zos_1x1_month_means',recursive=True)

['leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/.zattrs',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/.zgroup',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/.zmetadata',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/degree/.zarray',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/degree/.zattrs',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/degree/0',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/experiment_id/.zarray',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means/ACCESS-CM2_zos_month_means_hist1993-2022_fut2071-2100/experiment_id/.zattrs',
 'leap-persistent/timh37/CMIP6/zos