In [None]:
import xarray as xr
from pyclim_noresm.general_util_funcs import global_avg
from workflow.scripts.plotting_tools import global_map, create_facet_plot
from workflow.scripts.utils import regrid_global
import numpy as np
import matplotlib as mpl
# Keep variable/coordinate attributes (e.g. 'units') on the results of xarray
# arithmetic and reductions, which otherwise drop them by default.
xr.set_options(keep_attrs=True)



In [None]:
def reorder_lons(ds):
    """Shift longitudes from [0, 360) to [-180, 180) and sort along the X axis.

    The X and Y coordinates are located with the cf_xarray accessor
    (``ds.cf['X']`` / ``ds.cf['Y']``), so they need not literally be called
    'lon'/'lat'. When a shift is applied, bounds for both coordinates are
    generated with ``ds.cf.add_bounds``. A dataset whose X coordinate never
    exceeds 180 is returned unchanged.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with CF-identifiable X/Y coordinates.

    Returns
    -------
    xarray.Dataset
    """
    x = ds.cf['X'].name
    # Test the detected X coordinate itself rather than assuming it is named 'lon'.
    if ds[x].max() > 180:
        shifted = (ds[x] + 180) % 360 - 180
        ds = ds.assign_coords({x: shifted}).sortby(x)
        y = ds.cf['Y'].name
        # Bounds are rebuilt from the shifted/sorted coordinate values.
        ds = ds.cf.add_bounds([x, y])
    return ds


In [None]:
# Variables removed from the dust-diagnostic datasets on load; none of them is
# used later in this notebook. Names missing from a given file are ignored.
# (A missing comma previously fused 'od550aer' and 'tas' into 'od550aertas',
# so 'od550aer' was never dropped.)
drop_vars = ['drydust', 'emidust', 'od440aer', 'od550dust', 'od550aer',
             'tas', 'radatm', 'radatmcs', 'ang4487aer', 'Australia emidust',
             'Western north Africa emidust', 'Eastern north Africa emidust',
             'North America emidust', 'lifetime', 'Sahel emidust',
             'Middle east emidust', 'Central Asia emidust',
             'od550dust_mass_ext', 'Southern Africa emidust',
             'East Asia emidust', 'South America emidust']


def _open_by_model(paths, key_index, skip_first_step=False, reorder=False, drop=None):
    """Open every file in `paths` and return a ``{key: Dataset}`` dict.

    Parameters
    ----------
    paths : iterable of str
        Input file paths (snakemake input lists).
    key_index : int
        Index into ``path.split('_')`` giving the dict key (presumably the model
        name -- TODO confirm against the snakemake file naming).
    skip_first_step : bool
        If True, drop the first time step with ``isel(time=slice(1, None))``.
    reorder : bool
        If True, pass the dataset through `reorder_lons` after the time selection.
    drop : list of str, optional
        Variables removed with ``drop_vars(..., errors='ignore')``.
    """
    opened = {}
    for path in paths:
        ds = xr.open_dataset(path)
        if skip_first_step:
            ds = ds.isel(time=slice(1, None))
        if drop is not None:
            ds = ds.drop_vars(drop, errors='ignore')
        if reorder:
            ds = reorder_lons(ds)
        opened[path.split('_')[key_index]] = ds
    return opened


# Aerosol concentration fields: experiment runs skip their first time step and
# all fields are put on [-180, 180) longitudes.
so4conc_exp = _open_by_model(snakemake.input.exp_so4, 2, skip_first_step=True, reorder=True)
so4conc_ctrl = _open_by_model(snakemake.input.ctrl_so4, 2, reorder=True)

oaconc_exp = _open_by_model(snakemake.input.exp_oa, 2, skip_first_step=True, reorder=True)
oaconc_ctrl = _open_by_model(snakemake.input.ctrl_oa, 2, reorder=True)

soaconc_exp = _open_by_model(snakemake.input.exp_soa, 2, skip_first_step=True, reorder=True)
soaconc_ctrl = _open_by_model(snakemake.input.ctrl_soa, 2, reorder=True)

