In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import matplotlib.cm as cm
import matplotlib as mpl
from workflow.scripts.plotting_tools import create_facet_plot
from workflow.scripts.utils import regrid_global

In [28]:
# Run parameters injected by Snakemake.
vname = snakemake.wildcards.vName
# Cells whose merged, normalised forcing falls below this value are masked (see the mask cell).
# The key spelling 'mask_treshold' must match the Snakefile params.
mask_treshold = snakemake.params.get('mask_treshold', -2)
# Optional override for the averaging window; when falsy an experiment-dependent default is used.
time_slice = snakemake.params.get('time_slice', None)

# One forcing dataset per model (EC-Earth, MPI-ESM, NorESM).
inputs = snakemake.input
ec_forcing = xr.open_dataset(inputs.forcing_ec)
mpi_forcing = xr.open_dataset(inputs.forcing_mpi)
nor_forcing = xr.open_dataset(inputs.forcing_noresm)

In [32]:
def default_time_slice(experiment):
    """Return the default ``year`` index slice to average for ``experiment``.

    Experiments whose name starts with ``piClim-`` drop the first 5 years;
    every other experiment averages the last 30 years.
    """
    if experiment.split('-')[0] == 'piClim':
        return slice(5, None)
    return slice(-30, None)


# An explicit (truthy) `time_slice` param overrides the experiment-dependent default.
t_slice = time_slice if time_slice else default_time_slice(snakemake.wildcards.experiment)

# Time-mean field per model over the selected years.
ec_forcing_m, nor_forcing_m, mpi_forcing_m = (
    ds.isel(year=t_slice).mean(dim='year')
    for ds in (ec_forcing, nor_forcing, mpi_forcing)
)

In [33]:
def normalize_min_max(da):
    """Linearly rescale ``da`` so that its maximum maps to 0 and its minimum maps to 1.

    NOTE(review): this orientation is inverted relative to conventional min-max
    scaling ((x - min) / (max - min)). The function is not called in the visible
    cells — confirm the intent before using it.
    """
    top = da.max()
    bottom = da.min()
    return (da - top) / (bottom - top)


def normalize_z(da):
    """Standardise ``da``: subtract its mean and divide by its standard deviation.

    Both statistics are taken over all elements with default ddof (0) and no
    weighting. NOTE(review): on a lat/lon grid this is not area-weighted.
    """
    centred = da - da.mean()
    return centred / da.std()

# z-score each model's time-mean field so the three models share a common scale;
# each result is wrapped back into a Dataset (variable keeps its original name).
norm_mpi, norm_nor, norm_ec = (
    normalize_z(model_mean[vname]).to_dataset()
    for model_mean in (mpi_forcing_m, nor_forcing_m, ec_forcing_m)
)


In [34]:

# Regridding settings from the Snakemake config; element 0/1 of 'dxdy' are passed to
# regrid_global as its `lon` / `lat` arguments in the next cell.
reg_grid_params = snakemake.config['regrid_params']

In [35]:
# Put every model's normalised field on the same regular grid so they can be averaged cell by cell.
# NOTE(review): the `.cf` accessor is provided by cf_xarray, which is not imported in this
# notebook — presumably registered via workflow.scripts.utils; confirm.
dxdy = reg_grid_params['dxdy']


def _to_common_grid(ds):
    """Add lon/lat cell bounds to ``ds`` and regrid it with the configured spacing."""
    return regrid_global(ds.cf.add_bounds(['lon', 'lat']), lon=dxdy[0], lat=dxdy[1])


nor_regrid = _to_common_grid(norm_nor)
mpi_regrid = _to_common_grid(norm_mpi)
ec_regrid = _to_common_grid(norm_ec)

In [17]:
# Unweighted three-model mean of the z-scored fields (NaN in any model propagates).
forcing_sum = nor_regrid[vname] + mpi_regrid[vname] + ec_regrid[vname]
merge_forcing = forcing_sum / 3

