In [None]:
import xarray as xr 
import matplotlib.pyplot as plt
from workflow.scripts.utils import global_avg, calculate_pooled_variance, diff_means_greater_than_varability
import pandas as pd
import numpy as np
import matplotlib as mpl
%matplotlib inline

In [None]:

def _model_id(path):
    """Return the model id encoded in ``path``: its second-to-last ``_``-separated token."""
    return path.split("_")[-2]


# Experiment runs, keyed by model id; the first time step (index 0) is dropped.
# NOTE(review): presumably a transition/spin-up step — confirm.
ds_exp = {
    _model_id(path): xr.open_dataset(path).isel(time=slice(1, None))
    for path in snakemake.input.exp_data
}
# Control runs, keyed by model id, used in full.
ds_ctrl = {_model_id(path): xr.open_dataset(path) for path in snakemake.input.ctrl_data}

In [None]:
# Quick look at one experiment dataset. Prefer NorESM2-LM, but fall back to the
# first available model so this cell does not raise KeyError when the pipeline
# runs with a different model set.
ds_exp.get('NorESM2-LM', next(iter(ds_exp.values()), None))

In [None]:
def _fill_diag_df(data,df, variables):
    for variable in variables:
        dat = data.get(variable, None)
        if dat is None:
            df[variable] = np.nan
            
        else:
            if variable == 'emidust' or variable == 'concdust_sum':

                df[variable] = dat.values*1e-9
            else:
                df[variable] = dat.values
        
    return df
        
    
def calc_ang4487aer(ds):
    """Add an Angstrom exponent variable ``ang4487aer`` to ``ds``.

    Uses ``od870aer`` with ``od440aer`` when available, otherwise with
    ``od550aer``: ``-ln(tau_ref / tau_870) / ln(lambda_ref / 870)``.
    Returns ``ds`` unchanged when ``od870aer`` is missing; otherwise returns a
    new object with ``ang4487aer`` assigned (units '1', long_name naming the
    wavelength pair used).
    """
    tau_870 = ds.get('od870aer', None)
    if tau_870 is None:
        return ds

    tau_ref = ds.get('od440aer', None)
    wavelength = 440
    if tau_ref is None:
        # Fall back to the 550 nm optical depth.
        tau_ref, wavelength = ds['od550aer'], 550

    angstrom = -np.log(tau_ref/tau_870)/np.log(wavelength/870)
    angstrom.attrs.update({
        'units': '1',
        'long_name': f'Angstrom exponent {wavelength}-870 nm',
    })
    return ds.assign(ang4487aer=angstrom)
    

def _dust_area_totals(ds):
    """Return ``(burden, emission)``: area-weighted grid sums of ``concdust`` and ``emidust``."""
    # Burden: sum over the grid at every time step, then average over time.
    burden = (ds['concdust'].dropna(dim='time')*ds['cell_area']).sum(dim=['lon', 'lat']).mean(dim='time')
    # Emission: average over time first, then sum over the grid.
    emission = (ds['emidust'].dropna(dim='time').mean(dim='time')*ds['cell_area']).sum(dim=['lon', 'lat'])
    return burden, emission


