In [1]:
import numpy as np
import xarray as xr
import dask
import cftime
import os
import intake
import pandas as pd
from collections import defaultdict
from tqdm.autonotebook import tqdm
from xmip.postprocessing import combine_datasets, concat_members,concat_experiments
import gcsfs
from datetime import datetime
fs = gcsfs.GCSFileSystem() #list stores, stripp zarr from filename, load 

  from tqdm.autonotebook import tqdm


Configure the script:

In [2]:
def get_ddict_of_model_datasets(input_path,ssps):
    #input path must contain directories organized by model
    ddict = defaultdict(dict)
    for model_path in tqdm(fs.ls(input_path)):
        model_datasets = []
        model_ssps = []
        #print(model_path)
        for s,ssp in enumerate(ssps):
            #print(ssp)
            model_fns = ['gs://'+k for k in fs.ls(model_path) if ssp in k]
            if len(model_fns)>0:
                model_ds = xr.open_mfdataset(model_fns,engine='zarr',combine='nested',coords='minimal',compat='override',chunks={'source_id':1,'member_id':1,'lon':1000,'lat':1000,'time':200})
                model_datasets.append(model_ds)
                model_ssps.append(ssp)
        
        if len(model_ssps)>0:
            model_ssps_ds = xr.concat(model_datasets,dim='ssp',compat='override',coords='minimal')
            model_ssps_ds['ssp'] = model_ssps
        
            ddict[str(model_ssps_ds.source_id)]=model_ssps_ds
    return ddict

def get_ddict_of_datasets(input_path,ssps):
    #input path must contain directories organized by model
    ddict = defaultdict(dict)
    for model_path in tqdm(fs.ls(input_path)):
        for ssp in ssps:
            fns = ['gs://'+k for k in fs.ls(model_path) if ssp in k]
            for fn in fns:
                #ds = xr.open_dataset(fn,engine='zarr',use_cftime=True,chunks={'lat':500,'lon':500,'time':200})
                ds = xr.open_dataset(fn,engine='zarr',decode_times=xr.coders.CFDatetimeCoder(use_cftime=True),chunks={'lat':500,'lon':500,'time':200})
                
                ds = ds.expand_dims(dim={'experiment_id':[str(ssp)]})
                ds.attrs['experiment_id'] = ssp
                ddict[fn] = ds
    return ddict
    
def add_years_months_to_coords(ds):
    years = []
    months = []
    for k,ts in enumerate(ds.time.values):
        try: #cant remember what this was for, is this even the right syntax?
            years.append(ts.dt.year)
            months.append(ts.dt.month)
        except:
            years.append(ts.year)
            months.append(ts.month)
    ds = ds.assign_coords({'year':('time',years),'month':('time',months)})
    return ds

def enough_numMonths_in_period(ds,period,tol=.95):
    return len(ds.time) >= tol * 12 * (1+period[-1]-period[0])

def derive_dict_of_anmax(ddict,hist,fut):
    anmax = defaultdict(dict)
    for key,ds in tqdm(ddict.items()):
        #print(key)
        try:
            ds_hist = ds.sel(time=slice(str(hist[0]),str(hist[-1])))
            ds_fut = ds.sel(time=slice(str(fut[0]),str(fut[-1])))
            if not enough_numMonths_in_period(ds_hist,hist) or not enough_numMonths_in_period(ds_fut,fut):
                print('dataset incomplete for historical and/or future periods, not including: '+key)
                continue
            anmax[key] = xr.concat((ds_hist.groupby(ds_hist.time.dt.year).max().mean(dim='year'),
                                       ds_fut.groupby(ds_fut.time.dt.year).max().mean(dim='year')
                                       ),dim='period')
            anmax[key]['period'] = ['hist','fut']
            anmax[key].attrs = ds.attrs
        except: #irregular calendar, issue with selecting years/months
            print('Cannot select years/months due to time index issues, computing these coordinates manually, for: '+key)
            ds = add_years_months_to_coords(ds)
            ds_hist = ds.where((ds.year >= hist[0])&(ds.year <= hist[1]))
            ds_fut = ds.where((ds.year >= fut[0])&(ds.year <= fut[1]))

            if not enough_numMonths_in_period(ds_hist,hist) or not enough_numMonths_in_period(ds_fut,fut):
                print('dataset incomplete for historical and/or future periods, not including: '+key)
                continue
            
            anmax[key] = xr.concat((ds_hist.groupby(ds_hist.year).max().mean(dim='year'),
                                       ds_fut.groupby(ds_fut.year).max().mean(dim='year'),
                                       ),dim='period')
            anmax[key]['period'] = ['hist','fut']
            anmax[key].attrs = ds.attrs
    return anmax

