In [None]:
import xarray as xr 
import matplotlib.pyplot as plt
from workflow.scripts.utils import global_avg, calculate_pooled_variance, diff_means_greater_than_varability
import pandas as pd
import numpy as np
import matplotlib as mpl
%matplotlib inline

In [None]:

def _model_key(path):
    """Model id encoded in a file path: its second-to-last '_'-separated token."""
    return path.split("_")[-2]


# Experiment and control datasets keyed by model id; the first experiment time step is dropped.
ds_exp = {
    _model_key(path): xr.open_dataset(path).isel(time=slice(1, None))
    for path in snakemake.input.exp_data
}
ds_ctrl = {
    _model_key(path): xr.open_dataset(path)
    for path in snakemake.input.ctrl_data
}

In [None]:
def _fill_diag_df(data,df, variables):
    for variable in variables:
        dat = data.get(variable, None)
        if dat is None:
            df[variable] = np.nan
            
        else:
            if variable == 'emidust' or variable == 'concdust':

                df[variable] = dat.values*1e-9
            else:
                df[variable] = dat.values
        
    return df
        


def create_diagnostics_df(ctrl, exp, mod_id,
                          varables=('emidust', 'lifetime', 'concdust', 'concdust_sum', 'od550aer',
                                    'abs550aer', 'radatm', 'radatmcs', 'ang4487aer', 'od550dust_mass_ext')):
    """Summarise one model's control and experiment runs as per-variable scalars.

    Parameters
    ----------
    ctrl, exp : xarray.Dataset
        Control and experiment output. Both must contain ``concdust`` and ``emidust``
        (with a ``time`` dim) and ``cell_area``; all three are summed over ``lon``/``lat``.
    mod_id : str
        Model identifier, used as the ``name`` of every returned Series.
    varables : sequence of str
        Variables to report (parameter name kept unchanged for existing callers).
        A variable missing from a dataset is reported as NaN.

    Returns
    -------
    (series_exp, series_ctrl, reldiff, series_diff) : tuple of float pandas.Series
        Experiment values, control values, relative change in percent
        ``(exp - ctrl) / ctrl * 100`` (NaN where the variable is missing from ``ctrl``
        or its control value is 0), and the difference ``exp - ctrl``.
        ``emidust`` and ``concdust_sum`` are grid totals weighted by ``cell_area``;
        the remaining variables come from ``global_avg`` on the time-mean fields.
    """
    variables = list(varables)

    # Dust load: sum concdust * cell_area over the grid at each time step, then time-average.
    burd_ctrl = (ctrl['concdust'].dropna(dim='time') * ctrl['cell_area']).sum(dim=['lon', 'lat']).mean(dim='time')
    burd_exp = (exp['concdust'].dropna(dim='time') * exp['cell_area']).sum(dim=['lon', 'lat']).mean(dim='time')
    # Emission: time-mean emidust weighted by cell_area, summed over the grid.
    emis_ctrl = (ctrl['emidust'].dropna(dim='time').mean(dim='time') * ctrl['cell_area']).sum(dim=['lon', 'lat'])
    emis_exp = (exp['emidust'].dropna(dim='time').mean(dim='time') * exp['cell_area']).sum(dim=['lon', 'lat'])

    # Time-mean fields; dropna drops every time step in which a data variable has a NaN.
    ctrl = ctrl.dropna(dim='time').mean(dim='time')
    exp = exp.dropna(dim='time').mean(dim='time')

    diff = exp - ctrl
    # global_avg is a project helper; assumed to reduce each variable to a scalar — TODO confirm.
    ctrl = global_avg(ctrl).assign(emidust=emis_ctrl, concdust_sum=burd_ctrl)
    diff = global_avg(diff).assign(emidust=emis_exp - emis_ctrl, concdust_sum=burd_exp - burd_ctrl)
    exp = global_avg(exp).assign(emidust=emis_exp, concdust_sum=burd_exp)

    def _blank():
        # Explicit float dtype: without it pandas warns (1.x) or falls back to object
        # dtype (2.x), which leaves the stacked per-model frames non-numeric.
        return pd.Series(index=variables, name=mod_id, dtype=float)

    series_exp = _fill_diag_df(exp, _blank(), variables)
    series_ctrl = _fill_diag_df(ctrl, _blank(), variables)
    series_diff = _fill_diag_df(diff, _blank(), variables)

    reldiff = _blank()
    for variable in variables:
        ctrl_val = ctrl.get(variable, None)
        # Relative change is undefined for a missing variable or a zero control value.
        if ctrl_val is not None and bool(ctrl_val != 0):
            reldiff[variable] = diff[variable].values / ctrl_val.values * 100
        else:
            reldiff[variable] = np.nan
    return series_exp, series_ctrl, reldiff, series_diff

