In [2]:
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import dask
import intake
import fsspec
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from xmip.utils import google_cmip_col
from xmip.preprocessing import combined_preprocessing,_drop_coords
from xmip.postprocessing import merge_variables, combine_datasets, concat_experiments,_concat_sorted_time
from sklearn.decomposition import PCA

In [3]:
def regrid_to_era5(ds,era5_grid):
    """Bilinearly regrid `ds` onto `era5_grid` using xesmf.

    A new `xe.Regridder` is built from `ds` to `era5_grid` on every call
    (with `ignore_degenerate=True`) and applied to `ds` with `keep_attrs=True`.

    Parameters
    ----------
    ds : dataset to regrid (source grid).
    era5_grid : dataset describing the target grid.

    Returns
    -------
    The output of applying the regridder to `ds`.
    """
    bilinear_regridder = xe.Regridder(ds, era5_grid, 'bilinear', ignore_degenerate=True)
    regridded = bilinear_regridder(ds, keep_attrs=True)
    return regridded

def shorten_ssp_runs(ddict_in,end_year):
    """Cut off every SSP run after `end_year`.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys (strings) to datasets that support `.sel(time=...)`.
    end_year : int or str
        Last year to keep; it is converted to a string and used as the upper
        bound of the time slice.

    Returns
    -------
    dict
        A new dictionary with the same keys. Entries whose key contains 'ssp'
        are replaced by their time-sliced version, all other entries are passed
        through unchanged. `ddict_in` itself is not modified.
    """
    # Build a fresh dict instead of aliasing the input, so the caller's
    # dictionary is left untouched.
    return {
        k: (v.sel(time=slice(None, str(end_year))) if 'ssp' in k else v)
        for k, v in ddict_in.items()
    }

def shorten_historical_runs(ddict_in,start_year):
    """Drop the part of every historical run that lies before `start_year`.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys (strings) to datasets that support `.sel(time=...)`.
    start_year : int or str
        First year to keep; it is converted to a string and used as the lower
        bound of the time slice.

    Returns
    -------
    dict
        A new dictionary with the same keys. Entries whose key contains
        'historical' are replaced by their time-sliced version, all other
        entries are passed through unchanged. `ddict_in` itself is not modified.
    """
    # Build a fresh dict instead of aliasing the input, so the caller's
    # dictionary is left untouched.
    return {
        k: (v.sel(time=slice(str(start_year), None)) if 'historical' in k else v)
        for k, v in ddict_in.items()
    }

def drop_duplicate_timesteps(ddict_in):
    """Remove repeated time stamps from every dataset in the dictionary.

    For each dataset the unique values of `time` are determined with
    `np.unique`. If there are fewer unique values than timesteps, the dataset
    is reduced to the first occurrence of each time stamp (in the sorted order
    returned by `np.unique`) and a message naming the dataset key is printed.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys (strings) to datasets with a `time` coordinate and
        an `.isel(time=...)` method.

    Returns
    -------
    dict
        A new dictionary with the same keys; datasets without duplicates are
        passed through unchanged. `ddict_in` itself is not modified.
    """
    ddict_out = {}  # fresh dict so the caller's dictionary is left untouched
    for k, v in ddict_in.items():
        # idx holds the position of the first occurrence of each unique time
        unique_time, idx = np.unique(v.time,return_index=True)

        if len(v.time) != len(unique_time):
            v = v.isel(time=idx)
            print('Dropping duplicate timesteps for:' + k)

        ddict_out[k] = v

    return ddict_out

def drop_coords(ddict_in,coords_to_drop):
    """Apply `drop_dims(coords_to_drop, errors="ignore")` to every dataset.

    Note that, despite the function name, this calls `drop_dims` on each
    dataset, i.e. the listed names are treated as dimensions. Names that are
    not present are skipped because of `errors="ignore"`.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys to datasets that provide `.drop_dims`.
    coords_to_drop : str or list of str
        Dimension name(s) to remove.

    Returns
    -------
    dict
        A new dictionary with the same keys and the reduced datasets.
        `ddict_in` itself is not modified.
    """
    # Build a fresh dict rather than writing back into the caller's dictionary.
    return {
        k: v.drop_dims(coords_to_drop,errors="ignore")
        for k, v in ddict_in.items()
    }