bcconc_exp = _open_by_model(snakemake.input.exp_bc, 2, skip_first_step=True, reorder=True)
bcconc_ctrl = _open_by_model(snakemake.input.ctrl_bc, 2, reorder=True)

nh4conc_exp = _open_by_model(snakemake.input.exp_nh4, 2, skip_first_step=True, reorder=True)
nh4conc_ctrl = _open_by_model(snakemake.input.ctrl_nh4, 2, reorder=True)

# Cloud diagnostics. NOTE(review): every other exp/ctrl pair here drops the
# first time step of the experiment only; the cloud diagnostics previously
# dropped it from the control run instead, which looked like a copy-paste slip.
# Made consistent -- revert if the control files really carry an extra step.
cloud_diag_exp = _open_by_model(snakemake.input.exp_clddiag, -2, skip_first_step=True)
cloud_diag_ctrl = _open_by_model(snakemake.input.ctrl_clddiag, -2)

# Dust diagnostics with the unused variables removed.
dst_diag_exp = _open_by_model(snakemake.input.exp_dstdiag, -2, skip_first_step=True, drop=drop_vars)
dst_diag_ctrl = _open_by_model(snakemake.input.ctrl_dstdiag, -2, drop=drop_vars)

In [None]:
def fix_dataset(dust_diag, data_concso4, vname='concso4'):
    """Copy variable `vname` from `data_concso4` into the `dust_diag` dataset.

    If the lon/lat values of the two datasets differ (or their lengths differ),
    the lon/lat coordinates of `data_concso4` are overwritten with those of
    `dust_diag` before the copy. The variable is therefore placed on the
    diagnostic grid rather than being aligned against it. Overwriting requires
    both grids to have the same sizes.

    Afterwards, any 'bnds' dimension on the merged dataset is removed together
    with every variable that uses it. A 'bounds' dimension (presumably from
    ``cf.add_bounds``) is then renamed to 'bnds', and 'lon_bounds'/'lat_bounds'
    are renamed to 'lon_bnds'/'lat_bnds'.

    Parameters
    ----------
    dust_diag : xarray.Dataset
        Target dataset; its grid wins.
    data_concso4 : xarray.Dataset
        Source dataset containing `vname` (any concentration field, not only SO4).
    vname : str
        Name of the variable to copy.

    Returns
    -------
    xarray.Dataset
    """
    # array_equal is False on shape mismatch instead of broadcasting/raising like `==`.
    same_grid = (np.array_equal(data_concso4['lon'].values, dust_diag['lon'].values)
                 and np.array_equal(data_concso4['lat'].values, dust_diag['lat'].values))
    if not same_grid:
        data_concso4 = data_concso4.assign_coords(lon=dust_diag['lon'], lat=dust_diag['lat'])
    dust_diag = dust_diag.assign({vname: data_concso4[vname]})
    if 'bnds' in dust_diag.dims:
        # Drop the whole dimension (and lat_bnds/lon_bnds/... using it) so the
        # rename below cannot collide with an existing 'bnds' dim or 'lon_bnds'.
        dust_diag = dust_diag.drop_dims('bnds')
    if 'bounds' in dust_diag.dims:
        dust_diag = dust_diag.rename_dims({'bounds': 'bnds'})
        dust_diag = dust_diag.rename({'lon_bounds': 'lon_bnds', 'lat_bounds': 'lat_bnds'})
    return dust_diag

def corr_fields(ds, burd1='concso4', burd2='concdust'):
    """Return `ds` with the Pearson correlation of `burd1` and `burd2` over time added.

    The new variable is named 'corr' followed by both input names with their
    first four characters removed, e.g. ('concso4', 'concdust') -> 'corrso4dust'.
    """
    out_name = 'corr' + burd1[4:] + burd2[4:]
    correlation = xr.corr(ds[burd1], ds[burd2], dim='time')
    return ds.assign({out_name: correlation})

