In [None]:
import xarray as xa
from matplotlib import pyplot as plt
import cartopy
import glob
import re
def natsort(s):
    """Natural-sort key, so e.g. "r2i1p1f1" sorts before "r10i1p1f1".

    `s` is split into alternating text / digit runs (text at even positions, digit runs
    at odd positions).  Digit runs become ints and text is lower-cased, so the returned
    lists compare element-wise in human order.
    """
    # Raw string: "\d" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
    # isdecimal() matches exactly what \d captured; isdigit() would also accept e.g. "²",
    # which int() rejects.
    return [int(t) if t.isdecimal() else t.lower() for t in re.split(r"(\d+)", s)]
import xcdat as xc
import numpy as np
from matplotlib import colors
import matplotlib as mpl
# Use the DejaVu Serif font family for all figure text.
mpl.rcParams['font.family'] = 'DejaVu Serif'

In [None]:
# Per-variable settings used by get_forced:
#   var -> (table sub-directory, first year of the analysis window, variable name in the files)
_FORCED_VAR_CONFIG = {
    'tos': ('Omon', 1950, 'tos'),
    'pr': ('Amon', 1979, 'pr'),
    'monmaxpr': ('Aday', 1979, 'pr'),
}
# Models processed when get_forced is called without an explicit list (output order).
_DEFAULT_FORCED_MODELS = ('CESM2', 'MIROC6', 'MPI-ESM1-2-LR', 'MIROC-ES2L', 'CanESM5')


def _member_ids(model, mfiles):
    """Return the ensemble-member ids encoded in the paths `mfiles`, in natural sort order."""
    if model == "CESM2":
        # CESM2 file names carry the member id between "ssp370_" and ".1880".
        members = [p.split("ssp370_")[-1].split(".1880")[0] for p in mfiles]
    else:
        # Other models: the text after the last "_", up to its first ".".
        members = [p.split("_")[-1].split(".")[0] for p in mfiles]
    return sorted(members, key=natsort)


def _load_member_anomalies(fn, vid, time_period, reference_period):
    """Open one member file with xcdat and return the monthly departures of `vid`."""
    ds = xc.open_dataset(fn)
    ds = ds.bounds.add_missing_bounds(axes=['T'])
    ds = ds.squeeze()  # drop length-1 dimensions
    ds = ds.sel(time=slice(*time_period))
    # Departures from the monthly climatology computed over `reference_period`.
    ds = ds.temporal.departures(vid, freq='month', reference_period=reference_period)
    if 'file_qf' in ds.variables:
        ds = ds.drop_vars('file_qf')  # Dataset.drop is deprecated for variable names
    return ds


def _check_same_months(times, ref_times):
    """Raise ValueError unless `times` and `ref_times` hold the same (year, month) sequence."""
    if len(times) != len(ref_times):
        raise ValueError(
            f"model time has {len(times)} steps but reference time has {len(ref_times)}")
    for t, rt in zip(times, ref_times):
        if (t.year, t.month) != (rt.year, rt.month):
            raise ValueError("model time and reference time do not match")


def get_forced(var, root_dir="/p/lustre3/shiduan/ForceSMIP", models=None):
    """Return the ensemble-mean anomaly Dataset of `var` for each model.

    For every model, each member file under ``{root_dir}/Training/{table}/{var}/{model}``
    is processed as follows:

    - It is opened with xcdat and restricted to the analysis window.
    - It is converted to monthly departures from the climatology over
      ``{start_year}-01-01`` .. ``2022-12-31``.
    - It is put on the time axis of the very first member loaded in this call, after
      checking that the year/month sequences agree.
    - All members are stacked along a new ``member`` dimension and averaged with
      ``skipna=False``.

    Parameters
    ----------
    var : {'tos', 'pr', 'monmaxpr'}
        Field to process; selects the table directory, start year and file variable.
    root_dir : str
        Root of the ForceSMIP data tree.
    models : sequence of str, optional
        Models to process, in output order.  Defaults to
        CESM2, MIROC6, MPI-ESM1-2-LR, MIROC-ES2L, CanESM5.

    Returns
    -------
    list of xarray.Dataset
        One loaded (in-memory) ensemble-mean Dataset per model, in ``models`` order.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - `var` is unsupported.
        - A model directory contains no files.
        - A member id does not match exactly one ``.nc`` file.
        - A member's time axis differs from the reference one.
    """
    if var not in _FORCED_VAR_CONFIG:
        raise ValueError(f"unsupported var {var!r}; expected one of {sorted(_FORCED_VAR_CONFIG)}")
    cat, start_year, vid = _FORCED_VAR_CONFIG[var]
    if models is None:
        models = _DEFAULT_FORCED_MODELS
    # Climatological period for the anomaly calculation.
    reference_period = (str(start_year) + "-01-01", "2022-12-31")
    # Analysis window.  NOTE(review): .sel slices include their end label, so a sample
    # time-stamped exactly 2023-01-01 would be kept.  Confirm against the files' time convention.
    tv_time_period = (str(start_year) + "-01-01", "2023-01-01")
    ref_time = None  # time axis of the first member loaded; every member is placed on it
    model_mean_list = []
    for model in models:
        print(model)
        mpath = '/'.join([root_dir, 'Training', cat, var, model])
        mfiles = glob.glob(mpath + '/*')
        if not mfiles:
            raise ValueError("no member files found in " + mpath)
        member_datasets = []
        for member in _member_ids(model, mfiles):
            print('.', end='')  # progress: one dot per member
            fn = glob.glob(mpath + "/*_" + member + ".*.nc")
            # The member id must identify exactly one file.
            if len(fn) != 1:
                raise ValueError("Unexpected number of model members")
            ds = _load_member_anomalies(fn[0], vid, tv_time_period, reference_period)
            if ref_time is None:
                ref_time = ds.time
            _check_same_months(ds.time.values, ref_time.values)
            ds["time"] = ref_time.copy()
            member_datasets.append(ds)
        print()  # end the line of progress dots
        # A single concat over all members is linear and also yields a 'member'
        # dimension when the model has only one member.
        ds_model = xa.concat(member_datasets, dim='member')
        model_mean_list.append(ds_model.mean(dim='member', skipna=False).load())
    return model_mean_list