def create_diagnostics_df(ctrl, exp, mod_id,
                          variables=('emidust', 'lifetime', 'concdust', 'concdust_sum',
                                     'od550aer', 'abs550aer', 'ang4487aer',
                                     'od550dust_mass_ext', 'od550dust', 'od440aer', 'od870aer',
                                     'totdust', 'wetdust')):
    """Summarise one model's control and experiment runs as scalar diagnostics.

    Parameters
    ----------
    ctrl, exp : xarray.Dataset
        Control and experiment output for one model. Must contain ``concdust``,
        ``emidust`` and ``cell_area``, with ``time``, ``lat`` and ``lon`` dims.
    mod_id : str
        Model identifier; used as the ``name`` of every returned Series.
    variables : sequence of str
        Diagnostics to report (the index of each returned Series). Names
        missing from a dataset are reported as NaN.

    Returns
    -------
    tuple of pandas.Series (float dtype)
        ``(series_exp, series_ctrl, reldiff, series_diff)``: experiment values,
        control values, relative change ``(exp - ctrl) / ctrl * 100`` in
        percent, and absolute change ``exp - ctrl``. ``emidust`` and
        ``concdust_sum`` are the area-integrated totals.
    """
    variables = list(variables)

    burd_ctrl, emis_ctrl = _dust_area_totals(ctrl)
    burd_exp, emis_exp = _dust_area_totals(exp)

    # Time means of every field.
    # NOTE(review): Dataset.dropna(dim='time') drops a time step if any
    # variable is NaN anywhere in it — confirm this is intended.
    ctrl = ctrl.dropna(dim='time').mean(dim='time')
    exp = exp.dropna(dim='time').mean(dim='time')

    diff = exp - ctrl
    if 'ang4487aer' in variables:
        # Angstrom exponent computed from the *difference* fields (exp - ctrl),
        # not the difference of the two exponents.
        diff = calc_ang4487aer(diff)
    ctrl = global_avg(ctrl)
    diff = global_avg(diff)
    exp = global_avg(exp)

    # Replace the global means of these two with the area-integrated totals.
    ctrl = ctrl.assign(emidust=emis_ctrl, concdust_sum=burd_ctrl)
    exp = exp.assign(emidust=emis_exp, concdust_sum=burd_exp)
    diff = diff.assign(emidust=emis_exp - emis_ctrl, concdust_sum=burd_exp - burd_ctrl)

    def _empty():
        # Explicit float dtype: pandas >= 2 would otherwise create object dtype.
        return pd.Series(index=variables, name=mod_id, dtype=float)

    series_exp = _fill_diag_df(exp, _empty(), variables)
    series_ctrl = _fill_diag_df(ctrl, _empty(), variables)
    series_diff = _fill_diag_df(diff, _empty(), variables)

    reldiff = _empty()  # NaN wherever a relative change is undefined
    for variable in variables:
        ctrl_val = ctrl.get(variable, None)
        diff_val = diff.get(variable, None)
        # Missing in either dataset (arithmetic between Datasets may drop
        # variables not present in both) -> leave NaN instead of raising.
        if ctrl_val is None or diff_val is None:
            continue
        base = np.asarray(ctrl_val.values, dtype=float).item()
        # A zero baseline has no defined relative change; NaN propagates as NaN.
        if base == 0:
            continue
        reldiff[variable] = np.asarray(diff_val.values, dtype=float).item()/base*100
    return series_exp, series_ctrl, reldiff, series_diff

In [None]:
# One row (a pd.Series indexed by variable) per model, for each summary table.
dfs, rel_dfs, dfs_ctrl, dfs_diff = [], [], [], []
for mod_id in ds_exp:
    exp, ctrl, temp_rel, diff = create_diagnostics_df(ds_ctrl[mod_id], ds_exp[mod_id], mod_id)
    for table, row in ((dfs, exp), (rel_dfs, temp_rel), (dfs_ctrl, ctrl), (dfs_diff, diff)):
        table.append(row)


In [None]:
# Stack the per-model rows into tables (rows = models), ordered by model id:
# experiment values, relative change (%), control values, absolute change.
df, df_rel, df_ctrl, df_diff = (
    pd.DataFrame(rows).sort_index() for rows in (dfs, rel_dfs, dfs_ctrl, dfs_diff)
)

In [None]:
# Wet-deposition share of total dust deposition, for every table.
# In df_diff this is the ratio of the changes (Δwetdust / Δtotdust).
for frame in (df_ctrl, df_diff, df):
    frame['wetratio'] = frame['wetdust']/frame['totdust']
# Keep only positive Angstrom exponents of the difference fields; the rest become NaN.
df_diff['ang4487aer'] = df_diff['ang4487aer'].where(df_diff['ang4487aer'] > 0)

