In [None]:
import numpy as np
import xarray as xr
import dask
import intake
import pandas as pd
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from xmip.utils import google_cmip_col
from xmip.postprocessing import combine_datasets, _match_datasets,_concat_sorted_time

In [None]:
def reduce_cat_to_max_num_realizations(cmip6_cat):
    '''Reduce a pangeo cmip6 catalogue to a single ('ipf', grid_label) combination per model.

    For every source_id, only the rows of the combination of 'ipf' identifier (member_id
    without its realization part) and grid_label with the most datasets are kept
    (=most realizations if using require_all_on). Realizations that are only available
    for other combinations are omitted.

    Parameters
    ----------
    cmip6_cat : catalogue object exposing its dataframe as `.df` and `.esmcat._df`;
        the dataframe needs the columns 'source_id', 'member_id' and 'grid_label'.

    Returns
    -------
    The same catalogue object, with `.esmcat._df` replaced by the reduced dataframe
    (original columns, in their original order).
    '''
    df = cmip6_cat.df
    cols = df.columns.tolist()

    #extract 'ipf' from 'ripf'; assign() works on a copy so the catalogue's own dataframe is not modified
    df = df.assign(ipf=[s[s.find('i'):] for s in df.member_id])

    #number of datasets per (source_id,ipf,grid_label)
    counts = df.groupby(['source_id','ipf','grid_label']).size()

    #Per model, keep the combination with the highest count. The counts must be sorted across *all*
    #combinations of a model before taking head(1): value_counts() on a groupby only sorts within each
    #(source_id,ipf) group, which would select the alphabetically first 'ipf' instead of the maximum.
    #A stable sort keeps ties in their (alphabetical) group order.
    max_num_ds_tuples = (
        counts.sort_values(ascending=False, kind='mergesort')
              .groupby(level='source_id', sort=False)
              .head(1)
              .index.to_list()
    )
    df_filter = pd.DataFrame(max_num_ds_tuples,columns=['source_id','ipf','grid_label']) #generate df to merge catalogue df on

    df = df_filter.merge(right=df, on=['source_id','ipf','grid_label'], how='left') #do the subsetting
    df = df[cols] #drop the helper column and restore the original column order

    cmip6_cat.esmcat._df = df
    return cmip6_cat

def drop_vars_from_cat(cmip6_cat,vars_to_drop):
    '''Remove all catalogue rows whose variable_id is listed in vars_to_drop.

    Parameters
    ----------
    cmip6_cat : catalogue object exposing its dataframe as `.df` and `.esmcat._df`.
    vars_to_drop : list of variable_id values to remove.

    Returns
    -------
    The same catalogue object, with `.esmcat._df` replaced by the filtered dataframe
    (index reset to 0..n-1).
    '''
    df = cmip6_cat.df
    #filter with a boolean mask instead of dropping by index label, so that rows are
    #selected by position and a non-unique index cannot remove unrelated rows
    keep = ~df.variable_id.isin(vars_to_drop)
    cmip6_cat.esmcat._df = df[keep].reset_index(drop=True)
    return cmip6_cat

In [None]:
def preselect_years(ddict_in,start_year,end_year):
    '''Select a range of years from the datasets in ddict_in.

    Datasets are picked by the experiment name contained in their key:
    - start_year > 2014: only keys containing 'ssp', sliced to [start_year, end_year]
    - end_year <= 2014: only keys containing 'historical', sliced to [start_year, end_year]
    - otherwise: keys containing 'ssp' are sliced up to end_year and keys containing
      'historical' are sliced from start_year onwards

    Datasets whose key matches none of the requested experiments are left out.
    NB: may result in no timesteps being selected at all.
    '''
    assert start_year<end_year

    first, last = str(start_year), str(end_year)

    #(key substring, time slice) pairs, in the order in which they are tested against each key
    if start_year>2014: #only using SSP
        windows = [('ssp', slice(first, last))]
    elif end_year<=2014: #only using historical
        windows = [('historical', slice(first, last))]
    else: #using both
        windows = [('ssp', slice(None, last)), ('historical', slice(first, None))]

    selected = defaultdict(dict)
    for name, ds in ddict_in.items():
        for tag, window in windows:
            if tag in name:
                selected[name] = ds.sel(time=window)
                break #first matching experiment wins
    return selected

