In [2]:
import xarray as xr
import matplotlib.pyplot as plt
from pyclim_noresm.general_util_funcs import global_avg
import pandas as pd
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
import matplotlib as mpl
from workflow.scripts.utils import calc_relative_change, calc_abs_change
import numpy as np

In [3]:
# --- Configuration: every knob below comes from the Snakemake rule. ---
params = snakemake.params
start_year = params.get('start_year', 1950)

# Raw input lists as handed over by the rule.
paths_hist_sst = snakemake.input.paths_hist
paths_hist_sst_piaer = snakemake.input.paths_piaer_hist
vname = snakemake.wildcards.vName

# Control (piAer) and experiment (hist) files are later paired purely by
# sorted position -- assumes both lists sort into the same model order.
nmodels = len(paths_hist_sst_piaer)
ctrl_path = sorted(paths_hist_sst_piaer)
exp_path = sorted(paths_hist_sst)

time_slice = params.get('time_slice', slice(5, None))

# Per-variable multiplicative scaling from the workflow config; 1 when absent.
scaling_dict = snakemake.config['variable_scalings'].get(vname, None)
scaling_factor = scaling_dict['c'] if scaling_dict else 1

vcenter = params.get('vcenter', None)
label = params.get('label', '')

# Optional fixed y-limits, looked up per kind. For 'abs', a truthy
# `scaling_factor` rule param overrides the config scaling.
if snakemake.wildcards.kind == 'abs':
    minmax = params.get('abs_minmax', None)
    scaling_factor = params.get('scaling_factor', None) or scaling_factor
else:
    minmax = params.get('rel_minmax', None)

# minv/maxv are only defined here when limits were configured; otherwise
# the next cell derives them from the data.
if minmax:
    minv = minmax[0]
    maxv = minmax[1]

vcenter = float(vcenter) if vcenter else 0.

In [4]:


# Pair each control file with the experiment file at the same sorted position
# and build, per model, the global-mean change time series from start_year on.
if len(ctrl_path) != len(exp_path):
    raise ValueError(
        f'{len(ctrl_path)} control files but {len(exp_path)} experiment files; '
        'expected exactly one of each per model.')
if not exp_path:
    raise ValueError('No input files to compare.')
if snakemake.wildcards.kind not in ('rel', 'abs'):
    raise ValueError(
        f"Unsupported kind {snakemake.wildcards.kind!r}; expected 'rel' or 'abs'.")

diffs = {}  # keyed by the experiment dataset's source_id
mins = []
maxs = []
for ctrl, exp in zip(ctrl_path, exp_path):
    # load_dataset reads the data into memory AND closes the file;
    # open_dataset(...).load() left every file handle open.
    ds_exp = xr.load_dataset(exp)
    ds_ctrl = xr.load_dataset(ctrl)
    # Pairing relies on sort order only, so guard against mixing models.
    ctrl_source = ds_ctrl.attrs.get('source_id')
    if ctrl_source is not None and ctrl_source != ds_exp.source_id:
        raise ValueError(
            f'Paired files come from different models: {ctrl} ({ctrl_source}) '
            f'vs {exp} ({ds_exp.source_id}).')
    with xr.set_options(keep_attrs=True):
        if snakemake.wildcards.kind == 'rel':
            diff = calc_relative_change(ds_ctrl, ds_exp, time_average=False)
        else:  # 'abs' -- validated above
            diff = calc_abs_change(ds_ctrl, ds_exp, time_average=False)
            diff[diff.variable_id] = diff[diff.variable_id] * scaling_factor
        # Replace datetime coordinates with integer years so start_year can slice.
        diff = diff.assign(time=diff.time.dt.year)
        diff_df = global_avg(diff[vname].sel(time=slice(start_year, None))).to_pandas()
        diffs[ds_exp.source_id] = diff_df
        mins.append(diff_df.min())
        maxs.append(diff_df.max())

unit = params.get('units', diff[diff.variable_id].units)

# Without configured limits, default to the across-model MEAN of the per-model
# extremes, rounded outward. NOTE(review): this can clip models whose range
# exceeds the average -- confirm that is intended.
if not minmax:
    minv = np.floor(np.array(mins).mean())
    maxv = np.ceil(np.array(maxs).mean())

if snakemake.wildcards.kind == 'rel':
    label = diff[diff.variable_id].long_name
else:
    # NOTE(review): a 'label' rule param is not applied for kind='abs'; the
    # generated label below always wins -- confirm that is intended.
    if scaling_dict:
        units = scaling_dict['units']
    else:
        units = params.get('units', ds_ctrl[vname].units)
    # '\\Delta' yields the literal backslash mathtext needs; a bare '\D' is an
    # invalid escape sequence (SyntaxWarning on Python 3.12+).
    label = '$\\Delta$({} {}) of \n {} [{}]'.format(
        ds_exp.experiment_id,
        ds_ctrl.experiment_id,
        ds_ctrl[vname].long_name,
        units)

df = pd.DataFrame(diffs)

In [30]:
# One colour per model column, in column order; shared by both plot layers.
MODEL_COLORS = ["#1845fb", "#ff5e02", "#c91f16", "#c849a9",
                "#adad7d", "#86c8dd", "#578dff", "#656364"]

fig, ax = plt.subplots(figsize=(10, 5))

# Draw a zero reference line unless every value is strictly positive
# (NaN compares as not positive, so it also triggers the line).
if not (df > 0).to_numpy().all():
    ax.axhline(y=0, color='k', linewidth=3.1)

# Raw per-model series, faint and with empty labels so they stay out of the legend ...
df.plot(ax=ax, alpha=0.3, color=MODEL_COLORS, legend=False)
for line in ax.lines:
    line.set_label('')
# ... overlaid by a 3-point centred rolling mean, which keeps the column labels.
df.rolling(3, center=True).mean().plot(
    ax=ax, color=MODEL_COLORS, linewidth=2.1, legend=False)

ax.set_ylabel(label, fontsize=12)
# Fixed limits always for relative changes; for absolute ones only if configured.
if snakemake.wildcards.kind == 'rel' or minmax:
    ax.set_ylim(minv, maxv)
ax.xaxis.set_major_locator(MultipleLocator(5))
ax.set_xlabel('')
# tick_params styles all tick labels, including any created after this call.
ax.tick_params(axis='both', labelsize=12)
ax.legend(fontsize=14)
fig.savefig(snakemake.output.outpath, dpi=144, bbox_inches='tight')