In [1]:
'''Script to regrid CMIP6 datasets to target locations and store the results.'''
# standard library
import os
from collections import defaultdict

# third-party
import dask
import gcsfs
import intake
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
# tqdm.auto picks the notebook/console progress bar without emitting the
# TqdmExperimentalWarning that importing from tqdm.autonotebook raises in Jupyter.
from tqdm.auto import tqdm
from xmip.postprocessing import combine_datasets, _concat_sorted_time, merge_variables
from xmip.utils import google_cmip_col

# project-local helpers
from cmip_catalogue_operations import reduce_cat_to_max_num_realizations, drop_vars_from_cat, drop_older_versions
from cmip_ds_dict_operations import select_period, pr_flux_to_m, drop_duplicate_timesteps, drop_coords, drop_incomplete

# Google Cloud Storage filesystem handle; used below to check whether output stores already exist.
# TODO: list stores, strip 'zarr' from filename, load
fs = gcsfs.GCSFileSystem()

  from tqdm.autonotebook import tqdm


In [2]:
# --- configuration -----------------------------------------------------------
# Destination bucket/prefix for the output stores written by the regridding loop.
output_path = 'gs://leap-persistent/timh37/HighResMIP/timeseries_eu_gesla2_tgs/'
# False: skip any dataset whose output store already exists under output_path.
overwrite_existing = False

# Regular 1.5-degree lon/lat grid; latitude runs from 70 downwards (30 excluded).
# NOTE(review): not used by the per-tide-gauge regridding loop shown in this notebook.
target_grid = xr.Dataset(
    {
        "longitude": (
            ["longitude"],
            np.arange(-30, 22.5, 1.5),
            {"units": "degrees_east"},
        ),
        "latitude": (
            ["latitude"],
            np.arange(70, 30, -1.5),
            {"units": "degrees_north"},
        ),
    }
)

# Dataset with a 'tg' dimension whose locations serve as the regridding targets.
tg_coords = xr.open_dataset('/home/jovyan/CMIP6cex/cmip6_processing/gssr_mlr_coefs_1p5_9deg_gesla2.nc')

query_vars = ['sfcWind', 'pr']            # variables to process and write out
required_vars = ['sfcWind', 'pr', 'psl']  # variables every included model must provide
freq = 'day'                              # used as the CMIP6 table_id in the catalogue query

In [9]:
# Query simulations & manipulate the data catalogue.
col = google_cmip_col()  # Google Cloud CMIP6 catalogue

# require_all_on keeps only member/grid/experiment combinations that provide every
# variable in `required_vars`, so models lacking any required variable are excluded.
cat = col.search(
    activity_id='HighResMIP',
    table_id=freq,
    experiment_id=['highresSST-present', 'highresSST-future'],
    variable_id=required_vars,
    source_id='GFDL-CM4C192',
    require_all_on=['member_id', 'grid_label', 'experiment_id'],
)
# Then drop the variables that are required but not processed (only process desired variables).
cat = drop_vars_from_cat(cat, [v for v in required_vars if v not in query_vars])

# Keyword arguments for generating a dictionary of datasets from the catalogue.
kwargs = {'zarr_kwargs': {'consolidated': True, 'use_cftime': True}, 'aggregate': False}
ddict = cat.to_dataset_dict(**kwargs)  # open datasets into a dictionary

ddict = pr_flux_to_m(ddict)  # convert pr flux to accumulated pr if pr in ds
ddict = drop_duplicate_timesteps(ddict)  # remove duplicate timesteps from ds if present
ddict = drop_coords(ddict, ['bnds', 'nbnd', 'height'])  # remove some unused auxiliary coordinates

# Concatenate matching present/future experiments along time.
with dask.config.set({'array.slicing.split_large_chunks': True}):
    hist_fut = combine_datasets(
        ddict,
        _concat_sorted_time,
        match_attrs=['source_id', 'grid_label', 'table_id', 'variant_label', 'variable_id'],
        combine_func_kwargs={'join': 'inner', 'coords': 'minimal'},
    )

