In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from workflow.scripts.plotting_tools import global_map
import numpy as np
import matplotlib as mpl
import xesmf as xe
from workflow.scripts.utils import calc_error_gridded, regrid_global

In [3]:
# Rule parameters injected by Snakemake; every tunable has a fallback default.
params = snakemake.params
paths = sorted(snakemake.input.paths)
# Positional selection applied to the 'year' dimension before averaging.
time_slice = params.get('time_slice', slice(3, -1))
# Number of discrete colour levels in the diverging colormap.
nlevels = params.get('nlevels', 11)
# Whether to overlay the error-mask hatching on each map panel.
draw_error_mask = params.get('draw_error_mask', True)
# Colour-scale limits given as [vmin, vmax].
minmax = params.get('minmax', [-6, 6])
vmin, vmax = minmax[0], minmax[1]

In [5]:
# Load every input file, optionally regrid it, and store per model (keyed by the
# file's `source_id` attribute) the time-mean field and its gridded SEM.
#
# The regrid settings are identical for every file, so resolve them -- and open
# the target grid -- once, rather than re-opening the grid file per input path.
regrid_params = snakemake.config.get('regrid_params', None)
do_regrid = bool(regrid_params) and bool(snakemake.params.get('regrid', True))

out_grid = None
regrid_method = None
dxdy = None
if do_regrid:
    regrid_method = regrid_params.get('method', 'conservative')
    grid_path = regrid_params.get('grid_path', None)
    dxdy = regrid_params.get('dxdy', None)
    if grid_path:
        out_grid = xr.open_dataset(grid_path)
    elif not dxdy:
        print('No outgrid provided!')

dsets = {}
for path in paths:
    ds = xr.open_dataset(path)
    if do_regrid:
        with xr.set_options(keep_attrs=True):
            # Add lon/lat cell bounds when lon has none; presumably required by
            # the (default) conservative regridding method.
            if not ds.cf.bounds.get('lon', None):
                ds = ds.cf.add_bounds(['lon', 'lat'])
            if out_grid is not None:
                ds = regrid_global(ds, out_grid, method=regrid_method)
            elif dxdy:
                ds = regrid_global(ds, lon=dxdy[0], lat=dxdy[1], method=regrid_method)

    source_id = ds.source_id
    if source_id in dsets:
        # Previously this overwrite happened silently; make it visible.
        print(f'Duplicate source_id {source_id!r}: {path} replaces an earlier file.')
    field = ds[snakemake.wildcards.vName].isel(year=time_slice)
    dsets[source_id] = {
        # Mean over the selected years, wrapped as a Dataset whose single
        # variable is named after the model (the plotting cell indexes by it).
        'data': field.mean(dim='year', keep_attrs=True).to_dataset(name=str(source_id)),
        # Standard error of the mean from the project helper.
        'error': calc_error_gridded(field, kind='SEM'),
    }
    

In [6]:
# Draw one map panel per dataset (time-mean field plus optional error hatching),
# add a shared colorbar, and save the figure to the rule's output path.
# The mosaic has ten labelled panels plus a '' panel spanning the last two columns
# of the bottom row ('' is an ordinary label; subplot_mosaic's empty sentinel is '.').
fig, ax = plt.subplot_mosaic([['a)','b)','c)'],['d)','e)','f)'],['g)','h)','i)'],['j)','','']],
                             subplot_kw={'projection':ccrs.PlateCarree()}, figsize=(18,10), sharex=True, sharey=True)
if len(dsets) > len(ax):
    print(f'Only {len(ax)} panels available; {len(dsets) - len(ax)} dataset(s) will not be plotted.')
# pyplot.get_cmap still accepts a lut size; matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('RdYlBu_r', nlevels)
pmesh = None
for panel, source_id in zip(ax.keys(), dsets.keys()):
    mean_field = dsets[source_id]['data'][source_id]
    pmesh = mean_field.plot(ax=ax[panel], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False)
    if draw_error_mask:
        # Hatch cells whose magnitude is below the SEM. Comparing the signed mean
        # (assuming a non-negative SEM) would hatch every negative cell.
        error_mask = abs(mean_field) < dsets[source_id]['error']
        error_mask.plot.contourf(ax=ax[panel], hatches=[None,'...'], alpha=0., levels=3, add_colorbar=False)
    ax[panel].coastlines()
    ax[panel].set_title(source_id)
    global_map(ax[panel])
if pmesh is None:
    raise ValueError('No datasets to plot: snakemake.input.paths produced no data.')
# Hide panels that received no dataset so they do not render as empty map frames.
for panel in list(ax.keys())[len(dsets):]:
    ax[panel].set_visible(False)
cax = fig.add_axes([0.94,0.2,0.02,0.62])
fig.colorbar(pmesh, cax=cax, extend='both',
             label='$W/m^2$')
fig.savefig(snakemake.output.outpath, bbox_inches='tight', facecolor='white', dpi=144)