In [None]:
import xarray as xr
import matplotlib.pyplot as plt
from workflow.scripts.utils import global_avg
import matplotlib as mpl
import pandas as pd
import numpy as np
from workflow.scripts.utils import t_test_diff_sample_means

In [None]:
def _model_id(path):
    """Return the model identifier: the second-to-last ``_``-separated token of ``path``."""
    return path.split("_")[-2]


# Experiment datasets drop their first time step (slice(1, None)); control
# datasets are used in full.
# NOTE(review): the reason for skipping the first experiment step is not stated here — confirm.
ds_exp = {
    _model_id(path): xr.open_dataset(path).isel(time=slice(1, None))
    for path in snakemake.input.exp_data
}
ds_ctrl = {_model_id(path): xr.open_dataset(path) for path in snakemake.input.ctrl_data}

# Significance level for the t-test; falls back to 0.05 when CI_alpha is not configured.
conf_level = snakemake.params.get('CI_alpha', 0.05)

In [None]:
def _fill_diag_df(data,df, variables):
    
    for variable in variables:
        temp_data = data.get(variable)
        if temp_data is None:
            df[variable] = np.nan
        else:
            df[variable] = temp_data.values
    return df

def calc_check_diff(ctrl, exp, ci):
    """Run ``t_test_diff_sample_means`` per variable and collect its results.

    Parameters
    ----------
    ctrl, exp : xarray.Dataset
        Control and experiment data. Every data variable of ``exp`` must also
        exist in ``ctrl``.
    ci : float
        Significance level (despite the name): a variable is flagged when the
        test's p-value is strictly below it.

    Returns
    -------
    diff : xarray.Dataset
        Per variable, the third value returned by ``t_test_diff_sample_means``
        (called with ``global_mean=True``).
    sig_ds : xarray.Dataset
        Per variable, the boolean ``p < ci``.
    """
    # Template with exp's variables after dropping time (first step) and
    # averaging over lon/lat; every variable is replaced below.
    template = xr.zeros_like(exp.isel(time=0).mean(dim=['lon', 'lat']))
    diff = template
    sig_ds = xr.zeros_like(template)

    for name in exp.data_vars:
        _, p_value, delta = t_test_diff_sample_means(
            da_ctrl=ctrl[name],
            da_exp=exp[name],
            global_mean=True,
        )
        diff = diff.assign({name: delta})
        sig_ds = sig_ds.assign({name: p_value < ci})

    return diff, sig_ds

def create_diagnostics_df(ds_ctrl, ds_exp, mod_id,
                          variables=('lwp', 'pr', 'cl_low', 'cl_middle', 'cl_high',
                                     'cdncvi', 'clt', 'clivi'),
                          ci=0.05,
                          ):
    """Summarise one model's experiment vs. control run as pandas Series.

    Parameters
    ----------
    ds_ctrl, ds_exp : xarray.Dataset
        Control and experiment data; both are averaged over ``time`` and must
        carry a ``lat`` coordinate.
    mod_id : str
        Used as the ``name`` of every returned Series (becomes the row label
        when the Series are stacked into a DataFrame).
    variables : sequence of str
        Variables to report, in output order. Variables absent from the data
        are reported as NaN.
    ci : float
        Significance level forwarded to ``calc_check_diff``; p-values below it
        are flagged as significant.

    Returns
    -------
    (series_exp, series_ctrl, reldiff, series_diff, series_sig)
        float64 Series indexed by ``variables``:
        ``global_avg`` of the time-mean experiment and control fields, the
        difference as a percentage of the control value, the difference from
        ``calc_check_diff``, and the significance flag (1.0 significant,
        0.0 not significant, NaN missing).
    """
    variables = list(variables)

    ctrl = ds_ctrl.mean(dim='time')
    exp = ds_exp.mean(dim='time')

    # If latitudes are not exactly equal, overwrite the experiment's lat
    # coordinate with the control's values before testing, so both datasets
    # share one lat coordinate. NOTE(review): presumably this absorbs
    # round-off between otherwise identical grids — confirm.
    # np.array_equal returns False for different lengths instead of
    # attempting an element-wise broadcast.
    if not np.array_equal(ctrl.lat.values, exp.lat.values):
        ds_exp = ds_exp.assign_coords(lat=ds_ctrl.lat.values)
    diff, sig = calc_check_diff(ds_ctrl, ds_exp, ci)

    ctrl = global_avg(ctrl)
    exp = global_avg(exp)

    def _empty_series():
        # Explicit float dtype: without it pandas warns (1.x) or silently
        # creates object dtype (>= 2.0), which later breaks imshow.
        return pd.Series(index=variables, name=mod_id, dtype=float)

    series_exp = _fill_diag_df(exp, _empty_series(), variables)
    series_ctrl = _fill_diag_df(ctrl, _empty_series(), variables)
    series_diff = _fill_diag_df(diff, _empty_series(), variables)
    # Cast the boolean flags to float so they fit a float Series without an
    # incompatible-dtype upcast; 1.0 == True still holds for consumers.
    series_sig = _fill_diag_df(sig.astype(float), _empty_series(), variables)

    reldiff = _empty_series()
    for variable in variables:
        # Need both the control value and the difference. The diff only
        # holds the experiment's variables.
        if ctrl.get(variable) is not None and diff.get(variable) is not None:
            reldiff[variable] = diff[variable].values / ctrl[variable].values * 100

    return series_exp, series_ctrl, reldiff, series_diff, series_sig