In [None]:
def _get_fmt(data):
    """Format a number for a heatmap cell, choosing precision by magnitude.

    |data| > 100          -> 0 decimals
    1 < |data| <= 100     -> 1 decimal
    0.008 < |data| <= 1   -> 3 decimals
    otherwise (incl. NaN) -> 4 decimals

    Returns
    -------
    str
        The label produced by ``mpl.ticker.StrMethodFormatter``.
    """
    magnitude = abs(data)
    if magnitude > 100:
        fmt = "{x:.0f}"
    elif magnitude > 1:
        fmt = "{x:.1f}"
    elif magnitude > 0.008:
        # Upper bound is inclusive: |data| == 1 used to fall through to the
        # 4-decimal branch and print as "1.0000".
        fmt = "{x:.3f}"
    else:
        fmt = "{x:.4f}"
    return mpl.ticker.StrMethodFormatter(fmt)(data)
def annotate_heatmap(im, data, rel_change, valfmt="{x:.2f}",
                     textcolors=("black", "white"), threshold=3, **textkw):
    """Write each cell's value (and relative change, if any) onto a heatmap.

    Parameters
    ----------
    im : matplotlib image returned by ``Axes.imshow``
        Its array selects the text colour and its axes receive the labels.
    data : 2-D array
        Values to print, same shape as the image. NaN cells stay blank.
    rel_change : 2-D array
        Relative change in percent, same shape as ``data``. When not NaN it is
        printed in parentheses below the value.
    valfmt : str or Formatter
        Accepted for backward compatibility but not used; numbers are
        formatted by ``_get_fmt`` with magnitude-dependent precision.
    textcolors : sequence of two colours
        ``textcolors[0]`` where the image value is >= ``threshold`` (or NaN),
        ``textcolors[1]`` where it is below ``threshold``.
    threshold : float
        Image value (not data value) below which the second colour is used.
    **textkw
        Extra keyword arguments for ``Axes.text``. They can override the
        default centring, but ``color`` is always set from ``textcolors``.

    Returns
    -------
    list
        One entry per cell in row-major order: the created ``Text`` artist,
        or ``''`` for a NaN cell.
    """
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)
    # Underlying values of the (masked) image array; they drive the text colour.
    cdata = im.get_array().data

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            if np.isnan(value):
                # Blank placeholder only. Previously the loop fell through and
                # also appended the previous cell's Text (or hit an unbound
                # name on the first cell).
                texts.append('')
                continue
            label = _get_fmt(value)
            if not np.isnan(rel_change[i, j]):
                label = f"{label}\n ({_get_fmt(rel_change[i, j])} %)"
            kw.update(color=textcolors[int(cdata[i, j] < threshold)])
            texts.append(im.axes.text(j, i, label, **kw))

    return texts

In [None]:
# Display labels (matplotlib mathtext) for the heatmap and CSV columns.
# Backslashes for mathtext commands are escaped explicitly ("\\Delta"):
# unescaped "\D"/"\m" are invalid escape sequences (SyntaxWarning on
# Python >= 3.12), whereas "\n" is an intended line break.
translate_column_names = {
    'emidust': "$\\Delta \\mathrm{Emiss}_{DU}$ \n (Tg/yr)",
    'wetratio': 'DU$_{Wetdep}$ \n /DU$_{Totdep}$',
    'concdust_sum': '$\\Delta$DU burden \n (Tg)',
    'od550aer': '$\\Delta \\mathrm{AOD}_{550nm}$',
    'od550dust_mass_ext': 'DU MEC \n (m2 g-1)',
    'abs550aer_mass_abs': 'DU MAC \n (m2 g-1)',
    'lifetime': 'Lifetime \n (Days)',
    'abs550aer': '$\\Delta \\mathrm{AAOD}_{550nm}$',
    'ang4487aer': '$\\mathrm{Angström}_{440-870}$',
}