def pr_units_to_m(ddict_in):
    '''Convert 'pr' of every dataset in ddict_in from 'kg m-2 s-1' to daily accumulated 'm'.

    The datasets are modified in place and the input dictionary itself is returned.
    Raises an AssertionError if 'pr' of a dataset does not have units 'kg m-2 s-1'.
    '''
    seconds_per_day = 24*3600
    water_density = 1000 #kg/m3

    for name, ds in ddict_in.items():
        assert ds.pr.units == 'kg m-2 s-1'

        #multiply by the number of seconds in a day to get to kg m-2, and divide by density to get to m
        with xr.set_options(keep_attrs=True): #keep the attributes of 'pr' through the arithmetic
            ds['pr'] = seconds_per_day*ds['pr']/water_density
        ds.pr.attrs['units'] = 'm'

        ddict_in[name] = ds
    return ddict_in

def drop_duplicate_timesteps(ddict_in):
    '''Remove repeated time values from the datasets in ddict_in.

    Datasets with duplicate timesteps are replaced (in the input dictionary, which is
    also returned) by a selection at the first occurrence of each unique time value;
    a message is printed for each of them. Other datasets are left untouched.
    '''
    for name, ds in ddict_in.items():
        times_unique, first_occurrence = np.unique(ds.time,return_index=True)

        if len(ds.time) == len(times_unique):
            continue #no duplicates, nothing to do

        ddict_in[name] = ds.isel(time=first_occurrence)
        print('Dropping duplicate timesteps for:' + name)
    return ddict_in

def drop_coords(ddict_in,coords_to_drop):
    '''Drop the dimensions listed in coords_to_drop from every dataset in ddict_in.

    Uses drop_dims(..., errors="ignore") on each dataset. The input dictionary is
    updated with the results and returned.
    '''
    ddict_in.update({
        name: ds.drop_dims(coords_to_drop,errors="ignore")
        for name, ds in ddict_in.items()
    })
    return ddict_in

def align_lonlat(ds_list):
    '''Align every dataset in ds_list with the first one using join='override'.

    'time' and 'member_id' are excluded from the alignment. Returns a new list with
    the aligned datasets, in the same order as ds_list.
    '''
    #each dataset is aligned pairwise with the first one (passing the whole list to xr.align at once did not seem to work)
    return [
        xr.align(ds_list[0],ds,join='override',exclude=['time','member_id'])[1]
        for ds in ds_list
    ]

def merge_variables_aligning_lonlat(ds_list):
    '''Merge the datasets in ds_list after aligning them with align_lonlat.

    Overriding same-dimension lon/lat prior to merging ensures lon/lats are not padded.
    '''
    return xr.merge(align_lonlat(ds_list), join='outer', compat='override')

In [None]:
def select_gridcells_nearest_to_tgs(tg_ds,ds):
    '''
    Subset a gridded dataset at the grid cells nearest to the tide gauges (TGs).

    tg_ds = xr.Dataset containing 'lon' and 'lat' coordinates of tide gauges
    ds    = xr.Dataset containing CMIP6 data to subset; its longitude and latitude dimensions are
            taken to be the first dimensions whose name contains 'lon' and 'lat', respectively

    Returns ds selected at the grid cell nearest to each tide gauge. The coordinates of the selected
    grid cells are kept as 'gridcell_lon' and 'gridcell_lat', while 'lon' and 'lat' are set to the
    tide gauge coordinates.
    '''

    lon_name = list(k for k in ds.dims if 'lon' in k)[0] #find lon/lat dimension names
    lat_name = list(k for k in ds.dims if 'lat' in k)[0]

    deg2rad = np.pi/180 #coordinates are converted from degrees to radians

    #compute distances between TG coordinates and grid cell centers (haversine formula on the unit sphere)
    distances = 2*np.arcsin( np.sqrt(
        np.sin( deg2rad * 0.5*(ds[lat_name]-tg_ds.lat) )**2 +
        np.cos(deg2rad*tg_ds.lat)*np.cos(deg2rad*ds[lat_name])*np.sin(deg2rad*0.5*(ds[lon_name]-tg_ds.lon))**2) )

    idx_nearest = distances.argmin(dim=[lon_name,lat_name]) #find indices of nearest grid cells
    ds_subsetted = ds[idx_nearest] #subset ds at nearest grid cells

    #keep coordinates of nearest grid cells; rename using the detected names so that grids whose
    #dimensions are not literally called 'lon'/'lat' (e.g. 'longitude'/'latitude') are handled too
    ds_subsetted = ds_subsetted.rename_vars({lon_name:'gridcell_lon',lat_name:'gridcell_lat'})
    ds_subsetted = ds_subsetted.assign_coords(lon=tg_ds.lon,lat=tg_ds.lat) #replace coordinates with TG coordinates

    return ds_subsetted

