In [125]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.cm as cm
import matplotlib as mpl
import xesmf as xe
from workflow.scripts.utils import regrid_global
import numpy as np
from pyclim_noresm.general_util_funcs import global_avg
import scipy.stats as stats
import pandas as pd

In [3]:
# Bounding boxes keyed by region name. For each region, lon0/lon1 and
# lat0/lat1 are used as the start/stop of the lon/lat slices when a dataset
# is subset in plot_toa_imbalance_variability.
forcing_regions = {
    'East China': dict(lon0=80, lat0=25, lon1=145, lat1=50),
    'India': dict(lon0=63, lat0=5, lon1=93, lat1=30),
    'Europe': dict(lon0=-5, lat0=35, lon1=40, lat1=65),
    'North America': dict(lon0=-100, lat0=23, lon1=-50, lat1=53),
    # All-None bounds produce open-ended slices, i.e. no spatial subsetting.
    'Global': dict(lon0=None, lat0=None, lon1=None, lat1=None),
}

In [75]:
def read_data(paths, regrid_params=None):
    """Open, regrid and merge a collection of single-variable datasets.

    Each file is opened with xarray, given lon/lat bounds through the ``.cf``
    accessor, regridded with ``regrid_global`` and reduced to the data
    variable named by the dataset's ``variable_id`` attribute. That variable
    is renamed to ``"<variable_id>_<source_id>"``, tagged with ``source_id``
    and ``experiment_id`` attributes, and its ``time`` coordinate is replaced
    by the integer index ``0 .. len(time) - 1``.

    Parameters
    ----------
    paths : iterable
        Paths (or anything ``xr.open_dataset`` accepts) of the files to read.
    regrid_params : mapping, optional
        Regridding options. Must contain ``'dxdy'`` (indexable; element 0 is
        passed as ``lon`` and element 1 as ``lat`` to ``regrid_global``) and
        may contain ``'method'`` (default ``'conservative'``). When omitted,
        ``snakemake.config['regrid_params']`` is used, so the function then
        relies on the ``snakemake`` object injected by the workflow.

    Returns
    -------
    xarray.Dataset
        Result of ``xr.merge`` over the per-file data arrays.

    Notes
    -----
    NOTE(review): ``ds.cf`` needs the cf_xarray accessor to be registered;
    this notebook does not import cf_xarray directly, so it is presumably
    registered by one of the other imports -- confirm.
    """
    if regrid_params is None:
        # Fall back to the workflow configuration provided by Snakemake.
        regrid_params = snakemake.config['regrid_params']

    # The regridding options are identical for every file, so resolve them
    # once instead of on each loop iteration.
    method = regrid_params.get('method', 'conservative')
    dxdy = regrid_params['dxdy']

    arrays = []
    for p in paths:
        ds = xr.open_dataset(p)
        ds = ds.cf.add_bounds(['lon', 'lat'])
        ds = regrid_global(ds, lon=dxdy[0], lat=dxdy[1], method=method)

        # Suffix the variable name with the model id so that arrays from
        # different models can live side by side in the merged dataset.
        da = ds[ds.variable_id].rename(f'{ds.variable_id}_{ds.source_id}')
        da.attrs['source_id'] = ds.source_id
        da.attrs['experiment_id'] = ds.experiment_id

        # Replace the time coordinate by a plain integer index so that the
        # arrays align on position when merged.
        da = da.assign_coords(time=np.arange(0, len(da.time)))
        arrays.append(da)

    return xr.merge(arrays)

In [45]:
def cal_toa_imbalance(ds, models):
    """Build a dataset of per-model ``rsdt - rlut - rsut`` differences.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``rsdt_<model>``, ``rlut_<model>`` and ``rsut_<model>``
        for every entry of ``models`` (presumably the top-of-atmosphere
        radiative flux variables -- confirm against the input files).
    models : iterable of str
        Model identifiers used as variable-name suffixes.

    Returns
    -------
    xarray.Dataset
        One variable per model, named ``Imbalance_TOA_<model>``.
    """
    def _imbalance(model):
        # Difference of the three flux variables for a single model.
        net = ds[f'rsdt_{model}'] - ds[f'rlut_{model}'] - ds[f'rsut_{model}']
        return net.rename(f'Imbalance_TOA_{model}')

    return xr.merge([_imbalance(model) for model in models])

