In [1]:
import numpy as np
import xarray as xr
import os
from collections import defaultdict
import fnmatch
from tqdm.autonotebook import tqdm
import dask
import gcsfs
fs = gcsfs.GCSFileSystem() # file-system handle for Google Cloud Storage; equivalent to fsspec.filesystem('gs')

  from tqdm.autonotebook import tqdm


In [2]:
def po_t_of_refyear(da,threshold,refyear,dim):
    """Peaks-over-threshold selection using the threshold of a reference window.

    Parameters
    ----------
    da : array with a 'window' coordinate (xarray-like)
    threshold : quantile level passed on to ``quantile``
    refyear : label of the reference window, selected with ``sel(window=refyear)``
    dim : dimension over which the quantile is computed

    Returns
    -------
    ``da`` masked with ``where`` so that only values strictly greater than
    the reference-window quantile are kept.
    """
    reference_window = da.sel(window=refyear)
    reference_quantile = reference_window.quantile(threshold,dim=dim)
    exceeds_reference = da>reference_quantile
    return da.where(exceeds_reference)
            
def rolling_max(da,window_len,dim):
    """Centered rolling maximum of ``da`` along ``dim``.

    The rolling window spans ``window_len`` elements, is centered, and
    uses ``min_periods=1``.
    """
    roller = da.rolling({dim:window_len},center=True,min_periods=1)
    return roller.max()

def count_num_extremes_pmonth(extremes):
    """Count the number of flagged days per calendar month.

    ``extremes`` must be a boolean array (True where a (joint) extreme occurs
    on that day) with a 'time_in_window_idx' dimension and a 'time'
    coordinate. The 'time_in_window_idx' coordinate of a deep copy is
    overwritten with the calendar month of each time step, renamed to
    'month', and the array is summed per month. When 'time' has more than
    one dimension, the months of the first window are used.
    """
    flagged = extremes.copy(deep=True) #work on a copy so the caller's array keeps its coordinates
    time_is_multidim = len(extremes.time.shape)>1
    if time_is_multidim:
        month_of_step = flagged.time.dt.month.isel(window=0).values
    else:
        month_of_step = flagged.time.dt.month.values
    flagged['time_in_window_idx'] = month_of_step
    return flagged.rename({'time_in_window_idx':'month'}).groupby('month').sum()

Configure the bivariate sampling:

In [3]:
# --- bivariate sampling settings ---
max_lag = 0            # maximum number of days between co-occurring extremes (0 = no lag)
declus_window_len = 3  # rolling-window length used for declustering (1 = no declustering)
threshold = 0.98       # quantile above which an event is considered extreme

# --- time windows ---
output_yrs = np.arange(2000, 2100, 20)  # central years of the analysis windows
window_len = 40                         # length (years) of the period around each output year
ref_year = 2000                         # window whose data define the threshold values

# --- input directories (bucket paths), one per variable ---
var1_dir = 'leap-persistent/timh37/CMIP6/timeseries_eu_gesla2_tgs/surge'
var2_dir = 'leap-persistent/timh37/CMIP6/timeseries_eu_gesla2_tgs/pr'

# --- output location ---
out_dir = '/home/jovyan/CMIP6cex/output/num_extremes/'

# when False, datasets whose output file already exists are skipped
overwrite_existing = False

Open datasets into dictionary of datasets:

In [4]:
#derive the variable names and the output path from the configuration
#(the path is only built here; directories are created when results are written)
var1 = var1_dir.split('/')[-1]
var2 = var2_dir.split('/')[-1]

out_path = os.path.join(out_dir,var1_dir.split('/')[-2],var1+'_'+var2,
                        str(window_len)+'yr_'+str(threshold).replace('0.','p')+'_lag'+str(max_lag)+'d_declus'+str(declus_window_len)+'d_ref'+str(ref_year))

#list the models available for both variables; the hidden-entry test must be applied
#to the basename, because fs.ls() returns full bucket paths that never start with '.'
var1_models = [name for name in (p.split('/')[-1] for p in fs.ls(var1_dir)) if not name.startswith('.')]
var2_models = [name for name in (p.split('/')[-1] for p in fs.ls(var2_dir)) if not name.startswith('.')]
models = [k for k in var1_models if k in var2_models]
ddict = defaultdict(dict)