# Remove any overlap between the concatenated present and future experiments.
hist_fut = drop_duplicate_timesteps(hist_fut)
# Remove concatenated timeseries which are not monotonically increasing or have large time gaps
# (based on Julius Busecke's rudimentary testing in CMIP6-LEAP-feedstock).
hist_fut = drop_incomplete(hist_fut)
        


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'


In [21]:
regridded_datasets = defaultdict(dict)  # NOTE(review): not used within this cell


def _sample_at_tide_gauges(ds, tg_coords):
    """Nearest-neighbour sample ``ds`` at every location along ``tg_coords.tg``.

    For each index along the 'tg' dimension an xESMF 'nearest_s2d' regridder is
    built from the single-element selection ``tg_coords.isel(tg=[i])``. The
    squeezed results are concatenated along a new 'tg' dimension, which is then
    labelled with ``tg_coords.tg`` converted to str.

    Parameters
    ----------
    ds : xarray.Dataset
        Model dataset on its native grid.
    tg_coords : xarray.Dataset
        Dataset with a 'tg' dimension; presumably holds per-gauge lon/lat
        readable by xESMF (TODO confirm).

    Returns
    -------
    xarray.Dataset
        ``ds`` sampled at the gauges, chunked as {'time': 100000, 'tg': 200}.
    """
    # NOTE(review): this builds one regridder (one weight computation) per gauge
    # per dataset. A single locstream regridder may be much faster, but its output
    # layout would need checking against tg_coords' schema first.
    per_gauge = []
    for i_tg in range(len(tg_coords.tg)):
        regridder = xe.Regridder(ds, tg_coords.isel(tg=[i_tg]), method='nearest_s2d')
        per_gauge.append(regridder(ds, keep_attrs=True).squeeze())

    ds_at_tgs = xr.concat(per_gauge, dim='tg')
    ds_at_tgs['tg'] = tg_coords.tg.values
    ds_at_tgs = ds_at_tgs.unify_chunks().chunk({'time': 100000, 'tg': 200})
    # Object-dtype labels fail to encode in zarr; storing them as str is the work-around.
    ds_at_tgs['tg'] = ds_at_tgs.tg.astype('str')
    return ds_at_tgs


def _write_zarr(ds_out, store):
    """Write ``ds_out`` to ``store`` (mode='w'); remove a store this call created if the write fails.

    Raises whatever ``to_zarr`` raised.
    """
    existed_before = fs.exists(store)
    try:
        ds_out.to_zarr(store, mode='w')
    except Exception:
        # A partially written store would be skipped by the `fs.exists` check on the
        # next run. Only stores created by this call are removed, so a failed
        # overwrite never deletes earlier output outright.
        if not existed_before and fs.exists(store):
            fs.rm(store, recursive=True)
        raise


for key, ds in tqdm(hist_fut.items()):
    # The last '.'-separated field of the combined key is used as the variable name.
    variable = key.split('.')[-1]
    concat_key = key + '.highresSST-pres_fut'
    output_fn = os.path.join(output_path, variable, freq, ds.source_id, concat_key)

    try:
        if not overwrite_existing and fs.exists(output_fn):
            continue  # output already exists

        # assign_attrs returns a copy, so the dataset held in hist_fut is not mutated
        # and re-running this cell sees the same inputs.
        ds_in = ds.assign_attrs(time_concat_key=concat_key)
        if 'dcpp_init_year' in ds_in.dims:
            # Keep the first dcpp_init_year entry and drop that coordinate.
            ds_in = ds_in.isel(dcpp_init_year=0, drop=True)

        ds_at_tgs = _sample_at_tide_gauges(ds_in, tg_coords)
        _write_zarr(ds_at_tgs, output_fn)
        ds_at_tgs.close()
    except Exception as err:
        # Best effort: report the failing dataset and continue with the next one.
        # KeyboardInterrupt/SystemExit still propagate.
        tqdm.write(f'Failed to process {key}: {err!r}')

  0%|          | 0/3 [00:00<?, ?it/s]