In [87]:
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
import xesmf as xe
from workflow.scripts.utils import regrid_global
import pandas as pd
import numpy as np
from pyclim_noresm.general_util_funcs import global_avg
import statsmodels.api as sm
from matplotlib.ticker import LinearLocator, AutoMinorLocator, FixedLocator

In [3]:
def read_data(paths, tag='source_id', time_slice=None):
    """Open, regrid and merge a collection of files into a single dataset.

    Every file is opened with dask chunks of 40 along ``time``, gets lon/lat
    bounds added, is passed through ``regrid_global`` with the settings found
    in ``snakemake.config['regrid_params']``, and is then reduced to the one
    variable named by the file's ``variable_id`` attribute. That variable is
    renamed to ``'<variable_id>_<ds.attrs[tag]>'`` before merging.

    Parameters
    ----------
    paths : iterable
        Files to open with ``xr.open_dataset``.
    tag : str, default 'source_id'
        Name of the global attribute appended to the variable name.
    time_slice : slice, int or sequence, optional
        Positional selector handed to ``isel(time=...)`` on the merged result.
        ``None`` (the default) keeps every time step.

    Returns
    -------
    xarray.Dataset
        The merged variables, one per input file.

    Notes
    -----
    Relies on the module-level ``snakemake`` object. The regrid settings are
    now read before the loop, so they are looked up even when ``paths`` is
    empty.
    """
    # The regridding settings do not depend on the file, so read them once.
    grid_params = snakemake.config['regrid_params']
    method = grid_params.get('method', 'conservative')
    dxdy = grid_params['dxdy']

    dsets = []
    for p in paths:
        ds = xr.open_dataset(p, chunks={'time': 40})
        ds = ds.cf.add_bounds(['lon', 'lat'])
        if 'year' in ds.dims:
            ds = ds.rename({'year': 'time'})

        ds = regrid_global(ds, lon=dxdy[0], lat=dxdy[1], method=method)
        da = ds[ds.variable_id]

        da = da.rename(f'{ds.variable_id}_{ds.attrs[tag]}')
        da.attrs['source_id'] = ds.source_id
        da.attrs['experiment_id'] = ds.experiment_id
        # Swap the time coordinate for a plain 0..n-1 index; presumably so that
        # files with different calendars line up by position when merged.
        da = da.assign_coords(time=np.arange(0, len(da.time)))
        dsets.append(da)

    out_da = xr.merge(dsets)
    # Compare against None: a truthiness test would silently skip selectors
    # such as ``0``.
    if time_slice is not None:
        out_da = out_da.isel(time=time_slice)
    return out_da
def calc_load_diff(dsh,dspiaer, models):
    """Return, per model, the difference in lev-summed ``mmrdust * airmass``.

    For every entry ``m`` of ``models`` the variables ``mmrdust_<m>`` and
    ``airmass_<m>`` are taken from both inputs, cleaned with ``dropna`` along
    ``lev`` and then ``time``, multiplied, and summed over ``lev``. The result
    for ``dspiaer`` is subtracted from the result for ``dsh`` and stored as
    ``delta_dustload_<m>``.

    Parameters
    ----------
    dsh, dspiaer : mapping of variable name to array (e.g. xarray.Dataset)
        Must both hold ``mmrdust_<m>`` and ``airmass_<m>`` for each model.
    models : iterable of str
        Model identifiers used as variable-name suffixes.

    Returns
    -------
    The ``xr.merge`` of all per-model differences.
    """
    def _lev_summed_load(source, model):
        # Product of the two cleaned fields, collapsed along the lev dimension.
        mmr = source[f'mmrdust_{model}'].dropna(dim='lev').dropna(dim='time')
        air = source[f'airmass_{model}'].dropna(dim='lev').dropna(dim='time')
        return (mmr * air).sum(dim='lev')

    differences = [
        (_lev_summed_load(dsh, m) - _lev_summed_load(dspiaer, m)).rename(f'delta_dustload_{m}')
        for m in models
    ]
    return xr.merge(differences)

