In [None]:
import intake

from workflow.scripts.utils import regrid_global, calculate_pooled_variance
import time
import xarray as xr
from functools import partial
import xesmf
import pandas as pd
from pyclim_noresm.general_util_funcs import global_avg
from intake_esm.derived import DerivedVariableRegistry
import numpy as np
import scipy.stats as st
import yaml
from scipy.stats._resampling import _bootstrap_resample

In [None]:
def update_datatracker(var_id):
    """Remove every request that depends on ``var_id`` from the data tracker, then fail.

    Looks up which entries of ``config['burdens_dict']`` and
    ``config['variables']`` list ``var_id`` among their inputs. If none do,
    ``var_id`` itself is the name to remove. The matching names are removed
    from every list in the YAML file at ``snakemake.input.data_tracker``, and
    the file is rewritten in place.

    Parameters
    ----------
    var_id : str
        Variable that could not be found.

    Raises
    ------
    ValueError
        Always, after the tracker has been rewritten, so the calling job stops.
    """
    import yaml

    def find_issue(requirements):
        # Names whose list of required inputs contains var_id; a missing
        # (None) config section simply contributes nothing.
        return [name for name, required in (requirements or {}).items() if var_id in required]

    issue_vars = (find_issue(snakemake.config.get('burdens_dict'))
                  + find_issue(snakemake.config.get('variables')))
    if not issue_vars:
        issue_vars = [var_id]

    with open(snakemake.input.data_tracker, 'r') as f:
        data = yaml.load(f, yaml.SafeLoader)

    # A single pass removes all issue variables from every request list;
    # repeating it once per issue variable (as before) changed nothing.
    to_remove = set(issue_vars)
    data = {key: list(set(requests) - to_remove) for key, requests in data.items()}

    with open(snakemake.input.data_tracker, 'w') as f:
        yaml.dump(data, f, yaml.SafeDumper, default_flow_style=False)

    raise ValueError(f"Could not find {var_id}; removed {','.join(issue_vars)} from data tracker")

In [None]:
# Control experiment paired with each perturbation experiment; anything not
# listed falls back to 'piClim-control'. Should put this in config at some point.
control_exps = {'piClim-2xdust': 'piClim-control'}

conf = snakemake.config

# The regional dust-table rule has no `experiment` wildcard, so its experiment is fixed.
exp_id = ('piClim-2xdust' if snakemake.rule == "calc_dust_regional_erf_table"
          else snakemake.wildcards.experiment)
mod_id = snakemake.wildcards.model
ctrl_id = control_exps.get(exp_id, 'piClim-control')
time_slice = conf.get('time_slice', slice(2, None))

# Bootstrap sample count and confidence level for the significance statistics.
nSamples = snakemake.params.get('nSamples', 1000)
confidence_level = snakemake.params.get('conf_level', 0.95)

# Optional region mask; MASK is only defined when the rule provides one.
mask = snakemake.input.get('mask')
if mask:
    MASK = xr.open_dataset(mask)

# Ensemble member: a per-model override for this experiment if configured,
# otherwise the default variant.
variant_overrides = conf['model_specific_variant'].get(exp_id, None)
memb_id = (variant_overrides.get(mod_id, conf['variant_default'])
           if variant_overrides
           else conf['variant_default'])

In [None]:
dvr = DerivedVariableRegistry()


def _aerosol_free_flux(ds, all_sky, clear_sky, clear_sky_af, name, long_name):
    """Add ``name = all_sky + (clear_sky_af - clear_sky)`` to ``ds`` and drop the inputs.

    Attributes of the new variable are copied from ``all_sky`` with
    ``long_name`` replaced; the dataset's ``variable_id`` attribute is set to
    ``name``. Note that ``ds`` itself also receives the new variable before the
    inputs are dropped from the returned copy.
    """
    ds[name] = ds[all_sky] + (ds[clear_sky_af] - ds[clear_sky])
    # Arithmetic does not keep attrs, so take them from the all-sky input.
    attrs = ds[all_sky].attrs.copy()
    # drop_vars replaces the deprecated Dataset.drop(list_of_variable_names).
    ds = ds.drop_vars([all_sky, clear_sky_af, clear_sky])
    ds[name].attrs = attrs
    ds[name].attrs['long_name'] = long_name
    ds.attrs['variable_id'] = name
    return ds


