In [2]:
import matplotlib.pyplot as plt
from workflow.scripts.plotting_tools import create_facet_plot, global_map
import xarray as xr
from workflow.scripts.utils import model_levels_to_pressure_levels
import matplotlib as mpl
import numpy as np
import xesmf
from pyclim_noresm.general_util_funcs import global_avg

In [3]:


# Run configuration: input files, Snakemake params and wildcards.
paths = sorted(snakemake.input)
nmodels = len(paths)

params = snakemake.params
# Passed to `isel(time=...)` before time-averaging; the default drops the first 5 steps.
time_slice = params.get('time_slice', slice(5, None))
cmap = params.get('cmap', None)
nlevels = params.get('nlevels', 21)
method = params.get('method', 'mean')  # vertical reduction: 'mean' or 'sum'
height = snakemake.wildcards.plevel
scaling_factor = params.get('scaling_factor', 1000)  # multiplier applied to the plotted variable
label = params.get('label', None)  # colorbar label; None -> "<variable_id> <units>"
add_global_avg = params.get('add_global_avg', False)

# Discretise the colormap into `nlevels` colours. `mpl.cm.get_cmap` was
# deprecated in Matplotlib 3.7 and removed in 3.9; `pyplot.get_cmap(name, lut)`
# is the supported equivalent (and returns Colormap instances unchanged).
cmap = plt.get_cmap(cmap or 'viridis_r', nlevels)


In [31]:
# Draw one map panel per input file on a shared facet grid, add a common
# colorbar, and save the figure to snakemake.output.outpath.
# Per file: average over the selected time steps, select the vertical layer
# configured for the `height` wildcard, multiply by `scaling_factor`, then
# reduce over the vertical dimension with `method` ('mean' or 'sum').
if method not in ('mean', 'sum'):
    # Any other value would leave the vertical dimension in place and the
    # 2-D map plot below would then fail with a much less clear error.
    raise ValueError(f"Unsupported method {method!r}; expected 'mean' or 'sum'")

fig, ax, cax = create_facet_plot(nmodels)
norm = mpl.colors.Normalize(params.get('vmin', 0), params.get('vmax', None))
for path, axi in zip(paths, ax):
    temp_ds = xr.open_dataset(path)
    # Vertical coordinates whose formula starts with 'p' are converted to
    # pressure levels. Guarded so that a file without a recognised 'Z' axis
    # (or without a 'formula' attribute) reaches the 'plev' fallback below
    # instead of raising here.
    if ('Z' in temp_ds.cf.get_valid_keys()
            and temp_ds.cf['Z'].attrs.get('formula', '').startswith('p')):
        temp_ds = model_levels_to_pressure_levels(temp_ds)

    var_id = temp_ds.variable_id
    # Panel title. Read unconditionally: it was previously only set when a
    # time coordinate existed, leaving it undefined or stale otherwise.
    source_id = temp_ds.parent_source_id
    if 'time' in temp_ds.coords:
        temp_ds = temp_ds.isel(time=time_slice)
        temp_ds = temp_ds.mean(dim='time', keep_attrs=True)

    if 'Z' in temp_ds.cf.get_valid_keys():
        Z = temp_ds.cf['Z'].name
    else:
        Z = 'plev'

    # Model-level ('lev') data use the 'cloud_def_height' bounds as-is;
    # pressure data use 'cloud_def' bounds multiplied by 100
    # (presumably hPa -> Pa; confirm against the config). A falsy 'top'
    # means the layer is open-ended.
    if Z == 'lev':
        plevel_data = snakemake.config['cloud_def_height'][height]
        bound_factor = 1
    else:
        plevel_data = snakemake.config['cloud_def'][height]
        bound_factor = 100
    top = plevel_data['top']
    plevel_slice = slice(plevel_data['bottom'] * bound_factor,
                         top * bound_factor if top else None)

    with xr.set_options(keep_attrs=True):
        temp_ds = temp_ds.sel({Z: plevel_slice})
        temp_ds[var_id] = temp_ds[var_id] * scaling_factor

    # keep_attrs=True so the variable keeps its 'units' attribute for the
    # colorbar label; xarray reductions drop attrs by default.
    if method == 'mean':
        temp_ds = temp_ds.mean(dim=Z, keep_attrs=True)
    else:
        temp_ds = temp_ds.sum(dim=Z, keep_attrs=True)

    panel = ax[axi]
    pcm = temp_ds[var_id].plot(ax=panel, cmap=cmap, norm=norm, add_colorbar=False)
    global_map(panel)

    panel.set_ylabel('')
    panel.set_xlabel('')
    panel.set_title(source_id)
    if add_global_avg:
        gavg = global_avg(temp_ds).compute()
        panel.text(0.12, 0.05,
                   f"Mean: {gavg[var_id].values:.3f}",
                   transform=panel.transAxes, va='center', ha='center',
                   bbox=dict(boxstyle="square", ec=(1., 0.5, 0.5),
                             fc=(1., 0.8, 0.8)))

# Default colorbar label comes from the last processed dataset.
if label is None:
    units = temp_ds[var_id].attrs.get("units", "")
    label = f'{var_id} {units}'
fig.colorbar(pcm, cax=cax, extend=params.get('cb_extend', 'neither'),
             label=label)
fig.savefig(snakemake.output.outpath, bbox_inches='tight', dpi=144, facecolor='w')