In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.cm as cm
import matplotlib as mpl
import xesmf as xe
from workflow.scripts.utils import regrid_global
import numpy as np
from pyclim_noresm.general_util_funcs import global_avg
import scipy.stats as stats
import pandas as pd
from workflow.scripts.plotting_tools import create_facet_plot

In [3]:
# Named longitude/latitude boxes. In every bounded entry lon0/lat0 hold the
# smaller and lon1/lat1 the larger bound (values presumably in degrees —
# confirm against the consumer). 'Global' sets all four bounds to None.
forcing_regions = dict([
    ('East China', dict(lon0=80, lat0=25, lon1=145, lat1=50)),
    ('India', dict(lon0=63, lat0=5, lon1=93, lat1=30)),
    ('Europe', dict(lon0=-5, lat0=35, lon1=40, lat1=65)),
    ('North America', dict(lon0=-100, lat0=23, lon1=-50, lat1=53)),
    ('Global', dict(lon0=None, lat0=None, lon1=None, lat1=None)),
])

In [4]:
def read_data(paths):
    """Open, regrid and merge a collection of model output files.

    Each file is opened with xarray, given lon/lat bounds, regridded with
    ``regrid_global`` using the ``regrid_params`` section of the Snakemake
    config (``dxdy`` is required, ``method`` defaults to ``'conservative'``),
    and reduced to the single variable named by the dataset's ``variable_id``
    attribute.

    Parameters
    ----------
    paths : iterable
        Paths accepted by ``xr.open_dataset``. Every file must carry the
        global attributes ``variable_id``, ``source_id`` and ``experiment_id``.

    Returns
    -------
    The result of ``xr.merge`` over all files, with one variable per file
    named ``'<variable_id>_<source_id>'``. Each variable keeps ``source_id``
    and ``experiment_id`` in its attrs.

    Raises
    ------
    KeyError
        If the config has no ``regrid_params`` entry or it lacks ``dxdy``.
    """
    # The regridding settings do not depend on the file, so read them once
    # instead of once per loop iteration.
    grid_params = snakemake.config['regrid_params']
    method = grid_params.get('method', 'conservative')
    dxdy = grid_params['dxdy']

    dsets = []
    for p in paths:
        ds = xr.open_dataset(p)
        ds = ds.cf.add_bounds(['lon', 'lat'])
        ds = regrid_global(ds, lon=dxdy[0], lat=dxdy[1], method=method)

        da = ds[ds.variable_id]
        da = da.rename(f'{ds.variable_id}_{ds.source_id}')
        da.attrs['source_id'] = ds.source_id
        da.attrs['experiment_id'] = ds.experiment_id
        # Replace the time coordinate with a plain 0..n-1 index; presumably so
        # that files with different time stamps/calendars line up in the
        # merge — TODO confirm.
        da = da.assign_coords(time=np.arange(0, len(da.time)))
        dsets.append(da)

    return xr.merge(dsets)

def cal_toa_imbalance(ds, models, remove_mean=False, t=''):
    """Compute ``|rsdt| - |rlut{t}| - |rsut{t}|`` for every model.

    Parameters
    ----------
    ds :
        Container indexed by ``'rsdt_<model>'``, ``'rlut<t>_<model>'`` and
        ``'rsut<t>_<model>'``.
    models : iterable of str
        Model names used to build the variable names above.
    remove_mean : bool, optional
        When true, the mean over the ``time`` dimension is subtracted from
        each model's result.
    t : str, optional
        Suffix inserted after ``rlut``/``rsut`` (e.g. ``'cs'``); ``rsdt`` is
        always read without a suffix.

    Returns
    -------
    The ``xr.merge`` of the per-model results, each named after its model.
    """
    def _single_model(model):
        # The absolute value of each term is taken before subtracting.
        rsdt = np.abs(ds[f'rsdt_{model}'])
        rlut = np.abs(ds[f'rlut{t}_{model}'])
        rsut = np.abs(ds[f'rsut{t}_{model}'])
        net = (rsdt - rlut - rsut).rename(f'{model}')
        if remove_mean:
            net = net - net.mean(dim='time')
        return net

    return xr.merge([_single_model(model) for model in models])

In [5]:
# Inputs for the load calculation (read for both experiments).
mmr_mass_clim = read_data(snakemake.input.picli_load)
mmr_mass_aer = read_data(snakemake.input.piaer_load)

