In [None]:
import numpy as np
import xarray as xr
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from xmip.postprocessing import combine_datasets,_concat_sorted_time
from sklearn.decomposition import PCA

Open the subsetted CMIP6 `psl` and `sfcWind` data, concatenating all realizations of each experiment along `member_id` (**To-do:** test whether concatenation causes issues for models with many realizations, e.g., EC-Earth3):

In [None]:
in_dir = '/home/jovyan/CMIP6cf/output/subsetted_forcing/'

# Maps each dataset's `original_key` (with its last '.'-separated field, the member, removed)
# to a Dataset containing all realizations of that source/experiment along `member_id`.
ddict = defaultdict(dict)

for source_id in sorted(s for s in os.listdir(in_dir) if not s.startswith('.')):
    source_dir = os.path.join(in_dir, source_id)
    if not os.path.isdir(source_dir):  # skip stray files next to the per-model directories
        continue

    # experiment_id is taken to be the 3rd '_'-separated field of each file name
    experiment_ids = {s.split('_')[2] for s in os.listdir(source_dir) if not s.startswith('.')}

    for experiment_id in sorted(experiment_ids):
        # Open all files of this experiment, concatenating the realizations along member_id.
        # The '_' delimiters keep e.g. 'ssp245' from also matching 'ssp245-GHG' files.
        # NOTE: still needs testing for models with many realizations, like EC-Earth3.
        source_ds = xr.open_mfdataset(os.path.join(source_dir, '*_' + experiment_id + '_*.nc'),
                                      join='outer', combine='nested', compat='override',
                                      coords='minimal', concat_dim='member_id')
        ddict[source_ds.original_key.rsplit('.', 1)[0]] = source_ds
        

Concatenate matching realizations of the historical and SSP runs in time:

In [None]:
def _experiment_id(key):
    """Return the experiment field of a `ddict` key (its 3rd '.'-separated field)."""
    return key.split('.')[2]


# SSP experiments for which historical+SSP timeseries will be built
ssps = {_experiment_id(k) for k in ddict if 'ssp' in _experiment_id(k)}

# Maps '<combined key>.<ssp>' to the historical+SSP Dataset concatenated in time.
ddict_concat = defaultdict(dict)

for ssp in sorted(ssps):
    # Compare the experiment field exactly: substring tests would also select e.g.
    # 'ssp245-GHG' when looking for 'ssp245'.
    ddict_ssp = {k: ds for k, ds in ddict.items() if _experiment_id(k) in (ssp, 'historical')}

    # Append the SSP to the historical run, only for realizations for which both
    # experiments are provided (join='inner').
    hist_ssp = combine_datasets(ddict_ssp,
                                _concat_sorted_time,
                                match_attrs=['source_id', 'grid_label', 'table_id'],
                                combine_func_kwargs={'join': 'inner'})

    for key, ds in hist_ssp.items():  # put back together in one dictionary
        ddict_concat[key + '.' + ssp] = ds

Sanity-check the timeseries lengths (expect roughly one timestep per day):

In [None]:
# Each concatenated timeseries should have roughly one timestep per day of its span
# (within +/-10%); this should catch large gaps or duplicated periods after concatenation.
for k, v in ddict_concat.items():
    num_days = int((v.time[-1] - v.time[0]).dt.days)
    num_timesteps = v.sizes['time']
    assert 0.9 * num_days < num_timesteps < 1.1 * num_days, (
        f'{k}: {num_timesteps} timesteps span {num_days} days; expected roughly one per day'
    )

Generate the normalized forcing data (sea-level pressure, wind speed, and wind speed squared and cubed):

In [None]:
# Generate the forcing to compute surges with.
# Forcing predictors, in the order in which they are stacked into the `f` dimension.
FORCING_VARS = ["psl", "sfcWind", "sfcWind_sqd", "sfcWind_cbd"]

for k, v in ddict_concat.items():
    # Add wind squared & cubed. `assign` returns a new Dataset, so the object previously
    # stored in ddict_concat is not mutated in place.
    v = v.assign(sfcWind_sqd=v['sfcWind']**2, sfcWind_cbd=v['sfcWind']**3)

    # Standardize the predictors over time (all other dimensions are kept). Only the forcing
    # variables are normalized; xarray's mean/std skip NaNs by default for float data.
    # `assign` keeps the Dataset attrs, which plain Dataset arithmetic would drop.
    predictors = v[FORCING_VARS]
    normalized = (predictors - predictors.mean(dim='time')) / predictors.std(dim='time', ddof=0)
    v = v.assign({name: normalized[name] for name in FORCING_VARS})

    # Concatenate & stack the normalized predictors into one DataArray with dims
    # (time, <remaining dims, e.g. member_id and tg>, f), where
    # f = forcing_var x lon_around_tg x lat_around_tg (4 variables * num_degr * num_degr).
    v['forcing'] = v[FORCING_VARS].to_array(dim="forcing_var")
    v['forcing'] = v['forcing'].transpose("time", "forcing_var", "lon_around_tg", ...).stack(
        f=['forcing_var', 'lon_around_tg', 'lat_around_tg'], create_index=False)
    ddict_concat[k] = v

