In [None]:
from workflow.scripts.utils import read_list_input_paths
from pyclim_noresm.general_util_funcs import global_avg
import yaml
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import xarray as xr
from scipy.stats import linregress
# import seaborn

In [None]:
def get_forcing_value(dfs, variable):
    """Collect one forcing diagnostic from every model's ERF table.

    Parameters
    ----------
    dfs : dict
        Maps a model name to a ``pandas.DataFrame`` whose ``'diff'`` column is
        indexed by diagnostic name.
    variable : str
        Row label to look up in the ``'diff'`` column.

    Returns
    -------
    dict
        ``{model: value}``. The value is ``np.nan`` when the table has no
        ``'diff'`` column or when that column has no ``variable`` row.
    """
    values = {}
    for model, df in dfs.items():
        diff = df.get('diff')
        # DataFrame.get returns None for a missing column; fall back to NaN
        # (as is already done for a missing row) instead of failing on
        # ``None.get`` so one incomplete table does not abort the comparison.
        values[model] = np.nan if diff is None else diff.get(variable, np.nan)
    return values


# Refractive-index table shipped with the workflow (550 nm, going by the file
# name). Parsed with the safe YAML loader into plain Python objects.
DUST_OPTICS_FILE = 'workflow/input_data/refractive_indicies_550nm.yaml'
with open(DUST_OPTICS_FILE) as f:
    dust_optics = yaml.safe_load(f)


In [None]:
# One ERF table per input file. The dictionary key is the text after the last
# '_' in the path, cut at the first '.'.
dfs = {
    erf_path.rsplit('_', 1)[-1].partition('.')[0]: pd.read_csv(erf_path, index_col=0)
    for erf_path in snakemake.input.erfs
}

In [None]:
# Pull the four diagnostics of interest out of every model's ERF table.
# Each result is a {model: value} dictionary.
atmabs, atmabs_sw, SWDirectEff, ERFt = (
    get_forcing_value(dfs, diagnostic)
    for diagnostic in ('atmabs', 'atmabsSW', 'SWDirectEff', 'ERFt')
)

In [None]:
def get_global_value(dsets,variable):
    """Return the global-mean, time-mean value of ``variable`` per model.

    For every ``{model: dataset}`` entry the first time step is dropped
    (presumably spin-up; TODO confirm) and the rest is averaged over ``time``.
    ``variable`` is then looked up on the averaged dataset, with ``np.nan``
    used when it is absent. That field is passed to ``global_avg`` and the
    ``.values`` of the result is stored under the model name.
    """
    global_values = {}
    for model, dset in dsets.items():
        time_mean = dset.isel(time=slice(1, None)).mean(dim='time')
        field = time_mean.get(variable, np.nan)
        global_values[model] = global_avg(field).values
    return global_values

In [None]:
# Open the experiment and control datasets, keyed by model. The key is the
# second-to-last '_'-separated token of each file's base name.
dsets_exp = {
    path.rsplit('/', 1)[-1].split('_')[-2]: xr.open_dataset(path)
    for path in snakemake.input.exp_data
}
dsets_ctrl = {
    path.rsplit('/', 1)[-1].split('_')[-2]: xr.open_dataset(path)
    for path in snakemake.input.ctrl_data
}


In [None]:
# Global-mean absorption AOD (abs550aer) per model, experiment and control.
aaod_exp = get_global_value(dsets_exp, 'abs550aer')
aaod_ctrl = get_global_value(dsets_ctrl, 'abs550aer')

# Global-mean AOD (od550aer) per model, already wrapped as one-column frames.
aod_exp = pd.DataFrame.from_dict(
    get_global_value(dsets_exp, 'od550aer'),
    orient='index',
    columns=['od550aer'],
)
aod_ctrl = pd.DataFrame.from_dict(
    get_global_value(dsets_ctrl, 'od550aer'),
    orient='index',
    columns=['od550aer'],
)

df_aaod_exp = pd.DataFrame.from_dict(aaod_exp, orient='index', columns=['abs550aer'])
df_aaod_ctrl = pd.DataFrame.from_dict(aaod_ctrl, orient='index', columns=['abs550aer'])

# Experiment minus control, aligned on the model index.
df_diff_aaod = df_aaod_exp - df_aaod_ctrl
df_diff_aod = aod_exp - aod_ctrl