In [3]:
fs.ls('gs://leap-persistent/timh37/CMIP6/')

['leap-persistent/timh37/CMIP6/aslc_1x1',
 'leap-persistent/timh37/CMIP6/aslc_change',
 'leap-persistent/timh37/CMIP6/aslc_psmsl',
 'leap-persistent/timh37/CMIP6/aslc_relative_change',
 'leap-persistent/timh37/CMIP6/ibe_1x1',
 'leap-persistent/timh37/CMIP6/ibe_1x1_month_means',
 'leap-persistent/timh37/CMIP6/ibe_psmsl',
 'leap-persistent/timh37/CMIP6/ibe_psmsl_month_means',
 'leap-persistent/timh37/CMIP6/seasonal_slc',
 'leap-persistent/timh37/CMIP6/timeseries_eu_1p5',
 'leap-persistent/timh37/CMIP6/timeseries_eu_gesla2_tgs',
 'leap-persistent/timh37/CMIP6/zos_1x1',
 'leap-persistent/timh37/CMIP6/zos_1x1_month_means',
 'leap-persistent/timh37/CMIP6/zos_piControl_linfit',
 'leap-persistent/timh37/CMIP6/zos_piControl_quadfit',
 'leap-persistent/timh37/CMIP6/zos_psmsl',
 'leap-persistent/timh37/CMIP6/zos_psmsl_annual_maxima',
 'leap-persistent/timh37/CMIP6/zos_psmsl_month_means']

In [4]:
query_var = 'zos' #variables to process #TODO: add option to add zos to ibe
#ssps = ['ssp585']
ssps = ['ssp126','ssp245','ssp370','ssp585'] #SSPs to process #(TODO: loop over multiple, streamline code!)
zos_path = 'gs://leap-persistent/timh37/CMIP6/zos_psmsl'
ibe_path = 'gs://leap-persistent/timh37/CMIP6/ibe_psmsl'
output_path = 'gs://leap-persistent/timh37/CMIP6/zos_psmsl_annual_maxima'
hist = [1993,2022]
fut = [2071,2100]

Generate dictionary of month means per model:

In [5]:
zos_ddict = {}
zos_ddict = get_ddict_of_datasets(zos_path,ssps) #open datasets (grouped by model) into dictionary

ibe_ddict = {}
ibe_ddict = get_ddict_of_datasets(ibe_path,ssps) #open datasets (grouped by model) into dictionary


#anmax = derive_dict_of_anmax(ddict,hist,fut) #compute month means for each dataset
#anmax = concat_members(anmax) #not sure why this also automatically combines experiment ids, but ok?

  0%|          | 0/41 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [5]:
out_ddict = defaultdict(dict)
for model in np.unique([v.source_id for v in anmax.values()]):
    datasets = []
    for k,v in anmax.items():
        if v.source_id == model:
            datasets.append(v)
    model_dataset = xr.concat(datasets,coords='different',dim='experiment_id')
    out_ddict[str(model)] = model_dataset
    

Store:

In [7]:
for key,ds in tqdm(out_ddict.items()):
    fn = ds.source_id+'_'+query_var+'_anmax_hist'+str(hist[0])+'-'+str(hist[-1])+'_fut'+str(fut[0])+'-'+str(fut[-1])
    
    ds['member_id'] = ds['member_id'].astype('str')
    ds['experiment_id'] = ds['experiment_id'].astype('str')
    ds = ds.assign_coords({'source_id':ds.source_id})
    try:
        ds[query_var] = ds[query_var].chunk({'lat':500,'lon':500,'period':1,'member_id':50,'experiment_id':1})
    except:
        ds[query_var] = ds[query_var].chunk({'id':1000,'period':1,'member_id':200,'experiment_id':1})
    ds.to_zarr(os.path.join(output_path,fn),mode='w',zarr_version=2)
    ds.close()

  0%|          | 0/40 [00:00<?, ?it/s]

  ds.to_zarr(os.path.join(output_path,fn),mode='w',zarr_version=2)