@dvr.register(variable='rlutaf',
              query=dict(experiment_id=[exp_id, ctrl_id],
                         source_id=mod_id,
                         variable_id=['rlutcs', 'rlut', 'rlutcsaf'],
                         )
              )
def calc_rlutaf(ds):
    """Derived variable 'rlutaf' = rlut + (rlutcsaf - rlutcs) (longwave)."""
    return _aerosol_free_flux(ds, 'rlut', 'rlutcs', 'rlutcsaf', 'rlutaf',
                              'TOA Outgoing aerosol free Longwave Radiation')


@dvr.register(variable='rsutaf',
              query=dict(experiment_id=[exp_id, ctrl_id],
                         source_id=mod_id,
                         variable_id=['rsutcs', 'rsut', 'rsutcsaf'],
                         )
              )
def calc_rsutaf(ds):
    """Derived variable 'rsutaf' = rsut + (rsutcsaf - rsutcs) (shortwave)."""
    return _aerosol_free_flux(ds, 'rsut', 'rsutcs', 'rsutcsaf', 'rsutaf',
                              'TOA Outgoing aerosol free Shortwave Radiation')

In [None]:
def _use_derived_vars(var_id):
    """Return the ``conf['use_derived_vars']`` setting for ``var_id`` and the current model.

    The entry may be a plain value or a dict with an ``'all'`` default and
    per-model overrides keyed by ``mod_id``; a missing entry means ``False``.
    """
    setting = conf['use_derived_vars'].get(var_id, False)
    if isinstance(setting, dict):
        return setting.get(mod_id, setting.get('all', False))
    return setting


def _member_id(exp_id):
    """Ensemble member for ``exp_id``: the per-model override if configured, else the default variant."""
    overrides = conf['model_specific_variant'].get(exp_id, None)
    if overrides:
        return overrides.get(mod_id, conf['variant_default'])
    return conf['variant_default']


def _annual_mean_with_bounds(ds):
    """Annual-mean ``ds[ds.variable_id]`` and attach lon/lat bounds.

    The result keeps ``ds``'s attributes (with ``', annual average'`` appended
    to ``history``) and gets a yearly ``time`` coordinate built by
    ``xr.cftime_range(freq='YS')`` from the first to the last year. It carries
    the lon/lat bounds variables that cf_xarray finds on ``ds``; otherwise it
    gets bounds generated by ``cf.add_bounds``.
    """
    var_name = ds.variable_id
    # NOTE(review): newer pandas/xarray prefer 'YE' over 'Y'; kept for compatibility.
    data = ds[var_name].resample(time='Y').mean().to_dataset(name=var_name)
    data.attrs = ds.attrs.copy()
    data.attrs['history'] = data.attrs.get('history', '') + ', annual average'
    years = data.time.dt.strftime('%Y')
    data = data.assign_coords({'time': xr.cftime_range(start=str(years[0].values),
                                                      end=str(years[-1].values),
                                                      freq='YS')})
    for axis in ('lon', 'lat'):
        bounds = ds.cf.bounds.get(axis)
        if bounds is None:
            data = data.cf.add_bounds([axis])
        else:
            data = data.assign({bounds[0]: ds[bounds[0]]})
    return data


