In [2]:
import xarray as xr
import matplotlib.pyplot as plt
from pyclim_noresm.general_util_funcs import global_avg
import pandas as pd
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
import matplotlib as mpl
from workflow.scripts.utils import calc_relative_change, calc_abs_change
import numpy as np

In [3]:
# --- Workflow inputs ---------------------------------------------------------
# All settings are read from the `snakemake` object, which is not defined in
# this notebook (presumably injected by the Snakemake notebook runtime).
params = snakemake.params
vname = snakemake.wildcards.vName

# Input files, sorted so that the processing order is deterministic.
paths = sorted(snakemake.input.paths)
nmodels = len(paths)

# Optional parameters with their fall-back values.
time_slice = params.get('time_slice', slice(None, None))
start_year = params.get('start_year', 1950)
minmax = params.get('minmax', None)

# Per-variable scaling entry from the workflow config (None when absent).
scaling_dict = snakemake.config['variable_scalings'].get(vname, None)

# `minv` / `maxv` only exist when a (min, max) pair was supplied.
if minmax:
    minv, maxv = minmax[0], minmax[1]

In [22]:
# Build one globally averaged yearly series per model, keyed by `source_id`,
# then assemble the y-axis label from workflow parameters or file metadata.
dsets = {}
var_attrs = {}  # attributes of `vname` from the last file read (label/unit fall-backs)
for path in paths:
    # The context manager releases the file handle once the series is in memory.
    with xr.open_dataset(path) as ds:
        source_id = ds.source_id

        if 'year' in ds.dims:
            ds = ds.rename({'year': 'time'})

        with xr.set_options(keep_attrs=True):
            # Only datetime-like axes expose the `.dt` accessor. An axis that
            # already holds numeric years (e.g. the renamed 'year' dimension)
            # must be left untouched, otherwise `.dt.year` raises.
            if not np.issubdtype(ds['time'].dtype, np.number):
                ds = ds.assign(time=ds.time.dt.year)
            ts = global_avg(ds)

        # `to_pandas()` materialises the values while the file is still open.
        dsets[source_id] = ts[vname].to_pandas()
        var_attrs = dict(ds[vname].attrs)

df = pd.DataFrame(dsets)
if start_year:
    df = df.loc[start_year:, :]

# Resolve the label lazily: file metadata is consulted only when the workflow
# gave no override, so a missing `long_name` attribute cannot break a run that
# supplies its own label.
label = params.get('label', None)
if label is None:
    label = var_attrs.get('long_name', vname)

if scaling_dict:
    units = scaling_dict['units']
else:
    units = params.get('units', None)
    if units is None:
        units = var_attrs.get('units', '')

# Break long labels after the third word so they fit next to a narrow panel.
words = label.split(' ')
if len(words) > 3:
    label = ' '.join(words[:3]) + '\n' + ' '.join(words[3:])

label = f'{label} \n {units}'

In [40]:
# One panel per model: the raw yearly series (faint) overlaid with its
# 5-point centred rolling mean (bold), saved to the workflow output path.
palette = ["#1845fb", "#ff5e02", "#c91f16", "#c849a9",
           "#adad7d", "#86c8dd", "#578dff", "#656364"]
cycler = mpl.cycler(color=palette)

# `squeeze=False` always returns a 2-D array of axes, so a single-model run
# is handled the same way as a multi-model one. The panel count follows the
# columns actually present in `df`.
fig, axes = plt.subplots(nrows=len(df.columns), figsize=(10, 6),
                         sharex=True, squeeze=False)
axes = axes.ravel()

for idx, (ax, col) in enumerate(zip(axes, df.columns)):
    # Wrap around the palette so models beyond the eighth are still drawn.
    color = palette[idx % len(palette)]
    series = df[col]

    # Draw a zero reference line unless every value is strictly positive.
    if not (series > 0).all():
        ax.axhline(y=0, color='k', linewidth=3.1)

    series.plot(ax=ax, alpha=0.3, legend=False, color=color)
    # Blank the labels of everything drawn so far so that only the rolling
    # mean plotted below shows up in the legend.
    for line in ax.lines:
        line.set_label('')
    series.rolling(5, center=True).mean().plot(
        ax=ax, linewidth=2.1, legend=False, color=color)

    ax.set_ylabel(label, fontsize=10)
    ax.set_xlabel('')
    ax.legend(fontsize=14)
    ax.tick_params(axis='both', labelsize=12)
    # Apply the requested y-range to every panel, not just the last one.
    if minmax:
        ax.set_ylim(minv, maxv)

# Major x ticks every 5 units, set on the bottom panel.
axes[-1].xaxis.set_major_locator(MultipleLocator(5))

fig.savefig(snakemake.output.outpath, dpi=144, bbox_inches='tight')