In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib as mpl
import numpy as np
import random
import os
from workflow.scripts.plotting_tools import get_model_colordict
%matplotlib inline

In [None]:
def _model_key(path):
    """Return the model name encoded in an input CSV path.

    The key is the last ``_``-separated token of the file's *basename*, with
    everything from the first ``.`` onwards dropped, e.g.
    ``results/v1.2/forcing_NorESM2.csv`` -> ``NorESM2``. Using only the
    basename keeps dots in directory names (``./``, ``../``, ``v1.2/``) from
    truncating the path before the file name is reached.
    """
    return os.path.basename(path).split('.')[0].split('_')[-1]


# One forcing table per input file, keyed by the model name in the file name.
dfs = {}
for p in snakemake.input:
    key = _model_key(p)
    if key in dfs:
        # Two inputs resolving to the same model would otherwise silently
        # overwrite each other and drop a model from every plot.
        raise ValueError(
            f"Input {p!r} maps to model key {key!r}, which is already used by another input"
        )
    dfs[key] = pd.read_csv(p, index_col=0)

os.makedirs(snakemake.output[0], exist_ok=True)

In [None]:
def get_forcing(forcing_var: str, dataframes: dict) -> pd.DataFrame:
    """Collect one forcing row from every model table into a single frame.

    Parameters
    ----------
    forcing_var : str
        Row label to pull from each model's table (e.g. ``'ERFt'``).
    dataframes : dict
        Mapping of model name -> DataFrame whose index holds the forcing
        variables. The columns of the first table define the output columns.

    Returns
    -------
    pd.DataFrame
        One row per model, in ``dataframes`` iteration order. Models whose
        table has no ``forcing_var`` row are kept as all-NaN rows.

    Raises
    ------
    ValueError
        If ``dataframes`` is empty.
    """
    if not dataframes:
        raise ValueError("get_forcing needs at least one model DataFrame")
    first = next(iter(dataframes.values()))
    outdf = pd.DataFrame(index=list(dataframes), columns=first.columns)
    for model, df in dataframes.items():
        # Explicit membership test instead of catching KeyError: a model
        # lacking this variable is expected, any other error should surface.
        if forcing_var in df.index:
            outdf.loc[model, :] = df.loc[forcing_var]
    # The frame was created empty, so its columns start as object dtype; let
    # pandas infer numeric dtypes so downstream sorting/averaging is numeric.
    return outdf.infer_objects()

In [None]:
# Echo the active Snakemake rule name into the cell output; the next cell
# compares it against 'plot_forcing_decomposition' to set jitter_fix.
snakemake.rule

In [None]:
colors = get_model_colordict()

# Only the 'plot_forcing_decomposition' rule uses the single-line layout with
# alternating nudges; any other rule stacks points in fixed vertical steps.
jitter_fix = snakemake.rule == 'plot_forcing_decomposition'

In [None]:
def forcing_plot(df, jitter_fix=True):
    """Draw a one-row strip plot of per-model ``'diff'`` values plus their mean.

    Each row of ``df`` is one model; its ``'diff'`` value is drawn as a dot
    coloured via the module-level ``colors`` mapping (keyed by the row label).
    The mean of ``'diff'`` is drawn as a star at y=1.5. The new figure stays
    the current pyplot figure, so callers may ``plt.savefig`` it directly.

    Parameters
    ----------
    df : pd.DataFrame
        One row per model, with at least a ``'diff'`` column. Rows are sorted
        by ``'diff'`` (NaN rows last) before plotting.
    jitter_fix : bool, default True
        If True, dots sit on the line y=1. A dot whose value is within 0.04
        of the previous (sorted) value is shifted alternately 0.5 below or
        above that line, so near-coincident points stay visible. If False,
        dots are stacked in 0.2-high steps starting at y=0.4.

    Returns
    -------
    tuple
        ``(fig, ax)``: the created matplotlib figure and axes.
    """
    with mpl.rc_context({
        "xtick.major.size": 6,
        "xtick.minor.size": 3.8,
        "xtick.major.width": 1.2,
        "xtick.minor.width": .8,
        "axes.linewidth": 1.2,
    }):
        fig, ax = plt.subplots(figsize=(6, 1))
        ax.set_xlim(-1.11, 1.11)
        ax.set_ylim(0, 2)
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_yticks([])
        ax.xaxis.set_major_locator(mtick.MaxNLocator(7))
        ax.xaxis.set_minor_locator(mtick.AutoMinorLocator())
        ax.axvline(0.0, linestyle=':', linewidth=2, color='darkgrey')

        df = df.sort_values('diff')
        prev = None
        jitter = 0
        for model, series in df.iterrows():
            value = series['diff']
            if jitter_fix:
                # Compare against None rather than truthiness: a previous
                # value of exactly 0.0 must not reset the jitter sequence.
                if prev is not None and abs(value - prev) < 0.04:
                    # Close to the previous point: alternate below/above the
                    # baseline (0 -> -0.5 -> +0.5 -> -0.5 ...).
                    if jitter < 0:
                        jitter += 1
                    elif jitter == 0:
                        jitter = -.5
                    else:
                        jitter -= 1
                else:
                    jitter = 0
                prev = value
                ax.scatter(value, 1 + jitter, c=colors[model], label=model, s=40, zorder=100)
            else:
                ax.scatter(value, .4 + jitter, c=colors[model], label=model, s=20, zorder=100)
                jitter += 0.2

        # Only 'diff' is plotted, so average just that column (other columns
        # need not be numeric). Label matches the legend's 'Model mean' entry
        # instead of the last loop variable.
        mean_diff = df['diff'].mean()
        ax.scatter(mean_diff, 1.5, c='#FF005E', label='Model mean', marker='*',
                   s=80, zorder=101, edgecolor='k', linewidth=.5)
        ax.set_xlabel('W m-2')
    return fig, ax


