In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
from workflow.scripts.utils import calc_error_gridded
from workflow.scripts.plotting_tools import create_facet_plot, global_map
import xesmf
import matplotlib as mpl

In [21]:
# --- Workflow inputs --------------------------------------------------------
# `snakemake` is not defined in this notebook; it is expected to be injected
# into the namespace by the Snakemake notebook integration.
path = snakemake.input.path
params = snakemake.params
nmodels = len(path)  # not referenced by the cells below; kept for compatibility
vname = snakemake.wildcards.variable
height = snakemake.wildcards.plevel

# --- Optional rule parameters (each falls back to a default) ----------------
time_slice = params.get('time_slice', slice(5, None))
nlevels = params.get('nlevels', 21)
label = params.get('label', '')

# --- Vertical layer selection -----------------------------------------------
# The layer bounds come from config['cloud_def'][<plevel wildcard>] and are
# multiplied by 100 before slicing `plev`
# (presumably hPa -> Pa; TODO confirm against the config file).
# A falsy 'top' leaves the slice open-ended.
plevel_data = snakemake.config['cloud_def'][height]
plevel_slice = slice(
    plevel_data['bottom'] * 100,
    plevel_data['top'] * 100 if plevel_data['top'] else None,
)

# --- Colormap ---------------------------------------------------------------
# plt.get_cmap(name, lut) replaces mpl.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and later removed; it returns the colormap resampled to
# `nlevels` entries, exactly like the old call.
cmap = plt.get_cmap(params.get('cmap', None) or 'RdYlBu_r', nlevels)

# --- Scaling and units ------------------------------------------------------
# A per-variable entry in config['variable_scalings'] provides the factor 'c'
# and the resulting 'units'; an explicit rule parameter overrides the factor.
scaling_dict = snakemake.config['variable_scalings'].get(vname, None)
scaling_factor = scaling_dict['c'] if scaling_dict else 1
units = scaling_dict['units'] if scaling_dict else None  # None: keep file units
scaling_factor = params.get('scaling_factor', None) or scaling_factor

# --- Colour-scale settings --------------------------------------------------
# vcenter defaults to 0.0 when it is missing (or falsy).
vcenter = params.get('vcenter', None)
vcenter = float(vcenter) if vcenter else 0.

# minv/maxv are always bound (None when 'abs_minmax' is not configured) so a
# later cell can test them instead of hitting a NameError.
minmax = params.get('abs_minmax', None)
minv, maxv = (minmax[0], minmax[1]) if minmax else (None, None)

In [22]:


# Load the variable, select the configured pressure layer, apply the scaling
# factor and reduce to a time-mean 2-D field stored in `ds[vname]`.
# The file is opened in a context manager so its handle is released as soon as
# the (small) reduced field has been loaded into memory.
with xr.open_dataset(path) as _source:
    source_id = _source.parent_source_id

    da = _source[vname].isel(time=time_slice)
    if scaling_dict:
        # The scaled field is relabelled with the units from the scaling config.
        da = da.assign_attrs(units=units)
    with xr.set_options(keep_attrs=True):
        # keep_attrs: arithmetic would otherwise drop long_name/units.
        da = da.sel(plev=plevel_slice) * scaling_factor
    da = da.mean(dim='plev', keep_attrs=True)

    # Derive a label from the metadata only when none was configured; a label
    # passed through the rule params is no longer overwritten.
    if not label:
        _long_name = da.attrs.get('long_name', vname)
        _words = _long_name.split(' ', 1)
        # Drop the first word of long_name, unless that is all there is.
        _short_name = _words[1] if len(_words) > 1 else _long_name
        _units = da.attrs.get('units', '')
        label = _short_name + f' \n {_units}'

    # .load() pulls the values into memory before the file is closed.
    ds = da.mean(dim='time', keep_attrs=True).to_dataset(name=vname).load()


    


In [27]:
# Draw the time-mean field on a PlateCarree map and write it to the rule output.
fig, ax = plt.subplots(
    subplot_kw={'projection': ccrs.PlateCarree()},
    figsize=(14, 6),
)

# Centre the diverging norm on the configured `vcenter` (0.0 by default)
# instead of a hard-coded 0, so the rule parameter actually takes effect.
norm = mpl.colors.TwoSlopeNorm(vcenter=vcenter)
ds[vname].plot(ax=ax, norm=norm, cmap=cmap)

global_map(ax)  # project helper from workflow.scripts.plotting_tools
ax.set_title(source_id)
fig.savefig(snakemake.output.outpath, bbox_inches='tight', dpi=144, facecolor='w')