def get_variable(var_id, exp_id, cat):
    """Load ``var_id`` for the current model (``mod_id``) and ``exp_id`` as annual means.

    The catalog is searched with the configured member id. When the catalog
    has ``var_id`` itself and ``conf['use_derived_vars']`` does not request
    derivation, the derived-variable registry is cleared and the raw variable
    is read. If several tables provide it, ``conf['table_ids']`` (or
    ``table_id_default``) selects one. For non-derived variables with several
    versions, only the latest version is kept.

    Parameters
    ----------
    var_id : str
        Variable to load; may be a derived variable registered on ``dvr``.
    exp_id : str
        Experiment id to search for.
    cat : intake-esm datastore
        Catalog to search.

    Returns
    -------
    xarray.Dataset
        Annual means of ``var_id`` with yearly ``time`` and lon/lat bounds.

    Raises
    ------
    ValueError
        From ``update_datatracker``, when no dataset is found or when a derived
        variable is missing some of its inputs.
    """
    use_derived_vars = _use_derived_vars(var_id)
    memb_id = _member_id(exp_id)

    col = cat.search(experiment_id=exp_id,
                     source_id=mod_id,
                     variable_id=var_id,
                     member_id=memb_id)

    # Explicit `== False` kept on purpose: a None setting is not treated as False.
    if var_id in col.unique()['variable_id'] and use_derived_vars == False:
        # Raw variable available: disable derivation and read it directly.
        col.derivedcat = DerivedVariableRegistry()
        col = col.search(variable_id=var_id)
        if col.nunique()['table_id'] > 1:
            col = col.search(table_id=conf['table_ids'].get(var_id, conf['table_id_default']))
    if col.nunique().version > 1 and not col.unique()['derived_variable_id']:
        latest = max(col.df['version'].unique())
        col = col.search(version=[latest])

    try:
        dsets = col.to_dataset_dict(xarray_open_kwargs={'use_cftime': True}, progressbar=False)
        ds = dsets[list(dsets.keys())[0]]
        # drop_vars replaces the deprecated Dataset.drop for variable names.
        ds = ds.drop_vars('member_id').squeeze()
    except IndexError:
        # Nothing matched; update_datatracker rewrites the tracker and always raises.
        update_datatracker(var_id)
    if col.derivedcat.values():
        # Derived variable: all inputs listed in its registry query must be present.
        expected_variables = col.derivedcat[var_id].query['variable_id']
        if set(expected_variables) != set(col.unique()['variable_id']):
            update_datatracker(var_id)
    return _annual_mean_with_bounds(ds)

In [None]:
def regrid(ds, no_mask=False):
    """Regrid ``ds`` conservatively via ``regrid_global`` with the module-level ``MASK``.

    ``no_mask`` is accepted for call compatibility but ignored. ``MASK`` is
    only defined when the rule provides a ``mask`` input, so calling this
    without one fails with ``NameError``.
    """
    return regrid_global(ds, MASK, method='conservative')


def calc_erf(variables,
             exp_id,
             cat,
             time_slice=slice(2, None),
             n_samples=1000,
             mask=None,
             only_regrid=False,
             boot_strap=True):
    """Sign-weighted sum of global-mean fluxes (the radiative imbalance) for ``exp_id``.

    Each variable is loaded with ``get_variable``. Depending on the options,
    it is regridded (``only_regrid``) or regridded and set to 0 outside
    ``mask``. It is then globally averaged and added with the sign from
    ``signs`` (the listed down-welling fluxes +1, up-welling fluxes -1).

    Parameters
    ----------
    variables : sequence of str
        Flux variables to combine.
    exp_id : str
        Experiment to load.
    cat : intake-esm datastore
        Catalog passed to ``get_variable``.
    time_slice : slice
        Applied with ``isel(time=...)`` to the summed series.
    n_samples : int
        Number of bootstrap resamples.
    mask : array-like or None
        Region mask for ``.where(mask, 0.0)`` after regridding.
    only_regrid : bool
        Regrid without masking (takes precedence over ``mask``).
    boot_strap : bool
        If True, mean/std are the mean and standard deviation of the bootstrap
        resample means; otherwise, the mean/std of the series itself.

    Returns
    -------
    tuple
        ``(mean, std, forcing)`` where ``forcing`` is the sliced time series.

    Raises
    ------
    ValueError
        If ``variables`` is empty or a variable could not be loaded (the
        original error is chained).
    """
    if not variables:
        raise ValueError("calc_erf needs at least one variable")

    signs = {'rlut': -1, 'rlutcs': -1, 'rlutaf': -1, 'rlutcsaf': -1,
             'rsdt': 1,
             'rsut': -1, 'rsutcs': -1, 'rsutcsaf': -1, 'rsutaf': -1,
             'rlus': -1,
             'rsus': -1, 'rsuscs': -1,
             'rsds': 1, 'rsdscs': 1,
             'rlds': 1}

    dsets = {}
    for var_id in variables:
        try:
            dsets[var_id] = get_variable(var_id, exp_id, cat)
        except Exception as err:
            # Chain so the data-tracker message from get_variable is not lost.
            raise ValueError(f"Could not find {var_id}") from err

    data = {}
    for k, ds in dsets.items():
        try:
            data[k] = ds.load()
        except Exception:
            print(f"couldn't load {k}")
            raise

    if only_regrid:
        data = {k: regrid(ds) for k, ds in data.items()}
    elif mask is not None:
        data = {k: regrid(ds).where(mask, 0.0) for k, ds in data.items()}

    data = {k: global_avg(ds) for k, ds in data.items()}

    forcing = 0
    for v, ds in data.items():
        # NOTE(review): 'rsutafcs' never matches any name in `signs` (which uses
        # 'rsutcsaf'); confirm whether the clear-sky variant was meant here.
        # This branch requires 'rsdt' to be among `variables`.
        if mod_id in ['NorESM2-LM', 'NorESM2.0.6dev-LM'] and v in ['rsutaf', 'rsutafcs']:
            forcing += signs[v] * np.abs(ds[v] - data['rsdt']['rsdt'])
        else:
            forcing += signs[v] * ds[v]
    forcing = forcing.isel(time=time_slice)

    if boot_strap:
        # NOTE: _bootstrap_resample is a private SciPy helper; its signature may
        # change between SciPy releases.
        samples = _bootstrap_resample(forcing.values, n_samples)
        resample_means = samples.mean(axis=-1)
        mean, std = resample_means.mean(), resample_means.std()
    else:
        mean = forcing.mean()
        std = forcing.std()
    return mean, std, forcing

