In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib as mpl
import numpy as np
import random
import sys
import os
from workflow.scripts.plotting_tools import get_model_colordict
%matplotlib inline

In [None]:
# Default plotting order of the models, used when the workflow does not supply
# ``params.model_order``. Bars within a forcing group are stacked upward from
# their start position, so the first entry ends up at the lowest y position.
order = [
    'GISS-E2-1-G',
    'MIROC6',
    'GFDL-ESM4',
    'CNRM-ESM2-1',
    'UKESM1-0-LL',
    'IPSL-CM6A-LR-INCA',
    'NorESM2-LM',
    'MPI-ESM-1-2-HAM',
    'EC-Earth3-AerChem',
]

In [None]:
# Forcing tables keyed by model name. The key is the last '_'-separated token of
# the file name before its first '.' (e.g. ``forcing_table_MIROC6.csv`` -> ``MIROC6``).
# Only the basename is parsed, so a dot in a directory name cannot truncate the key.
dfs = {
    os.path.basename(p).split('.')[0].split('_')[-1]: pd.read_csv(p, index_col=0)
    for p in snakemake.input.forcing_tables
}
colors = get_model_colordict()
diag_tab = pd.read_csv(snakemake.input.diag_tables[0], index_col=0)
# Plotting order: the workflow parameter if given, otherwise the default ``order``.
model_order = snakemake.params.get('model_order', order)

In [None]:
def get_forcing(forcing_var: str, dataframes: dict) -> pd.DataFrame:
    """Collect one forcing term from every model's table into a single frame.

    Parameters
    ----------
    forcing_var : str
        Row label to extract from each table (e.g. ``'ERFt'``).
    dataframes : dict
        Model name -> table indexed by forcing term. The columns of the first
        table define the columns of the result.

    Returns
    -------
    pd.DataFrame
        One row per model, in the iteration order of ``dataframes``. A model whose
        table has no ``forcing_var`` row is kept as an all-NaN row. ``object``
        columns are passed through ``infer_objects`` so that purely numeric
        columns get a numeric dtype. An empty ``dataframes`` gives an empty frame.
    """
    if not dataframes:
        # Previously ``next(iter({}))`` raised a bare StopIteration here.
        return pd.DataFrame()

    columns = next(iter(dataframes.values())).columns
    # Built without data, so every column starts out as ``object`` dtype.
    out = pd.DataFrame(index=list(dataframes), columns=columns)
    for model, table in dataframes.items():
        try:
            out.loc[model, :] = table.loc[forcing_var]
        except KeyError:
            # Deliberate best effort: a model lacking this term stays an all-NaN row.
            continue
    return out.infer_objects()

In [None]:
def plot_forcings_bar(df, nmodels, pos0, axes,dist=.9, spacing_frac=.01, scaling: pd.Series=None,
                     model_order: list=None, color_map: dict=None):
    """Draw one group of horizontal bars (one per model) for a single forcing term.

    Every bar, with its error bar and centre dot, is drawn on each axis in ``axes``.
    Each axis then shows the part of the x range it covers. The mean of ``'diff'``
    over the plotted models is marked with a star on ``axes[0]`` (mean < -0.4),
    ``axes[1]`` (-0.4 <= mean <= 0.4) or ``axes[2]`` (mean > 0.4).

    Parameters
    ----------
    df : pd.DataFrame
        One row per model with ``'diff'``, ``'pooled_std'`` and
        ``'diff_sigificant'`` columns (e.g. the output of ``get_forcing``).
    nmodels : int
        Number of model slots in the group. Bars are spaced ``dist / nmodels`` apart.
    pos0 : float
        y position of the first (lowest) bar of the group.
    axes : sequence of Axes
        At least three axes; see the star placement above.
    dist : float
        Total vertical extent reserved for the group.
    spacing_frac : float
        Each bar's height is ``dist / nmodels - spacing_frac / dist``.
    scaling : pd.Series, optional
        Per-model divisor applied to every column before plotting.
    model_order : list, optional
        Plot models in this order, from the bottom up. Models not in ``df`` are
        skipped. When omitted, rows are sorted by ``'diff'``.
    color_map : dict, optional
        Model name -> colour. Defaults to the notebook-level ``colors`` dict.
    """
    palette = colors if color_map is None else color_map
    dx = dist / nmodels
    gap = spacing_frac / dist

    if scaling is not None:
        df = df.divide(scaling, axis=0)
    if model_order:
        # Honour the ``model_order`` argument (the global ``order`` was used before).
        # ``reindex`` tolerates models missing from ``df``: their all-NaN rows are
        # skipped below instead of raising KeyError.
        df = df.reindex(model_order)
    else:
        df = df.sort_values('diff')

    pos = pos0
    for model, series in df.iterrows():
        if pd.isna(series['diff']):
            # A model without this forcing term leaves no empty slot.
            continue
        hatch = '\\\\' if series['diff_sigificant'] == True else None
        for axi in axes:
            axi.barh(pos, series['diff'], height=dx - gap, zorder=100, facecolor=palette[model],
                     xerr=series['pooled_std'], capsize=2, hatch=hatch)
            axi.plot(series['diff'], pos, marker=".", markerfacecolor=palette[model],
                     markeredgecolor='k', ms=10, zorder=300)
        pos += dx

    diff_mean = df['diff'].mean()
    if pd.isna(diff_mean):
        # Nothing was plotted, so there is no mean to mark.
        return
    if diff_mean < -0.4:
        mean_ax = axes[0]
    elif diff_mean <= 0.4:
        mean_ax = axes[1]
    else:
        mean_ax = axes[2]
    # NOTE(review): barh centres bars on pos0, pos0+dx, ..., so this y sits half a
    # bar step above the centre of the plotted bars. Kept unchanged to preserve the
    # existing figure layout.
    mean_ax.plot(diff_mean, pos0 + (pos - pos0) / 2, mfc='#FF005E', linestyle='',
                 marker='*', ms=18, zorder=301, markeredgecolor='k')