In [None]:
#alternative model selection (inactive; kept as a bare string so it is not executed):
'''
my_models = ['BCC-CSM2-MR','CESM2','CESM2-WACCM','CMCC-ESM2','CMCC-CM2-SR5',
                'GFDL-CM4','GFDL-ESM4','HadGEM3-GC31-MM','MIROC6','MPI-ESM1-2-HR','MRI-ESM2-0',
                'NorESM2-MM','TaiESM1']
'''
my_models = ['EC-Earth3']

col = google_cmip_col()

#query arguments shared by both scenario searches
shared_query = dict(
    source_id=my_models,
    table_id='day',
    variable_id=['sfcWind','psl','pr'],
    require_all_on=['source_id', 'member_id','grid_label'],
)

#find instances providing all required variables for both historical & ssp245
cat_data_ssp245 = col.search(experiment_id=['historical','ssp245'], **shared_query)

#find instances providing all required variables for both historical & ssp585
cat_data_ssp585 = col.search(experiment_id=['historical','ssp585'], **shared_query)

#all instances we want: union of both search results without duplicated rows
combined_df = pd.concat(
    [cat_data_ssp245.df, cat_data_ssp585.df], ignore_index=True
).drop_duplicates(ignore_index=True)

cat_data = cat_data_ssp585 #reuse the ssp585 catalogue object as container for the combined dataframe
cat_data.esmcat._df = combined_df

cat_data = reduce_cat_to_max_num_realizations(cat_data) #per model, select grid and 'ipf' combination providing most realizations
cat_data = drop_vars_from_cat(cat_data,['psl','sfcWind']) #we query only instances that also provide 'psl' and 'sfcWind', but only process 'pr' in this script

In [None]:
#With empty groupby attributes nothing is actually aggregated even though aggregate=True is passed below;
#this circumvents the aggregate=false bug (https://github.com/intake/intake-esm/issues/496)
cat_data.esmcat.aggregation_control.groupby_attrs = []

kwargs = dict(
    zarr_kwargs=dict(consolidated=True, use_cftime=True),
    aggregate=True, #see the note on groupby_attrs above
)

ddict = cat_data.to_dataset_dict(**kwargs) #open datasets into dictionary

**NB: I don't seem to need any preprocessing. If I turn it on I get a lot of renaming failed warnings.**

In [None]:
ddict = pr_units_to_m(ddict) #'pr' from kg m-2 s-1 to daily accumulated m
ddict = drop_duplicate_timesteps(ddict) #CESM2-WACCM has duplicate timesteps
ddict = preselect_years(ddict, start_year=1850, end_year=2100)
ddict = drop_coords(ddict, coords_to_drop=['bnds'])

In [None]:
#join=outer pads NaNs which result in large chunks for timeseries that differ in length
with dask.config.set({'array.slicing.split_large_chunks': True}):
    ddict_merged = combine_datasets(
        ddict,
        merge_variables_aligning_lonlat,
        match_attrs=['source_id', 'grid_label', 'experiment_id', 'table_id','variant_label'],
    )

Do the subsetting at grid cells nearest to the tide gauges:

In [None]:
#contains coordinates of and MLR coefficients at TGs
mlrcoefs = xr.open_dataset('/home/jovyan/CMIP6cf/gssr_coefs_1degRes_forcing.nc')

ddict_at_tgs = defaultdict(dict)

for key,ds in tqdm(ddict_merged.items()):
    ds = ds.isel({'dcpp_init_year':0},drop=True) #select the first 'dcpp_init_year' entry and drop that coordinate
    ds.attrs.update(original_key=key) #remember which merged dataset the subset comes from
    ddict_at_tgs[key] = select_gridcells_nearest_to_tgs(mlrcoefs,ds)

In [None]:
#write each subsetted dataset to <output dir>/<source_id>/<key with '.' replaced by '_'>.nc
for key,ds in tqdm(ddict_at_tgs.items()):
    model_path = os.path.join('/home/jovyan/CMIP6cf/output/subsetted_pr/',ds.source_id)
    #makedirs also creates missing parent directories and does not fail if the directory already exists
    os.makedirs(model_path,exist_ok=True)
    ds.to_netcdf(os.path.join(model_path,key.replace('.','_')+'.nc'),mode='w')
    ds.close()

Note: the cell above takes about an hour for historical + ssp245 + ssp585 when EC-Earth3 is excluded.