In [None]:
def bootstrap_diff(exp_mean, exp_std, ctrl_mean, ctrl_std, n_samples=1000, confidence=0.95):
    """Monte-Carlo difference (exp - ctrl) of two normally distributed quantities.

    For each of ``n_samples`` iterations, one value is drawn from
    N(exp_mean, exp_std) and then one from N(ctrl_mean, ctrl_std), using the
    global NumPy random state, and their difference is recorded.

    Returns
    -------
    tuple
        ``(mean, std, ci_low, ci_high)``. The first two are the mean and
        standard deviation of the simulated differences. The last two bound the
        ``scipy.stats.bootstrap`` confidence interval of the *mean* difference
        at level ``confidence``.

    NOTE(review): because the interval is for the mean of the simulated draws,
    it narrows as ``n_samples`` grows; it is not a spread interval of the
    difference itself.
    """
    draws = [
        np.random.normal(exp_mean, exp_std) - np.random.normal(ctrl_mean, ctrl_std)
        for _ in range(n_samples)
    ]
    diffs = np.array(draws)
    interval = st.bootstrap((diffs,), np.mean, confidence_level=confidence).confidence_interval
    return np.mean(diffs), np.std(diffs), interval.low, interval.high

In [None]:
def t_test_diff(ts_exp, ts_crtl, confidence=0.95):
    """Pooled-variance two-sample t statistic for the difference of means.

    Parameters
    ----------
    ts_exp, ts_crtl : array-like (e.g. xarray.DataArray)
        Experiment and control series; must support ``len`` and ``.mean()``.
    confidence : float
        Two-sided confidence level used for the critical value.

    Returns
    -------
    tuple
        ``(t_val, std_error, pooled_std, t_crit)``. ``t_val`` is ``|diff| /
        std_error`` as a NumPy array. ``t_crit`` is the two-sided Student-t
        critical value with ``len(ts_exp) + len(ts_crtl) - 2`` degrees of
        freedom. The difference is significant when ``t_val > t_crit``.
    """
    n_exp, n_ctrl = len(ts_exp), len(ts_crtl)
    pooled_var = calculate_pooled_variance(ts_crtl, ts_exp)
    std_error = np.sqrt(pooled_var / n_exp + pooled_var / n_ctrl)
    diff = ts_exp.mean() - ts_crtl.mean()
    t_val = abs(diff / std_error)
    # Two-sided test: alpha = 1 - confidence is split across both tails, so the
    # quantile is 1 - alpha/2 (0.975 for 95%). The previous 1 - confidence/2
    # gave 0.525 and a near-zero critical value.
    t_crit = st.t.ppf(q=1 - (1 - confidence) / 2, df=n_exp + n_ctrl - 2)

    # np.asarray equals DataArray.values for xarray input and also accepts NumPy scalars.
    return np.asarray(t_val), std_error, np.sqrt(pooled_var), t_crit