In [None]:
def create_source_region_df(dset_dict, 
                            source_regions, 
                            var_id='emidust'):
    """Tabulate the mean of ``'<region> <var_id>'`` for every model and source region.

    Parameters
    ----------
    dset_dict : mapping of model id -> dataset
        Each dataset must provide a variable named ``f'{region} {var_id}'`` for every
        region; its ``.mean()`` over all dims is taken.
    source_regions : sequence of str
        Region names; they become the row index.
    var_id : str
        Variable suffix appended to each region name.

    Returns
    -------
    pandas.DataFrame
        Float frame with rows = ``source_regions`` and columns = model ids (in
        ``dset_dict`` iteration order).
    """
    model_ids = list(dset_dict)
    # Build the whole table at once instead of filling an object-dtype frame cell by cell.
    values = [
        [float(dset_dict[mod_id][f'{source_region} {var_id}'].mean()) for mod_id in model_ids]
        for source_region in source_regions
    ]
    return pd.DataFrame(values, index=list(source_regions), columns=model_ids, dtype=float)


In [None]:
# Per-model diagnostics: (exp, ctrl, relative change %, exp - ctrl) Series for each model
# in ds_exp; every such model must also be present in ds_ctrl.
diagnostics = {
    mod_id: create_diagnostics_df(ds_ctrl[mod_id], ds_exp[mod_id], mod_id)
    for mod_id in ds_exp
}
dfs = [exp_s for exp_s, _, _, _ in diagnostics.values()]
dfs_ctrl = [ctrl_s for _, ctrl_s, _, _ in diagnostics.values()]
rel_dfs = [rel_s for _, _, rel_s, _ in diagnostics.values()]
dfs_diff = [diff_s for _, _, _, diff_s in diagnostics.values()]

# Source regions come from the workflow config; their order drives the pie-chart wedges.
dust_source_regions = list(snakemake.config['dust_source_regions'].keys())
df_source = create_source_region_df(ds_exp, dust_source_regions)

In [None]:
# Stack the per-model Series into frames: one row per model, indexed (and sorted) by model id.
df, df_rel, df_ctrl, df_diff = (
    pd.DataFrame(rows).sort_index()
    for rows in (dfs, rel_dfs, dfs_ctrl, dfs_diff)
)

In [None]:
def _get_fmt(data):
    """Format one number for a heatmap annotation, choosing precision from its magnitude.

    |data| > 100            -> "{x:.0f}"
    1 < |data| <= 100       -> "{x:.1f}"
    0.008 < |data| <= 1     -> "{x:.3f}"
    otherwise (small, 0, NaN) -> "{x:.2E}"

    The string is produced by ``matplotlib.ticker.StrMethodFormatter``.
    """
    magnitude = abs(data)
    if magnitude > 100:
        fmt = "{x:.0f}"
    elif magnitude > 1:
        fmt = "{x:.1f}"
    elif magnitude > 0.008:
        # Upper bound is implied by the branches above, so |data| == 1 lands here
        # instead of falling through to scientific notation.
        fmt = "{x:.3f}"
    else:
        fmt = "{x:.2E}"
    return mpl.ticker.StrMethodFormatter(fmt)(data)
def annotate_heatmap(im, data, rel_change, valfmt="{x:.2f}",
                     textcolors=("black", "white"), threshold=3, **textkw):
    """Write each cell's value, and its relative change when available, onto a heatmap.

    Parameters
    ----------
    im : matplotlib image (e.g. from ``ax.imshow``)
        Its array picks the text colour per cell; text is drawn on ``im.axes``.
    data : 2-D array
        Values to print; cell (i, j) is drawn at x=j, y=i. NaN cells get no text.
    rel_change : 2-D array, same shape as ``data``
        Relative change in percent, printed as a second line "(… %)"; NaN suppresses it.
    valfmt : str or Formatter
        Unused — every number is formatted by ``_get_fmt``. Kept so existing callers
        that pass it keep working.
    textcolors : pair of colours
        ``textcolors[1]`` where the image value is below ``threshold``, else ``textcolors[0]``.
    threshold : float
        Image-value cut-off for switching the text colour.
    **textkw
        Extra keyword arguments forwarded to ``Axes.text`` (a ``color`` here is
        overridden per cell).

    Returns
    -------
    list
        One entry per cell in row-major order: the ``Text`` artist, or ``''`` for
        NaN cells where nothing is drawn.
    """
    # Centre text in each cell unless the caller overrides it.
    kw = {"horizontalalignment": "center", "verticalalignment": "center", **textkw}
    cdata = im.get_array().data

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            if np.isnan(value):
                # Record a placeholder and move on; no Text artist exists for this cell.
                texts.append('')
                continue
            kw["color"] = textcolors[int(cdata[i, j] < threshold)]
            label = f"{_get_fmt(value)}"
            if not np.isnan(rel_change[i, j]):
                label = f"{label}\n ({_get_fmt(rel_change[i, j])} %)"
            texts.append(im.axes.text(j, i, label, **kw))

    return texts

