In [None]:
import os
import random
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

from workflow.scripts.plotting_tools import get_model_colordict
%matplotlib inline

In [None]:
def model_name_from_path(path: str) -> str:
    """Return the model name encoded in a results file path.

    The name is the last ``_``-separated token of the file's base name, taken
    up to the first ``.`` of that base name, e.g.
    ``'results/v1.2/forcing_ModelA.csv'`` -> ``'ModelA'``. Only the base name
    is inspected, so dots or underscores in directory names cannot leak into
    the key.
    """
    return os.path.basename(path).split('.')[0].split('_')[-1]


# One DataFrame per input CSV (first column used as the index), keyed by model name.
dfs = {model_name_from_path(p): pd.read_csv(p, index_col=0) for p in snakemake.input}
colors = get_model_colordict()

In [None]:
def get_forcing(forcing_var: str, dataframes: dict) -> pd.DataFrame:
    """Collect one named row from every model's table into a single frame.

    Parameters
    ----------
    forcing_var : str
        Index label of the row to extract from each DataFrame (e.g. ``'ERFt'``).
    dataframes : dict
        Model name -> DataFrame. The columns of the first DataFrame define the
        columns of the result.

    Returns
    -------
    pandas.DataFrame
        One row per model, in the dict's order. Models whose table has no
        ``forcing_var`` row are kept as all-NaN rows. An empty ``dataframes``
        gives an empty DataFrame.
    """
    if not dataframes:
        return pd.DataFrame()

    template = next(iter(dataframes.values()))
    outdf = pd.DataFrame(index=list(dataframes), columns=template.columns)
    for model, df in dataframes.items():
        # Missing rows are expected (not every model reports every forcing);
        # they deliberately stay NaN instead of aborting the whole figure.
        if forcing_var in df.index:
            outdf.loc[model, :] = df.loc[forcing_var]
    return outdf

In [None]:
def plot_forcings_bar(df, nmodels, pos0, axes, dist=.9, spacing_frac=.01, colordict=None):
    """Draw one group of horizontal bars (one per model) on every axis in ``axes``.

    Rows of ``df`` are models. The columns read are:

    - ``'diff'``: bar length and dot position.
    - ``'pooled_std'``: error-bar half width.
    - ``'diff_sigificant'``: the bar is hatched when this equals True.

    Rows are sorted by ``'diff'`` (missing values last) and stacked upward from
    ``pos0`` in steps of ``dist / nmodels``. A star marks the mean ``'diff'``.
    It sits at the vertical centre of the bars that have a non-missing
    ``'diff'``, and is skipped when no row has one.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per model, e.g. the output of ``get_forcing``.
    nmodels : int
        Number of vertical slots the group is divided into.
    pos0 : float
        y position of the first (lowest) bar.
    axes : iterable of Axes
        Each axis receives identical artists (used for the broken x-axis panels).
    dist : float
        Total vertical extent allotted to the group.
    spacing_frac : float
        Controls the gap left between neighbouring bars.
    colordict : dict, optional
        Model name -> colour. Defaults to the notebook-level ``colors``.
    """
    palette = colors if colordict is None else colordict
    dx = dist / nmodels
    # NOTE(review): gap is spacing_frac / dist (not * dist) -- confirm intended units.
    gap = spacing_frac / dist
    ordered = df.sort_values('diff')

    pos = pos0
    n_valid = 0
    for model, series in ordered.iterrows():
        # `== True` rather than truthiness: a NaN flag must not count as significant.
        hatch = '\\\\' if series['diff_sigificant'] == True else None
        for axi in axes:
            axi.barh(pos, series['diff'], height=dx - gap, zorder=100, facecolor=palette[model],
                     xerr=series['pooled_std'], capsize=2, hatch=hatch)
            axi.plot(series['diff'], pos, marker=".", markerfacecolor=palette[model],
                     markeredgecolor='k', ms=10, zorder=300)
        if pd.notna(series['diff']):
            n_valid += 1
        pos += dx

    if n_valid == 0:
        # Nothing to average; the old formula divided by zero here.
        return
    mean_diff = pd.to_numeric(ordered['diff'], errors='coerce').mean()
    # Valid bars occupy slots 0..n_valid-1 (NaNs sorted last), so centre on them.
    # This reproduces the previous position for 5 models and is correct for any count.
    star_pos = pos0 + dx * (n_valid - 1) / 2
    for axi in axes:
        axi.plot(mean_diff, star_pos, mfc='#FF005E', linestyle='',
                 marker='*', ms=18, zorder=301, markeredgecolor='k')
        
    
    