In [None]:
# Plot the 'ERFt' row of every model table and save it as ERFt.pdf.
forcing_plot(
    get_forcing(forcing_var='ERFt', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/ERFt.pdf', bbox_inches="tight")

In [None]:
# Plot the 'LWDirectEff' row of every model table and save it as LWDirectEff.pdf.
forcing_plot(
    get_forcing(forcing_var='LWDirectEff', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/LWDirectEff.pdf', bbox_inches="tight")

In [None]:
# Plot the 'SWCloudEff' row of every model table and save it as SWCloudEff.pdf.
forcing_plot(
    get_forcing(forcing_var='SWCloudEff', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/SWCloudEff.pdf', bbox_inches="tight")

In [None]:
# Plot the 'LWCloudEff' row of every model table and save it as LWCloudEff.pdf.
forcing_plot(
    get_forcing(forcing_var='LWCloudEff', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/LWCloudEff.pdf', bbox_inches="tight")

In [None]:
# Plot the 'CloudEff' row of every model table and save it as CloudEff.pdf.
forcing_plot(
    get_forcing(forcing_var='CloudEff', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/CloudEff.pdf', bbox_inches="tight")

In [None]:
# Plot the 'SWDirectEff' row of every model table and save it as SWDirectEff.pdf.
forcing_plot(
    get_forcing(forcing_var='SWDirectEff', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/SWDirectEff.pdf', bbox_inches="tight")

In [None]:
# Plot the 'DirectEff' row of every model table and save it as DirectEff.pdf.
forcing_plot(
    get_forcing(forcing_var='DirectEff', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/DirectEff.pdf', bbox_inches="tight")

In [None]:
# The 'ERFtcsaf' row is plotted and saved under the file name Albedo.pdf.
forcing_plot(
    get_forcing(forcing_var='ERFtcsaf', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/Albedo.pdf', bbox_inches="tight")

In [None]:
# The 'ERFsurf' row is plotted and saved under the file name SurfaceBalance.pdf.
forcing_plot(
    get_forcing(forcing_var='ERFsurf', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/SurfaceBalance.pdf', bbox_inches="tight")

In [None]:
# The 'atmabs' row is plotted and saved as atmospheric_absorption.pdf.
forcing_plot(
    get_forcing(forcing_var='atmabs', dataframes=dfs),
    jitter_fix=jitter_fix,
)
plt.savefig(snakemake.output[0] + '/atmospheric_absorption.pdf', bbox_inches="tight")

In [None]:
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# Stand-alone legend figure: one circle per entry in the model colour map,
# followed by the star marker used for the model mean in forcing_plot.
model_handles = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=colour,
           markersize=15, label=name)
    for name, colour in colors.items()
]
mean_handle = Line2D([0], [0], marker='*', color='w', markerfacecolor='#FF005E',
                     markeredgecolor='k', markersize=20, label='Model mean')
legments = model_handles + [mean_handle]

fig, ax = plt.subplots()
ax.legend(handles=legments, loc='center', fontsize="large", frameon=False)
# Hide every spine and tick so only the legend itself is rendered.
for side in ('top', 'right', 'left', 'bottom'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
plt.savefig(snakemake.output[0] + '/legend.pdf', bbox_inches="tight")