In [None]:
# Display labels (matplotlib mathtext) for the diagnostic columns of the heatmap.
# "\\mathrm" uses an escaped backslash: a bare "\m" is an invalid escape sequence
# (warned about by recent Pythons), whereas "\n" is an intended line break.
translate_column_names = {
    'emidust': "$\\mathrm{Emiss}_{DU}$ \n (Tg/yr)",
    'concdust_sum': 'DU burden \n (Tg)',
    'od550aer': '$\\mathrm{AOD}_{550nm}$',
    'od550dust_mass_ext': 'DU MEC \n (m2 g-1)',
    'abs550aer_mass_abs': 'DU MAC \n (m2 g-1)',
    'ang4487aer': 'AE',
    'lifetime': 'Lifetime \n (Days)',
    'abs550aer': '$\\mathrm{AAOD}_{550nm}$',
    'radatm': 'ARC \n (W m-2)',
    'radatmcs': 'ARC$_{clearsky}$\n (W m-2)',
}

In [None]:
# rcParams used (via mpl.rc_context) while drawing the summary figure:
# small label/tick fonts and no axes spines on any side.
context_dict = {
    'axes.labelsize': 7.5,
    **{f'axes.spines.{side}': False for side in ('left', 'right', 'top', 'bottom')},
    'xtick.labelsize': 7,
    'ytick.labelsize': 7,
}

In [None]:

# Heatmap columns, left to right. The first five are ranked by relative change
# (exp vs ctrl), the last five by absolute value.
REL_RANK_COLS = ['emidust', 'concdust_sum', 'od550aer', 'abs550aer', 'ang4487aer']
ABS_RANK_COLS = ['abs550aer_mass_abs', 'od550dust_mass_ext', 'lifetime', 'radatm', 'radatmcs']
HEATMAP_COLS = REL_RANK_COLS + ABS_RANK_COLS

# .copy() so the column assignments below modify an independent frame, not a slice of
# `df` (avoids SettingWithCopyWarning and keeps this cell idempotent).
vis_df = df[['emidust', 'concdust_sum', 'od550aer', 'abs550aer', 'ang4487aer',
             'lifetime', 'radatm', 'radatmcs', 'od550dust_mass_ext']].copy()
# Radiative terms are shown as exp - ctrl differences rather than experiment values.
vis_df['radatm'] = df_diff['radatm']
vis_df['radatmcs'] = df_diff['radatmcs']
# MAC from the exp - ctrl differences; the 1e3 factor presumably converts kg -> g (TODO confirm units).
vis_df['abs550aer_mass_abs'] = df_diff['abs550aer'] / (df_diff['concdust'] * 1e3)
vis_df = vis_df[HEATMAP_COLS]

# Relative changes (%) are annotated only for REL_RANK_COLS; every other column is NaN.
# Reindexing to HEATMAP_COLS keeps df_rel position-aligned with vis_df, because
# annotate_heatmap pairs the two arrays cell by cell.
df_rel = df_rel[REL_RANK_COLS].reindex(columns=HEATMAP_COLS)

# Rank models per column, 1 = largest magnitude.
rank_df = df_rel[REL_RANK_COLS].abs().rank(ascending=False).join(
    vis_df[ABS_RANK_COLS].abs().rank(ascending=False)
)

vis_df = vis_df.rename(columns=translate_column_names)
rank_df = rank_df.rename(columns=translate_column_names)