In [None]:
# Collect, per model, the five Series returned by create_diagnostics_df.
# Models come from the experiment files; each must also have a control entry
# (ds_ctrl[mod_id] raises KeyError otherwise).
dfs, rel_dfs, dfs_ctrl, dfs_diff, dfs_sig = [], [], [], [], []
for mod_id in ds_exp:
    exp, ctrl, temp_rel, diff, sig = create_diagnostics_df(
        ds_ctrl[mod_id], ds_exp[mod_id], mod_id, ci=conf_level
    )
    for collection, series in ((dfs, exp), (rel_dfs, temp_rel), (dfs_ctrl, ctrl),
                               (dfs_diff, diff), (dfs_sig, sig)):
        collection.append(series)

In [None]:
# Stack the per-model Series into frames: one row per model (the Series name,
# i.e. mod_id), one column per variable, rows sorted by model id.
df, df_rel, df_ctrl, df_diff, df_sig = (
    pd.DataFrame(rows).sort_index()
    for rows in (dfs, rel_dfs, dfs_ctrl, dfs_diff, dfs_sig)
)

In [None]:
def _get_fmt(data):
    if abs(data) > 100:
        valfmt_temp = mpl.ticker.StrMethodFormatter("{x:.0f}")
    elif abs(data) > 1:
        valfmt_temp = mpl.ticker.StrMethodFormatter("{x:.1f}")
    elif abs(data) < 0.3 and  abs(data) > 0.0007:
        valfmt_temp = mpl.ticker.StrMethodFormatter("{x:.3f}")
    elif abs(data) < 0.0007:
        valfmt_temp = mpl.ticker.StrMethodFormatter("{x:.2e}")
        
    else:
        valfmt_temp = mpl.ticker.StrMethodFormatter("{x:.2f}")
            # print(data[i,j],data[i,j] is np.nan)
    return valfmt_temp(data)
