In [None]:
import xarray as xr 
import matplotlib.pyplot as plt
from workflow.scripts.utils import global_avg, calculate_pooled_variance, diff_means_greater_than_varability
import pandas as pd
import numpy as np
import matplotlib as mpl

In [None]:

# Open every experiment / control file handed over by Snakemake.  Each dataset
# is keyed by the second-to-last "_"-separated token of its path
# (presumably the model identifier -- confirm against the workflow's file naming).
ds_exp = dict(
    (path.split("_")[-2], xr.open_dataset(path))
    for path in snakemake.input.exp_data
)
ds_ctrl = dict(
    (path.split("_")[-2], xr.open_dataset(path))
    for path in snakemake.input.ctrl_data
)

In [None]:
def _fill_diag_df(data,df, variables):
    for variable in variables:
        df[variable] = data[variable].values
        if variable == 'emidust' or variable == 'concdust':

            df[variable] = data[variable].values*1e-9
        else:
            df[variable] = data[variable].values
    return df
        


def create_diagnostics_df(ctrl, exp, mod_id,
                        varables=['emidust','lifetime','concdust','od550aer','abs550aer']):
    """
    Summarise one model's control and experiment datasets as three pandas Series.

    Steps, in order:
      * total dust burden: ``concdust * cell_area`` summed over lon/lat, then
        averaged over time;
      * total dust emission: time-mean ``emidust`` times ``cell_area`` summed
        over lon/lat;
      * both datasets are time-averaged, their difference (exp - ctrl) is taken,
        and all three are passed through ``global_avg``;
      * ``emidust`` / ``concdust`` in the averaged datasets are replaced by the
        totals computed above (and by the difference of totals in the diff).

    Parameters
    ----------
    ctrl, exp : datasets with ``concdust``, ``emidust`` and ``cell_area`` on
        ``lon``/``lat``/``time`` dimensions, plus every name in ``varables``.
    mod_id : used as the ``name`` of each returned Series.
    varables : list of str
        Variables to report (the parameter name is kept as-is for callers).

    Returns
    -------
    (series_exp, series_ctrl, reldiff)
        Experiment values, control values, and the relative difference in
        percent (``diff / ctrl * 100``), each indexed by ``varables``.
    """
    def _burden(ds):
        # Area-weighted concentration: sum over the horizontal grid, then time mean.
        return (ds['concdust']*ds['cell_area']).sum(dim=['lon','lat']).mean(dim='time')

    def _emission(ds):
        # Time-mean emission times cell area, summed over the horizontal grid.
        return (ds['emidust'].mean(dim='time')*ds['cell_area']).sum(dim=['lon','lat'])

    total_burden_ctrl = _burden(ctrl)
    total_burden_exp = _burden(exp)
    total_emis_ctrl = _emission(ctrl)
    total_emis_exp = _emission(exp)

    ctrl_tmean = ctrl.mean(dim='time')
    exp_tmean = exp.mean(dim='time')
    delta = exp_tmean - ctrl_tmean

    ctrl_glob = global_avg(ctrl_tmean)
    delta_glob = global_avg(delta)
    exp_glob = global_avg(exp_tmean)

    # Swap the averaged emission / concentration fields for the integrated totals.
    ctrl_glob = ctrl_glob.assign(emidust=total_emis_ctrl).assign(concdust=total_burden_ctrl)
    exp_glob = exp_glob.assign(emidust=total_emis_exp).assign(concdust=total_burden_exp)
    delta_glob['emidust'] = total_emis_exp-total_emis_ctrl
    delta_glob['concdust'] = total_burden_exp-total_burden_ctrl

    series_exp = _fill_diag_df(exp_glob, pd.Series(index=varables,name=mod_id), varables)
    series_ctrl = _fill_diag_df(ctrl_glob, pd.Series(index=varables,name=mod_id), varables)

    reldiff = pd.Series(index=varables,name=mod_id)
    for name in varables:
        reldiff[name] = delta_glob[name].values/ctrl_glob[name].values*100

    return series_exp,series_ctrl, reldiff

