In [1]:
import numpy as np
import xarray as xr
import dask
import cftime
import os
import intake
import pandas as pd
from collections import defaultdict
from tqdm.autonotebook import tqdm
from xmip.utils import google_cmip_col
from xmip.preprocessing import rename_cmip6, promote_empty_dims, correct_coordinates, broadcast_lonlat, correct_lon, correct_units, fix_metadata,_drop_coords
from xmip.postprocessing import combine_datasets,_concat_sorted_time, match_metrics, _match_datasets
from CMIP6cex.cmip6_processing.cmip_catalogue_operations import reduce_cat_to_max_num_realizations, drop_older_versions
from CMIP6cex.cmip6_processing.cmip_ds_dict_operations import select_period, drop_duplicate_timesteps, drop_coords, drop_incomplete, drop_vars
import xesmf as xe
import gcsfs
# Google Cloud Storage filesystem handle.
# TODO: list existing stores, strip 'zarr' from the filenames, load them.
# NOTE(review): `fs` is not referenced in any of the cells shown here.
fs = gcsfs.GCSFileSystem()

  from tqdm.autonotebook import tqdm


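Helper functions (catalogue search, loading and cleanup, piControl dedrifting, area-weighted mean removal, regridding, availability overview):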

In [2]:
# create regridders per source_id (datasets of source_id should be on the same grid after passing through reduce_cat_to_max_num_realizations)
def create_regridder_dict(dict_of_ddicts, target_grid_ds):
    """Create one bilinear xESMF regridder per ``source_id``.

    One dataset of each model is used as the template for that model's regridder. This
    assumes all datasets of a ``source_id`` share the same grid, which should hold after
    ``reduce_cat_to_max_num_realizations``.

    Parameters
    ----------
    dict_of_ddicts : dict of dict of xr.Dataset
        Dictionaries of datasets (e.g. one per SSP); every dataset needs a
        ``source_id`` attribute.
    target_grid_ds : xr.Dataset
        Grid to regrid onto.

    Returns
    -------
    dict
        ``source_id`` -> ``xe.Regridder``.
    """
    # First dataset encountered for every source_id. Match on the attribute rather than on a
    # substring of the dictionary key; otherwise e.g. 'CESM2' would also match a
    # 'CESM2-WACCM...' key and the regridder would be built on the wrong model's grid.
    # Flattening here also avoids np.unique on a ragged nested list, which fails when the
    # dictionaries hold different numbers of datasets.
    template_ds = {}
    for dataset_dict in dict_of_ddicts.values():
        for ds in dataset_dict.values():
            template_ds.setdefault(ds.attrs['source_id'], ds)

    regridders = {}
    for si in tqdm(sorted(template_ds)):
        # building a regridder is expensive, so it is done exactly once per source_id
        regridders[si] = xe.Regridder(template_ds[si], target_grid_ds, 'bilinear', ignore_degenerate=True, periodic=True)
    return regridders
    
def partial_combined_preprocessing(ds): #'combined_preprocessing' from xmip is problematic for some datasets
    """Run a subset of xmip's ``combined_preprocessing`` on a single dataset.

    The bounds/vertex steps of ``combined_preprocessing`` are skipped:
    ``parse_lon_lat_bounds``, ``sort_vertex_order``, ``maybe_convert_bounds_to_vertex``
    and ``maybe_convert_vertex_to_bounds``. After ``fix_metadata``, any variables named in
    xmip's ``_drop_coords`` are dropped if present.

    Parameters
    ----------
    ds : xr.Dataset

    Returns
    -------
    xr.Dataset
    """
    pipeline = (
        rename_cmip6,         # fix naming
        promote_empty_dims,   # promote empty dims to actual coordinates
        correct_coordinates,  # demote coordinates from data_variables
        broadcast_lonlat,     # broadcast lon/lat
        correct_lon,          # shift all lons to consistent 0-360
        correct_units,        # fix the units
        fix_metadata,
    )
    out = ds
    for step in pipeline:
        out = step(out)
    return out.drop_vars(_drop_coords, errors="ignore")

def search_cloud(variable_id=None,experiment_id=None,table_id=None,require_all_on=None):
    """Search the merged CMIP6 cloud catalogues.

    The Google Cloud CMIP6 catalogue (``google_cmip_col``) and two temporary
    pangeo-leap-forge catalogues (QC'd and non-QC'd) are concatenated into one catalogue.
    Each row gets a ``prio`` column saying which catalogue is preferred when duplicates
    are present (higher = higher priority): QC'd = 2, Google Cloud = 1, non-QC'd = 0.
    The search result is passed through ``drop_older_versions`` (which presumably uses
    ``prio`` when resolving duplicates — confirm there).

    Parameters
    ----------
    variable_id, experiment_id, table_id, require_all_on : optional
        Forwarded as keyword arguments to the catalogue ``search``. Arguments left as
        ``None`` are not used as search criteria.

    Returns
    -------
    The searched catalogue after ``drop_older_versions``.
    """
    # Only pass on the criteria that were given (None = "do not filter on this"). Built
    # explicitly instead of introspecting locals(), whose semantics differ between
    # Python versions and which also picks up any local defined before the call.
    criteria = {
        'variable_id': variable_id,
        'experiment_id': experiment_id,
        'table_id': table_id,
        'require_all_on': require_all_on,
    }
    search_kwargs = {key: value for key, value in criteria.items() if value is not None}

    def_col = google_cmip_col() #google cloud catalog
    qc_col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/cmip6-pgf-ingestion-test/catalog/catalog.json") #temporary pangeo-leap-forge catalogue
    nonqc_col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/cmip6-pgf-ingestion-test/catalog/catalog_noqc.json")

    # Set from which catalogues data is preferred if duplicates are present (higher = higher
    # priority). assign() returns new frames instead of writing into the catalogues' own frames.
    merged_df = pd.concat(
        [def_col.df.assign(prio=1), qc_col.df.assign(prio=2), nonqc_col.df.assign(prio=0)],
        ignore_index=True,
    )
    def_col.esmcat._df = merged_df  # merge catalogs

    cat = def_col.search(**search_kwargs)
    return drop_older_versions(cat)

def generate_dict_of_datasets(cat,models_to_exclude,preprocessing_func):
    """Load every entry of an intake-esm catalogue into a dictionary of datasets.

    The catalogue's dataframe is first reduced to the standard CMIP6 columns, and rows
    whose ``source_id`` is in ``models_to_exclude`` are removed. The catalogue is modified
    in place. ``groupby_attrs`` is emptied, so each remaining row is loaded as its own
    dataset; keys are built from all columns.

    Parameters
    ----------
    cat : intake-esm catalogue
    models_to_exclude : list of str or None
        ``source_id`` values to skip (e.g. models with preprocessing issues).
    preprocessing_func : callable
        Applied to each dataset while loading (``preprocess`` of ``to_dataset_dict``).

    Returns
    -------
    dict of xr.Dataset
    """
    columns = ['activity_id', 'institution_id', 'source_id', 'experiment_id', 'member_id',
               'table_id', 'variable_id', 'grid_label', 'zstore', 'dcpp_init_year', 'version']
    df = cat.df[columns]
    # Boolean mask instead of drop(<matching rows>.index): dropping by index label would
    # also remove unrelated rows that happen to share a label with an excluded row.
    df = df[~df['source_id'].isin(models_to_exclude or [])]
    cat.esmcat._df = df
    cat.esmcat.aggregation_control.groupby_attrs = []
    return cat.to_dataset_dict(
        zarr_kwargs={'consolidated': True, 'use_cftime': True},
        aggregate=True,
        preprocess=preprocessing_func,
    )

def cleanup_datasets_in_dict(ddict):
    """Tidy every dataset of a dictionary of datasets.

    Steps:
      1. drop duplicate timesteps (``drop_duplicate_timesteps``);
      2. remove ``vertices_latitude``/``vertices_longitude`` as coordinates and as variables;
      3. for datasets containing ``dcpp_init_year``, select its first entry and drop it.

    Returns the dictionary produced by ``drop_vars``, with step 3 applied to it in place.
    """
    cleaned = drop_duplicate_timesteps(ddict)  # remove duplicate timesteps if present
    cleaned = drop_coords(cleaned, ['vertices_latitude', 'vertices_longitude'])
    cleaned = drop_vars(cleaned, ['vertices_latitude', 'vertices_longitude'])

    # keep only the first dcpp_init_year entry (dropping the coordinate) where it exists
    keys_with_init_year = [key for key, ds in cleaned.items() if 'dcpp_init_year' in ds]
    for key in keys_with_init_year:
        cleaned[key] = cleaned[key].isel(dcpp_init_year=0, drop=True)
    return cleaned

def reduce_areacello_cat(areacello_cat):
    """Keep one non-PMIP areacello entry per (source_id, grid_label).

    PMIP runs are removed because their land mask may be different. They are removed
    *before* de-duplicating. De-duplicating first could keep a PMIP row as the only entry
    of a (source_id, grid_label) pair and then drop it, so the model would lose an
    areacello that another activity provides. The dataframe is also reduced to the
    standard CMIP6 columns. The catalogue is modified in place and returned.

    Parameters
    ----------
    areacello_cat : intake-esm catalogue

    Returns
    -------
    The same catalogue object.
    """
    columns = ['activity_id', 'institution_id', 'source_id', 'experiment_id', 'member_id',
               'table_id', 'variable_id', 'grid_label', 'zstore', 'dcpp_init_year', 'version']
    df = areacello_cat.df
    df = df[df['activity_id'] != 'PMIP']  # land mask may be different
    df = df.drop_duplicates(subset=['source_id', 'grid_label'])  # keep the first entry per source/grid
    areacello_cat.esmcat._df = df[columns]
    return areacello_cat

def _match_twosided_attrs(ds_a, ds_b, attrs_a, attrs_b): #custom version of _match_attrs in xmip that allows to compare differently named attributes between datasets
    """returns the number of matched attrs between two datasets"""
    if len(attrs_a)!=len(attrs_b):
        raise Exception('lists of attributes in each dataset must be of equal length.')
        
    try:
        n_match = sum([ds_a.attrs[attrs_a[i]] == ds_b.attrs[attrs_b[i]] for i in range(len(attrs_a))])
        return n_match
    except:
        return 0

def dedrift_datasets_linearly(ds_ddict,pic_ddict,variable_id,min_numYrs_pic):
    """Subtract a linear piControl drift from every dataset in ``ds_ddict``.

    For each dataset, its piControl parent is looked up in ``pic_ddict``. The dataset's
    ``parent_source_id``/``grid_label``/``parent_variant_label`` are matched against the
    piControl datasets' ``source_id``/``grid_label``/``variant_label``. If several match,
    the first one is used.

    A degree-1 polynomial is fitted to ``variable_id`` over the full piControl record. It is
    evaluated on the dataset's time axis, shifted to be zero at the first timestep, and
    subtracted. This differs from xmip's default dedrifting, which fits only the part
    overlapping the experiment; piControl runs are sometimes too short to cover all
    experiments. Both dictionaries are assumed to have the same frequency.

    Parameters
    ----------
    ds_ddict : dict of xr.Dataset
        Datasets to dedrift. They are not modified; new datasets are returned.
    pic_ddict : dict of xr.Dataset
        Candidate piControl datasets. The chosen one gets an ``original_key`` attribute
        holding its dictionary key.
    variable_id : str
        Variable to dedrift.
    min_numYrs_pic : int
        Minimum number of distinct years a piControl dataset must cover to be used.

    Returns
    -------
    ds_ddict_dedrifted : defaultdict
        Dedrifted datasets, keyed like ``ds_ddict``.
    drift_ddict : defaultdict
        Subtracted drift (output of ``xr.polyval``, zero at the first timestep) per key.
    datasets_without_pic : list
        Keys that could not be dedrifted: no long-enough matching piControl, or
        incompatible time-difference dtypes.
    """
    ds_ddict_dedrifted = defaultdict(dict)
    drift_ddict = defaultdict(dict)
    datasets_without_pic = []

    # attribute names in the experiment dataset (a) and the corresponding ones in piControl (b)
    attrs_a = ['parent_source_id','grid_label','parent_variant_label']
    attrs_b = ['source_id','grid_label','variant_label']

    # Number of distinct years per piControl key. Computed on first use and reused, since it
    # depends only on the piControl dataset and not on the experiment being matched.
    pic_num_years = {}

    def _is_long_enough(pic_key):
        if pic_key not in pic_num_years:
            pic_num_years[pic_key] = len(np.unique(pic_ddict[pic_key].time.dt.year))
        return pic_num_years[pic_key] >= min_numYrs_pic

    for i, ds in tqdm(ds_ddict.items()):
        # Adapted from xmip's _match_datasets, which cannot match differently named
        # attributes. The first piControl that matches and is long enough is used.
        pic_key = next(
            (k for k in pic_ddict
             if _match_twosided_attrs(ds, pic_ddict[k], attrs_a, attrs_b) == len(attrs_a)
             and _is_long_enough(k)),
            None,
        )
        if pic_key is None:
            datasets_without_pic.append(i)
            continue

        pic_ds = pic_ddict[pic_key]
        # preserve the original dictionary key of the chosen dataset in attribute
        pic_ds.attrs["original_key"] = pic_key

        # the fit is evaluated on the experiment's time axis, so time-difference dtypes must agree
        if (pic_ds.time[1]-pic_ds.time[0]).dtype != (ds.time[1]-ds.time[0]).dtype:
            print('Time units piControl dataset different from dataset to be dedrifted, cannot dedrift: '+i)
            datasets_without_pic.append(i)
            continue

        pic_fit = pic_ds[variable_id].polyfit(dim='time',deg=1) # linear fit over the full piControl
        ds_drift = xr.polyval(ds.time,pic_fit) # evaluate fit on the experiment's time axis
        ds_drift = ds_drift - ds_drift.isel(time=0) # remove intercept: drift is zero at the first timestep
        drift_ddict[i] = ds_drift

        # Drop the parent member_id because it may differ from the dataset's member_id.
        # assign() returns a new dataset, so the caller's input dataset is left untouched.
        dedrifted_var = ds[variable_id] - ds_drift.polyfit_coefficients.isel(member_id=0,drop=True)
        ds_ddict_dedrifted[i] = ds.assign({variable_id: dedrifted_var})

    return ds_ddict_dedrifted, drift_ddict, datasets_without_pic

def subtract_ocean_awMean(ds_ddict,variable_id):
    """Remove the area-weighted spatial mean of ``variable_id`` from each dataset.

    The mean is taken over the ``x`` and ``y`` dimensions. It is weighted by
    ``areacello``, with NaN areas given zero weight. Only datasets that contain
    ``areacello`` are processed; datasets without it are left out of the returned
    dictionary. The input datasets are not modified.

    Parameters
    ----------
    ds_ddict : dict of xr.Dataset
    variable_id : str

    Returns
    -------
    defaultdict
        New datasets with the area-weighted mean removed, keyed like ``ds_ddict``.
    """
    noMean_ddict = defaultdict(dict)
    for k, v in tqdm(ds_ddict.items()):
        if 'areacello' not in v:
            continue  # no cell areas: cannot form an area-weighted mean, dataset is dropped
        aw_mean = v[variable_id].weighted(v.areacello.fillna(0)).mean(['x','y'])
        # assign() returns a new dataset instead of overwriting the variable in the caller's dataset
        noMean_ddict[k] = v.assign({variable_id: v[variable_id] - aw_mean})

    return noMean_ddict

def get_availability(dict_of_ddicts):
    """Summarise which members of which models are available in each dictionary.

    Parameters
    ----------
    dict_of_ddicts : dict of dict of xr.Dataset
        E.g. one dictionary of datasets per SSP. Datasets need a ``source_id`` attribute
        and a ``member_id``.

    Returns
    -------
    defaultdict
        ``availability[key][source_id]`` lists the unique member_ids of that model in
        ``dict_of_ddicts[key]`` (empty if the model is absent there).
        ``availability['all'][source_id]`` lists, sorted, the members present under every key.
    """
    # Flatten before de-duplicating: np.unique on a nested list fails when the inner lists
    # differ in length (i.e. when the dictionaries hold different numbers of datasets).
    all_models = sorted({ds.attrs['source_id'] for ddict in dict_of_ddicts.values() for ds in ddict.values()})
    availability = defaultdict(dict)

    for k, ddict in dict_of_ddicts.items():
        model_members = defaultdict(dict)
        for model in all_models:
            members = np.unique([ds.member_id for ds in ddict.values() if ds.source_id == model])
            model_members[str(model)] = list(members)
        availability[k] = model_members

    availability['all'] = defaultdict(dict)
    for model in all_models:
        member_sets = [set(availability[k][model]) for k in dict_of_ddicts]
        # sorted so the result does not depend on set iteration order
        availability['all'][str(model)] = sorted(set.intersection(*member_sets))
    return availability

Configure the script:

In [26]:
# Variable to process.
query_var = 'zos'

# SSP scenario(s) to process; each one is combined with 'historical'.
# Full set (TODO: loop over multiple, streamline code!):
# ssps = ['ssp126','ssp245','ssp370','ssp585']
ssps = ['ssp245']

# Regrid onto a regular lat/lon grid; the grid's 'name' attribute is used in the output path.
regrid = True
target_grid = xr.Dataset(
    {
        "lat": (["lat"], np.arange(-90, 90, 1), {"units": "degrees_north"}),
        "lon": (["lon"], np.arange(0, 360, 1), {"units": "degrees_east"}),
    },
    attrs={"name": "1x1"},
)
# NOTE(review): np.arange excludes its stop value, so latitudes run from -90 to 89 (a point
# on the south pole but none on the north pole) — confirm this grid is intended.

# Models excluded a priori because of preprocessing issues (to be sorted out?).
models_to_exclude = ['AWI-CM-1-1-MR','AWI-ESM-1-1-LR','AWI-CM-1-1-LR','KIOST-ESM']

# Minimum number of distinct years a piControl run must cover to be used for dedrifting.
min_pic_numYears = 150

# Output period, used as ds.sel(time=slice(output_period[0], output_period[1])).
output_period = ['1950','2500']

# Root location of the output stores.
output_path = 'gs://leap-persistent/timh37/CMIP6/'

Query datasets, put into dictionaries of datasets, and preprocess:

In [4]:
#search & load piControl dictionary of datasets (used below to dedrift the historical+SSP datasets)
pic_cat = search_cloud(query_var,'piControl','Omon',['source_id', 'member_id','grid_label']) #done separately because parent variant (i.e., piControl variant) is not necessarily the same as historical/SSPs variant
pic_ddict = generate_dict_of_datasets(pic_cat,models_to_exclude,partial_combined_preprocessing)
pic_ddict = cleanup_datasets_in_dict(pic_ddict) #do some preprocessing
pic_ddict = drop_incomplete(pic_ddict) #remove piControl timeseries which are not monotonically increasing or have large timegaps (based on checks in CMIP6-LEAP-feedstock)

#search & load areacello dictionary of datasets (used as weights in subtract_ocean_awMean)
areacello_cat = search_cloud(variable_id='areacello')
areacello_cat = reduce_areacello_cat(areacello_cat)
areacello_ddict = generate_dict_of_datasets(areacello_cat,models_to_exclude,partial_combined_preprocessing)
areacello_ddict = cleanup_datasets_in_dict(areacello_ddict)

#search hist+ssp catalogues, one per SSP
ssp_cats = defaultdict(dict)
for s,ssp in enumerate(ssps):
    cat = search_cloud(query_var,['historical',ssp],'Omon',['source_id', 'member_id','grid_label']) #done per SSP because availability may be different
    ssp_cats[ssp] = cat

#put ssp cats together (AFAIK no other way but to reuse an existing catalog and to assign the concatenation of the dataframes inside each separate catalogue as the new dataframe)
#NOTE(review): this is an alias, not a copy — assigning _df below also changes ssp_cats[ssps[0]]
ssp_cats_merged = ssp_cats[ssps[0]]
ssp_cats_merged.esmcat._df = pd.concat([v.df for k,v in ssp_cats.items()],ignore_index=True).drop_duplicates(ignore_index=True)
ssp_cats_merged = reduce_cat_to_max_num_realizations(ssp_cats_merged) #per model, select grid and 'ipf' combination providing most realizations (needs to be applied to all SSPs together to ensure the same variants are used under every scenario)

ssp_ddicts = defaultdict(dict) #one dictionary of datasets per SSP (every key is assigned before it is read, so a plain dict would also do)
for s,ssp in enumerate(ssps):
    ssp_cat = ssp_cats_merged.search(experiment_id=['historical',ssp],table_id='Omon',variable_id=query_var,require_all_on=['source_id', 'member_id','grid_label']) #retrieve ssp cat from reduced catalogue
    ssp_ddict = generate_dict_of_datasets(ssp_cat,models_to_exclude,partial_combined_preprocessing)
    ssp_ddict = cleanup_datasets_in_dict(ssp_ddict)

    with dask.config.set(**{'array.slicing.split_large_chunks': True}): #concatenate historical and SSP along time
        ssp_ddict = combine_datasets(ssp_ddict,_concat_sorted_time,match_attrs =['source_id', 'grid_label','table_id','variant_label','variable_id'],combine_func_kwargs={'join':'inner','coords':'minimal','compat':'override'})

    ssp_ddict = drop_duplicate_timesteps(ssp_ddict) #remove overlap between historical and ssp experiments, which sometimes exists, again using 'drop_duplicate_timesteps'

    #Unify calendars before drop_incomplete (to-do: put in a separate function?). If the historical
    #and SSP parts use different cftime calendars, subtracting two timesteps presumably raises. The
    #timesteps that are not proleptic gregorian are then converted to that calendar.
    inconsistent_experiment_calendars = [] #NOTE(review): never filled or read
    for k,v in ssp_ddict.items():
        try:
            v.time[-1] - v.time[0] #probe only: compares the first and last timestep
        except: #unify calendars
            #NOTE(review): the bare except also catches unrelated errors; consider narrowing it
            not_prolgreg = np.where(np.array([type(i) for i in v.time.values]) != cftime._cftime.DatetimeProlepticGregorian)[0] #find where calendar is not proleptic gregorian
            converted_time = v.isel(time=not_prolgreg).convert_calendar('proleptic_gregorian',use_cftime=True).time #convert at these indices
            newtime = v.time.values #replace old time index with new values
            #NOTE(review): depending on xarray internals, .values may share memory with the time
            #index, in which case the next line writes into it in place — TODO confirm
            newtime[not_prolgreg] = converted_time.values
            ssp_ddict[k]['time'] = newtime

    ssp_ddict = drop_incomplete(ssp_ddict) #remove historical+ssp timeseries which are not monotonically increasing or have large timegaps (based on checks in CMIP6-LEAP-feedstock)
    ssp_ddict = match_metrics(ssp_ddict,areacello_ddict,['areacello']) #add 'areacello' metric, if possible

    #NOTE(review): drift_ddict and keys_without_pic are overwritten for every SSP; only the last SSP's survive the loop
    ssp_ddict,drift_ddict,keys_without_pic = dedrift_datasets_linearly(ssp_ddict,pic_ddict,query_var,min_pic_numYears)

    ssp_ddict = subtract_ocean_awMean(ssp_ddict,query_var) #datasets without 'areacello' are not carried over here

    ssp_ddicts[ssp] = ssp_ddict #add to dictionary of dictionaries of datasets


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'





--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'





--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'




Dropping duplicate timesteps for:EC-Earth3-Veg.gn.Omon.r5i1p1f1.zos




  0%|          | 0/406 [00:00<?, ?it/s]

  result = blockwise(


  0%|          | 0/348 [00:00<?, ?it/s]

Get an idea of available models & members:

```python
availability = get_availability(ssp_ddicts)
for k,v in availability.items():
    print('')
    print(k)
    # Every entry of `availability` holds the same set of models, so iterate its keys. This
    # avoids np.unique on a nested list, which fails when the SSP dicts differ in size.
    # Print the model name next to its number of members so the counts can be read.
    for model in sorted(v):
        print(str(model) + ': ' + str(len(v[model])))
```

Regrid:

In [5]:
if regrid:
    # one regridder per source_id, shared by all datasets of that model
    regridder_dict = create_regridder_dict(ssp_ddicts,target_grid)

    for s, ssp in enumerate(ssps):
        print(ssp)
        ssp_ddict = ssp_ddicts[ssp]
        for key, ds in tqdm(ssp_ddict.items()):
            # Look up this model's regridder and replace the dataset by its regridded version.
            # NOTE: entries are replaced in place, so re-running this cell would try to regrid
            # already-regridded data.
            regridder = regridder_dict[ds.attrs['source_id']]
            regridded_ds = regridder(ds, keep_attrs=True)
            ssp_ddict[key] = regridded_ds

  0%|          | 0/35 [00:00<?, ?it/s]

ssp245


  0%|          | 0/347 [00:00<?, ?it/s]

  result_vars[name] = func(*variable_args)


Store output:

In [None]:
# output sub-directory: variable name, suffixed with the target grid name if the data were regridded
var_dir = query_var + ('_' + target_grid.attrs['name'] if regrid else '')

for s,ssp in enumerate(ssps):
    print(ssp)
    ssp_ddict = ssp_ddicts[ssp]
    for key,ds in tqdm(ssp_ddict.items()):

        ds = ds.sel(time=slice(output_period[0],output_period[1])) #select output period
        if ds.sizes['time'] == 0: # nothing inside output_period; ds.time[0] below would raise IndexError
            print('No timesteps within output period, skipping: '+key)
            continue
        ds_name = key+'.hist_'+ssp+'.'+str(ds.time[0].dt.year.values)+'-'+str(ds.time[-1].dt.year.values) #generate file name

        # Join with '/' explicitly: os.path.join uses the OS path separator, which is wrong for
        # gs:// URLs on Windows.
        output_fn = '/'.join([output_path.rstrip('/'), var_dir, ds.source_id, ds_name])

        # TODO: storing is currently disabled; uncomment to write the zarr stores.
        # to_zarr fails if chunks are not uniform (due to time concatenation); the fallback
        # rechunks the variable along time and retries.
        # NOTE(review): narrow the bare except to the error raised for non-uniform chunks.
        # try:
        #     ds.to_zarr(output_fn,mode='w')
        # except:
        #     ds[query_var] = ds[query_var].chunk({'time':'auto'})
        #     ds.to_zarr(output_fn,mode='w')