# Radiation inputs for the TOA-imbalance calculation.
piClim = read_data(snakemake.input.picli_rad)
piaer = read_data(snakemake.input.piaer_rad)

# Unique model names found in the radiation data. Sorted because plain set
# iteration order for strings changes between interpreter runs, which would
# shuffle the order in which models are assigned to plot panels.
models = sorted({piaer[d].attrs['source_id'] for d in piaer.data_vars})

In [6]:
def calc_load(mmr_air, models):
    """Build the per-model dust load from mixing ratio and air mass.

    For every model the product ``mmrdust_<model> * airmass_<model>`` is summed
    over the ``lev`` dimension and stored as ``loaddust_<model>`` with
    ``units='kg m-2'`` and ``long_name='Load of Dust'``.

    Parameters
    ----------
    mmr_air :
        Container indexed by ``'mmrdust_<model>'`` and ``'airmass_<model>'``.
    models : iterable of str
        Model names to process.

    Returns
    -------
    The ``xr.merge`` of the single-variable datasets, one per model.
    """
    loads = []
    for model in models:
        mixing_ratio = mmr_air[f'mmrdust_{model}']
        air_mass = mmr_air[f'airmass_{model}']
        column = (mixing_ratio * air_mass).sum(dim='lev')
        column.attrs.update({'units': 'kg m-2', 'long_name': 'Load of Dust'})
        loads.append(column.to_dataset(name=f'loaddust_{model}'))
    return xr.merge(loads)

In [7]:
# Dust load per model for each of the two mixing-ratio datasets.
# NOTE: an override of 'loaddust_NorESM2-LM' taken from a separately read
# NorESM dataset was previously applied after each call; it is disabled.
load_picli, load_piaer = (
    calc_load(mmr, models) for mmr in (mmr_mass_clim, mmr_mass_aer)
)

In [8]:
# TOA imbalance for both radiation datasets: first with the default (empty)
# variable suffix, then with the 'cs' suffix (rlutcs / rsutcs).
piClim_imbalance = cal_toa_imbalance(ds=piClim, models=models)
piaer_imbalance = cal_toa_imbalance(ds=piaer, models=models)
piClim_imbalance_cs = cal_toa_imbalance(ds=piClim, models=models, t='cs')
piaer_imbalance_cs = cal_toa_imbalance(ds=piaer, models=models, t='cs')

In [9]:
def plot_correlation_map(ds_load, ds_imbalance, models):
    """Map, per model, the correlation over time between imbalance and dust load.

    Both inputs get lon/lat bounds and are regridded with ``regrid_global``
    onto the grid returned by ``xe.util.grid_global(5, 5, cf=True)``. For each
    model, ``xr.corr`` along ``time`` is drawn on a Robinson-projection panel
    with a 15-step ``bwr`` colormap normalised to [-0.7, 0.7]. Cells whose
    two-sided t-test p-value (``n - 2`` degrees of freedom, ``n`` being the
    number of time steps) is below 0.1 are hatched.

    Parameters
    ----------
    ds_load :
        Dataset holding ``'loaddust_<model>'`` for every model and a ``time``
        coordinate.
    ds_imbalance :
        Dataset holding one variable per model, named after the model.
    models : sequence of str
        Model names; one panel is requested per model.

    Returns
    -------
    None. The figure is left as the current matplotlib figure so the caller
    can save it with ``plt.savefig``.
    """
    fig, ax, cax = create_facet_plot(
        len(models), subplot_kw={'projection': ccrs.Robinson()}, figsize=(14, 7)
    )
    norm = mpl.colors.Normalize(vmin=-0.7, vmax=0.7)
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap takes the same (name, lut) arguments and is kept.
    cmap = plt.get_cmap('bwr', 15)

    ds_load = ds_load.cf.add_bounds(['lon', 'lat'])
    ds_imbalance = ds_imbalance.cf.add_bounds(['lon', 'lat'])
    ds_out = xe.util.grid_global(5, 5, cf=True)
    ds_load = regrid_global(ds_load, ds_out)
    ds_imbalance = regrid_global(ds_imbalance, ds_out)

    # Degrees of freedom of the correlation t-test; identical for every model.
    dof = len(ds_load.time) - 2

    # `ax` is iterated and then indexed with its own items, i.e. it is expected
    # to behave like a mapping of panel key -> axes (confirm against
    # create_facet_plot).
    for a, m in zip(ax, models):
        corr = xr.corr(ds_imbalance[m], ds_load[f'loaddust_{m}'], dim='time')
        corr.plot(ax=ax[a], norm=norm, cmap=cmap, add_colorbar=False,
                  transform=ccrs.PlateCarree())

        # t statistic of the correlation coefficient and its two-sided p-value.
        # The survival function avoids the cancellation in 1 - cdf for large |t|.
        tv = (corr * np.sqrt(dof)) / np.sqrt(1 - corr**2)
        p = 2 * stats.t.sf(np.abs(tv), dof)

        # 1 where significant at the 10 % level, NaN elsewhere; drawn as hatching.
        mask = corr.copy()
        mask.data = xr.where(p < 0.1, 1, np.nan)
        mask.plot.contourf(ax=ax[a], hatches=['////'], colors='none', levels=2,
                           transform=ccrs.PlateCarree(), add_colorbar=False)

        ax[a].set_title(m)
        ax[a].coastlines()
        ax[a].gridlines()

    fig.colorbar(cm.ScalarMappable(norm, cmap=cmap), cax=cax)