In [None]:
def create_source_region_df(dset_dict, 
                            source_regions, 
                            var_id='emidust'):
    """
    Tabulate the mean of ``'<region> <var_id>'`` for every model and source region.

    Parameters
    ----------
    dset_dict : dict
        Maps a model id to a dataset-like object indexable by
        ``f'{region} {var_id}'``; each item must provide ``.mean().values``.
    source_regions : sequence of str
        Region names used to build the variable keys.
    var_id : str
        Variable suffix appended to each region name.

    Returns
    -------
    pandas.DataFrame
        Float table with one row per source region and one column per model.
    """
    table = pd.DataFrame(index=dset_dict.keys(), columns= source_regions)
    for model, dataset in dset_dict.items():
        for region in source_regions:
            key = f'{region} {var_id}'
            table.loc[model, region] = dataset[key].mean().values
    # Cells were filled one by one into an object-typed frame: cast, then
    # flip so that regions become the rows.
    return table.astype(float).T


In [None]:
# One (experiment, control, relative-difference) triple of Series per model,
# in the order returned by create_diagnostics_df.
_diag_triples = [
    create_diagnostics_df(ds_ctrl[model], ds_exp[model], model)
    for model in ds_exp
]
dfs = [exp_series for exp_series, _, _ in _diag_triples]
rel_dfs = [rel_series for _, _, rel_series in _diag_triples]
dfs_ctrl = [ctrl_series for _, ctrl_series, _ in _diag_triples]

# Region names come from the keys of the workflow configuration entry.
dust_source_regions = list(
    snakemake.config['dust_source_regions'].keys()
)

df_source = create_source_region_df(ds_exp, dust_source_regions)

In [None]:
# Stack the per-model Series into one table per quantity (each Series becomes
# a row labelled by its name) and sort the rows by that label.
df, df_rel, df_ctrl = (
    pd.DataFrame(series_list).sort_index()
    for series_list in (dfs, rel_dfs, dfs_ctrl)
)

In [None]:
def _get_fmt(data):
    """
    Format a scalar with a number of decimals chosen from its magnitude.

    * ``|data| > 100``  -> no decimals
    * ``|data| > 1``    -> one decimal
    * ``|data| < 0.3``  -> three decimals
    * anything else (including NaN, for which every comparison is false)
      -> two decimals

    Returns the string produced by the matching ``StrMethodFormatter``.
    """
    magnitude = abs(data)
    if magnitude > 100:
        pattern = "{x:.0f}"
    elif magnitude > 1:
        pattern = "{x:.1f}"
    elif magnitude < 0.3:
        pattern = "{x:.3f}"
    else:
        pattern = "{x:.2f}"
    formatter = mpl.ticker.StrMethodFormatter(pattern)
    return formatter(data)
def annotate_heatmap(im,data, rel_change,valfmt="{x:.2f}", 
                     textcolors=["black", "white"], threshold=3, **textkw):
    """
    Write "<value>\\n (<relative change> %)" into every cell of a heatmap.

    Parameters
    ----------
    im
        The image to annotate; ``im.get_array().data`` supplies the values the
        colour decision is based on and ``im.axes`` receives the text.
    data
        2-D array of values to print, indexed ``[row, column]``.
    rel_change
        Array indexed like ``data`` holding the relative change printed in
        parentheses.
    valfmt
        Kept for backward compatibility only; it is not used, because the
        number format is chosen per value by ``_get_fmt``.
    textcolors
        Pair of colours: ``textcolors[1]`` is used where the image value is
        below ``threshold``, ``textcolors[0]`` otherwise.
    threshold
        Image value below which the second text colour is used.
    **textkw
        Extra keyword arguments forwarded to ``Axes.text``; they override the
        default centred alignment.

    Returns
    -------
    list
        Exactly one entry per cell in row-major order: the created ``Text``
        object, or ``''`` for cells whose ``data`` value is NaN.
    """
    # Centre the text by default, but let the caller override it.
    kw = dict(horizontalalignment="center",verticalalignment="center")
    kw.update(textkw)

    cdata = im.get_array().data

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            if np.isnan(data[i,j]):
                # No label for missing values.  Previously the label of the
                # preceding cell was appended here as well (or the loop failed
                # with an unbound `text` when the very first cell was NaN).
                texts.append('')
                continue
            kw.update(color=textcolors[int(cdata[i, j] < threshold)])
            text = im.axes.text(j, i, f"{_get_fmt(data[i, j])}\n ({_get_fmt(rel_change[i, j])} %)", **kw)
            texts.append(text)

    return texts

