In [43]:
import matplotlib.pyplot as plt
from workflow.scripts.plotting_tools import create_facet_plot
import xarray as xr
from workflow.scripts.plotting_tools import global_map
import matplotlib as mpl
import numpy as np
import xesmf
from pyclim_noresm.general_util_funcs import global_avg

In [48]:


# Inputs and plotting options supplied by the snakemake rule.
# Paths are sorted so the panel order is deterministic between runs.
nmodels = len(snakemake.input)
paths = sorted(snakemake.input)

params = snakemake.params
# By default the first 5 time steps are dropped before averaging
# (presumably spin-up — TODO confirm with the rule author).
time_slice = params.get('time_slice', slice(5, None))
cmap_name = params.get('cmap', None)
nlevels = params.get('nlevels', 21)

# Multiplier applied to the plotted variable before drawing
# (presumably a unit conversion — NOTE(review): the units attribute is not updated).
scaling_factor = params.get('scaling_factor', 1000)
label = params.get('label', None)
add_global_avg = params.get('add_global_avg', False)

# Discrete colormap with `nlevels` colours, falling back to 'viridis_r'.
# `pyplot.get_cmap(name, lut)` is used because `matplotlib.cm.get_cmap`
# was deprecated in matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap(cmap_name if cmap_name else 'viridis_r', nlevels)


In [53]:
fig, ax, cax = create_facet_plot(nmodels)
# One shared normalisation so every model panel uses the same colour scale.
norm = mpl.colors.Normalize(params.get('vmin', 0), params.get('vmax', None))

var_name = None
units = ''
for path, axi in zip(paths, ax):
    panel = ax[axi]
    # Open each file in a context manager so the handle is released after the
    # panel is drawn; everything needed later is captured inside the block.
    with xr.open_dataset(path) as raw_ds:
        var_name = raw_ds.variable_id
        model_ds = raw_ds
        if 'time' in model_ds.coords:
            # Drop the leading time steps, then collapse the rest into one map.
            model_ds = model_ds.isel(time=time_slice).mean(dim='time', keep_attrs=True)
        with xr.set_options(keep_attrs=True):
            model_ds[var_name] = model_ds[var_name] * scaling_factor

        pcm = model_ds[var_name].plot(ax=panel, cmap=cmap, norm=norm, add_colorbar=False)
        global_map(panel)

        panel.set_ylabel('')
        panel.set_xlabel('')
        panel.set_title(model_ds.source_id)
        if add_global_avg:
            # float() so the 0-d result formats the same on all xarray versions.
            mean_value = float(global_avg(model_ds)[var_name])
            panel.text(0.12, 0.05,
                       f"Mean: {mean_value:.3f}",
                       transform=panel.transAxes, va='center', ha='center',
                       bbox=dict(boxstyle="square", ec=(1., 0.5, 0.5),
                                 fc=(1., 0.8, 0.8)))
        # NOTE(review): this is the file's units attribute; it is not adjusted
        # for `scaling_factor`, so the fallback label may show unscaled units.
        units = model_ds[var_name].attrs.get('units', '')

if label is None:
    # Fall back to "<variable_id> <units>" taken from the last dataset plotted.
    label = f'{var_name} {units}'.strip()
fig.colorbar(pcm, cax=cax, extend=params.get('cb_extend', 'neither'),
             label=label)
fig.savefig(snakemake.output.outpath, bbox_inches='tight', dpi=144, facecolor='w')