In [6]:
# Load the histSST-piaer inputs; slice(None, -5) drops the last five time steps
# of the merged dataset.
load_histSST_piaer = read_data(snakemake.input.histSST_piaer_load, time_slice=slice(None,-5))

In [4]:
# Load the histSST inputs with the same trimming (last five time steps dropped)
# as the histSST-piaer data.
load_histSST = read_data(snakemake.input.histSST_load,time_slice=slice(None,-5))

In [7]:
# Unique source_id values found on the loaded variables. Sorted so the model
# order (and with it the panel order of the figures) is the same on every run;
# iterating a bare set of strings gives an order that can change between
# kernel sessions.
models = sorted({load_histSST_piaer[d].source_id for d in load_histSST_piaer.data_vars})

In [8]:
# Per-model difference (histSST minus histSST-piaer) of the lev-summed
# mmrdust * airmass product.
dloaddust=calc_load_diff(load_histSST, load_histSST_piaer, models)

In [9]:
# The inputs were opened with dask chunks, so evaluate the result here once
# instead of re-evaluating it in every later cell.
dloaddust = dloaddust.compute()

In [11]:
# Fit coefficients table; calc_dust_erf reads row 'x' (multiplier) and row 'c'
# (offset) with one column per model.
fit_params = pd.read_csv(snakemake.input.toa_dustload_txt, index_col=0)

In [12]:
def calc_dust_erf(dloaddust, fit_params, models):
    """Apply the per-model linear fit to the globally averaged load difference.

    ``dloaddust`` is reduced with ``global_avg`` and converted to a DataFrame.
    For every model ``m`` a column ``ERFcs_dust_<m>`` is added, computed as
    ``delta_dustload_<m> * fit_params.loc['x', m] + fit_params.loc['c', m]``.

    Parameters
    ----------
    dloaddust : object accepted by ``global_avg``
        Must yield ``delta_dustload_<m>`` columns after ``to_dataframe()``.
    fit_params : pandas.DataFrame
        Rows ``'x'`` and ``'c'``, one column per model.
    models : iterable of str
        Model identifiers.

    Returns
    -------
    pandas.DataFrame
        The averaged load columns followed by the new ``ERFcs_dust_*`` columns.
    """
    load_frame = global_avg(dloaddust).to_dataframe()
    fitted = {
        f'ERFcs_dust_{m}': load_frame[f'delta_dustload_{m}'] * fit_params.loc['x', m]
        + fit_params.loc['c', m]
        for m in models
    }
    return load_frame.assign(**fitted)

In [42]:
# Inspection only (nothing is assigned): last entry along the first dimension
# (presumably time - TODO confirm) of the MPI-ESM-1-2-HAM load difference.
dloaddust['delta_dustload_MPI-ESM-1-2-HAM'][-1]

In [18]:
# Dust-attributed ERF per model, from the linear fit applied to the load change.
dfdusterf = calc_dust_erf(dloaddust, fit_params, models)

# ERF fields with the last five time steps dropped, evaluated, globally
# averaged and turned into a DataFrame in one chain.
erfs = global_avg(
    read_data(snakemake.input.histSST_ERFcs, time_slice=slice(None, -5)).compute()
).to_dataframe()

