In [None]:
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import yaml
# import seaborn
from pyclim_noresm.general_util_funcs import global_avg

from workflow.scripts.utils import read_list_input_paths

In [None]:
# Lists of input files handed to this notebook by the Snakemake rule.
paths_atm, paths_ctrl, paths_exp = (
    snakemake.input.atmabs,
    snakemake.input.abs550_ctrl,
    snakemake.input.abs550_exp,
)

# Time window used when averaging: keeps time index 3 onward, i.e. discards the
# first three time steps (presumably spin-up — TODO confirm).
time_slice = slice(3, None)

# Dust optical properties at 550 nm; looked up by model name further down.
with open('workflow/input_data/refractive_indicies_550nm.yaml') as optics_yaml:
    dust_optics = yaml.safe_load(optics_yaml)

def get_forcing(forcing_var: str, dataframes: dict) -> pd.DataFrame:
    """Gather the ``forcing_var`` row of every table into a single frame.

    Parameters
    ----------
    forcing_var : str
        Index label of the row to pull out of each table.
    dataframes : dict
        Mapping of key (e.g. model name) -> DataFrame. The column labels of
        the first table define the columns of the result.

    Returns
    -------
    pd.DataFrame
        One row per key of ``dataframes``, in iteration order. Keys whose
        table has no ``forcing_var`` row are kept as all-NaN rows. Columns are
        converted from the placeholder ``object`` dtype to their natural
        (e.g. float) dtype so the result can be plotted/aggregated directly.
        An empty mapping yields an empty DataFrame.
    """
    if not dataframes:
        return pd.DataFrame()

    first_key = next(iter(dataframes))
    outdf = pd.DataFrame(index=list(dataframes), columns=dataframes[first_key].columns)
    for key, table in dataframes.items():
        if forcing_var not in table.index:
            # Deliberately best-effort: tables lacking this diagnostic stay NaN.
            continue
        # Assignment aligns the row's values on column labels.
        outdf.loc[key, :] = table.loc[forcing_var]

    # The frame was created empty, so every column is ``object``; recover the
    # real dtypes (pandas scatter plots reject non-numeric columns).
    return outdf.infer_objects()

In [None]:
# Inspection cell: displays the names of the inputs supplied by the Snakemake rule.
snakemake.input.keys()

In [None]:
# Inspection cell: displays the diagnostics-table paths that are read into `diag_table` below.
snakemake.input.diag_table

In [None]:
# Key each diagnostics table by the last '_'-separated token of its file name
# (text before the first '.'). Only the basename is parsed so that a '.' in a
# directory component (e.g. './results' or 'v1.0/') cannot corrupt the key.
diag_table = {
    Path(p).name.split('.')[0].split('_')[-1]: pd.read_csv(p, index_col=0)
    for p in snakemake.input.diag_table
}

In [None]:
# Load per-model inputs. Each call is unpacked as (mapping of model -> dataset,
# variable name), matching how the results are used in the next cell.
absdict, vname = read_list_input_paths(paths_atm)
abs550dict, vname_ctrl = read_list_input_paths(paths_ctrl)
abs550exp, vname_exp = read_list_input_paths(paths_exp)
# The dust-AOD variable names are not needed, but bind them to explicit names
# rather than `_`: assigning `_` in the IPython namespace stops IPython from
# updating `_` with the last cell output.
dod550ctrl, vname_dod_ctrl = read_list_input_paths(snakemake.input.oddust550_ctrl)
dod550exp, vname_dod_exp = read_list_input_paths(snakemake.input.oddust550_exp)

In [None]:
def _global_mean_change(exp_dsets, ctrl_dsets, model, var_name, window):
    """Global average of the time-mean (exp - ctrl) difference of ``var_name``.

    Both datasets for ``model`` are restricted to the ``window`` time steps and
    averaged over ``time`` before differencing; the selected variable is
    returned as a raw value (``.values``).
    """
    exp_mean = exp_dsets[model].isel(time=window).mean(dim='time')
    ctrl_mean = ctrl_dsets[model].isel(time=window).mean(dim='time')
    return global_avg(exp_mean - ctrl_mean)[var_name].values


# Change (exp - ctrl) in `od550dust`, per model.
dod_change = {
    model: _global_mean_change(dod550exp, dod550ctrl, model, 'od550dust', time_slice)
    for model in dod550ctrl
}

# Change (exp - ctrl) in the abs550 variable, per model.
abs_change = {
    model: _global_mean_change(abs550exp, abs550dict, model, vname_ctrl, time_slice)
    for model in abs550dict
}

# Time-mean global average of the atmospheric-absorption variable, per model.
gabs = {}
for model, ds in absdict.items():
    # Some inputs carry a 'year' dimension instead of 'time'.
    if 'year' in ds.dims:
        ds = ds.rename_dims({'year': 'time'})
    time_mean = ds.isel(time=time_slice).mean(dim='time')
    # `.values` for consistency with the other change dictionaries.
    gabs[model] = global_avg(time_mean)[ds.variable_id].values

