In [2]:
import matplotlib.pyplot as plt
from workflow.scripts.plotting_tools import create_facet_plot
import xarray as xr
from workflow.scripts.utils import calc_relative_change, calc_abs_change, calc_error_gridded
from workflow.scripts.plotting_tools import global_map
import matplotlib as mpl
import numpy as np

In [3]:
# Configuration: read inputs, params and wildcards injected by Snakemake.
# Both file lists are sorted so the i-th control file pairs with the i-th
# experiment file (assumes per-model file names sort identically — TODO confirm).
ctrl_path = sorted(snakemake.input.path_ctrl)
exp_path = sorted(snakemake.input.path_exp)
nmodels = len(ctrl_path)

params = snakemake.params
kind = snakemake.wildcards.kind
if kind not in ('abs', 'rel'):
    raise ValueError("Unsupported kind wildcard {!r}; expected 'abs' or 'rel'".format(kind))

# Time selection passed to the change calculations; the default skips the
# first 5 steps (presumably spin-up — TODO confirm).
time_slice = params.get('time_slice', slice(5, None))
nlevels = params.get('nlevels', 21)
draw_error_mask = params.get('draw_error_mask', True)

# Discrete colormap with `nlevels` colors ('bwr' unless configured).
# plt.get_cmap is used because mpl.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap(params.get('cmap', None) or 'bwr', nlevels)

# Optional fixed color limits [min, max]. When absent, minv/maxv are derived
# from the data after loading (see `if not minmax` in the next cell).
minmax = params.get('abs_minmax' if kind == 'abs' else 'rel_minmax', None)
if minmax:
    minv, maxv = minmax[0], minmax[1]

if kind == 'abs':
    # Multiplier applied to absolute changes (e.g. a unit conversion).
    scaling_factor = params.get('scaling_factor', 1)

# Value placed at the midpoint of the diverging colormap; defaults to 0.
vcenter = params.get('vcenter', None)
vcenter = float(vcenter) if vcenter else 0.

# Per-model time-mean change fields and gridded error estimates, keyed by
# the experiment's source_id, plus per-model extrema for default color limits.
diffs = {}
errors = {}
mins = []
maxs = []

if len(ctrl_path) != len(exp_path):
    # zip() would silently drop unmatched files and mis-pair the models.
    raise ValueError('Got {} control files but {} experiment files'.format(
        len(ctrl_path), len(exp_path)))


def _change(ds_ctrl, ds_exp, time_average):
    """Return the experiment-vs-control change dataset for the current kind.

    'rel' uses calc_relative_change; 'abs' uses calc_abs_change and then
    multiplies the data variable (named by ``variable_id``) by
    ``scaling_factor``.
    """
    if snakemake.wildcards.kind == 'rel':
        return calc_relative_change(ds_ctrl, ds_exp, time_slice=time_slice,
                                    time_average=time_average)
    if snakemake.wildcards.kind == 'abs':
        change = calc_abs_change(ds_ctrl, ds_exp, time_slice=time_slice,
                                 time_average=time_average)
        change[change.variable_id] = change[change.variable_id] * scaling_factor
        return change
    raise ValueError("Unsupported kind wildcard {!r}".format(snakemake.wildcards.kind))


for ctrl, exp in zip(ctrl_path, exp_path):
    # load_dataset reads into memory and closes the file handle.
    ds_exp = xr.load_dataset(exp)
    ds_ctrl = xr.load_dataset(ctrl)
    t_dim = 'year' if 'year' in ds_exp.dims else 'time'
    with xr.set_options(keep_attrs=True):
        # The error estimate needs the full change time series; the plotted
        # field is the time-averaged change.
        diff_series = _change(ds_ctrl, ds_exp, time_average=False)
        error = calc_error_gridded(diff_series[diff_series.variable_id], time_dim=t_dim)
        diff = _change(ds_ctrl, ds_exp, time_average=True)
        diffs[ds_exp.source_id] = diff
        errors[ds_exp.source_id] = error
        mins.append(diff[diff.variable_id].min().values)
        maxs.append(diff[diff.variable_id].max().values)

if not minmax:
    # Default limits: mean of per-model extrema, rounded outward to integers.
    minv = np.floor(np.array(mins).mean())
    maxv = np.ceil(np.array(maxs).mean())

# Colorbar label taken from the last model's metadata.
last_da = diff[diff.variable_id]
if snakemake.wildcards.kind == 'rel':
    label = last_da.long_name
else:
    label = params.get('label', last_da.long_name)
    units = params.get('units', last_da.units)
    label = '{} {}'.format(label, units)


In [10]:
fig, ax, cax = create_facet_plot(nmodels)
# One shared diverging normalization so colors are comparable across panels.
# TwoSlopeNorm requires minv < vcenter < maxv.
norm = mpl.colors.TwoSlopeNorm(vmin=minv, vcenter=vcenter, vmax=maxv)

pcm = None
# NOTE(review): ax is indexed by the items zip() yields from it, so it is
# presumably a mapping of panel keys to Axes from create_facet_plot — confirm.
for model, axi in zip(diffs, ax):
    panel = ax[axi]
    change_ds = diffs[model]
    change = change_ds[change_ds.variable_id]

    pcm = change.plot(ax=panel, cmap=cmap, norm=norm, add_colorbar=False)
    global_map(panel)
    if draw_error_mask:
        # Hatch cells whose change magnitude exceeds the error estimate.
        # abs() makes significant decreases hatched too (assumes the error
        # from calc_error_gridded is non-negative — TODO confirm).
        significant = errors[model] < abs(change)
        significant.plot.contourf(ax=panel, hatches=[None, '...'], alpha=0.,
                                  levels=3, add_colorbar=False)
    panel.set_ylabel('')
    panel.set_xlabel('')
    panel.set_title(model)

if pcm is None:
    raise ValueError('No model data to plot')
fig.colorbar(pcm, cax=cax, extend='both', label=label)
fig.savefig(snakemake.output.outpath, bbox_inches='tight', dpi=144, facecolor='w')