In [1]:
import numpy as np
import fnmatch
import xarray as xr
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from xmip.postprocessing import combine_datasets,_concat_sorted_time
from sklearn.decomposition import PCA
import gcsfs
import dask
fs = gcsfs.GCSFileSystem() # Google Cloud Storage filesystem used below for fs.ls listings; equivalent to fsspec.filesystem('gs')

  from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!


In [2]:
grid_size = 9 #n by n degree grid around each TG
cmip6_resolution = 1.5 #resolution (degrees) of the common grid the CMIP6 timeseries were regridded to

# Number of grid cells per side of the box around each TG.
# Round before casting: plain int() truncates floating-point results such as
# 3/0.1 == 29.999999999999996, which would silently drop one cell per side.
num_grid_cells = int(round(grid_size/cmip6_resolution))

start_year=1970 #first year (inclusive) of the period to process
end_year=2100 #last year (inclusive) of the period to process

# ERA5-derived inputs, both indexed by tide gauge (`tg`):
#  - MLR coefficients (intercept + one coefficient per principal component) at each TG
#  - ERA5 PCA components used to align the signs of the CMIP6 principal components
mlr_coefs_path = '/home/jovyan/CMIP6cex/cmip6_processing/gssr_mlr_coefs_1p5_9deg_gesla2.nc'
era5_pcs_path = '/home/jovyan/CMIP6cex/cmip6_processing/era5_pca_components_1p5_9deg_gesla2.nc'

mlr_coefs = xr.open_dataset(mlr_coefs_path)
era5_pcs = xr.open_dataset(era5_pcs_path)

Loop over the `psl` and `sfcWind` timeseries, regridded to a common 1.5° × 1.5° grid, and open them:

In [4]:
var1 = 'psl'
var2 = 'sfcWind'

var1_dir = 'leap-persistent/timh37/CMIP6/timeseries_eu_1p5/'+var1
var2_dir = 'leap-persistent/timh37/CMIP6/timeseries_eu_1p5/'+var2

output_dir = 'leap-persistent/timh37/CMIP6/timeseries_eu_gesla2_tgs/'


def _visible_entries(path):
    """Return the names of the entries directly under `path` on `fs`, skipping hidden ones.

    `fs.ls` returns full paths, so the hidden-entry test must be applied to the last
    path component; testing the full path (which never starts with '.') filters nothing.
    """
    names = (entry.rstrip('/').split('/')[-1] for entry in fs.ls(path))
    return [name for name in names if not name.startswith('.')]


var1_models = _visible_entries(var1_dir)
var2_models = _visible_entries(var2_dir)

# only models that provide both variables, in var1's listing order
models = [m for m in var1_models if m in var2_models]
ddict = defaultdict(dict)

# iterate the list directly so tqdm knows the total
for source_id in tqdm(models):
    for file in fs.ls(os.path.join(var1_dir, source_id)):
        # NOTE: str.replace substitutes every occurrence of var1 in the path (directory and store name)
        var2_file = file.replace(var1, var2)

        # TODO: opening takes ~2s per run although chunk sizes look reasonable (~100 MB); investigate.
        try:
            var1_var2_ds = xr.open_mfdataset(('gs://'+file, 'gs://'+var2_file), engine='zarr', use_cftime=True)
        except Exception as err:  # e.g. no matching var2 store; a bare except would also swallow KeyboardInterrupt
            tqdm.write(f'Skipping {file}: {err!r}')
            continue

        k = file.split('/')[-1]
        ddict[k] = var1_var2_ds.sel(time=slice(str(start_year), str(end_year)))

0it [00:00, ?it/s]

In [4]:
#generate lon/lat indices around each TG
if not ddict:
    raise RuntimeError('No runs were opened (ddict is empty); check var1_dir/var2_dir and the cell that opens the runs.')

# take the grid from the first run; assumes all runs share the common 1.5 degree grid (see directory names) -- confirm
ds0 = ddict[next(iter(ddict))]

