In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib as mpl
import numpy as np
from workflow.scripts.plotting_tools import get_model_colordict
%matplotlib inline

In [None]:
# Default model sequence; the first entry is drawn lowest in each bar group.
# Replaced by the Snakemake `model_order` param when that param is provided.
order = """
    GISS-E2-1-G
    MIROC6
    GFDL-ESM4
    CNRM-ESM2-1
    UKESM1-0-LL
    IPSL-CM6A-LR-INCA
    NorESM2-LM
    MPI-ESM-1-2-HAM
    EC-Earth3-AerChem
""".split()

In [None]:
import os

# One forcing table per input file. The dict key (used downstream as the model name,
# e.g. to look up `colors[model]`) is the last '_'-separated token of the file name
# before its first '.'. Only the basename is parsed: splitting the whole path on '.'
# broke for paths such as './results/...' or directories containing dots.
dfs = {
    os.path.basename(p).split('.')[0].split('_')[-1]: pd.read_csv(p, index_col=0)
    for p in snakemake.input.forcing_tables
}
colors = get_model_colordict()
diag_tab = pd.read_csv(snakemake.input.diag_tables[0], index_col=0)
# Snakemake may override the default plotting order defined above.
model_order = snakemake.params.get('model_order', order)

In [None]:
def get_forcing(forcing_var: str,dataframes: dict):
    """Collect one forcing variable from several per-model tables into one frame.

    Parameters
    ----------
    forcing_var : str
        Row label looked up in every table (e.g. ``'DirectEff'``).
    dataframes : dict
        Maps a key (model name) to a table whose row index holds variable names.

    Returns
    -------
    pd.DataFrame
        One row per key of ``dataframes`` (insertion order) and the columns of the
        first table. Keys whose table lacks ``forcing_var`` are kept as all-NaN rows.
        Column dtypes are inferred, so numeric columns are returned as numbers
        instead of the ``object`` dtype an empty pre-allocated frame starts with.
        An empty ``dataframes`` yields an empty DataFrame.
    """
    if not dataframes:
        return pd.DataFrame()
    first_table = next(iter(dataframes.values()))
    outdf = pd.DataFrame(index=list(dataframes), columns=first_table.columns)
    for model, table in dataframes.items():
        # Deliberate best effort: a model without this variable stays an all-NaN row.
        if forcing_var not in table.index:
            continue
        outdf.loc[model, :] = table.loc[forcing_var]
    return outdf.infer_objects()

In [None]:
from math import e


def plot_forcings_bar(df, nmodels, pos0, ax,dist=.9, spacing_frac=.01, scaling: pd.Series=None
                      ,model_order: list = None, color_map: dict = None):
    """Draw one group of horizontal bars (one per model) plus a model-mean star.

    Parameters
    ----------
    df : pd.DataFrame
        One row per model (index = model name) with the columns ``'diff'`` (bar
        length), ``'pooled_std'`` (x error bar) and ``'diff_sigificant'``
        (rows equal to True are hatched).
    nmodels : int
        Number of slots in the group; each bar slot is ``dist / nmodels`` tall.
    pos0 : float
        y position of the first (lowest) bar.
    ax : matplotlib Axes
        Target axes; only ``barh`` and ``plot`` are called on it.
    dist : float
        Vertical extent reserved for the whole group.
    spacing_frac : float
        Sets the gap between bars (bar height = dx - spacing_frac / dist).
    scaling : pd.Series, optional
        Per-model divisor aligned on the index; every column except
        ``'diff_sigificant'`` is divided by it.
    model_order : list, optional
        Bars are drawn bottom-to-top in this order. Models absent from ``df``
        are skipped. When not given, rows are sorted by ascending ``'diff'``.
    color_map : dict, optional
        Model -> colour mapping. Defaults to the notebook-level ``colors`` dict.

    Rows whose ``'diff'`` is NaN are not drawn. Nothing is returned.
    """
    if color_map is None:
        color_map = colors
    dx = dist/nmodels

    if scaling is not None:
        # Normalise by the per-model scaling but keep the significance flags untouched.
        sign_diff = df['diff_sigificant'].copy()
        df = df.divide(scaling, axis=0)
        df['diff_sigificant'] = sign_diff
    if model_order:
        # Honour the requested order (previously the global `order` list was used
        # regardless of this argument). reindex instead of .loc so models missing
        # from df become NaN rows that are skipped below rather than a KeyError.
        df = df.reindex(model_order)
    else:
        df = df.sort_values('diff')

    gap = spacing_frac/dist
    pos = pos0
    n_plotted = 0
    for model, series in df.iterrows():
        if pd.isna(series['diff']):
            continue
        # `== True` on purpose: a NaN/missing flag must count as "not significant".
        hatch = '\\\\' if series['diff_sigificant'] == True else None
        ax.barh(pos, series['diff'], height=dx-gap, zorder=100, facecolor=color_map[model],
                xerr=series['pooled_std'], capsize=2, hatch=hatch)
        ax.plot(series['diff'], pos, marker=".", markerfacecolor=color_map[model],
                markeredgecolor='k', ms=10, zorder=300)
        n_plotted += 1
        pos += dx

    if n_plotted == 0:
        # Nothing was drawn; also avoids a ZeroDivisionError in the star position.
        return
    mean_diff = df['diff'].mean()
    # NOTE(review): the star sits at pos0 + 2*dist/n_plotted, which is not the centre
    # of the bar group (that would be pos0 + (n_plotted-1)*dx/2). Kept unchanged because
    # the figure layout appears tuned around it — confirm intent.
    ax.plot(mean_diff, pos0+dist/(n_plotted/2), mfc='#FF005E', linestyle='',
            marker='*', ms=18, zorder=301, markeredgecolor='k')
        
    
    