In [None]:
# Attach each model's SO4 concentration to its dust-diagnostic dataset (exp and ctrl).
dsets_exp = {
    model: fix_dataset(dst_diag_exp[model], data_concso4=so4_ds)
    for model, so4_ds in so4conc_exp.items()
}
dset_ctrl = {
    model: fix_dataset(dst_diag_ctrl[model], data_concso4=so4_ds)
    for model, so4_ds in so4conc_ctrl.items()
}

In [None]:
# Add the SO4/dust temporal correlation ('corrso4dust') to every experiment dataset.
dsets_exp = {
    model: corr_fields(dsets_exp[model], burd1='concso4', burd2='concdust')
    for model in so4conc_exp
}

In [None]:
# Per-model maps of the temporal correlation between concso4 and concdust in the
# experiment runs, sharing one colorbar in the facet plot's colorbar axis.
corr_cmap = 'RdBu_r'
corr_norm = mpl.colors.Normalize(vmin=-0.8, vmax=0.8)
fig, ax, cax = create_facet_plot(len(dsets_exp))
# Iterate the dict that is actually plotted (not the ctrl inputs).
for k, axk in zip(dsets_exp, ax):
    global_map(ax=ax[axk])
    dsets_exp[k]['corrso4dust'].plot(ax=ax[axk], cmap=corr_cmap, norm=corr_norm,
                                     add_colorbar=False, add_labels=False)
    ax[axk].set_title(k)
# One shared colorbar instead of a colorbar per panel plus an empty `cax`.
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=corr_norm, cmap=corr_cmap), cax=cax, extend='both')
cbar.set_label('Correlation (concso4, concdust)')

In [None]:
# Rebuild the SO4-merged datasets (without the correlation field) and take the
# experiment-minus-control difference of their time means.
dsets_exp = {
    model: fix_dataset(dst_diag_exp[model], data_concso4=so4_ds)
    for model, so4_ds in so4conc_exp.items()
}
dset_ctrl = {
    model: fix_dataset(dst_diag_ctrl[model], data_concso4=so4_ds)
    for model, so4_ds in so4conc_ctrl.items()
}

diff_dsets_mean = {
    model: dsets_exp[model].mean(dim='time') - ctrl_ds.mean(dim='time')
    for model, ctrl_ds in dset_ctrl.items()
}

In [None]:
# Per-model map of the change in time-mean SO4 (experiment minus control), scaled
# by cell_area / 1e3 and labelled 'Tonnes'. That label presumes concso4 is a
# column load in kg m-2 -- TODO confirm units.
fig, ax, cax = create_facet_plot(len(diff_dsets_mean))

cmap = mpl.cm.PuOr_r.resampled(11)
norm = mpl.colors.Normalize(vmin=-25, vmax=25)
for k, axk in zip(diff_dsets_mean, ax):
    global_map(ax=ax[axk])
    change_so4burd = diff_dsets_mean[k]['concso4'] * dset_ctrl[k]['cell_area'] / 1e3
    change_so4burd.attrs['units'] = 'Tonnes'
    burd_ctrl = (dset_ctrl[k]['concso4'].mean(dim='time') * dset_ctrl[k]['cell_area']).sum(dim=['lon', 'lat']) / 1e3
    change_so4burd.plot(ax=ax[axk], cmap=cmap, norm=norm, add_colorbar=False, add_labels=False)
    # Convert 0-d results to Python floats so the format specs below work on any xarray version.
    tot_change = float(change_so4burd.sum(dim=['lon', 'lat']))
    rel_diff = float(tot_change / burd_ctrl * 100)
    ax[axk].text(0.03, 0.08, f'Total change {tot_change:.3f} Tonnes ({rel_diff:.2f}%) ', transform=ax[axk].transAxes,
                 bbox={'facecolor': 'white'})
    ax[axk].set_title(k)
# The change can have either sign on a symmetric scale, so extend both ends.
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, extend='both')
cbar.set_label('Tonnes')

