In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import dask
import intake
import fsspec
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from xmip.utils import google_cmip_col
from xmip.postprocessing import combine_datasets, _match_datasets,_concat_sorted_time
from sklearn.decomposition import PCA

  from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!


In [2]:
def reduce_cat_to_max_num_realizations(cmip6_cat):
    '''Reduce grid labels in pangeo cmip6 catalogue by keeping, per source_id, the
    ('ipf', grid_label) combination with most datasets (=most realizations if using require_all_on).

    'ipf' is the member_id from its first 'i' onwards, e.g. 'r1i1p1f1' -> 'i1p1f1'.
    NB: realizations only available for another ipf/grid combination are omitted.

    Parameters
    ----------
    cmip6_cat : intake-esm catalogue whose .df has source_id, member_id and grid_label columns

    Returns
    -------
    The same catalogue object, with esmcat._df replaced by the subset (original column order kept).
    '''
    df = cmip6_cat.df.copy() #work on a copy so the helper 'ipf' column never leaks into the input dataframe
    cols = df.columns.tolist()

    df['ipf'] = [s[s.find('i'):] for s in df.member_id] #extract 'ipf' from 'ripf'

    #number of datasets per (source_id, ipf, grid_label). value_counts only sorts *within* each (source_id, ipf)
    #group, so sort over all groups before taking the first (= largest) entry per source_id
    counts = df.groupby(['source_id','ipf'])['grid_label'].value_counts()
    counts = counts.sort_values(ascending=False, kind='stable') #stable: ties keep their groupby order
    top_per_model = counts.groupby(level='source_id', sort=False).head(1).sort_index() #back in source_id order
    max_num_ds_tuples = top_per_model.index.to_list() #list of (source_id, ipf, grid_label) tuples
    df_filter = pd.DataFrame(max_num_ds_tuples,columns=['source_id','ipf','grid_label']) #generate df to merge catalogue df on

    df = df_filter.merge(right=df, on=['source_id','ipf','grid_label'], how='left') #do the subsetting
    df = df[cols] #drops the helper 'ipf' column and restores the original column order

    cmip6_cat.esmcat._df = df
    return cmip6_cat

def drop_vars_from_cat(cmip6_cat,vars_to_drop):
    '''Remove all catalogue entries whose variable_id is in vars_to_drop.

    Rows are filtered with a boolean mask rather than dropped by index label, so only
    the matching rows are removed even if the catalogue dataframe has a non-unique index.

    Parameters
    ----------
    cmip6_cat : intake-esm catalogue (its esmcat._df is replaced)
    vars_to_drop : list-like of variable_id values to remove

    Returns
    -------
    The same catalogue object; its dataframe gets a fresh 0..n-1 index.
    '''
    df = cmip6_cat.df
    keep = ~df.variable_id.isin(vars_to_drop)
    cmip6_cat.esmcat._df = df[keep].reset_index(drop=True)
    return cmip6_cat

In [3]:
def preselect_years(ddict_in,start_year,end_year,last_historical_year=2014):
    '''Select a range of years from historical and/or SSP datasets.

    Parameters
    ----------
    ddict_in : dict of datasets whose keys contain 'historical' or 'ssp'
    start_year, end_year : first and last year to keep (inclusive); start_year must be < end_year
    last_historical_year : final year of the historical experiment (2014 for CMIP6);
        SSP datasets are assumed to start after it

    Returns
    -------
    defaultdict with only the datasets needed to cover [start_year, end_year].
    NB: may result in no timesteps being selected at all for some datasets.
    '''
    assert start_year<end_year

    use_historical = start_year <= last_historical_year
    use_ssp = end_year > last_historical_year

    #when both experiments are used, historical ends and the SSPs start at the experiment boundary,
    #so only the outer ends of the requested range need to be cut
    hist_slice = slice(str(start_year), None if use_ssp else str(end_year))
    ssp_slice = slice(None if use_historical else str(start_year), str(end_year))

    ddict_out = defaultdict(dict)
    for k, v in ddict_in.items():
        if use_ssp and 'ssp' in k:
            ddict_out[k] = v.sel(time=ssp_slice)
        elif use_historical and 'historical' in k:
            ddict_out[k] = v.sel(time=hist_slice)
    return ddict_out

def drop_duplicate_timesteps(ddict_in):
    '''Return a copy of ddict_in where datasets with repeated time values keep only one entry per time.

    np.unique keeps the first occurrence of each time value and sorts them, so deduplicated
    datasets come back in ascending time order. Datasets without duplicates are passed through
    unchanged, and the input dict itself is not modified.
    '''
    import copy
    ddict_out = copy.copy(ddict_in) #shallow copy of the same container type; assigning into it leaves ddict_in untouched
    for k, v in ddict_in.items():
        unique_time, idx = np.unique(v.time,return_index=True)

        if len(v.time) != len(unique_time):
            ddict_out[k] = v.isel(time=idx)
            print('Dropping duplicate timesteps for:' + k)
    return ddict_out

def drop_coords(ddict_in,coords_to_drop):
    '''Drop the given dimensions (and all variables using them) from every dataset in ddict_in.

    Dimensions a dataset does not have are ignored. The dict is updated in place and returned.
    '''
    for key in list(ddict_in):
        dataset = ddict_in[key]
        ddict_in[key] = dataset.drop_dims(coords_to_drop, errors="ignore")
    return ddict_in

def align_lonlat(ds_list):
    '''Return the datasets of ds_list with their indexes overridden by those of ds_list[0].

    Each dataset is aligned pairwise against the first one with join='override', excluding the
    'time' and 'member_id' dimensions. Datasets on the same grid therefore end up with identical
    lon/lat labels, so combining them later does not pad lon/lat with NaNs.
    An empty list gives an empty list.
    '''
    #a list of datasets can't seem to be passed to xr.align directly, so align each one against the reference
    return [
        xr.align(ds_list[0], ds, join='override', exclude=['time','member_id'])[1]
        for ds in ds_list
    ]

def concat_members_aligning_lonlat(ds_list):
    '''Concatenate datasets along 'member_id' after overriding their lon/lat with those of the first dataset.

    Aligning first ensures same-size lon/lat indexes are not padded by the outer join; only the
    dimensions excluded from the alignment (time, member_id) can differ between inputs.
    '''
    aligned = align_lonlat(ds_list)
    concat_kwargs = dict(dim='member_id', join='outer', coords='minimal', compat='override')
    return xr.concat(aligned, **concat_kwargs)

def merge_variables_aligning_lonlat(ds_list):
    '''Merge datasets holding different variables after overriding their lon/lat with those of the first dataset.

    Aligning first ensures same-size lon/lat indexes are not padded by the outer join; with
    compat='override', conflicting variables are taken from the first dataset without comparison.
    '''
    aligned = align_lonlat(ds_list)
    return xr.merge(aligned, compat='override', join='outer')

In [4]:
#full list of models of interest (the analysis currently runs on a single model):
#my_models = ['BCC-CSM2-MR','CESM2','CESM2-WACCM','CMCC-ESM2','CMCC-CM2-SR5','EC-Earth3',
#                'GFDL-CM4','GFDL-ESM4','HadGEM3-GC31-MM','MIROC6','MPI-ESM1-2-HR','MRI-ESM2-0',
#                'NorESM2-MM','TaiESM1']
my_models = ['MPI-ESM1-2-HR']

col = google_cmip_col()

#query arguments shared by both scenarios: find instances providing all required daily variables
#for both historical and the respective SSP
common_search_kwargs = dict(
    source_id=my_models,
    table_id='day',
    variable_id=['sfcWind','psl','pr'],
    require_all_on=['source_id', 'member_id','grid_label'],
)

cat_data_ssp245 = col.search(experiment_id=['historical','ssp245'], **common_search_kwargs) #historical & ssp245
cat_data_ssp585 = col.search(experiment_id=['historical','ssp585'], **common_search_kwargs) #historical & ssp585

#reuse the ssp585 catalogue object (so it is modified too), giving it the union of both queries: all instances we want
cat_data = cat_data_ssp585
cat_data.esmcat._df = pd.concat([cat_data_ssp245.df,cat_data_ssp585.df],ignore_index=True).drop_duplicates(ignore_index=True)

cat_data = reduce_cat_to_max_num_realizations(cat_data) #per model, select grid and 'ipf' combination providing most realizations
cat_data = drop_vars_from_cat(cat_data,['pr']) #we query only instances that also provide 'pr', but don't process 'pr' in this script

In [5]:
cat_data.esmcat.aggregation_control.groupby_attrs = [] #to circumvent aggregate=false bug

kwargs = dict(
    zarr_kwargs=dict(consolidated=True, use_cftime=True),
    #aggregate=True avoids https://github.com/intake/intake-esm/issues/496; because groupby_attrs
    #was emptied above, the datasets are not actually aggregated
    aggregate=True,
)

ddict = cat_data.to_dataset_dict(**kwargs) #open datasets into dictionary


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'


  ddict = cat_data.to_dataset_dict(**kwargs) #open datasets into dictionary


In [6]:
#drop duplicate timesteps (CESM2-WACCM has duplicate timeseries), limit the analysis to 1850-2100 for now,
#and remove the 'bnds' dimension
ddict = drop_coords(preselect_years(drop_duplicate_timesteps(ddict), 1850, 2100), ['bnds'])

Concatenating members runs into memory issues for EC-Earth3 when it has more than 26 members.

In [7]:
#join='outer' pads NaNs, which results in large chunks for timeseries that differ in length; let dask split them
with dask.config.set({'array.slicing.split_large_chunks': True}):
    #one dataset per model/grid/experiment/table/member, holding all variables
    ddict_merged = combine_datasets(
        ddict,
        merge_variables_aligning_lonlat,
        match_attrs=['source_id', 'grid_label', 'experiment_id', 'table_id', 'variant_label'],
    )
    #ddict_concat = combine_datasets(ddict_merged,concat_members_aligning_lonlat,match_attrs=['source_id', 'grid_label', 'experiment_id', 'table_id'])

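Sanity-check the number of timesteps (roughly one per day). This check is written for `ddict_concat`, whose creation is currently commented out in the cell above: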

```python
for k,v in ddict_concat.items():
    num_days = (v.time[-1]-v.time[0]).dt.days
    assert (len(v.time) > .9*num_days) & (len(v.time) < 1.1*num_days)
```

Do the regridding & subsetting. First, determine the grid coordinates around each tide gauge:

In [8]:
mlrcoefs = xr.open_dataset('/home/jovyan/CMIP6cf/gssr_coefs_1degRes_forcing.nc') #contains coordinates of and MLR coefficients at TGs

era5_grid = xr.Dataset( #the 1x1 degree ERA5 grid (cell centres at x.5 degrees) used to derive the MLR coefficients
        {   "longitude": (["longitude"], np.arange(-40,30,1)+1/2, {"units": "degrees_east"}),
            "latitude": (["latitude"], np.arange(70,10,-1)-1/2, {"units": "degrees_north"}),})

#get coordinates of n x n degree grids around each tide gauge
num_degr = 2
num_pts = int(num_degr) #grid points per direction; equals num_degr because era5_grid has a 1-degree spacing
target_margin = 2 #extra degrees around the TG grids included in the regridding target

lat_ranges = np.zeros((len(mlrcoefs.tg),num_pts))
lon_ranges = np.zeros((len(mlrcoefs.tg),num_pts))

for t,tg in enumerate(mlrcoefs.tg.values):
    tg_loc = mlrcoefs.sel(tg=tg)
    lats_near = era5_grid.latitude[(era5_grid.latitude>=(tg_loc.lat-num_degr/2)) & (era5_grid.latitude<=(tg_loc.lat+num_degr/2))]
    lons_near = era5_grid.longitude[(era5_grid.longitude>=(tg_loc.lon-num_degr/2)) & (era5_grid.longitude<=(tg_loc.lon+num_degr/2))]
    if (len(lats_near) < num_pts) or (len(lons_near) < num_pts):
        raise ValueError('Fewer than %d ERA5 grid points within %s degrees of tide gauge %s' % (num_pts, num_degr/2, tg))
    #if the window edges coincide with grid points, the window holds num_pts+1 points; keep the first num_pts
    lat_ranges[t,:] = lats_near[:num_pts]
    lon_ranges[t,:] = lons_near[:num_pts]

#create da's to index the CMIP6 files with:
lons_da = xr.DataArray(lon_ranges,dims=['tg','lon_around_tg'],coords={'tg':mlrcoefs.tg,'lon_around_tg':np.arange(num_pts)})
lats_da = xr.DataArray(lat_ranges,dims=['tg','lat_around_tg'],coords={'tg':mlrcoefs.tg,'lat_around_tg':np.arange(num_pts)})

#subset of the ERA5 grid containing all TG grids plus a margin; latitude is descending, so its slice runs north to south
target_grid = era5_grid.sel(latitude=slice(float(lats_da.max())+target_margin,float(lats_da.min())-target_margin),
                            longitude=slice(float(lons_da.min())-target_margin,float(lons_da.max())+target_margin))

Now loop over each model/grid (there should be only one grid per model) and apply the model/grid-specific interpolation weights to each dataset belonging to that model/grid.

**Note:** It is probably more memory-efficient to loop over the tide gauges and regrid to each tide-gauge specific 2 by 2 degree grid, especially if considering larger domains. For now, we regrid to a subset of the ERA5 grid big enough to contain all tide-gauge specific grids (~20x20 grid points), and index the regridded dataset at the tide-gauge specific grids:

In [9]:
#wrap longitudes to [-180, 180) and sort them, so the model grids use the same longitude convention as target_grid
for key,ds in ddict_merged.items():
    lon_coord = list(k for k in ds.dims if 'lon' in k)[0] #name of the longitude dimension (first dimension whose name contains 'lon')
    ds.coords[lon_coord] = ((ds.coords[lon_coord] + 180) % 360) - 180 #wrap around 0, i.e. map longitudes to [-180, 180)
    ds = ds.reindex({ lon_coord : np.sort(ds[lon_coord])}) #wrapping breaks monotonicity; reorder to ascending longitude
    ddict_merged[key] = ds

ds_dict = {k: v for k, v in ddict_merged.items()} #working copy, consumed by the while loop below
ddict_subsetted = {k: v for k, v in ddict_merged.items()} #output dict; entries are replaced below by their regridded & subsetted versions

#group the datasets per model/grid (same pattern as xmip's combine_datasets) so regridding weights are computed once per group
while len(ds_dict) > 0: #<- from xmip's combine_datasets
    k = list(ds_dict.keys())[0]
    ds = ds_dict.pop(k)
    ds.attrs["original_key"] = k #remember the dict key so the result can be stored under it
    
    #NOTE(review): relies on xmip's private _match_datasets returning [ds, *matches], popping the matches from
    #ds_dict (pop=True) and setting attrs['original_key'] on them -- confirm against the installed xmip version
    matched_datasets = _match_datasets(ds, ds_dict, ['source_id', 'grid_label'], pop=True) #find datasets belonging to same model/grid
  
    regridder = xe.Regridder(matched_datasets[0],target_grid,'bilinear',ignore_degenerate=True) #interpolation weights for this grid
    
    for matched_ds in matched_datasets:
        #override the indexes of matched_ds (except time/member_id/dcpp_init_year) with those of matched_datasets[0],
        #the dataset the regridder was built from; first_ds itself is not used
        first_ds,aligned_ds = xr.align(matched_datasets[0],matched_ds,join='override',exclude=['time','member_id','dcpp_init_year']) #make sure lon/lat coordinates are exactly the same
        
        #assumes a dcpp_init_year dimension is present (isel raises otherwise); keeps its first entry
        aligned_ds = aligned_ds.isel(dcpp_init_year=0,drop=True) #get rid of dcpp_init_year dimension
        #regrid, then pick each TG's grid points with vectorized indexing by lats_da/lons_da
        #(result has tg, lat_around_tg and lon_around_tg dimensions instead of latitude/longitude)
        ddict_subsetted[matched_ds.original_key] = regridder(aligned_ds,keep_attrs=True).sel(latitude=lats_da,longitude=lons_da) #regrid to target grid and add to ddict

Store the datasets (directories structured per model):

In [10]:
#write each regridded & subsetted dataset to netCDF, in one directory per model
for key,ds in tqdm(ddict_subsetted.items()):
    model_path = os.path.join('/home/jovyan/CMIP6cf/output/subsetted_forcing/',ds.source_id)
    os.makedirs(model_path,exist_ok=True) #unlike os.mkdir, also creates missing parent directories; no error if it already exists
    ds.to_netcdf(os.path.join(model_path,key.replace('.','_')+'.nc'),mode='w') #file name: dict key with '.' replaced by '_'

  0%|          | 0/6 [00:00<?, ?it/s]

Next, derive surges from the atmospheric forcing:
- combine the historical and SSP data (**is a custom combination needed when querying multiple SSPs at once?**):

In [None]:
#historical and SSP datasets must be combined, so experiment_id cannot be a match attribute. Combine historical with
#one SSP at a time; otherwise historical, ssp245 and ssp585 of a model would all be concatenated along time.
#Also match on variant_label (as in the variable merge above) so different members are not concatenated along time.
ssps = sorted({v.attrs['experiment_id'] for v in ddict_subsetted.values() if 'ssp' in v.attrs['experiment_id']})

ddict_forcing = {}
for ssp in ssps or [None]: #[None]: no SSP data selected, only combine the historical datasets
    ddict_hist_ssp = {k: v for k, v in ddict_subsetted.items() if v.attrs['experiment_id'] in ('historical', ssp)}
    ddict_combined = combine_datasets(ddict_hist_ssp,
                                      _concat_sorted_time,
                                      match_attrs=['source_id', 'grid_label', 'table_id', 'variant_label']) #appends SSP to historical
    for k, v in ddict_combined.items():
        ddict_forcing[k if ssp is None else k + '.' + ssp] = v #suffix the key with the SSP to keep scenarios apart

- generate the normalized forcing:

In [None]:
#generate forcing to compute surges with
for k,v in ddict_forcing.items():
    attrs = v.attrs
    
    v['sfcWind_sqd'] = v['sfcWind']**2 #add wind squared
    v['sfcWind_cbd'] = v['sfcWind']**3 #add wind cubed
    
    v = (v-v.mean(dim='time'))/v.std(dim='time',ddof=0) #normalize
    v.attrs = attrs
    
    #concatenate & stack normalized forcing variables to data array with shape (time,(4 variables * num_degr * num_degr))
    v['forcing'] = v[["psl", "sfcWind", "sfcWind_sqd","sfcWind_cbd"]].to_array(dim="forcing_var") 
    v['forcing'] = v['forcing'].transpose("time","forcing_var","lon_around_tg",...).stack(f=['forcing_var','lon_around_tg','lat_around_tg'],create_index=False)
    ddict_forcing[k]=v

- derive the principal components and multiply with regression coefficients derived from ERA5:

In [None]:
ddict_surges = defaultdict(dict)
for k,ds in ddict_forcing.items(): #loop over datasets
    #finite MLR coefficients per TG (intercept first, then one per PC); identical for all members, so extract once
    finite_coefs = []
    for tg in ds.tg.values:
        tg_coefs = mlrcoefs.mlrcoefs.sel(tg=tg)
        finite_coefs.append(tg_coefs[np.isfinite(tg_coefs)].values)

    #fill a plain numpy array (cheaper than indexing a DataArray per TG); NaN wherever forcing is missing
    surges = np.full((len(ds.member_id),len(ds.time),len(ds.tg)), np.nan)

    for i_member,member in enumerate(tqdm(ds.member_id.values)):
        #load forcing data array into memory (for all tg for current dataset and member); rebinding forcing_mem for
        #the next member releases the reference to this one
        forcing_mem = ds.forcing.sel(member_id=member).load()

        for i_tg,tg in enumerate(ds.tg.values):
            forcing_tg = forcing_mem.sel(tg=tg).data #model forcing at TG, (time, f)
            coefs = finite_coefs[i_tg]
            num_pcs = len(coefs)-1 #number of coefs = number of PCs to derive, intercept doesn't count

            i_timesteps_w_data = np.flatnonzero(np.isfinite(forcing_tg).all(axis=1)) #remove missing values for PCA
            forcing_valid = forcing_tg[i_timesteps_w_data]

            #get principal components (using sklearn to keep deterministic signs consistent)
            pca = PCA(num_pcs)
            pca.fit(forcing_valid)
            pcs = pca.transform(forcing_valid)

            #multiply with ERA5 regression coefficients to compute surges: intercept + sum_i coef_i * PC_i
            surges[i_member,i_timesteps_w_data,i_tg] = coefs[0] + pcs @ coefs[1:]

    ds['surges'] = (('member_id','time','tg'), surges)
    #attach TG coordinates by tg label rather than by position
    ddict_surges[k] = ds['surges'].assign_coords(lon=('tg', mlrcoefs.lon.sel(tg=ds.tg).data),
                                                 lat=('tg', mlrcoefs.lat.sel(tg=ds.tg).data)).assign_attrs(ds.attrs)
    #probably want to store the surges and close the dataset here to avoid overflowing the memory?

In [None]:
#display the dictionary of surge DataArrays (one entry per combined dataset)
ddict_surges

**NB: not sure whether this can be done without a loop. The alternative is `eofs.xarray`, but I don't think it can be applied along axes either. It also yields principal components with different signs, because it does not apply sklearn's sign convention (`svd_flip()`).**

**How to make sure this loop doesn't overflow the memory for large queries?**

=> Takes about 6 minutes per model per member for historical+ssp (1850-2100), or 4 minutes if I preselect (1950-2100). I think we will analyse approximately 150 members in total times two scenarios, so that would take 20-30h? Takes longer for higher resolution models with larger file sizes (e.g., EC-Earth3).

Code for listing the instance_id of runs with missing timesteps:
```python
for k,v in ddict.items():
    if (len(v.time) < len(np.arange(1850,2015))*12*30):
    #if len(v.time) < len(np.arange(2015,2101))*12*30:
        print(v.attrs['intake_esm_attrs:zstore'][0:-1].replace('gs://cmip6/','').replace('/','.'))
```


    