# Keep models that have both an abs550 change and a tabulated refractive index.
# Sorted so the row order of `df` is reproducible (set order is hash-dependent).
keys = sorted(set(abs_change).intersection(dust_optics))

gabs = {k: gabs[k] for k in keys}
abs_change = {k: abs_change[k] for k in keys}
ni_optics = {k: dust_optics[k]['complex'] for k in keys}

# NOTE(review): dod_change is not restricted to `keys`, so models that only have
# dust-AOD inputs still appear as rows (NaN in the other columns). plot_fig's
# model drop may rely on that — confirm before filtering.
df = pd.DataFrame([abs_change, ni_optics, gabs, dod_change],
                  index=['abs550aer', 'ni', 'atmabs', 'dod550nm'])
df = df.astype(float).T

In [None]:

# Attach the 'diff' column of each forcing diagnostic; aligned on model name.
for forcing_name in ['SWDirectEff', 'LWDirectEff', 'ERFt']:
    df[forcing_name] = get_forcing(forcing_name, diag_table)['diff']

In [None]:
# Change in abs550aer against the imaginary refractive index, log y-axis.
fig, ax = plt.subplots(figsize=(8, 6))
df.plot.scatter(x='ni', y='abs550aer', ax=ax, s=50)
for model_name, row in df.iterrows():
    # Label each point with its model name, nudged right and down.
    ax.annotate(model_name, (row['ni'], row['abs550aer']),
                xytext=(10, -5), textcoords='offset points', fontsize=12)
ax.set_xlim(0, 0.007)
ax.set_ylim(0.0001, 0.1)
ax.semilogy()
ax.grid(linestyle='--')

In [None]:
# Atmospheric absorption against the imaginary refractive index, linear y-axis.
fig, ax = plt.subplots(figsize=(8, 6))
df.plot.scatter(x='ni', y='atmabs', ax=ax, s=50)
for model_name, row in df.iterrows():
    # Label each point with its model name, nudged right and down.
    ax.annotate(model_name, (row['ni'], row['atmabs']),
                xytext=(10, -5), textcoords='offset points', fontsize=12)
ax.set_xlim(0, 0.007)
ax.set_ylim(0.1, 2)
ax.grid(linestyle='--')

In [None]:
def plot_fig(df, exclude=('MIROC6', 'GISS-E2-1-G')):
    """Scatter ``ERFt`` against ``dod550nm``, coloured by ``abs550aer``.

    Each point is labelled with its index label (model name). The x-axis is
    inverted, and a separate colorbar axis is added to the right of the plot.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``ERFt``, ``dod550nm`` and ``abs550aer``.
    exclude : iterable of str, optional
        Index labels (models) to leave out of the plot. Labels not present in
        ``df`` are ignored.

    Returns
    -------
    tuple
        ``(fig, ax)`` — the created figure and the main axes.
    """
    df = df.drop(list(exclude), axis=0, errors='ignore')
    y = 'ERFt'
    x = 'dod550nm'
    fig, ax = plt.subplots(figsize=(4, 3.6))

    # 13-level colormap shared by the points and the colorbar so they agree.
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('Blues', 13)
    norm = mpl.colors.Normalize(vmin=0.0001, vmax=0.0012)
    df.plot.scatter(y=y, x=x, ax=ax, s=50, c='abs550aer', colorbar=False,
                    norm=norm, colormap=cmap)

    for model, row in df.iterrows():
        ax.annotate(f'{model}', (row[x], row[y]),
                    xytext=(10, -5), textcoords='offset points', fontsize=8)

    # Colorbar in its own axes, placed in figure coordinates right of the plot.
    cax = fig.add_axes([0.94, 0.2, 0.02, 0.62])
    fig.colorbar(mpl.cm.ScalarMappable(norm, cmap=cmap), cax=cax, extend='max',
                 label=r'$\Delta$ AAOD 550nm')

    ax.set_xlim(0, 0.04)
    ax.set_xlabel(r'$\Delta$ AOD 550nm')
    ax.invert_xaxis()
    ax.set_ylabel("Total DRE")
    return fig, ax

plot_fig(df)

# Write the current figure twice: the Snakemake-declared output at 300 dpi and a
# PDF copy under results/figs.
# NOTE(review): the PDF name says SWDirectEff, but plot_fig plots ERFt — confirm.
savefig_opts = {'bbox_inches': 'tight'}
plt.savefig(snakemake.output.absortion_plot, dpi=300, **savefig_opts)
plt.savefig('results/figs/AerChemMIP/SWDirectEff_AAOD_refractive_index.pdf', **savefig_opts)