In [None]:
# Turn each {model: value} dictionary into a one-column frame indexed by model.
# The names are rebound from dict to DataFrame, so re-run the cell that builds
# the dictionaries before executing this cell a second time.
atmabs, atmabs_sw, SWDirectEff, ERFt = (
    pd.DataFrame.from_dict(values, orient='index', columns=[column])
    for values, column in (
        (atmabs, 'atmabs'),
        (atmabs_sw, 'atmabs_sw'),
        (SWDirectEff, 'SWDirectEff'),
        (ERFt, 'ERFt'),
    )
)


In [None]:
# One row per top-level key of the YAML mapping. The name is rebound from the
# loaded mapping to a DataFrame, so reload the YAML before re-running this cell.
dust_optics = pd.DataFrame.from_dict(
    data=dust_optics,
    orient='index',
)

In [None]:
# Join every per-model column side by side; rows are aligned on the index.
df = pd.concat(
    [
        df_diff_aod,
        df_diff_aaod,
        atmabs,
        atmabs_sw,
        SWDirectEff,
        ERFt,
        dust_optics,
    ],
    axis='columns',
)

In [None]:
# Show the combined table as the cell output for a visual check.
df

In [None]:
def plot_fig(df):
    """Scatter ``ERFt`` against the AOD change, coloured by the AAOD change.

    Draws one annotated point per row, a least-squares regression line, an
    inset table of the ``'complex'`` column (shown under the header ``$n_i$``)
    and a colour bar for the point colours. The figure is left as the current
    pyplot figure; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per model with at least the columns ``'od550aer'``,
        ``'ERFt'``, ``'abs550aer'`` and ``'complex'``.
    """
    # MIROC6 and GISS-E2-1-G are excluded from the figure; errors='ignore'
    # keeps the function usable when a run does not contain them.
    df = df.drop(['MIROC6', 'GISS-E2-1-G'], axis=0, errors='ignore')
    y = 'ERFt'
    x = 'od550aer'
    x_max = 0.04  # upper x-axis limit, also the end point of the fitted line

    fig, ax = plt.subplots(figsize=(4*1.5, 3.6*1.51))
    # plt.get_cmap replaces mpl.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('Blues', 13)
    norm = mpl.colors.Normalize(vmin=0.0001, vmax=0.0012)
    # The points share the discretised colormap with the colour bar so that
    # the bar describes the plotted colours exactly.
    df.plot.scatter(y=y, x=x, ax=ax, s=50, c='abs550aer', colorbar=False,
                    norm=norm, colormap=cmap, zorder=100)

    # Label every point with its row label (the model name).
    for model, row in df.iterrows():
        ax.annotate(f'{model}', (row[x], row[y]), xytext=(5, -5),
                    textcoords='offset points', fontsize=8, zorder=200)

    # Fit only rows that have both values: a single NaN would otherwise turn
    # slope and intercept into NaN and the line would silently disappear.
    fit_data = df[[x, y]].dropna()
    if len(fit_data) >= 2:
        fit = linregress(fit_data[x], fit_data[y])
        ax.plot([x_max, 0.0],
                [fit.slope * x_max + fit.intercept, fit.intercept],
                '--', color='red', linewidth=3)

    # Inset table of the 'complex' column, rounded and sorted ascending.
    table_data = (
        df[['complex']]
        .round(decimals=4)
        .rename(columns={'complex': '$n_i$'})
        .sort_values('$n_i$')
    )
    pd.plotting.table(ax=ax, data=table_data, loc=3,
                      bbox=[0.35, 0.58, 0.12, 0.4])

    # Dedicated axes to the right of the plot for the colour bar.
    cax = fig.add_axes([0.94, 0.2, 0.02, 0.62])
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax,
                 extend='max', label=r'$\Delta$ AAOD 550nm')

    ax.set_ylim(-0.55, 0.25)
    ax.set_xlim(0, x_max)
    ax.set_xlabel(r'$\Delta$ AOD 550nm')
    # The x axis runs from x_max down to 0.
    ax.invert_xaxis()
    ax.set_ylabel("Total DRE")

# Draw the figure, then write the current pyplot figure to the path Snakemake
# assigned to the ``absortion_plot`` output.
plot_fig(df)
plt.savefig(
    snakemake.output.absortion_plot,
    dpi=300,
    bbox_inches='tight',
)