lat_ranges = np.zeros((len(mlr_coefs.tg), num_grid_cells)) #initialize
lon_ranges = np.zeros((len(mlr_coefs.tg), num_grid_cells))

for t, tg in enumerate(mlr_coefs.tg.values): #grids around TGs
    tg_info = mlr_coefs.sel(tg=tg)

    # grid coordinates within +/- grid_size/2 degrees of the TG, truncated to num_grid_cells per side
    lats_in_box = ds0.latitude[(ds0.latitude >= (tg_info.lat - grid_size/2)) & (ds0.latitude <= (tg_info.lat + grid_size/2))]
    lons_in_box = ds0.longitude[(ds0.longitude >= (tg_info.lon - grid_size/2)) & (ds0.longitude <= (tg_info.lon + grid_size/2))]

    if lats_in_box.size < num_grid_cells or lons_in_box.size < num_grid_cells:
        raise ValueError(f'Fewer than {num_grid_cells} grid cells per side within the {grid_size} degree box around tg={tg} '
                         f'(found {lats_in_box.size} lats, {lons_in_box.size} lons)')

    lat_ranges[t, :] = lats_in_box[0:num_grid_cells]
    lon_ranges[t, :] = lons_in_box[0:num_grid_cells]

# indexers sharing the `tg` dim: .sel with both yields one (lat_around_tg, lon_around_tg) box per TG
lons_da = xr.DataArray(lon_ranges, dims=['tg','lon_around_tg'], coords={'tg':mlr_coefs.tg,'lon_around_tg':np.arange(0,num_grid_cells)})
lats_da = xr.DataArray(lat_ranges, dims=['tg','lat_around_tg'], coords={'tg':mlr_coefs.tg,'lat_around_tg':np.arange(0,num_grid_cells)})

def _surge_at_tg(predictors_at_tg, mlr_coefs_at_tg, era5_pcs_at_tg):
    """Reconstruct the surge at one tide gauge from normalized CMIP6 predictors.

    The predictors are projected onto their own principal components (PCs). The sign of
    each PC is then aligned with ERA5 by comparing the pressure part of the PCA components
    with the ERA5 components. Finally, the ERA5 regression coefficients are applied.

    Parameters
    ----------
    predictors_at_tg : xr.Dataset
        `ds.sel(tg=tg)`; its `predictors` variable is (time, f), where f stacks
        (predictor_var, lon_around_tg, lat_around_tg) with psl as the first predictor_var.
    mlr_coefs_at_tg : xr.DataArray
        MLR coefficients at this TG. The finite entries are used in order as
        [intercept, coef_PC1, coef_PC2, ...]. NOTE(review): this assumes the finite
        coefficients are intercept-first and ordered like the PCs; confirm.
    era5_pcs_at_tg : xr.Dataset
        `era5_pcs.sel(tg=tg)`; its `component` variable holds the ERA5 PCA components.

    Returns
    -------
    tuple of (np.ndarray, np.ndarray) or None
        Indices of the timesteps without any NaN predictor, and the surge at those
        timesteps. Returns None when there are no finite coefficients or no such
        timesteps, because PCA cannot be fitted in that case.
    """
    finite_coefs = mlr_coefs_at_tg[np.isfinite(mlr_coefs_at_tg)].values
    num_pcs = finite_coefs.size - 1  # number of mlr coefs = number of PCs to derive, intercept doesn't count

    predictors = predictors_at_tg.predictors
    idx_timesteps_w_data = np.argwhere(np.isfinite(predictors).all(dim='f').values).flatten()  # omit timesteps with NaN if any
    if num_pcs < 0 or idx_timesteps_w_data.size == 0:
        return None

    predictors_w_data = predictors.isel(time=idx_timesteps_w_data)  # computed once, used for fit and transform

    #get principal components (using sklearn to keep deterministic signs consistent)
    pca = PCA(num_pcs)
    pca.fit(predictors_w_data)
    pcs = pca.transform(predictors_w_data)

    components = xr.DataArray(data=pca.components_, dims=['pc','f'], coords=dict(pc=np.arange(num_pcs), f=predictors_at_tg.f))

    #compare with ERA5 components, only considering the pressure part of the forcing (first num_grid_cells**2 features)
    psl_f = np.arange(num_grid_cells**2)
    model_psl = components.isel(f=psl_f)
    era5_psl = era5_pcs_at_tg.isel(f=psl_f).isel(pc=np.arange(num_pcs)).component
    rmses = np.sqrt(((model_psl - era5_psl)**2).mean(dim='f'))          #original sign
    rmses_flipped = np.sqrt(((model_psl + era5_psl)**2).mean(dim='f'))  #opposite sign

    #flip a pc unless its original sign matches ERA5 strictly better
    signs = np.where((rmses < rmses_flipped).values, 1, -1)
    pcs = pcs * signs

    #multiply with ERA5 regression coefficients: intercept + sum_i coef_i * pc_i
    surge = np.sum(finite_coefs * np.column_stack((np.ones(pcs.shape[0]), pcs)), axis=1)
    return idx_timesteps_w_data, surge