def annotate_heatmap(im, data, rel_change=None, sig_df=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"), threshold=3, **textkw):
    """Write each cell's value as text on top of an ``imshow`` heatmap.

    Parameters
    ----------
    im : matplotlib image
        Object returned by ``ax.imshow``. Its array selects the text color and
        the text artists are added to ``im.axes``.
    data : 2-D array
        Values to print; cell ``(i, j)`` is written at ``x=j, y=i``. NaN cells
        are left blank.
    rel_change : 2-D array, optional
        If given, printed in parentheses with a ``%`` suffix on a second line.
    sig_df : 2-D array, optional
        Cells whose entry compares equal to True are drawn bold, all others
        light. NaN therefore counts as not significant.
    valfmt : str
        Accepted for backward compatibility but not used. Every value is
        formatted by ``_get_fmt`` according to its magnitude.
    textcolors : sequence of two colors
        ``textcolors[0]`` is used where ``abs(image value) <= threshold``,
        ``textcolors[1]`` where it is larger.
    threshold : float
        Compared against the absolute raw image value (not normalized).
    **textkw
        Extra keyword arguments forwarded to ``Axes.text``.

    Returns
    -------
    list
        Exactly one entry per cell in row-major order: the ``Text`` artist, or
        ``''`` for a NaN cell.
    """
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Raw image values (the unmasked data) drive the light/dark text choice.
    cdata = im.get_array().data

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            if sig_df is not None:
                # `== True` (not truthiness) so NaN flags stay "light".
                kw.update(weight='bold' if sig_df[i, j] == True else 'light')  # noqa: E712
            kw.update(color=textcolors[int(abs(cdata[i, j]) > threshold)])

            value = data[i, j]
            if np.isnan(value):
                # Blank cell: record a placeholder and move on, so no stale
                # or undefined Text object is appended for this cell.
                texts.append('')
                continue

            label = _get_fmt(value)
            if rel_change is not None:
                label = f"{label}\n ({_get_fmt(rel_change[i, j])} %)"
            texts.append(im.axes.text(j, i, label, **kw))

    return texts

In [None]:
# Display metadata for each diagnostic variable:
#   'name'  – column label (matplotlib mathtext) shown on the heatmap,
#   'scale' – factor the differences in that column are multiplied by before plotting.
# Backslashes are escaped ("\\mathrm") so Python does not see "\m" as an
# invalid escape sequence (a SyntaxWarning since Python 3.12); "\n" is an
# intentional line break inside the label. The resulting strings are
# unchanged.
# NOTE(review): 'cl_low' uses "[%]" while the other cloud fractions use "(%)",
# and the cdncvi scale (1e-8) vs. its "/1000, cm^-2" label should be checked
# against the input units — confirm.
translate_column_names = {
    'lwp': {'name': 'LWP \n (g m$^{-2}$)',
            'scale': 1e3},
    'pr': {'name': 'Precip \n (mm year$^{-1}$)',
           'scale': 1},
    'cl_low': {'name': '$\\mathrm{CldFrac}_{low}$ \n [%]',
               'scale': 1},
    'cl_middle': {'name': '$\\mathrm{CldFrac}_{mid}$ \n (%)',
                  'scale': 1},
    'cl_high': {'name': '$\\mathrm{CldFrac}_{high}$ \n (%)',
                'scale': 1},
    'cdncvi': {'name': '$\\mathrm{N_d}/1000$ \n (cm$^{-2}$)',
               'scale': 1e-8},
    'clt': {'name': 'CldFrac \n (-)',
            'scale': 1},
    'clivi': {'name': 'IWP \n (g m$^{-2}$)',
              'scale': 1e3},
}

In [None]:
# Map raw variable names to their display labels (same mapping for both frames).
column_labels = {var: meta['name'] for var, meta in translate_column_names.items()}

vis_df = df_diff.copy()
# vis_df.loc['UKESM1-0-LL',['cl_low','cl_middle','cl_high']] = np.nan
# Rescale each absolute difference into its display unit, then relabel.
for var, meta in translate_column_names.items():
    vis_df[var] = vis_df[var] * meta['scale']
vis_df = vis_df.rename(columns=column_labels)

# Relative changes are already in percent: relabel only, no scaling.
vis_rel = df_rel.copy().rename(columns=column_labels)


# vis_df.loc['GISS-E2-1-G','$\mathrm{N_d}/1000$ \n (cm$^{-2}$)']=np.nan
# vis_rel.loc['GISS-E2-1-G','$\mathrm{N_d}/1000$ \n (cm$^{-2}$)']=np.nan

In [None]:
# Heatmap: colors show the relative change (vis_rel), cell text shows the
# scaled absolute difference (vis_df); significant cells are written in bold.
n_rows, n_cols = vis_df.shape
fig, ax = plt.subplots(figsize=(8.3, 4.2))

# White minor-grid lines between cells act as cell separators.
ax.grid(color='w', linestyle='-', linewidth=3, which='minor')
ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
ax.spines[:].set_visible(False)

# 9 discrete diverging colors; missing (NaN) cells are drawn light grey.
cmap = mpl.colormaps.get_cmap('PiYG').resampled(9)
cmap.set_bad("#E6E6E6")
im = ax.imshow(vis_rel, cmap=cmap, aspect='auto', vmin=-2.2, vmax=2.2)

cbar = fig.colorbar(im, ax=ax, location='right', pad=0.06, shrink=0.8, extend='both')
cbar.ax.set_yticks([-2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2])
cbar.ax.set_ylabel('Relative change [%]')

ax.set_xticks(np.arange(n_cols), labels=vis_df.columns, fontsize=8)
ax.xaxis.tick_top()
ax.set_yticks(np.arange(n_rows), labels=vis_df.index, fontsize=10)
ax.tick_params(which="minor", bottom=False, left=False, top=False)

texts = annotate_heatmap(im, data=vis_df.values, sig_df=df_sig.values, threshold=1.8, fontsize=9)

fig.savefig(snakemake.output.outpath, bbox_inches='tight')