In [None]:
# Per-model ensemble-mean anomaly Datasets for each field (one Dataset per model,
# in get_forced's model order).
pr_model_mean_list = get_forced('pr')              # 'Amon' table, window starting 1979
monmaxpr_model_mean_list = get_forced('monmaxpr')  # 'Aday' table, variable 'pr', from 1979
tos_model_mean_list = get_forced('tos')            # 'Omon' table, window starting 1950

In [None]:
# Figure columns, left to right (same order as get_forced's default model list).
models = ['CESM2', 'MIROC6', 'MPI-ESM1-2-LR', 'MIROC-ES2L', 'CanESM5']
# Per-model values shown in the top-row titles (presumably equilibrium climate
# sensitivity in K; confirm).
ecs = {'CESM2':5.15, 'CanESM5':5.64, 'MIROC6':2.60, 'MPI-ESM1-2-LR':3.03, 'MIROC-ES2L':2.66}


def _linear_trend_per_decade(data):
    """Return the grid-point least-squares linear trend of `data` along time, per decade.

    Time is replaced by a unit-spaced step index 0..N-1, so the fitted slope is per time
    step; multiplying by 12*10 converts per-month to per-decade (assumes one sample per
    month).  The result is in the field's own units per decade.
    """
    # np.arange gives a step of exactly 1.  The previous np.linspace(1, N+1, N) had a step
    # of N/(N-1), which biased every trend low by a factor of (N-1)/N.
    step_index = np.arange(data.sizes['time'])
    # assign_coords returns a new DataArray, so the caller's Dataset keeps its time axis.
    coeffs = data.assign_coords(time=step_index).polyfit(dim='time', deg=1)
    return coeffs.polyfit_coefficients.sel(degree=1) * 12 * 10


def _plot_trend_row(fig, row_axes, datasets, vid, vmax, cmap, cbar_rect,
                    column_titles=None, coastlines=True):
    """Map each model's per-decade trend of `vid` on one figure row with a shared colourbar.

    The colour scale is symmetric (-vmax..vmax, centred on 0).  Every panel uses the same
    30 contour bands, so panels sharing the colourbar are directly comparable.
    Returns the (min, max) trend over all panels in the row.
    """
    norm = colors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
    levels = np.linspace(-vmax, vmax, 31)  # identical bands in every panel of the row
    row_min, row_max = np.inf, -np.inf
    for j, (ax, ds) in enumerate(zip(row_axes, datasets)):
        trend = _linear_trend_per_decade(ds[vid])
        # extend='both' colours values beyond +/-vmax with the end colours, matching the colourbar.
        ax.contourf(trend.lon, trend.lat, trend, transform=cartopy.crs.PlateCarree(),
                    norm=norm, cmap=cmap, levels=levels, extend='both')
        if column_titles is not None:
            ax.set_title(column_titles[j])
        if coastlines:
            ax.add_feature(cartopy.feature.COASTLINE, alpha=.5)
        row_min = min(row_min, float(trend.min()))
        row_max = max(row_max, float(trend.max()))
    cbar_ax = fig.add_axes(cbar_rect)
    fig.colorbar(mappable=mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax, extend='both')
    return row_min, row_max


# One figure row per field: (label, per-model Datasets, variable id, colour half-range,
# colormap, colourbar rectangle [left, bottom, width, height], draw coastlines?).
trend_rows = [
    ('pr', pr_model_mean_list, 'pr', .35, 'BrBG', [.91, 0.67, 0.008, 0.15], True),
    ('monmaxpr', monmaxpr_model_mean_list, 'pr', 2.8, 'BrBG', [.91, 0.42, 0.008, 0.15], True),
    # No coastlines on the tos row, as in the original figure.
    ('tos', tos_model_mean_list, 'tos', .5, 'coolwarm', [.91, 0.15, 0.008, 0.15], False),
]
column_titles = [model + ' ' + "{:.2f}".format(ecs[model]) for model in models]

fig, axes = plt.subplots(nrows=len(trend_rows), ncols=len(models), squeeze=False,
                         subplot_kw={'projection': cartopy.crs.Robinson(central_longitude=180)},
                         figsize=(15, 5))
for r, (label, datasets, vid, vmax, cmap, cbar_rect, coastlines) in enumerate(trend_rows):
    row_min, row_max = _plot_trend_row(
        fig, axes[r], datasets, vid, vmax, cmap, cbar_rect,
        column_titles=column_titles if r == 0 else None,  # model names + ECS on the top row only
        coastlines=coastlines)
    # Range over all models in the row (previously only the last model was reported);
    # useful for choosing vmax.
    print(f"{label} trend per decade across models: max={row_max:.3g}, min={row_min:.3g}")
plt.subplots_adjust(hspace=0, wspace=0.05)
fig.savefig('model-intercomparison.png', dpi=160, bbox_inches='tight')
plt.show()