#first and last year needed to cover all requested windows
first_yr = str(int(output_yrs[0]-window_len/2))
last_yr = str(int(output_yrs[-1]+window_len/2))

#open datasets (one combined var1+var2 dataset per file found under var1_dir)
for source_id in tqdm(models):
    for file in fs.ls(os.path.join(var1_dir,source_id)):
        if file.split('/')[-1].startswith('.'): #skip hidden entries
            continue

        try:
            var1_var2_ds = xr.open_mfdataset(('gs://'+file,'gs://'+file.replace(var1,var2)),engine='zarr',use_cftime=True)
        except Exception as err: #e.g. no matching var2 store; report instead of failing silently
            print('Could not open '+file+' together with its '+var2+' counterpart: '+repr(err))
            continue

        k = file.split('/')[-1]
        ddict[k] = var1_var2_ds.sel(time=slice(first_yr,last_yr))

0it [00:00, ?it/s]

In [5]:
#For every opened dataset: select the requested time windows, identify (joint) extremes
#with thresholds taken from the reference window, count them per month and store the
#result as a netCDF file.
for k,ds in tqdm(ddict.items()):
    source_id = k.split('.')[0]
    output_dir = os.path.join(out_path,source_id)
    output_fn = os.path.join(output_dir,k.replace(var1,'num_extremes'))+'.nc'

    #skip already processed datasets *before* loading any data into memory
    if (not overwrite_existing) and os.path.exists(output_fn):
        continue
    os.makedirs(output_dir,exist_ok=True)

    if 'member_id' not in ds.dims:
        ds = ds.expand_dims('member_id')
    ds = ds.load()
    ds = ds.transpose('time',...)

    attrs = ds.attrs
    ds = ds.isel(member_id=0) #only the first member of each dataset is processed

    #remove leap days
    if len(np.unique(ds.time.resample(time='1Y').count()))>1: #remove leap days so that each computation window has the same length
        with dask.config.set(**{'array.slicing.split_large_chunks': True}):
            ds = ds.sel(time=~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))) #^probably (hopefully) only has a small effect on the results

    days_in_year = int(ds.time.resample(time='1Y').count()[0])

    #construct time period indices (positions along 'time', counted from the first loaded year)
    first_window_idx = np.arange(0*days_in_year,window_len*days_in_year)
    if window_len%2 !=0: #odd
        window_start_idx = days_in_year*(output_yrs-int(output_yrs[0]-window_len/2)-int(np.floor(window_len/2)))
    else: #even
        window_start_idx = days_in_year*(output_yrs-int(output_yrs[0]-window_len/2)-int(window_len/2)+1)

    window_idx_values = first_window_idx[:,np.newaxis]+window_start_idx[np.newaxis,:]

    if np.max(window_idx_values)>=len(ds.time): #if window exceeds simulation length
        if (int(output_yrs[-1]+window_len/2)==2100) and (int(ds.time.dt.year[-1])==2099): #if end year is 2099 instead of 2100, shift windows to 1 year earlier
            window_start_idx = window_start_idx - days_in_year #shift windows back by 1 year
            window_idx_values = window_idx_values - days_in_year

        #only skip if the (possibly shifted) windows still do not fit; negative indices would silently wrap around
        if (np.max(window_idx_values)>=len(ds.time)) or (np.min(window_idx_values)<0):
            print('Requested time periods exceed simulation length for '+k)
            continue #skip

    window_idx = xr.DataArray( #indices of time periods
        data=window_idx_values,
        dims=["time_in_window_idx","window"],
        coords=dict(time_in_window_idx=first_window_idx,window=output_yrs),)

    #select data in user-defined time windows
    try:
        ds_wdws = ds.isel(time=window_idx)
    except Exception as err:
        print('could not select requested time windows from '+k+': '+repr(err))
        continue

    #find extremes using POT with historical threshold values
    var1_extremes = po_t_of_refyear(ds_wdws[var1],threshold,ref_year,dim='time_in_window_idx')
    var2_extremes = po_t_of_refyear(ds_wdws[var2],threshold,ref_year,dim='time_in_window_idx')

    #decluster: keep an exceedance only if it equals the rolling maximum of its neighbourhood
    var1_extremes_declustered = var1_extremes.where(var1_extremes==var1_extremes.rolling({'time_in_window_idx':declus_window_len},center=True,min_periods=1).max(skipna=True))
    var2_extremes_declustered = var2_extremes.where(var2_extremes==var2_extremes.rolling({'time_in_window_idx':declus_window_len},center=True,min_periods=1).max(skipna=True))

    #boolean array of when days experience joint extremes within 'max_lag' lag from eachother
    joint_extremes = np.isfinite((rolling_max(var2_extremes_declustered,max_lag*2+1,dim='time_in_window_idx')*var1_extremes_declustered))

    #store number of extremes per month in dataset
    num_extremes = count_num_extremes_pmonth(joint_extremes).to_dataset(name='num_joint_extremes') #joint extremes
    num_extremes['num_'+var1+'_extremes'] = count_num_extremes_pmonth(np.isfinite(var1_extremes_declustered)) #univariate extremes var1
    num_extremes['num_'+var2+'_extremes'] = count_num_extremes_pmonth(np.isfinite(var2_extremes_declustered)) #univariate extremes var2

    #decompose future changes in number of joint extremes (only if not declustering and 0-day lag)

    #initialize output with NaN, so the variables exist even when the decomposition is not computed
    num_extremes['num_joint_extremes_'+var1+'_driven']          = np.nan*num_extremes['num_joint_extremes'].copy(deep=True) #dN due to univariate changes var1
    num_extremes['num_joint_extremes_'+var2+'_driven']          = np.nan*num_extremes['num_joint_extremes'].copy(deep=True) #dN due to univariate changes var2
    num_extremes['num_joint_extremes_'+var1+'_'+var2+'_driven'] = np.nan*num_extremes['num_joint_extremes'].copy(deep=True) #dN due to univariate changes var1 and var2

    num_extremes['num_'+var1+'_extremes_refWindow_futT']        = np.nan*num_extremes['num_'+var1+'_extremes'].copy(deep=True) #number of peaks in future windows exceeding the reference threshold values
    num_extremes['num_'+var2+'_extremes_refWindow_futT']        = np.nan*num_extremes['num_'+var2+'_extremes'].copy(deep=True) #number of peaks in future windows exceeding the reference threshold values

    if (max_lag == 0) and (declus_window_len==1):

        #1) sort (in magnitude) the values in the reference period of each variable
        if ('latitude' in ds.dims) and ('longitude' in ds.dims):
            sorted_var1_ref = xr.DataArray(data=np.sort(ds_wdws.sel(window=ref_year)[var1],axis=0),dims=ds_wdws.sel(window=ref_year)[var1].dims,
                                               coords=dict(time_in_window_idx=ds_wdws.time_in_window_idx,latitude=ds_wdws.latitude,longitude=ds_wdws.longitude))
            sorted_var2_ref = xr.DataArray(data=np.sort(ds_wdws.sel(window=ref_year)[var2],axis=0),dims=ds_wdws.sel(window=ref_year)[var2].dims,
                                          coords=dict(time_in_window_idx=ds_wdws.time_in_window_idx,latitude=ds_wdws.latitude,longitude=ds_wdws.longitude))
        elif ('tg' in ds.dims):
            sorted_var1_ref = xr.DataArray(data=np.sort(ds_wdws.sel(window=ref_year)[var1],axis=0),dims=ds_wdws.sel(window=ref_year)[var1].dims,
                                               coords=dict(time_in_window_idx=ds_wdws.time_in_window_idx,tg=ds_wdws.tg))
            sorted_var2_ref = xr.DataArray(data=np.sort(ds_wdws.sel(window=ref_year)[var2],axis=0),dims=ds_wdws.sel(window=ref_year)[var2].dims,
                                          coords=dict(time_in_window_idx=ds_wdws.time_in_window_idx,tg=ds_wdws.tg))

        #2) find the number of extremes that exceed the reference period threshold value in other windows
        for win in ds_wdws.window: #loop over each window
            num_var1_extremes_in_wdw = np.isfinite(var1_extremes_declustered).sum(dim='time_in_window_idx').sel(window=win)
            num_var2_extremes_in_wdw = np.isfinite(var2_extremes_declustered).sum(dim='time_in_window_idx').sel(window=win)

            #use that number to define the equivalent threshold in the historical period (var_{U_{var}}^{hist} in the paper)
            #NOTE(review): a count of 0 gives index -0 == 0, i.e. the smallest reference value - confirm this is intended
            var1_eqv_thresholds = sorted_var1_ref.isel(time_in_window_idx=-1*(num_var1_extremes_in_wdw))
            var2_eqv_thresholds = sorted_var2_ref.isel(time_in_window_idx=-1*(num_var2_extremes_in_wdw))

            #3) determine the extremes above those threshold values in the reference window
            var1_extremes_fut_threshold = ds_wdws[var1].sel(window=ref_year).where(ds_wdws[var1].sel(window=ref_year)>=var1_eqv_thresholds)
            var2_extremes_fut_threshold = ds_wdws[var2].sel(window=ref_year).where(ds_wdws[var2].sel(window=ref_year)>=var2_eqv_thresholds)

            #4) determine the joint extremes using these extremes:
            joint_extremes_var1_driven = np.isfinite((rolling_max(var2_extremes_declustered.sel(window=ref_year),max_lag*2+1,dim='time_in_window_idx')*var1_extremes_fut_threshold))
            joint_extremes_var2_driven = np.isfinite((rolling_max(var2_extremes_fut_threshold,max_lag*2+1,dim='time_in_window_idx')*var1_extremes_declustered.sel(window=ref_year)))
            joint_extremes_var1_var2_driven = np.isfinite((rolling_max(var2_extremes_fut_threshold,max_lag*2+1,dim='time_in_window_idx')*var1_extremes_fut_threshold))

            #5) count per month & write to output dataset
            num_extremes['num_joint_extremes_'+var1+'_driven'].loc[dict(window=win)] = count_num_extremes_pmonth(joint_extremes_var1_driven)
            num_extremes['num_joint_extremes_'+var2+'_driven'].loc[dict(window=win)] = count_num_extremes_pmonth(joint_extremes_var2_driven)
            num_extremes['num_joint_extremes_'+var1+'_'+var2+'_driven'].loc[dict(window=win)] = count_num_extremes_pmonth(joint_extremes_var1_var2_driven)

            num_extremes['num_'+var1+'_extremes_refWindow_futT'].loc[dict(window=win)] = count_num_extremes_pmonth(np.isfinite(var1_extremes_fut_threshold))
            num_extremes['num_'+var2+'_extremes_refWindow_futT'].loc[dict(window=win)] = count_num_extremes_pmonth(np.isfinite(var2_extremes_fut_threshold))

    #store the results
    num_extremes = num_extremes.expand_dims('member_id') #add back member_id as dimension
    num_extremes.attrs = attrs #add original attributes

    #add information about the joint extremes analysis
    num_extremes.attrs['window_length'] = str(window_len)
    num_extremes.attrs['declustering_length'] = str(declus_window_len)
    num_extremes.attrs['allowed_lag'] = str(max_lag)
    num_extremes.attrs['ref_window'] = str(ref_year)
    num_extremes.attrs['source_id'] = source_id
    num_extremes.to_netcdf(output_fn,mode='w')
    num_extremes.close()

  0%|          | 0/436 [00:00<?, ?it/s]