In [None]:
# Requests still outstanding, as recorded in the data-tracker YAML.
with open(snakemake.input.data_tracker, 'r') as f:
    data_vars = yaml.safe_load(f)

erfs_vars = data_vars['ERFs']
# Open the ESM catalog with the derived-variable registry attached.
cat = intake.open_esm_datastore(snakemake.input.catalog, registry=dvr)



In [None]:
def setup_forcing_table(erfs_list, cat,nSamples=1000,confidence_level=0.95, mask=None, only_regrid=False):
    """Build the ERF table for ``exp_id`` versus ``ctrl_id`` (module globals).

    Entries of ``conf['variables']`` named in ``erfs_list`` fall into two groups:

    * direct forcings (name starts with ``'ERF'``): the value is the list of
      flux variables passed to ``calc_erf`` for both experiment and control;
    * derived forcings (any other name): the value lists two forcing names, and
      the row describes ``listvar[0] - listvar[1]``.

    Forcings required by a derived entry but not requested directly are also
    computed. If they are not in ``erfs_list``, the ``.loc`` assignment adds
    them as extra rows.

    Parameters
    ----------
    erfs_list : list of str
        Requested forcing names; used as the initial DataFrame index.
    cat : intake-esm datastore
        Catalog passed to ``calc_erf``.
    nSamples : int
        Bootstrap sample count for ``calc_erf`` and ``bootstrap_diff``.
    confidence_level : float
        Confidence level for ``t_test_diff`` and ``bootstrap_diff``.
    mask, only_regrid
        Passed unchanged to ``calc_erf``.

    Returns
    -------
    pandas.DataFrame
        One row per forcing. ``diff``, ``diff_std`` and ``diff_ci_*`` come from
        ``bootstrap_diff``; ``t_val``, ``st_error``, ``pooled_std`` and
        ``diff_sigificant`` come from ``t_test_diff``.

    NOTE(review): the declared ``'std_error'`` column is never written. The
    standard error goes into ``'st_error'``, which, like ``'pooled_std'``, is
    appended as a new column on first assignment.
    """
    erfs = {k : v for k,v in conf['variables'].items() if k in erfs_list}
    
    forcings = {k : v for k, v in erfs.items() if k.startswith('ERF')}

    derived_forcings = {k : v for k, v in erfs.items() if k.startswith('ERF') == False}
    df = pd.DataFrame(index=erfs_list, 
                      columns=['diff', 'diff_std', 'diff_ci_low', 'diff_ci_high','mean_exp',
                               'mean_ctrl', 'std_exp', 'std_ctrl', 't_val', 'diff_sigificant', 'std_error'])
    
    # Collect (in first-seen order) every forcing that a derived forcing depends on.
    req_derived_ERFs = []
    for k, reqERF in derived_forcings.items():
        for e in reqERF:
            if e not in req_derived_ERFs:
                req_derived_ERFs.append(e)
    # Dependencies that are not already among the requested direct forcings.
    additional_ERFs = {k : conf['variables'][k] for k in req_derived_ERFs if k not in forcings.keys()}
    # Sliced imbalance time series per forcing, reused by the derived-forcing loop.
    forcings_ts_exp = {}
    forcings_ts_ctrl = {}
    forcings = {**forcings, **additional_ERFs}
    for k,listvar in forcings.items():   
        # calc_erf returns (mean, std, time series) for one experiment.
        exp_imbalance = calc_erf(listvar,exp_id,cat,time_slice=time_slice, n_samples=nSamples, mask=mask,
                                only_regrid=only_regrid)
        forcings_ts_exp[k] = exp_imbalance[-1] 
        df.loc[k, 'mean_exp'] = exp_imbalance[0]; df.loc[k, 'std_exp'] = exp_imbalance[1]
        ctrl_imbalance = calc_erf(listvar,ctrl_id,cat,time_slice=time_slice, n_samples=nSamples,mask=mask,
                                 only_regrid=only_regrid)
        forcings_ts_ctrl[k] = ctrl_imbalance[-1] 
        # Significance from the time series; the difference estimate from the (mean, std) summaries.
        t_val, st_error,pooled_std,t_crit = t_test_diff(exp_imbalance[-1] ,ctrl_imbalance[-1],confidence_level)
        df.loc[k, 'mean_ctrl'] = ctrl_imbalance[0]; df.loc[k, 'std_ctrl'] = ctrl_imbalance[1]
        diff_m,diff_std, ciL, ciH = bootstrap_diff(exp_imbalance[0],exp_imbalance[1],ctrl_imbalance[0], 
                                          ctrl_imbalance[1], nSamples,confidence_level)
        df.loc[k,'diff'] = diff_m; df.loc[k,'diff_ci_low'] = ciL; df.loc[k,'diff_ci_high'] = ciH
        df.loc[k, 'diff_std'] = diff_std
        df.loc[k,'t_val'] = t_val
        df.loc[k, 'diff_sigificant'] = t_val > t_crit
        df.loc[k,'st_error'] = st_error
        df.loc[k,'pooled_std'] = pooled_std
    # Derived forcings: listvar[0] - listvar[1], using rows filled above.
    for k, listvar in derived_forcings.items():
        erf_tot = df.loc[listvar[0],'diff']
        erf_tot_std = df.loc[listvar[0],'diff_std']
        erf_effect = df.loc[listvar[1],'diff']
        erf_effect_std = df.loc[listvar[1],'diff_std']
        diff_m,diff_std, ciL, ciH = bootstrap_diff(erf_tot, erf_tot_std, erf_effect, 
                                                    erf_effect_std, nSamples, confidence_level)
        df.loc[k,'diff'] = diff_m; df.loc[k,'diff_ci_low'] = ciL; df.loc[k,'diff_ci_high'] = ciH
        df.loc[k, 'diff_std'] = diff_std
        df.loc[k, 'mean_exp'] = df.loc[listvar[0],'mean_exp'] - df.loc[listvar[1],'mean_exp']
        df.loc[k, 'mean_ctrl'] = df.loc[listvar[0],'mean_ctrl'] - df.loc[listvar[1],'mean_ctrl']
        df.loc[k, 'std_ctrl'] = (forcings_ts_ctrl[listvar[0]] - forcings_ts_ctrl[listvar[1]]).std().values
        df.loc[k, 'std_exp'] = (forcings_ts_exp[listvar[0]] - forcings_ts_exp[listvar[1]]).std().values
        # t-test on the difference series of the two component forcings.
        t_val,st_error,pooled_std, t_crit = t_test_diff(forcings_ts_exp[listvar[0]] - forcings_ts_exp[listvar[1]], 
                                   forcings_ts_ctrl[listvar[0]] - forcings_ts_ctrl[listvar[1]], confidence_level)
        df.loc[k,'t_val'] = t_val
        df.loc[k,'st_error'] = st_error
        df.loc[k, 'diff_sigificant'] = t_val > t_crit
        df.loc[k,'pooled_std'] = pooled_std
        
    return df