def concat_realizations_most_common_ipf(ds_list):
    '''Concatenate, along 'member_id', only the realizations that share the most common 'ipf'.

    The 'ipf' part of a member id is everything from the first 'i' onwards
    (e.g. 'r1i1p1f1' -> 'i1p1f1'). The 'ipf' that occurs most often in
    `ds_list` is selected; if several are equally common, the lexicographically
    smallest one is taken. Only datasets whose 'ipf' equals the selected one
    exactly are concatenated.

    Parameters
    ----------
    ds_list : list of datasets, each with a `member_id` coordinate whose first
        entry is the member id string, and with 'lat' and 'lon' variables.

    Returns
    -------
    The selected datasets concatenated along 'member_id' (outer join, minimal
    coordinates, compat='override'). The datasets in `ds_list` are not modified.
    '''
    from collections import Counter

    def _ipf(member_id):
        # strip the realization index: keep everything from the first 'i' on
        return member_id[member_id.find('i'):]

    # Sorting the 'ipf' strings themselves makes the tie-break deterministic:
    # Counter.most_common lists equally common elements in first-seen order.
    ipf_ids = sorted(_ipf(ds.member_id.data[0]) for ds in ds_list)
    most_common_ipf = Counter(ipf_ids).most_common(1)[0][0]

    # Exact comparison: a substring test would also accept e.g. 'r1i1p1f11'
    # when the selected 'ipf' is 'i1p1f1'.
    ds_pick = [ds for ds in ds_list if _ipf(ds.member_id.data[0]) == most_common_ipf]

    #not ideal way to ensure coordinates are the same, otherwise differences in coordinates are padded with nans if using {join='outer'}
    ref_lat = ds_pick[0]['lat']
    ref_lon = ds_pick[0]['lon']
    ds_aligned = []
    for ds in ds_pick:
        if ds['lat'].shape == ref_lat.shape and ds['lon'].shape == ref_lon.shape:
            ds = ds.copy()  # shallow copy, so the caller's dataset keeps its own lat/lon
            ds['lat'] = ref_lat
            ds['lon'] = ref_lon
        ds_aligned.append(ds)

    return xr.concat(ds_aligned, dim='member_id', join='outer', coords='minimal',compat='override') #return xr.concat(ds_pick, dim='member_id')

In [4]:
# Full list of candidate models, kept for reference as an unassigned string literal
# (it has no effect at runtime):
'''
my_models = ['BCC-CSM2-MR','CESM2','CESM2-WACCM','CMCC-ESM2','CMCC-CM2-SR5','EC-Earth3',
                'GFDL-CM4','GFDL-ESM4','HadGEM3-GC31-MM','MIROC6','MPI-ESM1-2-HR','MRI-ESM2-0',
                'NorESM2-MM','TaiESM1']

'''
# Model(s) actually queried in this run.
my_models = ['MPI-ESM1-2-HR']
col = google_cmip_col()
# NOTE(review): `experiment_id` is not used by the search below, which hard-codes
# ['historical','ssp585'] -- keep the two in sync or pass this variable in.
experiment_id='ssp585'
source_id = my_models
# Keyword arguments forwarded to `to_dataset_dict`.
kwargs = {
    'zarr_kwargs':{
        'consolidated':True,
        'use_cftime':True
    },
    # presumably keeps one dataset per catalog entry instead of letting intake-esm
    # aggregate them -- TODO confirm against the intake-esm version in use
    'aggregate':False
}

# Daily sea-level pressure and surface wind speed for the historical and ssp585
# experiments, restricted to two members.
cat_data = col.search(
    source_id=source_id,
    experiment_id=['historical','ssp585'],
    table_id='day',
    variable_id=['psl','sfcWind'],
    member_id=['r1i1p1f1','r2i1p1f1'],
    require_all_on=['source_id', 'member_id','grid_label']
)
ddict = cat_data.to_dataset_dict(**kwargs)
# Alternative with xmip preprocessing, disabled because of the warnings it produces:
#ddict = cat_data.to_dataset_dict(**kwargs,preprocess=combined_preprocessing) # a lot of 'renaming failed' warnings here


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'


  ddict = cat_data.to_dataset_dict(**kwargs)


**NB: I don't seem to need any preprocessing. If I turn it on I get a lot of renaming failed warnings.**

In [11]:
# Trim and clean the individual datasets before merging variables:
ddict = shorten_ssp_runs(ddict,2100) # keep SSP runs up to and including 2100
ddict = shorten_historical_runs(ddict,1950) # may also want to shorten historical, may not need 1850-1950?
ddict = drop_duplicate_timesteps(ddict) # CESM2-WACCM has duplicate timesteps
ddict = drop_coords(ddict,['bnds']) # removes the 'bnds' dimension where present (missing names are ignored)

Code for listing the instance_id of runs with missing timesteps:
```python
for k,v in ddict.items():
    if (len(v.time) < len(np.arange(1850,2015))*12*30):
    #if len(v.time) < len(np.arange(2015,2101))*12*30:
        print(v.attrs['intake_esm_attrs:zstore'][0:-1].replace('gs://cmip6/','').replace('/','.'))
```


    