#Legacy per-member workflow: for every model and experiment available for both variables,
#count the (joint) extremes per month for each ensemble member and write one netCDF per member.
#NOTE(review): 'input_is_gridded' is not defined anywhere in the visible notebook; it must be
#set (True for latitude/longitude input, False for 'tg' input) before this cell can run.
var1 = var1_dir.split('/')[-1]
var2 = var2_dir.split('/')[-1]

models_var1 = [k.split('/')[-1] for k in fs.ls(var1_dir)]
models_var2 = [k.split('/')[-1] for k in fs.ls(var2_dir)]
source_ids = sorted(list(set(models_var1) & set(models_var2))) #intersection of models

#'not' instead of '~': bitwise inversion of a bool is always truthy, so hidden entries were never skipped
for source_id in [k for k in source_ids if not k.startswith('.')]: #loop over models

    var1_model_path = os.path.join(var1_dir,source_id)
    var2_model_path = os.path.join(var2_dir,source_id)

    #get experiment_id's: last '_'-separated part of each file name minus its final 5 characters
    #(presumably the '.zarr' suffix); hidden files are recognised by their basename
    var1_exps = [s.split('/')[-1].split('_')[-1][0:-5] for s in fs.ls(var1_model_path) if not s.split('/')[-1].startswith('.')]
    var2_exps = [s.split('/')[-1].split('_')[-1][0:-5] for s in fs.ls(var2_model_path) if not s.split('/')[-1].startswith('.')]
    experiment_ids = list(set(var1_exps) & set(var2_exps))

    for experiment_id in experiment_ids: #loop over experiments
        #load data:
        fn = fnmatch.filter(fs.ls(var1_model_path),'*'+experiment_id+'*')[0]
        fn = fn.split('/')[-1]
        print('Processing file: '+fn)
        if input_is_gridded==False:
            var1_var2_data = xr.open_mfdataset((os.path.join('gs://',var1_model_path,fn),os.path.join('gs://',var2_model_path,fn)),engine='zarr',compat='override',chunks={'member_id':1,'time':100000})
        else:
            var1_var2_data = xr.open_mfdataset((os.path.join('gs://',var1_model_path,fn),os.path.join('gs://',var2_model_path,fn)),engine='zarr',chunks={'member_id':1,'time':100000,'longitude':5})

        #generate output paths
        output_path = '/home/jovyan/CMIP6cex/output/num_extremes/'+var1+'_g2_'+var2+'_'+var1_dir.split('_')[-1]+'/'+str(window_len)+'yr_'+str(threshold).replace('0.','p')+'_lag'+str(max_lag)+'d_declus'+str(declus_window_len)+'d_ref'+str(ref_year)
        output_model_path = os.path.join(output_path,var1_var2_data.source_id)
        output_fn = os.path.join(output_model_path,fn.replace('.zarr','.nc'))

        #construct time window indices
        if len(np.unique(var1_var2_data.time.resample(time='1Y').count()))>1: #remove leap days so that each computation window has the same length
            with dask.config.set(**{'array.slicing.split_large_chunks': True}):
                var1_var2_data = var1_var2_data.sel(time=~((var1_var2_data.time.dt.month == 2) & (var1_var2_data.time.dt.day == 29))) #^probably (hopefully) only has a small effect on the results

        days_in_year = int(var1_var2_data.time.resample(time='1Y').count()[0])

        #window positions are counted in years since 1850 (assumes the time axis starts in 1850 - TODO confirm)
        first_window_idx = np.arange(0*days_in_year,window_len*days_in_year)
        if window_len%2 !=0: #odd
            window_start_idx = days_in_year*(output_yrs-1850-int(np.floor(window_len/2)))
        else: #even
            window_start_idx = days_in_year*(output_yrs-1850-int(window_len/2)+1)

        if np.max(first_window_idx[:,np.newaxis]+window_start_idx[np.newaxis,:])>=len(var1_var2_data.time): #if window exceeds simulation length
            continue #skip

        window_idx = xr.DataArray( #indices of windows
            data=first_window_idx[:,np.newaxis]+window_start_idx[np.newaxis,:],
            dims=["time_in_window_idx","window"],
            coords=dict(
                time_in_window_idx=first_window_idx,
                window=output_yrs
            ),
        )

        #create the output folder including any missing parent folders
        os.makedirs(output_model_path,exist_ok=True)

        for member in tqdm(var1_var2_data.member_id): #loop over members of each model to compute the dependence

            var1_var2_data_mem = var1_var2_data.sel(member_id=member)
            with dask.config.set(**{'array.slicing.split_large_chunks': False}):
                var1_var2_data_wdws = var1_var2_data_mem.isel(time=window_idx) #select data in user-defined time windows

            data_is_complete = np.isfinite(var1_var2_data_wdws[var1]).all(dim='time_in_window_idx') * np.isfinite(var1_var2_data_wdws[var2]).all(dim='time_in_window_idx') #check data-completeness in each window

            #derive peaks
            var1_peaks = po_t_of_refyear(var1_var2_data_wdws[var1],threshold,ref_year,dim='time_in_window_idx')
            var2_peaks = po_t_of_refyear(var1_var2_data_wdws[var2],threshold,ref_year,dim='time_in_window_idx')

            var1_peaks_declustered = var1_peaks.where(var1_peaks==var1_peaks.rolling({'time_in_window_idx':declus_window_len},center=True,min_periods=1).max(skipna=True))
            var2_peaks_declustered = var2_peaks.where(var2_peaks==var2_peaks.rolling({'time_in_window_idx':declus_window_len},center=True,min_periods=1).max(skipna=True))

            #determine joint extremes within 'max_lag' lag from eachother
            joint_extremes = np.isfinite((rolling_max(var2_peaks_declustered,max_lag*2+1,dim='time_in_window_idx')*var1_peaks_declustered)) #previously: 'co_occurring'

            #generate output dataset for current member
            #(count_num_extremes_pmonth is the helper defined in this notebook; 'sum_num_extremes_pmonth' does not exist)
            num_extremes_mem = count_num_extremes_pmonth(joint_extremes).to_dataset(name='num_joint_extremes')
            num_extremes_mem['num_'+var1+'_extremes'] = count_num_extremes_pmonth(np.isfinite(var1_peaks_declustered))
            num_extremes_mem['num_'+var2+'_extremes'] = count_num_extremes_pmonth(np.isfinite(var2_peaks_declustered))

            ####DECOMPOSITION OF CHANGES (probably not correct if declustering!!):
            #1) sort (in magnitude) values in reference period to determine the equivalent threshold percentiles in other windows
            if input_is_gridded:
                sorted_var1_ref = xr.DataArray(data=np.sort(var1_var2_data_wdws.sel(window=ref_year)[var1],axis=0),dims=['time_in_window_idx','latitude','longitude'],
                                                   coords=dict(time_in_window_idx=var1_var2_data_wdws.time_in_window_idx,latitude=var1_var2_data_wdws.latitude,longitude=var1_var2_data_wdws.longitude)).chunk({'longitude':5})
                sorted_var2_ref = xr.DataArray(data=np.sort(var1_var2_data_wdws.sel(window=ref_year)[var2],axis=0),dims=['time_in_window_idx','latitude','longitude'],
                                              coords=dict(time_in_window_idx=var1_var2_data_wdws.time_in_window_idx,latitude=var1_var2_data_wdws.latitude,longitude=var1_var2_data_wdws.longitude)).chunk({'longitude':5})
            else:
                sorted_var1_ref = xr.DataArray(data=np.sort(var1_var2_data_wdws.sel(window=ref_year)[var1],axis=0),dims=['time_in_window_idx','tg'],
                                                   coords=dict(time_in_window_idx=var1_var2_data_wdws.time_in_window_idx,tg=var1_var2_data_wdws.tg))
                sorted_var2_ref = xr.DataArray(data=np.sort(var1_var2_data_wdws.sel(window=ref_year)[var2],axis=0),dims=['time_in_window_idx','tg'],
                                              coords=dict(time_in_window_idx=var1_var2_data_wdws.time_in_window_idx,tg=var1_var2_data_wdws.tg))
            #initialize output arrays
            num_extremes_mem['num_joint_extremes_'+var1+'_driven'] = num_extremes_mem['num_joint_extremes'].copy(deep=True)
            num_extremes_mem['num_joint_extremes_'+var2+'_driven'] = num_extremes_mem['num_joint_extremes'].copy(deep=True)
            num_extremes_mem['num_joint_extremes_'+var1+'_'+var2+'_driven'] = num_extremes_mem['num_joint_extremes'].copy(deep=True)
            num_extremes_mem['num_'+var1+'_extremes_refWindow_futT'] = num_extremes_mem['num_'+var1+'_extremes'].copy(deep=True)
            num_extremes_mem['num_'+var2+'_extremes_refWindow_futT'] = num_extremes_mem['num_'+var2+'_extremes'].copy(deep=True)

            for win in var1_var2_data_wdws.window: #loop over each window to do the decomposition
                #2) find the threshold values in the reference period corresponding to the percentile of events exceeding the reference threshold values in the future (var_{U_{var}}^{hist} in the paper)
                var1_eqv_thresholds = sorted_var1_ref.isel(time_in_window_idx=-1*(np.isfinite(var1_peaks_declustered).sum(dim='time_in_window_idx').sel(window=win).load()))
                var2_eqv_thresholds = sorted_var2_ref.isel(time_in_window_idx=-1*(np.isfinite(var2_peaks_declustered).sum(dim='time_in_window_idx').sel(window=win).load()))

                #3) determine the peaks above those threshold values in the reference window
                var1_peaks_fut_threshold = var1_var2_data_wdws[var1].sel(window=ref_year).where(var1_var2_data_wdws[var1].sel(window=ref_year)>=var1_eqv_thresholds) #determine the peaks in the reference period above those values
                var2_peaks_fut_threshold = var1_var2_data_wdws[var2].sel(window=ref_year).where(var1_var2_data_wdws[var2].sel(window=ref_year)>=var2_eqv_thresholds)

                #4) determine the joint extremes for different components:
                # a) var2 peaks above standard threshold in reference period, var1 above future threshold percentile in reference period
                joint_extremes_var1_driven = np.isfinite((rolling_max(var2_peaks_declustered.sel(window=ref_year),max_lag*2+1,dim='time_in_window_idx')*var1_peaks_fut_threshold))

                # b) var1 peaks above standard threshold in reference period, var2 above future threshold percentile in reference period
                joint_extremes_var2_driven = np.isfinite((rolling_max(var2_peaks_fut_threshold,max_lag*2+1,dim='time_in_window_idx')*var1_peaks_declustered.sel(window=ref_year)))

                # c) var1 and var 2 above future threshold percentile in reference period
                joint_extremes_var1_var2_driven = np.isfinite((rolling_max(var2_peaks_fut_threshold,max_lag*2+1,dim='time_in_window_idx')*var1_peaks_fut_threshold))

                #count per month
                num_extremes_mem['num_joint_extremes_'+var1+'_driven'].loc[dict(window=win)] = count_num_extremes_pmonth(joint_extremes_var1_driven)
                num_extremes_mem['num_joint_extremes_'+var2+'_driven'].loc[dict(window=win)] = count_num_extremes_pmonth(joint_extremes_var2_driven)
                num_extremes_mem['num_joint_extremes_'+var1+'_'+var2+'_driven'].loc[dict(window=win)] = count_num_extremes_pmonth(joint_extremes_var1_var2_driven)
                num_extremes_mem['num_'+var1+'_extremes_refWindow_futT'].loc[dict(window=win)] = count_num_extremes_pmonth(np.isfinite(var1_peaks_fut_threshold))
                num_extremes_mem['num_'+var2+'_extremes_refWindow_futT'].loc[dict(window=win)] = count_num_extremes_pmonth(np.isfinite(var2_peaks_fut_threshold))

            #store metadata
            num_extremes_mem['complete_window'] = data_is_complete #store where windows miss data

            num_extremes_mem = num_extremes_mem.expand_dims(dim={"member_id": 1}) #add coordinates & dimensions

            num_extremes_mem.attrs = var1_var2_data.attrs #keep original attributes and add information on the extremes analysis
            num_extremes_mem.attrs['window_length'] = str(window_len)
            num_extremes_mem.attrs['declustering'] = 'Rolling window of '+str(declus_window_len)+' days'
            num_extremes_mem.attrs['allowed_lag'] = str(max_lag)
            num_extremes_mem.attrs['ref_window'] = str(ref_year)

            #one output file per member: the member id is appended to the file name
            num_extremes_mem.to_netcdf(output_fn.replace('.nc','_'+num_extremes_mem.member_id.values[0]+'.nc'),mode='w')
            num_extremes_mem.close()
    