In [None]:
if snakemake.rule != "calc_dust_regional_erf_table":
    # Standard rule: a single ERF table written to `outpath`.
    df = setup_forcing_table(erfs_vars, cat, nSamples, confidence_level)
    df.to_csv(snakemake.output.outpath)
else:
    import os
    outputs = snakemake.output
    # Regional dust tables. Each one is skipped when its CSV already exists.
    # 1) Regridded, no mask applied.
    if not os.path.exists(outputs.outpath_all):
        df_all = setup_forcing_table(erfs_vars, cat, nSamples, confidence_level, only_regrid=True)
        df_all.to_csv(outputs.outpath_all)
    # 2) Outside the mask (MASK[mod_id] < 1).
    if not os.path.exists(outputs.outpath_unmasked):
        df_inversemask = setup_forcing_table(erfs_vars, cat, nSamples, confidence_level,
                                             mask=MASK[mod_id] < 1)
        df_inversemask.to_csv(outputs.outpath_unmasked)
    # 3) Inside the mask (MASK[mod_id] > 0).
    if not os.path.exists(outputs.outpath_masked):
        df_masked = setup_forcing_table(erfs_vars, cat, nSamples, confidence_level,
                                        mask=MASK[mod_id] > 0)
        df_masked.to_csv(outputs.outpath_masked)