In [14]:
# Correlation maps for the `piaer` data (imbalance without the 'cs' suffix),
# written to the `piaer_allsky` output.
plot_correlation_map(load_piaer, piaer_imbalance, models)
plt.savefig(snakemake.output.piaer_allsky, dpi=144, bbox_inches='tight')

In [15]:
# Correlation maps for the `piClim` data (imbalance without the 'cs' suffix),
# written to the `picli_allsky` output.
plot_correlation_map(load_picli, piClim_imbalance, models)
plt.savefig(snakemake.output.picli_allsky, dpi=144, bbox_inches='tight')

In [16]:
# Correlation maps for the `piClim` data using the 'cs' imbalance,
# written to the `picli_clearsky` output.
plot_correlation_map(load_picli, piClim_imbalance_cs, models)
plt.savefig(snakemake.output.picli_clearsky, dpi=144, bbox_inches='tight')

In [17]:
# Correlation maps for the `piaer` data using the 'cs' imbalance,
# written to the `piaer_clearsky` output.
plot_correlation_map(load_piaer, piaer_imbalance_cs, models)
plt.savefig(snakemake.output.piaer_clearsky, dpi=144, bbox_inches='tight')

In [11]:
def plot_std_map(ds_imbalance, models):
    """Map, per model, the standard deviation over time of the imbalance.

    The input gets lon/lat bounds and is regridded with ``regrid_global`` onto
    the grid returned by ``xe.util.grid_global(5, 5, cf=True)``. Each model's
    standard deviation along ``time`` is drawn on a Robinson-projection panel
    with a 15-step ``YlOrBr`` colormap normalised to [0, 5]; the shared
    colorbar is extended at the upper end.

    Parameters
    ----------
    ds_imbalance :
        Dataset holding one variable per model, named after the model.
    models : sequence of str
        Model names; one panel is requested per model.

    Returns
    -------
    None. The figure is left as the current matplotlib figure.
    """
    fig, ax, cax = create_facet_plot(
        len(models), subplot_kw={'projection': ccrs.Robinson()}, figsize=(14, 7)
    )
    norm = mpl.colors.Normalize(vmin=0, vmax=5)
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap takes the same (name, lut) arguments and is kept.
    cmap = plt.get_cmap('YlOrBr', 15)

    ds_imbalance = ds_imbalance.cf.add_bounds(['lon', 'lat'])
    ds_out = xe.util.grid_global(5, 5, cf=True)
    ds_imbalance = regrid_global(ds_imbalance, ds_out)

    # `ax` is iterated and then indexed with its own items, i.e. it is expected
    # to behave like a mapping of panel key -> axes (confirm against
    # create_facet_plot).
    for a, m in zip(ax, models):
        std = ds_imbalance[m].std(dim='time')
        std.plot(ax=ax[a], norm=norm, cmap=cmap, add_colorbar=False,
                 transform=ccrs.PlateCarree())

        ax[a].set_title(m)
        ax[a].coastlines()
        ax[a].gridlines()

    fig.colorbar(cm.ScalarMappable(norm, cmap=cmap), cax=cax, extend='max')

In [12]:
# Temporal standard deviation of the 'cs' imbalance for the `piaer` data.
plot_std_map(ds_imbalance=piaer_imbalance_cs, models=models)

In [13]:
# Temporal standard deviation of the 'cs' imbalance for the `piClim` data.
plot_std_map(ds_imbalance=piClim_imbalance_cs, models=models)