In [12]:
# Merge the per-variable datasets (psl, sfcWind) of each run into one dataset.
# The dask option 'array.slicing.split_large_chunks' is switched off for the
# duration of the merge; the outer join keeps timesteps present in either variable.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ddict_merged = merge_variables(ddict,merge_kwargs={'join':'outer'}) #produces large chunks

In [13]:
# Group the merged datasets by source_id / grid_label / experiment_id / table_id
# and combine each group with `concat_realizations_most_common_ipf`, i.e. along
# 'member_id'. As in the previous cell, dask's splitting of large chunks is disabled.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ddict_concat = combine_datasets(
        ddict_merged,
        concat_realizations_most_common_ipf,
        match_attrs=['source_id', 'grid_label', 'experiment_id', 'table_id']
    )
# NB: This leaves multiple datasets for the same model with different grid labels.
# TODO: probably need a function that keeps only the grid label with the most variants.
# Since this occurs only for a few models it is probably OK to just save these and remove them later?

**NB: Padding with NaNs above creates large time chunks for runs with missing timesteps. Adding the missing timesteps to Google Cloud would solve this?**

**NB: `concat_realizations_most_common_ipf()` keeps multiple datasets for the same model with different grid labels. May be useful to filter these out too later, but it only occurs for a few models**

Do the regridding & subsetting:

In [14]:
# NOTE(review): absolute path tied to one user's home directory; adjust when running elsewhere.
mlrcoefs = xr.open_dataset('/home/jovyan/CMIP6cf/gssr_coefs_1degRes_forcing.nc') #contains coordinates of and MLR coefficients at TGs

era5_grid = xr.Dataset( #the ERA5 grid used to derive the MLR coefficients
        {
            "longitude": (["longitude"], np.arange(-40,30,1)+1/2, {"units": "degrees_east"}),
            "latitude": (["latitude"], np.arange(70,10,-1)-1/2, {"units": "degrees_north"}),
        }
    )

#get coordinates of n x n degree grids around each tide gauge
# num_degr is the window width in degrees. On this 1-degree grid a window of
# num_degr degrees contains num_degr (or num_degr+1) cell centres, of which the
# first num_degr are kept. num_degr must be an integer.
num_degr = 2
lat_ranges = np.zeros((len(mlrcoefs.tg),num_degr))
lon_ranges = np.zeros((len(mlrcoefs.tg),num_degr))

for t,tg in enumerate(mlrcoefs.tg.values):
    tg_loc = mlrcoefs.sel(tg=tg) #select the tide gauge once per iteration instead of four times

    lat_in_window = (era5_grid.latitude>=(tg_loc.lat-num_degr/2)) & (era5_grid.latitude<=(tg_loc.lat+num_degr/2))
    lon_in_window = (era5_grid.longitude>=(tg_loc.lon-num_degr/2)) & (era5_grid.longitude<=(tg_loc.lon+num_degr/2))

    # a tide gauge too close to the edge of the grid has fewer than num_degr
    # cells in its window, which makes these assignments fail with a shape error
    lat_ranges[t,:] = era5_grid.latitude[lat_in_window][0:num_degr]
    lon_ranges[t,:] = era5_grid.longitude[lon_in_window][0:num_degr]

#create da's to index the CMIP6 files with:
lons_da = xr.DataArray(lon_ranges,dims=['tg','lon_around_tg'],coords={'tg':mlrcoefs.tg,'lon_around_tg':np.arange(num_degr)})
lats_da = xr.DataArray(lat_ranges,dims=['tg','lat_around_tg'],coords={'tg':mlrcoefs.tg,'lat_around_tg':np.arange(num_degr)})

In [15]:
# Regrid every concatenated dataset to the ERA5 grid and keep only the grid
# cells around each tide gauge.
ddict_subsetted = {k: v for k, v in ddict_concat.items()} #copy dictionary with concatenated realizations
for key,ds in tqdm(ddict_subsetted.items()):
    
    ds = ds.isel(dcpp_init_year=0,drop=True) #get rid of dcpp_init_year dimension
    
    #change longitude coordinates to -180 -> 180 (avoids getting NaNs at the 0-meridian)
    lon_coord = list(k for k in ds.dims if 'lon' in k)[0] #name of the first dimension containing 'lon'
    ds.coords[lon_coord] = ((ds.coords[lon_coord] + 180) % 360) - 180 #wrap around 0
    ds = ds.reindex({ lon_coord : np.sort(ds[lon_coord])}) #put the wrapped longitudes back in ascending order
    
    # lats_da / lons_da have dims (tg, lat_around_tg) / (tg, lon_around_tg), so the
    # selection is done per tide gauge rather than on the full latitude/longitude grid
    ds = regrid_to_era5(ds,era5_grid).sel(latitude=lats_da,longitude=lons_da) #regrid to ERA5 grid and index at n x n degree grids around each TG
    ddict_subsetted[key] = ds #replaces the value of an existing key, so the dict size does not change while iterating

  0%|          | 0/2 [00:00<?, ?it/s]