In [22]:
# Regrid the ensemble mean onto a coarse 5 x 5 grid before thresholding.
merged_ds = merge_forcing.to_dataset(name='Merged_normed_ERF')
merged_normed = regrid_global(merged_ds.cf.add_bounds(['lon', 'lat']), lat=5, lon=5)

# 1 where the merged normalised forcing is below the threshold, 0 elsewhere (NaN compares False -> 0).
below = merged_normed['Merged_normed_ERF'] < mask_treshold
mask = xr.where(below, 1, 0)


In [23]:
# Map of the ensemble-mean normalised forcing with the threshold mask outlined in orange.
MERGED_VLIM = 10  # symmetric colour limit for the merged field
fig, ax, cax = create_facet_plot(1, figsize=(10, 5), subplot_kw={'projection': ccrs.Robinson()})
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
cmap = plt.get_cmap('PRGn', 21).copy()
merge_forcing.plot(ax=ax['A'], add_colorbar=False, transform=ccrs.PlateCarree(),
                   cmap=cmap, vmax=MERGED_VLIM, vmin=-MERGED_VLIM)
ax['A'].coastlines()
fig.colorbar(cm.ScalarMappable(mpl.colors.Normalize(vmin=-MERGED_VLIM, vmax=MERGED_VLIM), cmap=cmap),
             cax=cax)

# `mask` was computed from the coarse-regridded ensemble mean in the previous cell;
# nothing needs recomputing here.
mask.plot.contour(ax=ax['A'], transform=ccrs.PlateCarree(), colors='orange', levels=1);

In [24]:
# Per-model time-mean forcing (not normalised) with the merged-threshold mask outlined.
MODEL_VLIM = 15  # symmetric colour limit, W/m^2
fig, ax, cax = create_facet_plot(3, subplot_kw={'projection': ccrs.Robinson()}, figsize=(12, 6))
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
cmap = plt.get_cmap('PRGn', 21).copy()

# (panel key, time-mean dataset, panel title)
panels = [
    ('A', nor_forcing_m, 'NorESM2-LM'),
    ('B', ec_forcing_m, 'EC-Earth3-AerChem'),
    ('C', mpi_forcing_m, 'MPI-ESM-1-2-HAM'),
]
for key, forcing_m, _ in panels:
    forcing_m[vname].plot(ax=ax[key], transform=ccrs.PlateCarree(), add_colorbar=False,
                          cmap=cmap, vmax=MODEL_VLIM, vmin=-MODEL_VLIM)
    ax[key].coastlines()
cb = fig.colorbar(cm.ScalarMappable(mpl.colors.Normalize(-MODEL_VLIM, MODEL_VLIM), cmap=cmap), cax=cax)

for key in ax:
    mask.plot.contour(ax=ax[key], transform=ccrs.PlateCarree(), colors='orange', levels=1)

# Titles are set last: xarray's 2D plot methods (including the contour above) set the axis title themselves.
for key, _, title in panels:
    ax[key].set_title(title)
cb.set_label('$W/m^2$')
# Save this specific figure rather than whatever pyplot considers "current".
fig.savefig(snakemake.output.plot, dpi=144, bbox_inches='tight')

In [25]:
# Save the threshold mask plus per-region subsets to a single netCDF file.
m = mask.to_dataset(name='ERF_region_mask')
m.attrs['treshold'] = mask_treshold  # attribute key kept as-is for downstream readers

# Each region variable keeps ERF_region_mask values inside its lon/lat box and 0.0 elsewhere.
# NOTE(review): the boxes use negative longitudes, which assumes regrid_global returns a
# -180..180 longitude axis — confirm, otherwise North_America would be all zeros.
region_boxes = {
    'Asia': (m.lon > 70) & (m.lat > 20),
    'North_America': ((m.lon > -100) & (m.lon < -30)) & (m.lat > 20),
    'Europe': ((m.lon > 20) & (m.lon < 70)) & (m.lat > 20),
}
for region, inside in region_boxes.items():
    m[region] = m['ERF_region_mask'].where(inside, 0.0)

# Written once, after all region variables have been added.
m.to_netcdf(snakemake.output.mask)