In [79]:
def make_plot(dfdusterf, erfs, models=None):
    """Draw one panel per model comparing ERF with and without the dust term.

    Each panel shows ``erfs['ERFtcs_<m>'] - dfdusterf['ERFcs_dust_<m>']``
    (square markers, plotted against ``dfdusterf.index``) and
    ``erfs['ERFtcs_<m>']`` (round markers, plotted against ``erfs.index``),
    with the standard deviation of both series written inside the panel.

    Parameters
    ----------
    dfdusterf : pandas.DataFrame
        Holds the ``ERFcs_dust_<m>`` columns.
    erfs : pandas.DataFrame
        Holds the ``ERFtcs_<m>`` columns.
    models : iterable of str, optional
        Models to plot, one panel each. When omitted, the models are taken
        from the ``ERFcs_dust_*`` column names of ``dfdusterf`` in column
        order, so the function no longer depends on a notebook-level
        ``models`` variable.

    Returns
    -------
    (matplotlib.figure.Figure, numpy.ndarray of Axes)
        The figure and its panels, top to bottom.

    Raises
    ------
    ValueError
        If there is no model to plot.
    """
    prefix = 'ERFcs_dust_'
    if models is None:
        models = [c[len(prefix):] for c in dfdusterf.columns
                  if isinstance(c, str) and c.startswith(prefix)]
    models = list(models)
    if not models:
        raise ValueError('make_plot needs at least one model to plot')

    # One row per model (2 inches each, i.e. the original 8x6 for three models).
    # squeeze=False keeps a 2-D array even for a single model.
    fig, axes = plt.subplots(nrows=len(models), figsize=(8, 2 * len(models)),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes[:, 0]
    # The y axis is shared, so setting the limits once applies to every panel.
    axes[0].set_ylim(-1.2, 0.2)

    for panel, m in zip(axes, models):
        total = erfs[f'ERFtcs_{m}']
        du_removed = total - dfdusterf[f'ERFcs_dust_{m}']
        # Dust-removed series first: callers relabel the legend handles by position.
        panel.plot(dfdusterf.index, du_removed, label=m, marker='s')
        panel.plot(erfs.index, total, label=m, marker='o')
        panel.set_title(m)
        # Backslashes are doubled so '\sigma' is not an invalid escape sequence.
        stats_text = (
            f'$\\sigma_{{no dust}}$ = {du_removed.std():.2f} \n '
            f'$\\sigma_{{all}}$ = {total.std():.2f}'
        )
        panel.text(0.08, 0.15, stats_text, transform=panel.transAxes)
        panel.set_ylabel('W m-2')
    return fig, axes

In [88]:
# Inspection only: display the configured output path for the first figure.
snakemake.output.outpath

In [89]:
# Figure for the last 30 rows of both tables.
make_plot(dfdusterf.iloc[-30:, :], erfs.iloc[-30:, :])
fig = plt.gcf()
ax = plt.gca()  # most recently created panel

# Major ticks at every third index from 130 to 160, labelled 1980 ... 2010
# (index + 1850), with minor ticks splitting each interval in three.
tick_idx = np.arange(130, 163, 3)
ax.xaxis.set_major_locator(FixedLocator(tick_idx))
ax.set_xticklabels(tick_idx + 1850)
ax.xaxis.set_minor_locator(AutoMinorLocator(3))

# Reuse that panel's two line handles for a single figure-level legend.
h, l = ax.get_legend_handles_labels()
fig.legend(h, ['Dust variability removed', 'All variability'], loc='lower center', ncol=2)
plt.savefig(snakemake.output.outpath, bbox_inches='tight', dpi=144)

In [86]:
# Figure for rows -60 to -30 (inclusive) of both tables.
make_plot(dfdusterf.iloc[-60:-29, :], erfs.iloc[-60:-29, :])
fig = plt.gcf()
ax = plt.gca()  # most recently created panel

# Major ticks at every third index from 100 to 130, labelled 1950 ... 1980
# (index + 1850), with minor ticks splitting each interval in three.
tick_idx = np.arange(100, 133, 3)
ax.xaxis.set_major_locator(FixedLocator(tick_idx))
ax.set_xticklabels(tick_idx + 1850)
ax.xaxis.set_minor_locator(AutoMinorLocator(3))

# Reuse that panel's two line handles for a single figure-level legend.
h, l = ax.get_legend_handles_labels()
fig.legend(h, ['Dust variability removed', 'All variability'], loc='lower center', ncol=2)
plt.savefig(snakemake.output.outpath1950, bbox_inches='tight', dpi=144)