Derive the principal components and multiply them with the regression coefficients derived from ERA5 (**To-do:** check the lat/lon coordinates of the result; they must be those of the TGs to match the `pr` timeseries):

In [None]:
# File with the TG coordinates and the ERA5-derived MLR coefficients at each TG.
MLRCOEFS_PATH = '/home/jovyan/CMIP6cf/gssr_coefs_1degRes_forcing.nc'
mlrcoefs = xr.open_dataset(MLRCOEFS_PATH)

In [None]:
# Not populated: the surge timeseries are written to disk below instead of kept in memory.
ddict_surges = defaultdict(dict)


def _predict_surge(forcing_tg, coefs):
    """Predict the surge timeseries at a single TG from its stacked, normalized forcing.

    The principal components of the forcing are derived with sklearn's PCA (fitted on
    the timesteps without missing forcing) and combined linearly with the ERA5 MLR
    coefficients.

    Parameters
    ----------
    forcing_tg : numpy.ndarray
        Forcing with shape (time, f) for one member and one TG.
    coefs : numpy.ndarray
        Finite MLR coefficients for this TG: the intercept first, followed by one
        coefficient per principal component.

    Returns
    -------
    i_timesteps : numpy.ndarray
        Indices of the timesteps at which all forcing values are finite.
    surge : numpy.ndarray
        Predicted surge at those timesteps. Both arrays are empty when there is no
        timestep with complete forcing or no coefficient at all.
    """
    i_timesteps = np.flatnonzero(np.isfinite(forcing_tg).all(axis=1))
    num_pcs = coefs.size - 1  # number of coefs = number of PCs to derive; the intercept doesn't count
    if i_timesteps.size == 0 or num_pcs < 0:
        return i_timesteps[:0], np.empty(0)

    X = forcing_tg[i_timesteps]  # remove missing values for PCA
    pca = PCA(num_pcs)  # sklearn is used to keep the PC signs deterministic
    pca.fit(X)
    pcs = pca.transform(X)
    return i_timesteps, coefs[0] + pcs @ coefs[1:]


for k, ds in ddict_concat.items():  # loop over datasets
    tgs = ds.tg.values

    # Finite MLR coefficients per TG (intercept first). They do not depend on the member,
    # so they are looked up once per dataset.
    coefs_per_tg = []
    for tg in tgs:
        tg_coefs = mlrcoefs.mlrcoefs.sel(tg=tg).values
        coefs_per_tg.append(tg_coefs[np.isfinite(tg_coefs)])

    # Fill a plain NumPy array (NaN where no surge could be computed) and attach it once,
    # instead of writing through ds['surge'][...] for every TG.
    surge = np.full((ds.sizes['member_id'], ds.sizes['time'], len(tgs)), np.nan)

    for i_member, member in tqdm(enumerate(ds.member_id.values), total=ds.sizes['member_id'], desc=k):
        # load the forcing for all TGs of the current dataset and member into memory
        forcing_mem = ds['forcing'].sel(member_id=member).load()

        for i_tg, tg in enumerate(tgs):
            # explicit transpose so PCA gets (n_samples=time, n_features=f)
            forcing_tg = forcing_mem.sel(tg=tg).transpose('time', 'f').values
            i_timesteps, surge_tg = _predict_surge(forcing_tg, coefs_per_tg[i_tg])
            surge[i_member, i_timesteps, i_tg] = surge_tg

    ds['surge'] = (('member_id', 'time', 'tg'), surge)

    # Look up the TG coordinates by tg label rather than by position, so that the result
    # does not depend on ds.tg and mlrcoefs.tg having the same order and length.
    surge_da = ds['surge'].assign_coords(
        lon=('tg', mlrcoefs.lon.sel(tg=tgs).data),
        lat=('tg', mlrcoefs.lat.sel(tg=tgs).data),
    ).assign_attrs(ds.attrs)
    ds['surge'] = surge_da

    # store:
    model_path = os.path.join('/home/jovyan/CMIP6cf/output/surge_timeseries/', ds.source_id)
    os.makedirs(model_path, exist_ok=True)

    # Write surge_da rather than ds['surge']: assigning a DataArray into a Dataset drops its
    # lon/lat coordinates if the Dataset already has coordinates with those names.
    surge_da.to_dataset(name='surge').to_netcdf(os.path.join(model_path, k.replace('.', '_') + '.nc'), mode='w')
    ds.close()