In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
from workflow.scripts.utils import calc_error_gridded
from workflow.scripts.plotting_tools import create_facet_plot, global_map
import xesmf
import matplotlib as mpl

In [4]:
# --- Configuration: everything here is read from the Snakemake rule. ---
paths = sorted(snakemake.input.paths)
params = snakemake.params
nmodels = len(paths)
time_slice = params.get('time_slice', slice(5, None))
height = snakemake.wildcards.plevel
nlevels = params.get('nlevels', 21)
cmap = params.get('cmap', None)
vname = snakemake.wildcards.variable

# Pressure band for this cloud level. Config bounds are multiplied by 100,
# presumably hPa -> Pa to match the data's ``plev`` coordinate (TODO confirm).
# A falsy ``top`` means the band is open-ended.
plevel_data = snakemake.config['cloud_def'][height]
plevel_top = plevel_data['top'] * 100 if plevel_data['top'] else None
plevel_slice = slice(plevel_data['bottom'] * 100, plevel_top)

scaling_dict = snakemake.config['variable_scalings'].get(vname, None)

# ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; ``pyplot.get_cmap``
# is the supported equivalent and still accepts a lookup-table size, giving a
# discrete colormap with ``nlevels`` colours. Falls back to 'bwr'.
cmap = plt.get_cmap(cmap or 'bwr', nlevels)

# Multiplicative scaling: config default per variable, overridable by the rule.
scaling_factor = scaling_dict['c'] if scaling_dict else 1
vcenter = params.get('vcenter', None)
minmax = params.get('abs_minmax', None)
if minmax:
    maxv = minmax[1]
    minv = minmax[0]
if params.get('scaling_factor', None):
    scaling_factor = params.get('scaling_factor')
label = params.get('label', '')

# Centre of the diverging colour scale; defaults to 0.
vcenter = float(vcenter) if vcenter else 0.

if scaling_dict:
    units = scaling_dict['units']

In [5]:
# Load each model file, select the pressure band, scale, and average over
# pressure and time. Each model's min/max is collected so the next cell can
# derive shared colour limits.
dsets = {}
vminmed = []
vmaxmed = []
for path in paths:
    # Context manager releases the file handle; ``.load()`` below pulls the
    # reduced field into memory before the file is closed.
    with xr.open_dataset(path) as raw:
        source_id = raw.parent_source_id
        field = raw[vname].isel(time=time_slice)
        if scaling_dict:
            field = field.assign_attrs(units=units)
        with xr.set_options(keep_attrs=True):
            field = field.sel(plev=plevel_slice) * scaling_factor
        field = field.mean(dim='plev', keep_attrs=True)
        field = field.mean(dim='time', keep_attrs=True).load()

    # Only derive the colorbar label from metadata when the rule did not
    # supply one (previously a user-provided label was always overwritten).
    # ``[-1]`` keeps the whole name if ``long_name`` has no space.
    if not params.get('label', ''):
        long_name = field.attrs.get('long_name', vname)
        label = long_name.split(' ', 1)[-1] + f" \n {field.attrs.get('units', '')}"

    #error = calc_error_gridded(field.copy(),kind='SEM', time_dim='time')
    field_ds = field.to_dataset(name=f'{source_id}')
    dsets[source_id] = {'data': field_ds}

    values = field_ds[f'{source_id}'].values
    field_min = np.nanmin(values)
    field_max = np.nanmax(values)
    # Skip all-NaN fields so they do not poison the median limits.
    if not np.isnan(field_min):
        vminmed.append(field_min)
    if not np.isnan(field_max):
        vmaxmed.append(field_max)
vmaxmed = np.array(vmaxmed)
vminmed = np.array(vminmed)


    


In [6]:
# Shared colour limits for all panels. An explicit ``abs_minmax`` rule
# parameter takes precedence (it was previously parsed and then silently
# overwritten). Otherwise use the median across models of each model's
# min/max, presumably to damp the influence of outlier models.
if minmax:
    minv, maxv = minmax[0], minmax[1]
else:
    minv = np.median(vminmed)
    maxv = np.median(vmaxmed)

In [7]:
# One map panel per model, sharing a single norm and colorbar.
fig, ax, cax = create_facet_plot(nmodels)

# Hatching of insignificant regions is disabled: ``error_mask`` is never
# computed in this notebook, so enabling this flag would raise NameError.
draw_error_mask = False

# TwoSlopeNorm requires vmin < vcenter < vmax (it raises ValueError otherwise),
# so only use the diverging norm when the limits straddle ``vcenter`` and the
# lower side is not negligible. Otherwise fall back to a sequential 0..maxv
# scale with a sequential colormap. NOTE(review): all-negative data
# (maxv <= 0) still yields an invalid Normalize.
if minv < vcenter < maxv and abs(minv / maxv) > 1e-2:
    norm = mpl.colors.TwoSlopeNorm(vmin=minv, vcenter=vcenter, vmax=maxv)
else:
    norm = mpl.colors.Normalize(vmin=0, vmax=maxv)
    cmap = plt.get_cmap(params.get('cmap', None) or 'YlOrRd', nlevels)

pcm = None
# ``ax`` is indexed by its own iteration items, so it is presumably a mapping
# of panel keys to Axes (TODO confirm against create_facet_plot).
for source_id, axi in zip(dsets, ax):
    panel = ax[axi]
    pcm = dsets[source_id]['data'][source_id].plot(
        ax=panel, cmap=cmap, norm=norm, add_colorbar=False)
    global_map(panel)
    if draw_error_mask:
        error_mask.plot.contourf(ax=panel, hatches=[None, '...'], alpha=0.,
                                 levels=3, add_colorbar=False)
    panel.set_ylabel('')
    panel.set_xlabel('')
    panel.set_title(source_id)

if pcm is None:
    raise ValueError('No datasets were loaded; nothing to plot.')
fig.colorbar(pcm, cax=cax, extend='both', label=label)
fig.savefig(snakemake.output.outpath, bbox_inches='tight', dpi=144, facecolor='w')