**NB: If desired, we could store the subsetted data at this point?**

Next, derive surges from the atmospheric forcing:
- combine the historical and ssp data (**need custom combination if querying multiple SSPs at once?**):

In [16]:
# Group the subsetted datasets by source_id / grid_label / table_id and combine
# each group with xmip's `_concat_sorted_time`. 'experiment_id' is not part of
# match_attrs, presumably so that the historical and SSP runs of a model end up
# in the same group -- verify if more than one SSP is queried.
ddict_forcing = combine_datasets(ddict_subsetted,
                        _concat_sorted_time,
                       match_attrs =['source_id', 'grid_label','table_id'] ) #appends SSP to historical

- generate the normalized forcing:

In [17]:
#generate forcing to compute surges with
for k,v in ddict_forcing.items():
    attrs = v.attrs
    
    v['sfcWind_sqd'] = v['sfcWind']**2 #add wind squared
    v['sfcWind_cbd'] = v['sfcWind']**3 #add wind cubed
    
    v = (v-v.mean(dim='time'))/v.std(dim='time',ddof=0) #normalize
    v.attrs = attrs
    
    #concatenate & stack normalized forcing variables to data array with shape (time,(4 variables * num_degr * num_degr))
    v['forcing'] = v[["psl", "sfcWind", "sfcWind_sqd","sfcWind_cbd"]].to_array(dim="forcing_var") 
    v['forcing'] = v['forcing'].transpose("time","forcing_var","lon_around_tg",...).stack(f=['forcing_var','lon_around_tg','lat_around_tg'],create_index=False)
    ddict_forcing[k]=v

- derive the principal components and multiply with regression coefficients derived from ERA5 (**need to work on a way to handle NaNs in timeseries**):

In [18]:
# For every dataset, member and tide gauge: derive the principal components of
# the normalized forcing and combine them with the regression coefficients in
# `mlrcoefs` to obtain a surge time series.
ddict_surges = defaultdict(dict)
for k,ds in tqdm(ddict_forcing.items()): #loop over datasets
    forcing = ds.forcing.load() #load forcing data array into memory (for all member_id & tg for current dataset)

    #the MLR coefficients depend only on the tide gauge, so look them up once per dataset instead of once per member
    finite_coefs_per_tg = []
    for tg in ds.tg:
        tg_coefs = mlrcoefs.mlrcoefs.sel(tg=tg)
        finite_coefs_per_tg.append(tg_coefs[np.isfinite(tg_coefs)].values)

    #collect the results in a plain NumPy buffer and attach it to the dataset once at the end,
    #instead of writing element-wise through a temporary `ds['surges']` DataArray
    surges = np.full((len(ds.member_id),len(ds.time),len(ds.tg)), np.nan) #initialize output

    for imember,member in enumerate(ds.member_id):
        forcing_mem = forcing.sel(member_id=member)

        for itg,tg in enumerate(ds.tg):
            #get model forcing at TG
            forcing_tg = forcing_mem.sel(tg=tg)

            #get MLR coefficients at TG
            finite_coefs = finite_coefs_per_tg[itg]
            num_pcs = len(finite_coefs)-1 #number of coefs = number of PCs to derive, intercept doesn't count

            #get principal components (using sklearn to keep deterministic signs consistent)
            pca = PCA(num_pcs)
            pca.fit(forcing_tg.data)
            pcs = pca.transform(forcing_tg.data)

            #multiply with ERA5 regression coefficients to compute surges
            #(a column of ones is prepended so the first coefficient acts as the intercept)
            surges[imember,:,itg] = np.sum(finite_coefs * np.column_stack((np.ones(pcs.shape[0]),pcs)),axis=1)

    ds['surges'] = (('member_id','time','tg'), surges)
    #lon/lat are attached by position, which assumes ds.tg is in the same order as mlrcoefs.tg
    ddict_surges[k] = ds['surges'].assign_coords(lon=('tg', mlrcoefs.lon.data),lat=('tg', mlrcoefs.lat.data)).assign_attrs(ds.attrs)
    #+store?

  0%|          | 0/1 [00:00<?, ?it/s]

**NB: not sure if it is possible to do this without a loop? The alternative is 'eofs.xarray', but I don't think that can be applied along axes either, and that results in different signs for the principal components because of sklearn.svd_flip()**

=> Takes about 12 minutes per model per member for historical+ssp (1850-2100), or 8 minutes if I preselect (1950-2100). I think we will analyse approximately 150 members in total times two scenarios, so that would take 40-60h?