def compute_dust_timeseries(ds, models, vname='loaddust'):
    """Return mean-removed ``global_avg`` series of ``<vname>_<model>``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding one ``<vname>_<model>`` variable per requested model.
    models : iterable of str
        Model identifiers used as variable-name suffixes.
    vname : str, optional
        Variable-name prefix; defaults to ``'loaddust'``.

    Returns
    -------
    pandas object
        ``global_avg`` of the selected variables converted with
        ``to_pandas()``, with the mean of each series subtracted.
    """
    selected = xr.merge([ds[f'{vname}_{model}'] for model in models])
    series = global_avg(selected).to_pandas()
    # Subtract the mean so the result is an anomaly about zero.
    return series - series.mean()

def F_test(ts1, ts2):
    """One-sided F-test that the variance of ``ts1`` exceeds that of ``ts2``.

    Parameters
    ----------
    ts1, ts2 : array-like with ``__len__`` and ``var(ddof=...)``
        Samples to compare, e.g. pandas Series/DataFrame or numpy arrays.
        For a DataFrame the test is evaluated column by column.

    Returns
    -------
    float or array-like
        Upper-tail p-value ``P(F >= var(ts1) / var(ts2))`` under an F
        distribution with ``(len(ts1) - 1, len(ts2) - 1)`` degrees of freedom.

    Raises
    ------
    ValueError
        If either sample has fewer than two observations.
    """
    n1, n2 = len(ts1), len(ts2)
    if n1 < 2 or n2 < 2:
        raise ValueError("F_test needs at least two observations per sample")

    # Unbiased sample variances. ddof=1 is spelled out so numpy inputs
    # (default ddof=0) are treated the same way as pandas inputs.
    var1 = ts1.var(ddof=1)
    var2 = ts2.var(ddof=1)

    # Degrees of freedom come from the number of observations, not from the
    # variance estimate itself.
    df1 = n1 - 1
    df2 = n2 - 1

    F = var1 / var2
    # sf is the survival function (1 - cdf) and keeps precision for tiny p-values.
    pval = stats.f.sf(F, df1, df2)
    return pval

In [76]:
# Load and regrid every experiment listed in the Snakemake inputs.
piClim = read_data(snakemake.input.picli_vars)
piaer = read_data(snakemake.input.piaer_vars)
hist = read_data(snakemake.input.hist)
histSST = read_data(snakemake.input.histSST)
# First difference along time (the name suggests this serves as detrending).
hist_detrended = hist.diff(dim='time')

# Unique model ids present in the piaer data. They are sorted because set
# iteration order for strings can change between kernel sessions, and the
# order of `models` determines the variable/column order downstream (the
# plotting code picks columns by position).
models = sorted({piaer[d].attrs['source_id'] for d in piaer.data_vars})

In [77]:
# First difference of the histSST data along time (the name suggests this is
# used as a simple detrending step).
histSST_detrend = histSST.diff(dim='time')

In [78]:
# Experiment datasets in the order piClim, piaer, hist, histSST. This list is
# zipped position-by-position with the imbalance list and the historical
# flags in plot_toa_imbalance_variability, so the three must share this order.
dsets = [piClim, piaer, hist_detrended, histSST_detrend]

In [51]:
# Compute the per-model imbalance for each experiment, in the order
# piClim, piaer, hist (differenced), histSST (differenced).
toa_imbalance_list = [
    cal_toa_imbalance(experiment, models)
    for experiment in (piClim, piaer, hist_detrended, histSST_detrend)
]
piclim_imbalance, piaer_imbalance, hist_imbalance, histSST_imbalance = toa_imbalance_list

# Flags aligned with toa_imbalance_list: True marks the historical experiments.
hist_bool = [False, False, True, True]