In [None]:
# Column labels for the heatmap.  The mathtext parts are raw strings so that
# "\mathrm" is not parsed as an (invalid) escape sequence; the "\n" line
# breaks stay in ordinary strings so they remain real newlines.
translate_column_names = {
    'emidust' : r"$\mathrm{Emiss}_{DU}$" " \n (Tg/yr)" ,
    'lifetime' : 'Lifetime \n (Days)',
    'concdust' : 'DU burden \n (Tg)',
    'od550aer' : r'$\mathrm{AOD}_{550nm}$',  # was "550mn"
    'abs550aer' : r'$\mathrm{AAOD}_{550nm}$'

}

In [None]:

from matplotlib.patches import Patch

# Layout: an 8 x 5 grid.  The heatmap (panel a) takes the top four rows and
# columns 1-4; the pie charts (panel b) go into rows 4-5 and rows 6-7 below it.
fig = plt.figure(figsize=(8.3,8.8))
gs = fig.add_gridspec(8,5,height_ratios=[1,1,1,1,1,1,0.1,1])
ax = fig.add_subplot(gs[:4, 1:])

# --- Panel a: rank heatmap -------------------------------------------------
# White minor-grid lines placed on the cell edges separate the heatmap cells.
ax.grid(color='w', linestyle='-', linewidth=3, which='minor')
df = df.rename(columns=translate_column_names)
ax.set_xticks(np.arange(df.shape[1]+1)-.5, minor=True)
ax.set_yticks(np.arange(df.shape[0]+1)-.5, minor=True)

ax.spines[:].set_visible(False)

# Cells are coloured by the rank of each value within its column (rank 1 = largest).
cmap = mpl.colormaps.get_cmap('YlGn_r').resampled(9)
cmap.set_bad("#E6E6E6")
im=ax.imshow(df.rank(ascending=False), cmap=cmap, vmin=1, vmax=10, aspect='auto')
cbar = ax.figure.colorbar(im, ax=ax, location='right', pad=0.03, shrink=0.8)
cbar.ax.invert_yaxis()
# Ticks at 2..10 carry the labels 1..9.
# NOTE(review): this puts each label on the edge of its colour bin rather than
# at its centre -- confirm that this placement is intended.
cbar.ax.set_yticks([2,3,4,5,6,7,8,9,10])
cbar.ax.set_yticklabels(['1','2','3','4','5','6','7','8','9'])
cbar.ax.set_ylabel('Rank', rotation=0, va='center')
ax.set_xticks(np.arange(df.shape[1]), labels=df.columns, fontsize=9)
ax.xaxis.tick_top()
ax.set_yticks(np.arange(df.shape[0]), labels=df.index, fontsize=10)
ax.tick_params(which="minor", bottom=False, left=False, top=False)

texts = annotate_heatmap(im, data=df.values, rel_change=df_rel.values, threshold=4)

# --- Panel b: one pie chart per column of df_source --------------------------
color_dict = {'Western north Africa': '#1f77b4',
                 'Eastern north Africa': '#ff7f0e',
                 'Sahel': '#2ca02c',
                 'Middle east': '#d62728',
                 'Central Asia': '#9467bd',
                 'East Asia': '#8c564b',
                 'Southern Africa': '#e377c2',
                 'North America': '#7f7f7f',
                 'Australia': '#bcbd22',
                 'South America': '#17becf'}

df_color = df_source.index.map(color_dict)
# DataFrame.iteritems() was removed in pandas 2.0; items() is the equivalent
# and is available in older versions as well.
for i, (name, data) in enumerate(df_source.items()):
    if i > 4:
        # From the sixth pie on, continue in the bottom row starting at column 0.
        tax = fig.add_subplot(gs[6:, i - 5])
    else:
        tax = fig.add_subplot(gs[4:6, i])
    data.plot.pie(ax=tax,colors=df_color, labels=None, autopct='%1.0f%%', textprops={'fontsize': 7}, pctdistance=1.3,)
    tax.set_title(name, fontsize=10)
    tax.set_ylabel('')

# Shared legend for the pie-chart colours.
legelem = [Patch(facecolor=color_dict[source], label=source) for source in color_dict]
fig.legend(handles=legelem, loc='lower center', ncol=5, fontsize=8, bbox_to_anchor=(0.5, 0.03), frameon=False)

# Panel labels.
ax.text(-0.15, 1.1, 'a)', transform=ax.transAxes, fontsize=12, fontweight='bold', va='top', ha='left')
fig.text(0.1,0.45, 'b)', fontsize=12, fontweight='bold', va='top', ha='left')
plt.savefig(snakemake.output.outpath, bbox_inches='tight')