In [None]:
with mpl.rc_context(context_dict):
    from matplotlib.patches import Patch

    # Grid rows 0-3: rank heatmap (panel a); rows 6 and 8: up to five pie charts
    # each (panel b); rows 4, 5 and 7 stay empty as spacing.
    fig = plt.figure(figsize=(8.3, 8.8))
    gs = fig.add_gridspec(9, 5, height_ratios=[1, 1, 1, 1, 1, 0.01, 1.2, 0.1, 1.2])
    ax = fig.add_subplot(gs[:4, :])

    # White minor-grid lines on the cell borders separate the heatmap tiles.
    ax.grid(color='w', linestyle='-', linewidth=3, which='minor')
    ax.set_xticks(np.arange(vis_df.shape[1] + 1) - .5, minor=True)
    ax.set_yticks(np.arange(vis_df.shape[0] + 1) - .5, minor=True)

    cmap = mpl.colormaps.get_cmap('YlGn_r').resampled(9)
    cmap.set_bad("#E6E6E6")  # colour for NaN (unranked) cells
    im = ax.imshow(rank_df, cmap=cmap, vmin=1, vmax=10, aspect='auto')
    cbar = ax.figure.colorbar(im, ax=ax, location='right', pad=0.06, shrink=0.8)
    cbar.ax.invert_yaxis()
    # NOTE(review): ticks sit at 2..10 but are labelled 1..9 — confirm this offset is intended.
    cbar.ax.set_yticks([2, 3, 4, 5, 6, 7, 8, 9, 10])
    cbar.ax.set_yticklabels(['1', '2', '3', '4', '5', '6', '7', '8', '9'])
    cbar.ax.set_title('Rank', fontsize=7)

    ax.set_xticks(np.arange(vis_df.shape[1]), labels=vis_df.columns)
    ax.xaxis.tick_top()
    ax.set_yticks(np.arange(vis_df.shape[0]), labels=vis_df.index)
    ax.tick_params(which="minor", bottom=False, left=False, top=False)

    def _bracket_label(text, x):
        """Caption at figure fraction (x, 0.452) with a '-[' bracket arrow marking a column group."""
        ax.annotate(text, xy=(x, 0.452), xycoords='figure fraction',
                    xytext=(0, -16), textcoords='offset points',
                    ha="center", va="bottom",
                    arrowprops=dict(arrowstyle="-[", mutation_scale=91, mutation_aspect=.35))

    # x positions are hand-tuned for this figsize and column layout.
    _bracket_label('Ranked by relative change', 0.295)
    _bracket_label('Ranked by absolute value', 0.6)

    texts = annotate_heatmap(im, data=vis_df.values, rel_change=df_rel.values, threshold=4, fontsize=7)

    color_dict = {'Western north Africa': '#1f77b4',
                  'Eastern north Africa': '#ff7f0e',
                  'Sahel': '#2ca02c',
                  'Middle east': '#d62728',
                  'Central Asia': '#9467bd',
                  'East Asia': '#8c564b',
                  'Southern Africa': '#e377c2',
                  'North America': '#7f7f7f',
                  'Australia': '#bcbd22',
                  'South America': '#17becf'}
    # Wedge colours follow the row (source region) order of df_source.
    df_color = df_source.index.map(color_dict)

    def my_autopct(pct):
        """Wedge label: whole-number percent, hidden for wedges of 3 % or less."""
        return '{:1.0f}%'.format(pct) if pct > 3 else ''

    # One pie per df_source column (model), five per row on grid rows 6 and 8.
    # DataFrame.iteritems() was removed in pandas 2.0; items() is the supported spelling.
    for idx, (name, data) in enumerate(df_source.items()):
        pie_row, pie_col = divmod(idx, 5)
        tax = fig.add_subplot(gs[6 + 2 * pie_row, pie_col])
        data.plot.pie(ax=tax, colors=df_color, labels=None, autopct=my_autopct,
                      textprops={'fontsize': 6.5}, pctdistance=1.2)
        tax.set_title(name, fontsize=8)
        tax.set_ylabel('')

    legelem = [Patch(facecolor=color, label=source) for source, color in color_dict.items()]
    fig.legend(handles=legelem, loc='lower center', ncol=5, fontsize=7,
               bbox_to_anchor=(0.5, 0.03), frameon=False)

    fig.text(0.1, 0.9, 'a)', fontsize=12, fontweight='bold', va='top', ha='left')
    fig.text(0.1, 0.42, 'b)', fontsize=12, fontweight='bold', va='top', ha='left')
    fig.text(0.33, 0.42, 'Relative distribution of dust emissions', fontsize=12, va='top', ha='left')
    fig.savefig(snakemake.output.outpath, bbox_inches='tight')