In [81]:
# Display the merged piClim dataset for inspection.
piClim

In [126]:
def plot_toa_imbalance_variability(toa_imbalance, dust_dset, forcing_regions, historical,
                                   dust_models=('NorESM2-LM',), n_steps=30):
    """Plot imbalance and regional dust anomalies and return their correlations.

    One subplot row is drawn per region. For every non-historical experiment
    the mean-removed ``global_avg`` series of the imbalance dataset is drawn
    on the left axis, and the dust-load series of the region (via
    ``compute_dust_timeseries``) on a twin right axis. The correlation
    between the first column of each series is recorded.

    Only the dust dataset is subset to the region; the imbalance dataset is
    passed to ``global_avg`` without spatial subsetting.

    Parameters
    ----------
    toa_imbalance : sequence of xarray.Dataset
        Imbalance datasets, one per experiment.
    dust_dset : sequence of xarray.Dataset
        Datasets holding the ``loaddust_<model>`` variables and an
        ``experiment_id`` attribute, aligned with ``toa_imbalance``.
    forcing_regions : mapping
        Region name -> mapping with ``lon0``, ``lon1``, ``lat0``, ``lat1``
        used as slice bounds (``None`` leaves a bound open).
        NOTE(review): label-based slicing assumes the bounds use the same
        longitude convention and coordinate ordering as the regridded data
        (two regions use negative longitudes) -- confirm against the grid.
    historical : sequence of bool
        Aligned with ``toa_imbalance``; entries flagged True are skipped.
    dust_models : sequence of str, optional
        Models whose dust load is used. Defaults to ``('NorESM2-LM',)``.
    n_steps : int, optional
        Number of leading time steps kept from each series. Defaults to 30.

    Returns
    -------
    pandas.DataFrame
        Correlations with regions as columns and experiment ids as rows.
    """
    # squeeze=False always returns a 2-D axes array, so a single-region
    # mapping no longer yields a bare (non-iterable) Axes object.
    fig, axes = plt.subplots(nrows=len(forcing_regions), figsize=(10, 14), squeeze=False)
    colors = ["#1845fb", "#ff5e02", "#c91f16", "#c849a9", "#adad7d", "#86c8dd", "#578dff", "#656364"]
    correlations = {}
    for reg, ax_i in zip(forcing_regions, axes[:, 0]):
        bounds = forcing_regions[reg]
        ax_t = ax_i.twinx()
        correlations[reg] = {}
        for toa_im, dust_ds, is_hist, c in zip(toa_imbalance, dust_dset, historical, colors):
            if is_hist:
                # Historical experiments are not plotted or correlated.
                continue

            # Restrict the dust data to the region's bounding box.
            regional = dust_ds.sel(lon=slice(bounds['lon0'], bounds['lon1']),
                                   lat=slice(bounds['lat0'], bounds['lat1']))
            experiment = regional.attrs["experiment_id"]

            dust_ts = compute_dust_timeseries(regional, list(dust_models)).iloc[:n_steps, :]

            toa_ts = global_avg(toa_im).to_pandas().iloc[:n_steps, :]
            toa_ts = toa_ts - toa_ts.mean()
            toa_ts.plot(ax=ax_i, color=c, legend=False, marker='o', linestyle='-.')

            # Correlate the first column of each frame (positional, so the
            # result depends on the variable order in the inputs).
            correlations[reg][experiment] = toa_ts.iloc[:, 0].corr(dust_ts.iloc[:, 0])
            dust_ts.plot(ax=ax_t, label=f'loaddust {experiment}', color=c, marker='s', legend=False)
        ax_i.set_title(reg)
    return pd.DataFrame(correlations)

In [127]:
# Draw the per-region figure and keep the returned correlation table.
# The three lists passed here are zipped by position inside the function,
# so they must share the same experiment order.
corrs = plot_toa_imbalance_variability(toa_imbalance_list, dsets, forcing_regions, hist_bool)

In [129]:
# Display the correlation table (columns: regions, rows: experiment ids).
corrs