In [None]:
# Attach each model's OA concentration ('concoa') and difference the time means.
dsets_exp = {
    model: fix_dataset(dst_diag_exp[model], data_concso4=oa_ds, vname='concoa')
    for model, oa_ds in oaconc_exp.items()
}
dset_ctrl = {
    model: fix_dataset(dst_diag_ctrl[model], data_concso4=oa_ds, vname='concoa')
    for model, oa_ds in oaconc_ctrl.items()
}

diff_dsets_mean = {
    model: dsets_exp[model].mean(dim='time') - ctrl_ds.mean(dim='time')
    for model, ctrl_ds in dset_ctrl.items()
}

In [None]:
# Per-model map of the change in time-mean OA (experiment minus control), scaled
# by cell_area / 1e3 and labelled 'Tonnes' (presumes concoa is in kg m-2 -- TODO confirm).
fig, ax, cax = create_facet_plot(figsize=(12, 7), nplots=len(diff_dsets_mean))

cmap = mpl.cm.PuOr_r.resampled(11)
norm = mpl.colors.Normalize(vmin=-60, vmax=60)
for k, axk in zip(diff_dsets_mean, ax):
    global_map(ax=ax[axk])
    change_burd = diff_dsets_mean[k]['concoa'] * dset_ctrl[k]['cell_area'] / 1e3
    change_burd.attrs['units'] = 'Tonnes'
    burd_ctrl = (dset_ctrl[k]['concoa'].mean(dim='time') * dset_ctrl[k]['cell_area']).sum(dim=['lon', 'lat']) / 1e3
    change_burd.plot(ax=ax[axk], cmap=cmap, norm=norm, add_colorbar=False, add_labels=False)
    # Convert 0-d results to Python floats so the format specs below work on any xarray version.
    tot_change = float(change_burd.sum(dim=['lon', 'lat']))
    rel_diff = float(tot_change / burd_ctrl * 100)
    ax[axk].text(0.03, 0.08, f'Total change {tot_change:.3f} Tonnes ({rel_diff:.2f}%) ', transform=ax[axk].transAxes,
                 bbox={'facecolor': 'white'})
    ax[axk].set_title(k)
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, extend='both')
cbar.set_label('Tonnes')

In [None]:
# Attach each model's SOA concentration ('concsoa') and difference the time means.
dsets_exp = {
    model: fix_dataset(dst_diag_exp[model], data_concso4=soa_ds, vname='concsoa')
    for model, soa_ds in soaconc_exp.items()
}
dset_ctrl = {
    model: fix_dataset(dst_diag_ctrl[model], data_concso4=soa_ds, vname='concsoa')
    for model, soa_ds in soaconc_ctrl.items()
}

diff_dsets_mean = {
    model: dsets_exp[model].mean(dim='time') - ctrl_ds.mean(dim='time')
    for model, ctrl_ds in dset_ctrl.items()
}


In [None]:
# Per-model map of the change in time-mean SOA (experiment minus control), scaled
# by cell_area / 1e3 and labelled 'Tonnes' (presumes concsoa is in kg m-2 -- TODO confirm).
fig, ax, cax = create_facet_plot(figsize=(10, 4.5), nplots=len(diff_dsets_mean))

cmap = mpl.cm.PuOr_r.resampled(11)
norm = mpl.colors.Normalize(vmin=-60, vmax=60)
for k, axk in zip(diff_dsets_mean, ax):
    global_map(ax=ax[axk])
    change_burd = diff_dsets_mean[k]['concsoa'] * dset_ctrl[k]['cell_area'] / 1e3
    change_burd.attrs['units'] = 'Tonnes'
    burd_ctrl = (dset_ctrl[k]['concsoa'].mean(dim='time') * dset_ctrl[k]['cell_area']).sum(dim=['lon', 'lat']) / 1e3
    change_burd.plot(ax=ax[axk], cmap=cmap, norm=norm, add_colorbar=False, add_labels=False)
    # Convert 0-d results to Python floats so the format specs below work on any xarray version.
    tot_change = float(change_burd.sum(dim=['lon', 'lat']))
    rel_diff = float(tot_change / burd_ctrl * 100)
    ax[axk].text(0.03, 0.08, f'Total change {tot_change:.3f} Tonnes ({rel_diff:.2f}%) ', transform=ax[axk].transAxes,
                 bbox={'facecolor': 'white'})
    ax[axk].set_title(k)
