In [3]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from workflow.scripts.utils import calc_error, calc_error_gridded
from workflow.scripts.plotting_tools import global_map, create_facet_plot
from pyclim_noresm.general_util_funcs import global_avg
import numpy as np
import matplotlib as mpl
import pandas as pd

In [4]:
# Inputs injected by Snakemake: the input files (sorted so the model / panel
# order is deterministic) and the name of the variable to analyse.
paths = sorted(snakemake.input.paths)
vname = snakemake.wildcards.vName

# Optional rule parameters with notebook-side defaults:
#   time_slice      -- positional selection along 'year' (default skips the first 5 entries)
#   nlevels         -- number of discrete colour levels used for the maps
#   draw_error_mask -- whether to hatch map cells where the signal is below its standard error
params = snakemake.params
time_slice, nlevels, draw_error_mask = (
    params.get('time_slice', slice(5, None)),
    params.get('nlevels', 11),
    params.get('draw_error_mask', True),
)

In [5]:
# Per-model time-mean of the global average of `vname`, plus its standard
# error (kind='SEM' from calc_error), keyed by the dataset's parent_source_id.
# The stored values are whatever `.values` yields (numpy arrays); the summary
# cell converts them to float.
erfs = {'mean': {}, 'std': {}}
for path in paths:
    # Close the file once the numbers are extracted; `.values` forces the
    # computation inside the `with`, so nothing lazy outlives the handle.
    with xr.open_dataset(path) as raw:
        ds = raw.isel(year=time_slice).rename({'year': 'time'})
        source = ds.parent_source_id
        if source in erfs['mean']:
            # A silent overwrite would drop a model from the table and leave an
            # empty panel in the figure (which is sized by len(paths)).
            raise ValueError(f"duplicate parent_source_id {source!r} in {path}")
        erfs['mean'][source] = global_avg(ds[vname]).mean(dim='time').values
        erfs['std'][source] = calc_error(ds[vname], kind='SEM', time_dim='time').values
    

In [6]:
# Summary table: one row per model plus a 'Multi-model' row.
#   mean -> time-mean global-mean value of each model
#   std  -> its standard error as returned by calc_error (kind='SEM')
# Convert to float up front so every reduction below runs on a numeric
# column rather than an object column of numpy arrays.
df = pd.DataFrame(erfs).astype(float)
nmodels = len(df)
multimodel_mean = df['mean'].mean()
# Standard error of the multi-model mean from the inter-model spread
# (pandas std uses ddof=1, so this is NaN when only one model is present).
error = df['mean'].std() / np.sqrt(nmodels)
df.loc['Multi-model', 'mean'] = multimodel_mean
df.loc['Multi-model', 'std'] = error
# Symmetric x-limit for the bar chart: large enough for the bars *and* their
# error bars (otherwise set_xlim clips the xerr whiskers), plus a 0.1 margin.
vmax = round(float((df['mean'].abs() + df['std'].fillna(0)).max()) + 0.1, 1)

In [7]:
# Per-model time-mean maps and the gridded standard error used for hatching:
#   dsets[source_id] = {'data': Dataset holding one variable named source_id,
#                       'error': result of calc_error_gridded(kind='SEM')}
dsets = {}
for path in paths:
    with xr.open_dataset(path) as raw:
        source_id = raw.parent_source_id
        # Load the selected years into memory so the field (and anything
        # derived from it) stays valid after the file is closed.
        field = raw[vname].isel(year=time_slice).load()
    error = calc_error_gridded(field, kind='SEM')
    mean_map = field.mean(dim='year', keep_attrs=True).to_dataset(name=f'{source_id}')
    dsets[source_id] = {'data': mean_map, 'error': error}

# Symmetric colour limits for the maps; the LWDirectEff / LWDirectEff_cs
# variables use a narrower range than everything else.
if vname in ['LWDirectEff', 'LWDirectEff_cs']:
    cvmin, cvmax = -2, 2
else:
    cvmin, cvmax = -5, 5
     

In [8]:
# Figure: one map panel per model (time-mean field, optionally hatched where
# |mean| < gridded SEM) followed by a bar chart of the global means.
fig, ax, cax = create_facet_plot(len(paths) + 1,
                                 subplot_kw={'projection': ccrs.PlateCarree()},
                                 figsize=(14, 12), last_axis_plain=True, create_cax=False)
# Discrete colormap with `nlevels` bins. pyplot.get_cmap(name, lut) replaces
# matplotlib.cm.get_cmap, which was deprecated in 3.7 and removed in 3.9.
cmap = plt.get_cmap('RdYlBu_r', nlevels)

map_keys = list(ax.keys())[:-1]
for k, i in zip(map_keys, dsets.keys()):
    field = dsets[i]['data'][i]
    mesh = field.plot.pcolormesh(ax=ax[k], vmin=cvmin, vmax=cvmax, cmap=cmap,
                                 add_colorbar=False)
    if draw_error_mask:
        # Hatch cells whose signal magnitude is below its standard error.
        # Comparing the signed mean would hatch every negative cell, however
        # strongly negative it is, because the SEM is non-negative.
        error_mask = abs(field) < dsets[i]['error']
        error_mask.plot.contourf(ax=ax[k], hatches=[None, '...'], alpha=0., levels=3,
                                 add_colorbar=False)
    ax[k].set_title(i)
    global_map(ax[k])

# Colorbar axes to the right of the map panels (figure-fraction coordinates).
cax = fig.add_axes([0.94, 0.39, 0.02, 0.47])
last = list(ax.keys())[-1]
df['mean'].plot.barh(ax=ax[last], legend=False, xerr=df['std'])
ax[last].axvline(color='darkgrey')  # zero reference line (axvline defaults to x=0)
ax[last].set_xlim(-vmax, vmax)
ax[last].set_xlabel('$W/m^2$')
ax[last].yaxis.tick_right()
fig.colorbar(mesh, cax=cax, extend='both', label='$W/m^2$')
fig.savefig(snakemake.output.outpath, bbox_inches='tight', facecolor='white', dpi=144)