In [None]:
def setup_plot():
    """Create the single-axes figure used for the forcing-efficiency bar chart.

    Applies thicker/longer x ticks via a temporary rc context, fixes the x/y
    ranges and x tick locations, hides y ticks, mirrors tick labels on top,
    draws a dashed horizontal separator at y=3.9 and a dotted zero line, and
    writes the group titles (left, rotated) and per-group labels (right).
    Returns the Axes; the owning figure is reachable through ``plt.gcf()``.
    """
    tick_style = {
        "xtick.major.size" : 6,
        "xtick.minor.size" : 3.8,
        "xtick.major.width" : 1.2,
        "xtick.minor.width" : .8,
        "axes.linewidth" : .8,
        "xtick.labelsize" : 10
    }
    # Vertical group titles on the left edge: (y in axes coords, text).
    group_titles = (
        (.725, 'Direct radiative forcing efficiency'),
        (.22, 'Surface forcing efficiency'),
    )
    # Right-aligned labels for each bar group: (y in axes coords, text), top to bottom.
    group_labels = (
        (.95, 'LW'),
        (.75, 'SW'),
        (.55, 'Net'),
        (.32, 'SW Clearsky'),
        (.05, 'Clearsky'),
    )

    with mpl.rc_context(tick_style):
        _, ax = plt.subplots(figsize=(8.3, 8), sharey=True)
        ax.set_xlim(-.25, 0.075)
        ax.set_ylim(-.25, 9.1)
        ax.xaxis.set_major_locator(mpl.ticker.FixedLocator([-.2, -.15, -.1, -.05, 0, .05]))
        ax.set_yticks([])
        ax.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(4))
        # Separates the surface groups (below) from the direct radiative groups (above).
        ax.axhline(3.9, linestyle='--', color='k')
        ax.tick_params(top=True, which='both', labeltop=True)

        for side in ('left', 'right'):
            ax.spines[side].set_visible(True)

        for y_pos, title in group_titles:
            ax.text(0, y_pos, title, va='center', ha='right', rotation='vertical',
                    transform=ax.transAxes, fontsize='large')

        ax.axvline(0.0, linestyle=':', linewidth=2, color='darkgrey')
        ax.set_xlabel('W m-2 Tg-1')

        for y_pos, label in group_labels:
            ax.text(.98, y_pos, label, va='center', ha='right',
                    transform=ax.transAxes, fontsize='large')

    return ax

# Bar groups drawn bottom-to-top: (forcing variable, y offset of the group).
# The offsets must stay consistent with setup_plot's y-limits, its dashed
# separator at y=3.9 and the right-hand group labels.
FORCING_GROUPS = [
    ('ERFsurfcs', .02),
    ('ERFsurfswcs', 2),
    ('DirectEff', 4.1),
    ('SWDirectEff', 5.8),
    ('LWDirectEff', 7.6),
]
# Backslash escaped explicitly: '\D' in a plain string literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+). The column label value is unchanged.
BURDEN_COL = '$\\Delta$DU burden \n (Tg)'

axes = setup_plot()
dist=1.8
du_burden = diag_tab[BURDEN_COL]
for forcing_var, y0 in FORCING_GROUPS:
    # Every group is normalised by the dust burden change, giving W m-2 Tg-1.
    plot_forcings_bar(get_forcing(forcing_var, dfs), len(dfs), y0, axes, dist=dist,
                      scaling=du_burden, model_order=model_order)


from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# Legend entries are listed in reverse of model_order, intended to match the bars
# top-to-bottom (bars are drawn bottom-up). A separate dict is built instead of
# reassigning the notebook-level `colors`, which plot_forcings_bar reads: overwriting
# it made this cell non-idempotent and could drop models on a later re-run.
if model_order:
    legend_colors = {model: colors[model] for model in model_order[::-1]}
else:
    legend_colors = colors

legments = [
    Line2D([0], [0], markerfacecolor=c, marker='o', label=m, color='w', markersize=10)
    for m, c in legend_colors.items()
]
legments.append(Line2D([0], [0], markerfacecolor='#FF005E', marker='*', label='Model mean',
                       color='w', markeredgecolor='k', markersize=20))

fig = plt.gcf()
fig.legend(handles=legments, ncol=1, bbox_to_anchor=[0.13, 0.61, 0.5, 0.5],
           loc='lower left', fontsize='small')
plt.savefig(snakemake.output[0], bbox_inches='tight')