In [None]:
def setup_plot():
    """Build the empty three-panel figure that the forcing bar groups are drawn on.

    The three panels form a broken x axis and share one y axis:

    - left tail: -1.8..-0.4
    - centre: -0.4..0.4, with a dotted zero line and the ``'W m-2'`` label
    - right tail: 0.4..1.8

    Every panel gets dashed horizontal lines at y=5.85 and y=2.64 and tick
    labels on top. The left panel carries the vertical group titles. The right
    panel carries the LW/SW/Net/Surface/TOA row labels.

    Returns
    -------
    The array of three Axes returned by ``plt.subplots``.
    """
    rc_overrides = {
        "xtick.major.size": 6,
        "xtick.minor.size": 3.8,
        "xtick.major.width": 1.2,
        "xtick.minor.width": .8,
        "axes.linewidth": .8,
        "xtick.labelsize": 10,
    }
    # panel index -> (x-limits, major tick positions) of the broken x axis
    panel_layout = {
        1: ((-.4, .4), [-.4, -.2, 0, .2, .4]),
        -1: ((.4, 1.8), [1, 1.5]),
        0: ((-1.8, -.4), [-1.5, -1]),
    }
    # vertical titles on the left panel: (y in axes fraction, text)
    group_titles = (
        (.85, 'Direct radiative effects'),
        (.487, 'Cloud radiative effects'),
        (.155, 'Effective dust radiative effect'),
    )
    # row labels on the right panel, top to bottom: (y in axes fraction, text)
    row_labels = (
        (.935, 'LW'), (.825, 'SW'), (.71, 'Net'),
        (.605, 'LW'), (.495, 'SW'), (.381, 'Net'),
        (.22, 'Surface'), (.072, 'TOA'),
    )

    with mpl.rc_context(rc_overrides):
        _, axs = plt.subplots(figsize=(6, 11.8), ncols=3,
                              gridspec_kw={'width_ratios': [0.2, .6, .2], 'wspace': 0.001},
                              sharey=True)
        left, centre, right = axs[0], axs[1], axs[-1]

        for idx, (xlim, ticks) in panel_layout.items():
            axs[idx].set_xlim(*xlim)
            axs[idx].xaxis.set_major_locator(mtick.FixedLocator(ticks))
        left.set_ylim(-.25, 9.1)  # y is shared, so this sets every panel

        for panel in axs:
            # hide inner spines so the three panels read as one broken axis
            panel.spines['right'].set_visible(False)
            panel.spines['left'].set_visible(False)
            panel.set_yticks([])
            panel.xaxis.set_minor_locator(mtick.AutoMinorLocator(4))
            for y_line in (5.85, 2.64):
                panel.axhline(y_line, linestyle='--', color='k')
            panel.tick_params(top=True, which='both', labeltop=True)
        left.spines['left'].set_visible(True)
        right.spines['right'].set_visible(True)

        for y, title in group_titles:
            left.text(0, y, title, va='center', ha='right', rotation='vertical',
                      transform=left.transAxes, fontsize='large')

        centre.axvline(0.0, linestyle=':', linewidth=2, color='darkgrey')
        centre.set_xlabel('W m-2')

        for y, label in row_labels:
            right.text(.68, y, label, va='center', ha='right',
                       transform=right.transAxes, fontsize='large')
    return axs

# y position of the lowest bar of each group, keyed by the forcing row name in
# the per-model tables. These values are tuned to the fixed y-limits, dashed
# lines and text labels placed by ``setup_plot``; change them together.
FORCING_GROUP_POSITIONS = {
    'ERFt': -0.1,
    'ERFsurf': 1.3,
    'CloudEff': 2.9,
    'SWCloudEff': 4,
    'LWCloudEff': 5,
    'DirectEff': 6,
    'SWDirectEff': 7.2,
    'LWDirectEff': 8.2,
}
dist = 1.3  # vertical extent given to each bar group

axes = setup_plot()
for forcing_var, group_pos0 in FORCING_GROUP_POSITIONS.items():
    plot_forcings_bar(get_forcing(forcing_var, dfs), len(dfs), group_pos0, axes, dist=dist)

# Legend: one coloured dot per entry in the colour dict, plus the model-mean star.
legments = [
    Line2D([0], [0], markerfacecolor=c, marker='o', label=m, color='w', markersize=10)
    for m, c in colors.items()
]
legments.append(Line2D([0], [0], markerfacecolor='#FF005E', marker='*', label='Model mean', color='w',
                       markeredgecolor='k', markersize=20))

# Use the figure that owns the axes rather than pyplot's implicit "current figure".
fig = axes[0].figure
fig.legend(handles=legments, ncol=3, loc='upper center')
fig.savefig(snakemake.output[0], bbox_inches='tight')