for k, ds in tqdm(ddict.items()): #for each run
    attrs = ds.attrs
    # NOTE: these assignments also add the variables to the dataset stored in ddict (in place)
    ds['sfcWind_sqd'] = ds['sfcWind']**2 #add wind squared
    ds['sfcWind_cbd'] = ds['sfcWind']**3 #add wind cubed

    ds = (ds-ds.mean(dim='time'))/ds.std(dim='time',ddof=0) #normalize predictor variables per grid cell

    # Subset to the TG boxes *before* loading. The normalization above is per grid cell,
    # so the values equal those from load-then-select, but only the boxes are held in memory.
    ds = ds.isel(member_id=0).sel(latitude=lats_da, longitude=lons_da).load()

    #concatenate & stack normalized forcing variables to data array with shape (time, tg, 4 variables * num_grid_cells**2)
    ds['predictors'] = ds[["psl", "sfcWind", "sfcWind_sqd","sfcWind_cbd"]].to_array(dim="predictor_var")
    ds['predictors'] = ds['predictors'].transpose("time","predictor_var","lon_around_tg",...).stack(f=['predictor_var','lon_around_tg','lat_around_tg'],create_index=False)

    #initialize output dataset per run; TGs that cannot be processed stay NaN
    surge_ds = xr.Dataset(data_vars=dict(surge=(['time','tg'], np.full((len(ds.time), len(mlr_coefs.tg)), np.nan))),
                          coords=dict(time=ds.time, tg=mlr_coefs.tg))

    for t, tg in enumerate(mlr_coefs.tg.values): #loop over tide gauges (PCA is fitted per TG)
        result = _surge_at_tg(ds.sel(tg=tg), mlr_coefs.mlr_coefs.sel(tg=tg), era5_pcs.sel(tg=tg))
        if result is None:
            tqdm.write(f'{k}: no finite MLR coefficients or no complete timesteps at tg={tg}; surge left as NaN')
            continue
        idx_timesteps_w_data, surge = result
        surge_ds['surge'][idx_timesteps_w_data, t] = surge

    surge_ds.attrs = attrs
    # Attach this run's member label explicitly, so that expand_dims promotes it instead of
    # creating a default integer index (which would be stored as '0').
    if 'member_id' in ds.coords:
        surge_ds = surge_ds.assign_coords(member_id=ds['member_id'].values)
    surge_ds              = surge_ds.expand_dims('member_id')
    surge_ds['tg']        = surge_ds.tg.astype('str') #something wrong with encoding object types in zarr, this is the work-around
    surge_ds['member_id'] = surge_ds.member_id.astype('str')

    #storage
    output_fn = os.path.join(output_dir,'surge',surge_ds.source_id,k.replace(var1,'surge'))

    surge_ds.chunk({'time':len(surge_ds.time)}).to_zarr('gs://'+output_fn,mode='w')


  0%|          | 0/28 [00:00<?, ?it/s]

KeyboardInterrupt: 