In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.cm as cm
import matplotlib as mpl
import xesmf as xe
from workflow.scripts.utils import regrid_global
import numpy as np
from pyclim_noresm.general_util_funcs import global_avg
import scipy.stats as stats
import pandas as pd
import statsmodels.api as sm

In [3]:
def read_data(paths, tag='experiment_id'):
    """Open, regrid and merge a set of single-variable files.

    Parameters
    ----------
    paths : iterable of path-like
        Files opened with ``xr.open_dataset``. Each must carry the global
        attributes ``variable_id`` (name of its data variable), ``source_id``,
        ``experiment_id`` and the attribute named by ``tag``.
    tag : str, default 'experiment_id'
        Global attribute appended to the variable name; the output variable is
        named ``f'{variable_id}_{attrs[tag]}'``.

    Returns
    -------
    xarray.Dataset
        Merge of the renamed, regridded variables. Each keeps ``source_id`` and
        ``experiment_id`` in its attrs, and its ``time`` coordinate is replaced
        by integer positions ``0 .. len(time) - 1``.
    """
    # The regridding settings come from the Snakemake config and are identical
    # for every file, so read them once rather than on every iteration.
    grid_params = snakemake.config['regrid_params']
    method = grid_params.get('method', 'conservative')
    dxdy = grid_params['dxdy']

    dsets = []
    for p in paths:
        ds = xr.open_dataset(p)
        # NOTE(review): the `.cf` accessor is provided by cf_xarray, which this
        # notebook does not import directly; presumably a project import
        # registers it — confirm.
        ds = ds.cf.add_bounds(['lon', 'lat'])
        ds = regrid_global(ds, lon=dxdy[0], lat=dxdy[1], method=method)
        da = ds[ds.variable_id]
        da = da.rename(f'{ds.variable_id}_{ds.attrs[tag]}')
        da.attrs['source_id'] = ds.source_id
        da.attrs['experiment_id'] = ds.experiment_id
        # Replace the original time axis with integer positions, presumably so
        # runs with different start dates line up on a common index when merged.
        da = da.assign_coords(time=np.arange(0, len(da.time)))
        dsets.append(da)
    return xr.merge(dsets)


In [4]:
# Regridded TOA radiation fields for each model, one variable per experiment.
rad_noresm, rad_mpi, rad_ec = (
    read_data(snakemake.input.noresm_rad),
    read_data(snakemake.input.mpi_rad),
    read_data(snakemake.input.ec_rad),
)

In [5]:
# Regridded dust mass-mixing-ratio / air-mass fields for each model.
load_nor, load_mpi, load_ec = (
    read_data(snakemake.input.noresm_load),
    read_data(snakemake.input.mpi_load),
    read_data(snakemake.input.ec_load),
)

In [6]:
# Experiment ids present in each model's radiation dataset. Sorted so the
# experiment order (and therefore merge/concat order downstream) is the same on
# every run: iteration order of a set of strings depends on the hash seed.
exp_noresm = sorted({rad_noresm[d].experiment_id for d in rad_noresm.data_vars})
exp_mpi = sorted({rad_mpi[d].experiment_id for d in rad_mpi.data_vars})
exp_ec = sorted({rad_ec[d].experiment_id for d in rad_ec.data_vars})

In [53]:
def calc_toa_imbalance(ds_rad, experiments, ds_load=None, tag='cs', subtract_mean_dTOA=True):
    """Global-mean TOA imbalance (and optionally dust load) for each experiment.

    For each ``exp`` the imbalance is ``|rsdt| - |rlut{tag}| - |rsut{tag}|``
    (each series with NaN time steps dropped) reduced with ``global_avg``.

    Parameters
    ----------
    ds_rad : xarray.Dataset
        Must contain ``rsdt_{exp}``, ``rlut{tag}_{exp}`` and ``rsut{tag}_{exp}``.
    experiments : iterable of str
        Experiment ids used as variable-name suffixes.
    ds_load : xarray.Dataset, optional
        If given, must contain ``mmrdust_{exp}`` and ``airmass_{exp}``. Their
        product is summed over ``lev``, globally averaged and returned as an
        anomaly about its time mean.
    tag : str, default 'cs'
        Suffix inserted after ``rlut``/``rsut`` (``'cs'`` selects ``rlutcs``/``rsutcs``).
    subtract_mean_dTOA : bool, default True
        If True, the imbalance is returned as an anomaly about its time mean.

    Returns
    -------
    xarray.Dataset
        ``dTOA_{exp}`` for every experiment plus, when ``ds_load`` is given,
        ``loaddust_{exp}``.
    """
    toa_imbalance = []
    loads = []
    for exp in experiments:
        # dropna removes time steps that are NaN for this experiment
        # (presumably padding from merging runs of different length).
        temp_da = (np.abs(ds_rad[f'rsdt_{exp}'].dropna(dim='time'))
                   - np.abs(ds_rad[f'rlut{tag}_{exp}'].dropna(dim='time'))
                   - np.abs(ds_rad[f'rsut{tag}_{exp}'].dropna(dim='time')))

        # Compare with None explicitly: the truth value of an xarray Dataset
        # means "has data variables", not "was supplied".
        if ds_load is not None:
            load_da = (ds_load[f'mmrdust_{exp}'].dropna(dim='time')
                       * ds_load[f'airmass_{exp}'].dropna(dim='time'))
            load_da = load_da.sum(dim='lev')
            load_da = global_avg(load_da)
            load_da = load_da - load_da.mean()
            loads.append(load_da.to_dataset(name=f'loaddust_{exp}'))

        temp_da = global_avg(temp_da)
        if subtract_mean_dTOA:
            temp_da = temp_da - temp_da.mean()
        toa_imbalance.append(temp_da.to_dataset(name=f'dTOA_{exp}'))

    return xr.merge(toa_imbalance + loads)

In [8]:
# Clear-sky (default tag='cs') TOA imbalance and dust-load anomalies per model.
ec_load_im, nor_load_im, mpi_load_im = (
    calc_toa_imbalance(rad_ec, exp_ec, ds_load=load_ec),
    calc_toa_imbalance(rad_noresm, exp_noresm, ds_load=load_nor),
    calc_toa_imbalance(rad_mpi, exp_mpi, ds_load=load_mpi),
)

In [9]:
def stack_data(ds, experiments):
    """Pool ``(loaddust, dTOA)`` pairs from several experiments into one table.

    ``ds.to_dataframe()`` must be indexed by ``time`` and hold
    ``loaddust_{exp}`` and ``dTOA_{exp}`` columns for every ``exp``.

    Steps:
    - rows with a NaN in either column are dropped per experiment;
    - the experiments are stacked and the ``time`` column is discarded;
    - rows whose dTOA has an absolute z-score of 3 or more over the pooled
      data are removed.

    Returns
    -------
    pandas.DataFrame
        Columns ``loaddust`` and ``dTOA``.
    """
    wide = ds.to_dataframe()
    per_experiment = [
        wide[[f'loaddust_{exp}', f'dTOA_{exp}']]
        .dropna()
        .rename(columns={f'loaddust_{exp}': 'loaddust', f'dTOA_{exp}': 'dTOA'})
        for exp in experiments
    ]
    pooled = pd.concat(per_experiment).reset_index().drop(columns='time')
    # Outlier screen on dTOA only, using the pooled sample statistics.
    keep = np.abs(stats.zscore(pooled['dTOA'])) < 3
    return pooled[keep]
        

In [10]:
# Pooled (loaddust, dTOA) tables per model, with dTOA outliers removed.
mpi_load_dTOA, nor_load_dTOA, ec_load_dTOA = (
    stack_data(mpi_load_im, exp_mpi),
    stack_data(nor_load_im, exp_noresm),
    stack_data(ec_load_im, exp_ec),
)

In [16]:
def make_fit(df, regularized=True):
    """Regress ``dTOA`` on ``loaddust`` with an intercept term.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``loaddust`` and ``dTOA`` columns; rows with missing values
        are dropped by statsmodels (``missing='drop'``).
    regularized : bool, default True
        If True, use ``OLS.fit_regularized(L1_wt=1)``; otherwise ``OLS.fit()``.
        NOTE(review): no ``alpha`` is passed to ``fit_regularized``; confirm the
        penalty weight is what is intended.

    Returns
    -------
    tuple
        ``(result, params[0], params[1])``. Callers treat the two numbers as
        ``(intercept, slope)``.
    """
    design = sm.add_constant(df['loaddust'])
    ols = sm.OLS(df['dTOA'], design, missing='drop')
    result = ols.fit_regularized(L1_wt=1) if regularized else ols.fit()
    params = result.params.values
    return result, params[0], params[1]

In [12]:
def plot_fit():
    """Three-panel scatter of dTOA vs dust-load anomaly with an OLS fit per model.

    Uses the notebook-level ``*_load_im`` datasets and ``exp_*`` lists. For each
    model it stacks the data with ``stack_data`` and fits
    ``make_fit(..., False)`` (unregularised). Each panel shows the points, the
    regression line, R^2, and the slope with its p-value.

    The figure is left as the current pyplot figure (the caller saves it with
    ``plt.savefig``); returns None.
    """
    # (title, pooled data, x-axis half-width) for each panel, left to right.
    panels = [
        ('NorESM2-LM', stack_data(nor_load_im, exp_noresm), 6e-6),
        ('EC-Earth3-AerChem', stack_data(ec_load_im, exp_ec), 2e-5),
        ('MPI-ESM-1-2-HAM', stack_data(mpi_load_im, exp_mpi), 4e-5),
    ]
    # The regression line spans far beyond every panel's xlim so it always
    # reaches both edges of the axes.
    line_x = np.array([-10e-4, 10e-4])
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    fig, ax = plt.subplots(figsize=(14, 4), ncols=len(panels))
    for axis, (title, data, half_width) in zip(ax, panels):
        res, intercept, slope = make_fit(data, False)
        axis.scatter(data['loaddust'], data['dTOA'], s=5)
        axis.plot(line_x, line_x * slope + intercept, linewidth=3, color='grey')
        axis.set_ylim(-0.5, 0.5)
        axis.set_xlim(-half_width, half_width)
        axis.set_title(title)
        axis.text(0.17, 0.08, f'$R^2 = $ {res.rsquared:.2f}',
                  va='center', ha='center', transform=axis.transAxes)
        # Raw f-string: '\;' is a mathtext thin space, not a Python escape
        # (a non-raw '\;' is an invalid escape sequence).
        axis.text(0.5, 0.92,
                  rf'dTOA={res.params["loaddust"]:2.2E} $* \;C_{{dust}}$, p = {res.pvalues["loaddust"]:.2E}',
                  va='center', ha='center', transform=axis.transAxes, bbox=props)

    ax[0].set_ylabel('TOA imbalance [W m-2]')
    ax[1].set_xlabel('Dust load anomaly \n [kg m-2]')

In [15]:
# Inspect the pooled EC-Earth3-AerChem (loaddust, dTOA) table.
ec_load_dTOA

In [14]:
# Inspect the pooled NorESM2-LM (loaddust, dTOA) table.
nor_load_dTOA

In [13]:
# Draw the three-panel fit figure; plt.savefig writes the current figure,
# which is the one plot_fit just created.
plot_fit()
plt.savefig(snakemake.output.toa_dustload_png)

In [109]:
# Unregularised fits per model; their (intercept, slope) pairs are exported below.
res_nor, c_nor, x_nor = make_fit(nor_load_dTOA, regularized=False)
res_ec, c_ec, x_ec = make_fit(ec_load_dTOA, regularized=False)
res_mpi, c_mpi, x_mpi = make_fit(mpi_load_dTOA, regularized=False)

In [114]:
# Inspect the EC-Earth3-AerChem intercept from the fit above.
c_ec

In [117]:
# NOTE(review): yaml is not used in this cell.
import yaml

# Table with one column per model and rows 'c' (intercept) and 'x' (slope).
out_dict = {
    'NorESM2-LM': {'c': c_nor, 'x': x_nor},
    'EC-Earth3-AerChem': {'c': c_ec, 'x': x_ec},
    'MPI-ESM-1-2-HAM': {'c': c_mpi, 'x': x_mpi},
}

outdf = pd.DataFrame(out_dict)
outdf.to_csv(snakemake.output.toa_dustload_txt)

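## Validation

Optional check of the fitted dust-load/dTOA relations on the histSST simulations. The cells below only run when `val = True` is set in the next cell.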

In [118]:
# Set to True to run the histSST validation cells below.
val = False
if val:
    # histSST radiation and dust fields; variables are suffixed with the model
    # name (tag='source_id') rather than the experiment id.
    histSST = read_data(snakemake.input.histSST_rad, tag='source_id')

    mmr_mass_hist = read_data(snakemake.input.histSST_load, tag='source_id')

    # Sorted so the model order is reproducible (set iteration order over
    # strings depends on the per-process hash seed).
    models = sorted({histSST[d].attrs['source_id'] for d in histSST.data_vars})

    def calc_load(mmr_air, models):
        """Global-mean dust load per model.

        For each ``model``, ``mmrdust_{model} * airmass_{model}`` is summed over
        ``lev`` and reduced with ``global_avg``.

        Returns
        -------
        xarray.Dataset
            One ``loaddust_{model}`` variable per model, with ``units`` and
            ``long_name`` attributes set.
        """
        def _load_for(model):
            column = (mmr_air[f'mmrdust_{model}'] * mmr_air[f'airmass_{model}']).sum(dim='lev')
            mean_load = global_avg(column)
            mean_load.attrs['units'] = 'kg m-2'
            mean_load.attrs['long_name'] = 'Load of Dust'
            return mean_load.to_dataset(name=f'loaddust_{model}')

        return xr.merge([_load_for(model) for model in models])



    # Clear-sky histSST TOA imbalance, without subtracting the time mean.
    histSST_imbalance_cs = calc_toa_imbalance(histSST, models, tag='cs', subtract_mean_dTOA=False)

    # Unregularised piClim-based fits used to predict histSST dTOA.
    res_nor, c_nor, x_nor = make_fit(nor_load_dTOA, regularized=False)
    res_ec, c_ec, x_ec = make_fit(ec_load_dTOA, regularized=False)
    res_mpi, c_mpi, x_mpi = make_fit(mpi_load_dTOA, regularized=False)

    loads_histSST = calc_load(mmr_mass_hist, models)

    def create_validation_data(loading, dTOA, t_slice=None):
        """Pair a dust-load anomaly series with dTOA over a common time window.

        Parameters
        ----------
        loading, dTOA : xarray.DataArray
            Series along ``time`` (assumed 1-D, as produced by ``global_avg`` —
            TODO confirm).
        t_slice : slice, optional
            Positional window applied to both series with ``isel``; defaults to
            the first 30 time steps.

        Returns
        -------
        pandas.DataFrame
            Column ``dTOA`` (as given, not de-meaned) and column ``loaddust``
            (anomaly about its mean over the window).
        """
        if t_slice is None:
            t_slice = slice(0, 30)
        # Select the same positional window from both series. A label-based
        # `sel` on only one of them includes the end label, giving one extra
        # dTOA step with no matching loading value.
        loading = loading.isel(time=t_slice)
        dTOA = dTOA.isel(time=t_slice)

        loading = loading - loading.mean()
        loading = loading.rename('loaddust')
        dfTOA = dTOA.to_dataset(name='dTOA').to_pandas()
        dfTOA['loaddust'] = loading.to_pandas()
        return dfTOA

    # Apply each model's fitted dTOA ~ loaddust relation to its own histSST
    # dust-load anomaly.
    nor_val = create_validation_data(loads_histSST['loaddust_NorESM2-LM'], histSST_imbalance_cs['dTOA_NorESM2-LM'])
    nor_val['predicted'] = res_nor.predict(sm.add_constant(nor_val['loaddust']))
    mpi_val = create_validation_data(loads_histSST['loaddust_MPI-ESM-1-2-HAM'],
                                     histSST_imbalance_cs['dTOA_MPI-ESM-1-2-HAM'], slice(-40, -10))
    mpi_val['predicted'] = res_mpi.predict(sm.add_constant(mpi_val['loaddust']))

    ec_val = create_validation_data(loads_histSST['loaddust_EC-Earth3-AerChem'],
                                    histSST_imbalance_cs['dTOA_EC-Earth3-AerChem'])
    # EC-Earth3-AerChem predictions must use the EC-Earth fit, not the MPI one.
    ec_val['predicted'] = res_ec.predict(sm.add_constant(ec_val['loaddust']))

    # NOTE(review): this MPI anomaly covers time steps 0-29, while mpi_val above
    # uses slice(-40, -10); the later cell that subtracts the MPI prediction
    # from dhTOA_mpi therefore mixes two periods — confirm the intended window.
    dhTOA_mpi = histSST_imbalance_cs['dTOA_MPI-ESM-1-2-HAM'].isel(time=slice(0,30)) - histSST_imbalance_cs['dTOA_MPI-ESM-1-2-HAM'].isel(time=slice(0,30)).mean()

In [68]:
if val:
    # Overlay the MPI dust-load anomalies of the three piClim experiments and
    # of the histSST validation window on one axes.
    # NOTE(review): mpi_val is indexed by the histSST time positions of its
    # window (slice(-40, -10)), so it is offset on the x-axis from the piClim
    # series — confirm this is intended.
    ax = plt.gca()
    mpi_load_im['loaddust_piClim-aer'].plot(ax=ax)
    mpi_load_im['loaddust_piClim-2xdust'].plot(ax=ax)
    mpi_load_im['loaddust_piClim-control'].plot(ax=ax)
    mpi_val['loaddust'].plot(ax=ax)

In [100]:
if val:
    # MPI fit applied to each piClim dust-load anomaly, compared with the
    # histSST-window prediction.
    ax = plt.gca()

    # NOTE(review): ax.plot draws these against sample position, whereas the
    # pandas .plot below uses mpi_val's time index — x-axes may not line up.
    ax.plot(res_mpi.predict(sm.add_constant(mpi_load_im['loaddust_piClim-aer'])))
    ax.plot(res_mpi.predict(sm.add_constant(mpi_load_im['loaddust_piClim-2xdust'])))
    ax.plot(res_mpi.predict(sm.add_constant(mpi_load_im['loaddust_piClim-control'])))
    (mpi_val['predicted']).plot(ax=ax)

In [108]:
if val:
    ax = plt.gca()
    # MPI histSST dTOA anomaly (about its window mean) with and without the
    # fitted dust contribution removed; distinct labels so the legend can tell
    # the two series apart.
    ((mpi_val['dTOA']-mpi_val['dTOA'].mean())-mpi_val['predicted']).plot(ax=ax, marker='s', label='dTOA histSST - dust fit')
    ((mpi_val['dTOA']-mpi_val['dTOA'].mean())).plot(ax=ax, marker='s', label='dTOA histSST')
    ax.legend()

In [120]:
if val:
    # MPI histSST dTOA anomaly with the fitted dust contribution subtracted.
    # NOTE(review): dhTOA_mpi covers time steps 0-29, while mpi_val['loaddust']
    # comes from slice(-40, -10), so this combines values from different
    # periods. Also check how xarray aligns a DataArray with the pandas object
    # returned by predict — confirm the intended window.
    mpi_no_dust_hist = (dhTOA_mpi-
                        res_mpi.predict(sm.add_constant(mpi_val['loaddust'])))

    ax = plt.gca()
    dhTOA_mpi.plot(ax=ax)
    mpi_no_dust_hist.plot(ax=ax)

In [119]:
if val:
    ax = plt.gca()
    # MPI histSST dTOA anomaly (about its window mean), then the same series
    # after also subtracting the fitted dust prediction.
    (mpi_val['dTOA']-mpi_val['dTOA'].mean()).plot(ax=ax)
    (mpi_val['dTOA']-(mpi_val['predicted']+mpi_val['dTOA'].mean())).plot(ax=ax)