# Assign the colorbar (avoids echoing its repr as cell output) and label it like the other maps.
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, extend='both')
cbar.set_label('Tonnes')

In [None]:
# Attach each model's BC concentration ('concbc') and difference the time means.
dsets_exp = {
    model: fix_dataset(dst_diag_exp[model], data_concso4=bc_ds, vname='concbc')
    for model, bc_ds in bcconc_exp.items()
}
dset_ctrl = {
    model: fix_dataset(dst_diag_ctrl[model], data_concso4=bc_ds, vname='concbc')
    for model, bc_ds in bcconc_ctrl.items()
}

diff_dsets_mean = {
    model: dsets_exp[model].mean(dim='time') - ctrl_ds.mean(dim='time')
    for model, ctrl_ds in dset_ctrl.items()
}


In [None]:
# Per-model map of the change in time-mean BC (experiment minus control), scaled
# by cell_area / 1e3 and labelled 'Tonnes' (presumes concbc is in kg m-2 -- TODO confirm).
fig, ax, cax = create_facet_plot(nplots=len(diff_dsets_mean))

cmap = mpl.cm.PuOr_r.resampled(11)
norm = mpl.colors.Normalize(vmin=-8, vmax=8)
for k, axk in zip(dsets_exp, ax):
    global_map(ax=ax[axk])
    change_burd = diff_dsets_mean[k]['concbc'] * dset_ctrl[k]['cell_area'] / 1e3
    change_burd.attrs['units'] = 'Tonnes'
    burd_ctrl = (dset_ctrl[k]['concbc'].mean(dim='time') * dset_ctrl[k]['cell_area']).sum(dim=['lon', 'lat']) / 1e3
    change_burd.plot(ax=ax[axk], cmap=cmap, norm=norm, add_colorbar=False, add_labels=False)
    # Convert 0-d results to Python floats so the format specs below work on any xarray version.
    tot_change = float(change_burd.sum(dim=['lon', 'lat']))
    rel_diff = float(tot_change / burd_ctrl * 100)
    ax[axk].text(0.03, 0.08, f'Total change {tot_change:.3f} Tonnes ({rel_diff:.2f}%) ', transform=ax[axk].transAxes,
                 bbox={'facecolor': 'white'})
    ax[axk].set_title(k)
# Assign the colorbar (avoids echoing its repr as cell output) and label it like the other maps.
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, extend='both')
cbar.set_label('Tonnes')

In [None]:
# Models included in the cloud-droplet-number ('cdncvi') difference maps.
models_cdnc = [
    'NorESM2-LM',
    'MPI-ESM-1-2-HAM',
    'EC-Earth3-AerChem',
]

In [None]:
def diff_nd(nd_exp, nd_ctrl):
    """Return ``nd_exp - nd_ctrl`` after putting `nd_exp` on the control grid.

    The lon/lat coordinates of `nd_exp` are overwritten with those of `nd_ctrl`.
    This stops small coordinate mismatches between the runs from making xarray's
    default inner-join alignment drop grid points in the subtraction. Both grids
    must have the same lon/lat sizes.

    Parameters
    ----------
    nd_exp, nd_ctrl : xarray.Dataset or xarray.DataArray
        Experiment and control fields with 'lon' and 'lat' coordinates.

    Returns
    -------
    Same type as the inputs: the experiment-minus-control difference.
    """
    # assign_coords is the API for replacing coordinates (assign targets data variables).
    nd_exp = nd_exp.assign_coords(lon=nd_ctrl['lon'], lat=nd_ctrl['lat'])
    return nd_exp - nd_ctrl
    