In [None]:
def setup_plot():
    """Build the empty 4-panel figure for the forcing bar chart and return its axes.

    The four panels share the y axis (limits -0.25 to 9.1).
      - ``ax[0]`` is a blank left panel with all spines and ticks hidden.
      - ``ax[1]``, ``ax[2]`` and ``ax[3]`` form a broken x axis covering
        [-1.8, -0.4], [-0.4, 0.4] and [0.4, 1.8] respectively.
      - Each broken-axis panel has x tick labels on top and dashed horizontal
        separators at y = 2.64 and y = 5.85.
      - Group titles are written vertically just left of ``ax[1]``.
      - Row labels (LW/SW/Net, Surface/TOA) are written inside ``ax[3]``.

    Returns
    -------
    numpy.ndarray of Axes
        The four axes, left to right.
    """
    tick_rc = {
        "xtick.major.size": 6,
        "xtick.minor.size": 3.8,
        "xtick.major.width": 1.2,
        "xtick.minor.width": .8,
        "axes.linewidth": .8,
        "xtick.labelsize": 10,
    }
    with mpl.rc_context(tick_rc):
        _, panels = plt.subplots(figsize=(7.7, 9), ncols=4, sharey=True,
                                 gridspec_kw={'width_ratios': [0.35, 0.2, .6, .2], 'wspace': 0.001})
        ax_blank, ax_neg, ax_mid, ax_pos = panels

        # x range and major ticks of each panel of the broken axis
        panel_ranges = (
            (ax_mid, (-.4, .4), [-.4, -.2, 0, .2, .4]),
            (ax_pos, (.4, 1.8), [1, 1.5]),
            (ax_neg, (-1.8, -.4), [-1.5, -1]),
        )
        for panel, xlim, major_ticks in panel_ranges:
            panel.set_xlim(*xlim)
            panel.xaxis.set_major_locator(mpl.ticker.FixedLocator(major_ticks))
        # y is shared, so this sets the y range of every panel
        ax_neg.set_ylim(-.25, 9.1)

        for panel in panels[1:]:
            for side in ('right', 'left'):
                panel.spines[side].set_visible(False)
            panel.set_yticks([])
            panel.xaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(4))
            # dashed separators between the groups of forcing terms
            for y_sep in (5.85, 2.64):
                panel.axhline(y_sep, linestyle='--', color='k')
            panel.tick_params(top=True, which='both', labeltop=True)

        for side in ('left', 'right', 'top', 'bottom'):
            ax_blank.spines[side].set_visible(False)
        ax_blank.set_yticks([])
        ax_blank.set_xticks([])
        ax_blank.tick_params(top=False, which='both', labeltop=False)

        # re-show the outer frame edges that the loop above hid
        ax_neg.spines['left'].set_visible(True)
        ax_pos.spines['right'].set_visible(True)

        group_titles = (
            (.85, 'Direct radiative effects'),
            (.487, 'Cloud radiative effects'),
            (.155, 'Effective dust radiative effect'),
        )
        for y_axes, title in group_titles:
            ax_neg.text(0, y_axes, title, va='center', ha='right', rotation='vertical',
                        transform=ax_neg.transAxes, fontsize='large')

        ax_mid.axvline(0.0, linestyle=':', linewidth=2, color='darkgrey')
        ax_mid.set_xlabel('W m-2')

        row_labels = (
            (.935, 'LW'), (.825, 'SW'), (.71, 'Net'),
            (.605, 'LW'), (.495, 'SW'), (.381, 'Net'),
            (.22, 'Surface'), (.072, 'TOA'),
        )
        for y_axes, label in row_labels:
            ax_pos.text(.68, y_axes, label, va='center', ha='right',
                        transform=ax_pos.transAxes, fontsize='large')
    return panels

axes = setup_plot()
dist = 1.15

# (forcing term, y position of the first bar of its group), listed bottom to top.
forcing_rows = [
    ('ERFt', 0.),
    ('ERFsurf', 1.35),
    ('CloudEff', 2.8),
    ('SWCloudEff', 3.85),
    ('LWCloudEff', 4.85),
    ('DirectEff', 6),
    ('SWDirectEff', 7.05),
    ('LWDirectEff', 8.15),
]
for forcing_var, group_pos0 in forcing_rows:
    # axes[0] is the blank left panel; bars go on the three broken-axis panels.
    plot_forcings_bar(get_forcing(forcing_var, dfs), len(dfs), group_pos0, axes[1:],
                      dist=dist, model_order=model_order)


from pyexpat import model  # NOTE(review): appears unused; `model` below is only a comprehension variable
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# Legend entries in reverse ``model_order``, presumably to match the upward stacking
# of the bars. A new name is used so that the global ``colors`` read by
# ``plot_forcings_bar`` is not rebound (re-running the plotting would otherwise see
# a filtered dict).
legend_colors = {model: colors[model] for model in model_order[::-1]}

legend_handles = [
    Line2D([0], [0], markerfacecolor=c, marker='o', label=m, color='w', markersize=10)
    for m, c in legend_colors.items()
]
legend_handles.append(Line2D([0], [0], markerfacecolor='#FF005E', marker='*', label='Model mean',
                             color='w', markeredgecolor='k', markersize=20))

# Use the figure that owns the axes rather than relying on pyplot's "current figure".
fig = axes[0].figure
fig.legend(handles=legend_handles, ncol=1, bbox_to_anchor=[0.075, 0.125, 0.5, 0.5], loc='lower left',
           fontsize='small')
fig.savefig(snakemake.output[0], bbox_inches='tight')