In [None]:
# rcParams for the heatmap figure (applied via mpl.rc_context): small fonts
# and no axes spines on any side.
context_dict = {
    'axes.labelsize': 7.5,
    **{f'axes.spines.{side}': False for side in ('left', 'right', 'top', 'bottom')},
    'xtick.labelsize': 5.5,
    'ytick.labelsize': 6.5,
}

In [None]:
# NOTE(review): not used anywhere in this notebook; presumably the latent heat
# of condensation of water in J kg-1 — confirm or remove.
lh_cond_water = 2260*(10**3)

# Heatmap columns, left to right.
variable_order = ['emidust', 'wetratio', 'concdust_sum', 'od550aer', 'abs550aer', 'ang4487aer',
                  'abs550aer_mass_abs', 'od550dust_mass_ext', 'lifetime']

# Experiment values for lifetime and dust MEC; changes (exp - ctrl) for the
# rest, plus a mass absorption coefficient built from the changes
# (ΔAAOD / Δconcdust). NOTE(review): the 1e3 factor presumably converts the
# concdust units so the result is in m2 g-1 — confirm.
vis_df = (
    df[['lifetime', 'od550dust_mass_ext']]
    .join(df_diff[['emidust', 'wetratio', 'concdust_sum', 'od550aer', 'abs550aer', 'ang4487aer']])
    .assign(abs550aer_mass_abs=df_diff['abs550aer']/(df_diff['concdust']*1e3))
    .loc[:, variable_order]
)

# Rank each column across models by magnitude (1 = largest |value|). Built
# from the already-ordered vis_df so its columns line up with the labels.
rank_df = vis_df.abs().rank(ascending=False)

# Relative changes (%) printed below the values. Only emission, burden, AOD
# and AAOD get one; every other column is NaN so only the value is printed.
# Selecting then reindexing builds a new frame (no SettingWithCopyWarning).
df_rel = df_rel[['emidust', 'concdust_sum', 'od550aer', 'abs550aer']].reindex(columns=variable_order)

vis_df = vis_df.rename(columns=translate_column_names)
rank_df = rank_df.rename(columns=translate_column_names)

In [None]:
if snakemake.output.outpath.endswith('.csv'):
    # Tabular output: values/changes and the relative changes (%).
    vis_df.to_csv(snakemake.output.outpath)
    df_rel.to_csv(snakemake.output.relpath)
else:
    with mpl.rc_context(context_dict):
        fig, ax = plt.subplots(figsize=(8.3, 4.2))
        n_rows, n_cols = vis_df.shape

        # White grid lines on minor ticks placed between cells.
        ax.grid(color='w', linestyle='-', linewidth=3, which='minor')
        ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
        ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)

        # Cells coloured by rank: 9 discrete colours spanning [1, 10], so rank
        # r fills the band [r, r + 1). NaN ranks use the "bad" grey.
        cmap = mpl.colormaps['YlGn_r'].resampled(9)
        cmap.set_bad("#E6E6E6")
        im = ax.imshow(rank_df, cmap=cmap, vmin=1, vmax=10, aspect='auto')

        cbar = fig.colorbar(im, ax=ax, location='right', pad=0.06, shrink=0.8)
        cbar.ax.invert_yaxis()  # rank 1 at the top
        # Centre each rank label in its colour band (previously the labels sat
        # on the band edges: value r + 1 was labelled r).
        ranks = np.arange(1, 10)
        cbar.ax.set_yticks(ranks + 0.5, labels=[str(r) for r in ranks])
        cbar.ax.set_title('Rank', fontsize=7)

        ax.set_xticks(np.arange(n_cols), labels=vis_df.columns)
        ax.xaxis.tick_top()
        ax.set_yticks(np.arange(n_rows), labels=vis_df.index)
        # Must come after tick_top(), which re-enables top ticks for both
        # major and minor ticks.
        ax.tick_params(which="minor", bottom=False, left=False, top=False)

        annotate_heatmap(im, data=vis_df.values, rel_change=df_rel.values, threshold=4, fontsize=7)

        fig.savefig(snakemake.output.outpath, bbox_inches='tight')