In [None]:

# Time-mean experiment-minus-control cloud diagnostics for each CDNC model.
diff_dsets_nd = {
    model: diff_nd(
        cloud_diag_exp[model].mean(dim='time'),
        cloud_diag_ctrl[model].mean(dim='time'),
    )
    for model in models_cdnc
}

In [None]:
# Per-model map of the change in time-mean cdncvi (experiment minus control),
# with its global average printed on each panel.
fig, ax, cax = create_facet_plot(nplots=len(diff_dsets_nd))

cmap = mpl.cm.PuOr_r.resampled(11)
norm = mpl.colors.Normalize(vmin=-80, vmax=80)
for k, axk in zip(diff_dsets_nd, ax):
    global_map(ax=ax[axk])
    # NOTE(review): the 1e-9 factor is meant to give the units in the labels below;
    # the source units of cdncvi are not visible here -- TODO confirm.
    change_nd = diff_dsets_nd[k]['cdncvi'] * 1e-9
    change_nd.plot(ax=ax[axk], cmap=cmap, norm=norm, add_colorbar=False, add_labels=False)
    # Convert to a Python float so the format spec works whether global_avg
    # returns a 0-d DataArray or a scalar.
    tot_change = float(global_avg(change_nd))
    ax[axk].text(0.03, 0.08, f'Mean change {tot_change:.3f} #1000*cm-3 ', transform=ax[axk].transAxes,
                 bbox={'facecolor': 'white'})
    ax[axk].set_title(k)
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, extend='both')
cbar.set_label('CDNCVI [#/1000cm-3]')

In [None]:
# Attach each model's NH4 concentration ('concnh4') and difference the time means.
dsets_exp = {
    model: fix_dataset(dst_diag_exp[model], data_concso4=nh4_ds, vname='concnh4')
    for model, nh4_ds in nh4conc_exp.items()
}
dset_ctrl = {
    model: fix_dataset(dst_diag_ctrl[model], data_concso4=nh4_ds, vname='concnh4')
    for model, nh4_ds in nh4conc_ctrl.items()
}

diff_dsets_mean = {
    model: dsets_exp[model].mean(dim='time') - ctrl_ds.mean(dim='time')
    for model, ctrl_ds in dset_ctrl.items()
}


In [None]:
# Per-model map of the change in time-mean NH4 (experiment minus control), scaled
# by cell_area / 1e3 and labelled 'Tonnes' (presumes concnh4 is in kg m-2 -- TODO confirm).
fig, ax, cax = create_facet_plot(nplots=len(diff_dsets_mean), figsize=(12, 5.5))

cmap = mpl.cm.PuOr_r.resampled(11)
norm = mpl.colors.Normalize(vmin=-8, vmax=8)
for k, axk in zip(dsets_exp, ax):
    global_map(ax=ax[axk])
    change_burd = diff_dsets_mean[k]['concnh4'] * dset_ctrl[k]['cell_area'] / 1e3
    change_burd.attrs['units'] = 'Tonnes'
    burd_ctrl = (dset_ctrl[k]['concnh4'].mean(dim='time') * dset_ctrl[k]['cell_area']).sum(dim=['lon', 'lat']) / 1e3
    change_burd.plot(ax=ax[axk], cmap=cmap, norm=norm, add_colorbar=False, add_labels=False)
    # Convert 0-d results to Python floats so the format specs below work on any xarray version.
    tot_change = float(change_burd.sum(dim=['lon', 'lat']))
    rel_diff = float(tot_change / burd_ctrl * 100)
    ax[axk].text(0.03, 0.08, f'Total change {tot_change:.3f} Tonnes ({rel_diff:.2f}%) ', transform=ax[axk].transAxes,
                 bbox={'facecolor': 'white'})
    ax[axk].set_title(k)
# Assign the colorbar (avoids echoing its repr as cell output) and label it like the other maps.
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax, extend='both')
cbar.set_label('Tonnes')