# Analysis notebook for PiC_UVnudge runs
## Set up
### Packages

In [1]:
import numpy as np
import xarray as xr
import xesmf as xe
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import pandas as pd
import scipy
from scipy import stats, interpolate
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.mathtext import _mathtext as mathtext
import matplotlib.ticker as mticker
from matplotlib import gridspec, animation
import matplotlib.path as mpath
import matplotlib.colors as colors
import matplotlib.dates as mdates
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import warnings
warnings.simplefilter('ignore', UserWarning)
warnings.filterwarnings('ignore')
import datetime as dt
from datetime import timedelta
from cmcrameri import cm
from Processing_functions import FixLongitude, Wilks_pcrit
import jinja2
import cftime
import dask
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from functools import partial
from collections import defaultdict
import os

In [2]:
## Plot types to make - CHANGE
# attribute structure: [make this plot?, processing type]
# 0: Weighted spatial mean or Arctic sea ice area
# 1: Spatial maximum (AMOC)
# 2: Leave alone (doing spatial map or sea ice concentration)
# 3: Zonal
plots = {
    'amoc': [False, 1],
    'toa': [False, 0],
    'map': [False, 2],
    'ts': [False, 0],
    'sia': [False, 0],
    'mtrd': [False, 0],
    'strd': [True, 2],
    'zon': [False, 3],
    'ztrd': [False, 3]
}

## Categorical plot type - DO NOT CHANGE
# attribute structure: [any plot of this category enabled?, names of enabled plots]
plot_types = {
    'spatial': [False, []],
    'line': [False, []],
    'mtrd': [False, []],
    'zonal': [False, []]
}

# Set up plot_types based on plots: 'mtrd' always forms its own category;
# otherwise processing types <= 1 are line plots, 2 spatial, 3 zonal
# (any other processing type is ignored)
for pl, att in plots.items():
    if not att[0]:
        continue
    if pl == 'mtrd':
        category = 'mtrd'
    elif att[1] <= 1:
        category = 'line'
    elif att[1] == 2:
        category = 'spatial'
    elif att[1] == 3:
        category = 'zonal'
    else:
        continue
    plot_types[category][0] = True
    plot_types[category][1].append(pl)

# Spatial domain - CHANGE s_domain & t_domain only
s_domain = False # True: Global, False: Arctic
# Forced to Global whenever TOA or AMOC plots are requested
s_domain = True if (plots['toa'][0] or plots['amoc'][0]) else s_domain # Make sure TOA/AMOC use the global domain
# Derived (do not edit): Arctic latitude band for atmosphere fields when not global
a_domain = plot_types['spatial'][0] or plot_types['zonal'][0] # True (spatial/zonal plots): 50-90N, False: 70-90N
t_domain = 1980 # start year; data are sliced from t_domain through 2023

## Time averaging type - CHANGE
time_avg = 2   # 0: Monthly, 1: Yearly, 2: Seasonal, 3: All data, 4: Timeseries

## Ensemble mean or All members - CHANGE
ens_type = 1   # 0: All_members, 1: Mean

# Variables - CHANGE
comp = 'ice' # model component: 'atm', 'ice' or 'ocn' (key into var_list)
freq = 0 # 0: monthly, 1: daily (daily variable names get a '_d' suffix)
var_ind = 0 # index of the variable within var_list[comp]

# DO NOT CHANGE
var_list = {'atm': ['TREFHT','PSL','RESTOM','Z3'],
            'ice': ['aice'],
            'ocn': ['MOC']}
var_ext = {0: '', 1: '_d'}
# Full variable name, e.g. 'aice' (monthly) or 'aice_d' (daily)
var = var_list[comp][var_ind]+var_ext[freq]

# Plot levels (presumably hPa) for spatial trends - CHANGE
plot_levels = [300, 500, 850, 925]

In [3]:
## Test numbers - DO NOT CHANGE
# NOTE(review): tst_nums is not referenced anywhere in the visible notebook
tst_nums = np.arange(1,4)

## Test names
# Dataset types:
# 0: LENS2 piControl ensemble
# 1: PiC_UVnudge ensemble
# 2: PiC_UVnudge single run
# 3: observations
# attribute structure: [use dataset, dataset type (0-3 above)]
# CHANGE ONLY - use dataset
# (ERA5 and GISTEMP are switched on automatically from the plot selection:
#  ERA5 for everything except AMOC/TOA, GISTEMP for ts plots or TREFHT monthly trends)
ds_names = {
    'LENS2 piControl': [True, 0],
    'PiC_UVnudge': [False, 1],
    'PiC_UVnudge_LM': [False, 1],
    'PiC_UVnudge_MM': [False, 1],
    'PiC_UVnudgenew': [False, 2],
    'PiC_UVnudge_2006': [True, 1],
    'PiC_UVnudge_LM2006': [True, 1],
    'PiC_UVnudge_MM2006': [True, 1],
    'PiC_UVnudge_1988': [False, 2],
    'PiC_UVnudge_2006_2000': [False, 2],
    'ERA5': [not (plots['amoc'][0] or plots['toa'][0]), 3],
    'GISTEMP': [plots['ts'][0] or (plots['mtrd'][0] and var == 'TREFHT'), 3]
}
# Case-name prefix of the nudged experiments (dataset types 1 and 2)
vercompres = 'b.e21.B1850cmip6.f09_g17.'
# Case-name prefix of the LENS2 piControl files (dataset type 0)
cesm2piC = 'b.e21.B1850.f09_g17.'

## Filepaths - DO NOT CHANGE
path_to_work = '/glade/work/glydia/'
path_to_lensdata = path_to_work+'processed_CESM2_LENS_data/'
path_to_expdata = path_to_work+'Arctic_controls_processed_data/'
path_to_plotdata  = path_to_expdata+'plotting_data/'

# Extensions - DO NOT CHANGE
# History-stream tag per component, indexed by freq (0: monthly, 1: daily)
h_ext = {'atm': ['.h0.'],
       'ice': ['.h.','.h1.'],
       'ocn': ['.h.']}
# Date string in file names, indexed by freq. Key True is only used for LENS2
# piControl slice files in LoadData ('*' globs all slices); False everywhere else
yr_extn = {False: [".195001-202312.",".19500101-20231231."],
           True: [".*.",".05010101-05741231."]}
# True for variables on vertical (pressure) levels; aligned with var_list
vert_lev = {'atm': [False,False,False,True],
            'ice': [False],
            'ocn': [False]}
# 2-D monthly variables are stored as netCDF, everything else as Zarr
file_bool = not vert_lev[comp][var_ind] and freq == 0
file_ext = {True: 'nc', False: 'zarr'}

In [4]:
########################## DO NOT CHANGE ANYTHING BELOW THIS LINE #############################

In [5]:
%%time
    
## Select plot type (time-averaging label used in file names and printouts)
time_str_list = dict(enumerate(['month', 'year', 'season', 'all', 'timeseries']))
time_outstr = time_str_list[time_avg]

## Select ensemble type
ens_str_list = dict(enumerate(['All_members', 'Mean']))
ens_str = ens_str_list[ens_type]

## Select time and spatial domain strings
sd_str_list = {True: 'Global', False: 'Arctic'}
sd_str = sd_str_list[s_domain]
td_str = f'{t_domain}'

CPU times: user 5 µs, sys: 1 µs, total: 6 µs
Wall time: 7.39 µs


In [6]:
## Print Script Configurations
# Plot types: list every enabled category with its plots. catcount (the number
# of enabled categories) is checked by the next cell.
print('Plot types: ')
catcount = 0
for cat, att in plot_types.items():
    if not att[0]:
        continue
    catcount += 1
    print('   '+cat)
    for pl in att[1]:
        print('      '+pl)

# Variable (the Z3 configuration is reported as Z3, U, V)
print('Variable: Z3, U, V' if var == 'Z3' else 'Variable(s): '+var)

# Datasets enabled in ds_names
print('Datasets:')
for dsname, attr in ds_names.items():
    if attr[0]:
        print('   '+dsname)

# Domains, averaging and ensemble settings
print('Spatial domain: '+sd_str)
print('Time domain: '+td_str+'-2023')
print('Time averaging: '+time_outstr)
print('Ensemble averaging: '+ens_str)

Plot types: 
   spatial
      strd
Variable(s): aice
Datasets:
   LENS2 piControl
   PiC_UVnudge_2006
   PiC_UVnudge_LM2006
   PiC_UVnudge_MM2006
   ERA5
Spatial domain: Arctic
Time domain: 1980-2023
Time averaging: season
Ensemble averaging: Mean


In [7]:
## Check number of different averaging types
# CreateMasterDS() picks a single processing path per run, so stop when the
# configuration enables plots from more than one category.
if not catcount <= 1:
    raise SystemExit("ERROR: More than one averaging type fed into CreateMasterDS()")

### Cluster

In [8]:
# Dask workers submitted through PBS on the 'casper' queue:
# one core and 50 GiB of memory per job, 12 h walltime
cluster = PBSCluster(
    name='PiC_UVnudge_process_'+var,
    account='UCUB0155',
    queue='casper',
    walltime='12:00:00',
    cores=1,
    memory='50GiB',
)
# Scale to 4*9 = 36 workers
cluster.scale(4*9)
client = Client(cluster)

In [9]:
# Display the Dask client summary (dashboard link, workers, threads, memory)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/glydia/Arctic_breakdown/proxy/36821/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/glydia/Arctic_breakdown/proxy/36821/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.97:43471,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/glydia/Arctic_breakdown/proxy/36821/status,Total threads: 0
Started: Just now,Total memory: 0 B


### Custom functions

#### pre_process

In [10]:
def pre_process(da):
    """Give one LENS2 piControl slice file a common monthly time axis.

    Used as the ``preprocess`` hook of ``xr.open_mfdataset`` in LoadData. The
    'time' coordinate is replaced by month starts from 1950-01 to 2023-12
    (888 steps), so every slice must have exactly that many time steps.
    """
    monthly_starts = pd.date_range(start='1950-01-01', end='2023-12-01', freq='MS')
    return da.assign_coords(time=monthly_starts)

#### LoadData

In [11]:
def LoadData(casename, set_type, varname):
    """Open the processed output file(s) for one dataset and variable.

    Parameters
    ----------
    casename : str
        Dataset name (key of ds_names), used to build the file path.
    set_type : int
        Dataset type from ds_names: 0 LENS2 piControl, 1/2 PiC_UVnudge
        ensemble/single run, 3 observations.
    varname : str
        Variable name used in the file name (and as the Zarr group).

    Returns
    -------
    xarray.Dataset
        Opened with dask chunks of 12 time steps (netCDF/Zarr), or stacked
        along a new 'slice' dimension for the LENS2 piControl slice files.
    """
    # Original note: True -> use the 50 LENS2 piControl slices, False -> 501-574 only.
    # The global `var` (not varname) decides this, so Z3 runs keep False for every variable.
    cond_lens = var != 'Z3'

    # Pieces shared by every file name
    stream_and_var = h_ext[comp][freq] + varname
    suffix = file_ext[file_bool]

    if set_type == 0:
        # cond_lens True: '*' globs all slice files; False: single file with the 1950-2023 time string
        totalpath = (path_to_lensdata + cesm2piC + 'CMIP6-piControl.001'
                     + stream_and_var + yr_extn[cond_lens][freq] + suffix)
    elif set_type == 3:
        totalpath = (path_to_work + 'processed_' + casename + '_data/'
                     + casename + stream_and_var + yr_extn[False][freq] + suffix)
    else:
        totalpath = (path_to_expdata + 'processed_' + casename + '_data/'
                     + vercompres + casename + stream_and_var + yr_extn[False][freq] + suffix)

    if set_type == 0 and cond_lens:
        # Stack the piControl slice files along a new 'slice' dimension
        return xr.open_mfdataset(totalpath, combine='nested', concat_dim='slice', preprocess=pre_process)
    if file_bool:
        return xr.open_dataset(totalpath, chunks={'time': 12})
    return xr.open_zarr(totalpath, group=varname, chunks={'time': 12})

In [12]:
# Set spatial domain slicing (dicts are applied with .loc[...] in CreateMasterDS)
# Ocean (AMOC): lat_aux_grid 33-55, moc_z bounds written as metres*100
# (presumably moc_z is in cm), first moc_comp, transport_reg 1
slice_ocn = {'lat_aux_grid': slice(33, 55), 'moc_z': slice(800*100, 2200*100), 'moc_comp': 0, 'transport_reg': 1}
# Sea ice: model rows nj 250-385, observations 50-90N
slice_icemod = {'nj': slice(250, 385)}
slice_iceobs = {'lat': slice(50, 90)}
# Atmosphere: 70-90N for weighted means, 50-90N for spatial/zonal fields
slice_atmwei = {'lat': slice(70, 90)}
slice_atmspt = {'lat': slice(50, 90)}
# Time: from t_domain through the end of 2023
slice_time = {'time': slice(td_str + '-01-01', '2023-12-31')}

#### CreateMasterDS

In [13]:
def CreateMasterDS(varname):
    """Load, slice and reduce every enabled dataset, then merge them into one Dataset.

    The reduction is chosen from the enabled plots (first match in `plots` wins):
      'S'  - Arctic sea ice area (processing type 0 with varname 'aice')
      'W'  - cos(lat)-weighted spatial mean (processing type 0)
      'M'  - maximum over moc_z/lat_aux_grid (processing type 1, AMOC)
      'Z'  - zonal mean (processing type 3)
      None - keep the full spatial field (only processing type 2 enabled)

    Parameters
    ----------
    varname : str
        Variable to load. For ERA5, 'U'/'V' are read as the ensemble mean of
        'Target_U'/'Target_V' from the PiC_UVnudge output.

    Returns
    -------
    xarray.Dataset
        One data variable per enabled dataset (named after the dataset), plus
        'LENS2 piControl mean' (mean over 'slice') when LENS2 piControl is
        enabled, restricted to slice_time.
    """
    ds_load_list = []
    pd_time = pd.date_range('1950-01-01','2023-12-01',freq='MS')

    # Find which processing the enabled plots need (type 2 = none, keep looking)
    funcstr = None
    for plotname, attrs in plots.items():
        if not attrs[0]:
            continue
        if attrs[1] == 0:
            funcstr = 'S' if varname == 'aice' else 'W'
            break
        elif attrs[1] == 1:
            funcstr = 'M'
            break
        elif attrs[1] == 3:
            funcstr = 'Z'
            break

    print('Figure out processing calculations to do')

    ## Load, slice and reduce each enabled dataset
    for dsname, attrs in ds_names.items():
        if not attrs[0]:
            continue
        print(dsname)

        # Name of the variable inside the file. It only differs from varname for
        # ERA5 U/V and is kept separate so varname is unchanged for later datasets.
        load_var = varname
        if dsname == 'ERA5':
            if varname == 'U' or varname == 'V':
                # Use the nudging targets stored with PiC_UVnudge, averaged over members
                load_var = 'Target_'+varname
                ds_load = LoadData('PiC_UVnudge', 1, load_var)
                ds_load = ds_load.mean('ensemble_member')
            else:
                ds_load = LoadData(dsname, attrs[1], load_var)
                # Attach grid-cell areas (used by CalcSIA for observations)
                ds_load = CalcGridArea(ds_load)
        else:
            ds_load = LoadData(dsname, attrs[1], load_var)

        # Reduce to a DataArray named after the dataset
        ds_load = ds_load[load_var].rename(dsname)
        print('  Initial data loading complete')

        # Make all datasets share the same monthly time axis
        ds_load = ds_load.assign_coords(time=pd_time)
        print('  Fixed time dimension')

        ## Spatial domain slicing
        if comp == 'ice':
            ds_load = ds_load.loc[slice_iceobs if attrs[1] == 3 else slice_icemod]
        elif comp == 'ocn':
            ds_load = ds_load.loc[slice_ocn].drop_vars('moc_components')
        else:
            if not s_domain:
                ds_load = ds_load.loc[slice_atmspt if a_domain else slice_atmwei]
            # Remove cyclic point if not ERA5
            if dsname != 'ERA5':
                ds_load = ds_load.where(ds_load.lon != 180., drop=True)
        print('  Sliced data')

        ## Reduction by variable and graph type
        if funcstr is None or funcstr == 'Z':
            # ERA5 keeps its own dimension names (presumably so its grid does
            # not clash with the CESM2 lat/lon in the merged Dataset)
            if dsname == 'ERA5':
                ds_load = ds_load.rename({'lat':'latE', 'lon':'lonE'})
                if varname == 'Z3':
                    ds_load = ds_load.rename({'lev':'plev'})
            if funcstr == 'Z':
                print('  Calculating zonal average')
                ds_load = ds_load.mean('lonE' if dsname == 'ERA5' else 'lon', skipna=True)
            else:
                print('  No processing, spatial')
        elif funcstr == 'S':
            print('  Calculating sea ice area')
            ds_load = CalcSIA(ds_load, attrs[1])
        elif funcstr == 'W':
            print('  Calculating weighted average')
            ds_load = CalcWeightedMean(ds_load)
        elif funcstr == 'M':
            print('  Calculating maximum')
            ds_load = ds_load.max(('moc_z','lat_aux_grid'))

        # Ensemble mean for PiC_UVnudge ensembles
        if ens_type and attrs[1] == 1:
            print('  Calculating ensemble mean')
            ds_load = ds_load.mean('ensemble_member')

        # Reductions such as CalcSIA can drop the name, so set it again
        ds_load = ds_load.rename(dsname)

        print('  Processing on dataset complete')
        ds_load_list.append(ds_load)

    # Merge dataarrays
    ds_proc = xr.merge(ds_load_list)
    print('All datasets merged')

    # Mean over the LENS2 piControl slices (only when LENS2 piControl was loaded)
    if 'LENS2 piControl' in ds_proc:
        ds_proc['LENS2 piControl mean'] = ds_proc['LENS2 piControl'].mean('slice')

    # Slice time
    ds_proc = ds_proc.loc[slice_time]

    return ds_proc

#### CalcWeightedMean

In [14]:
def CalcWeightedMean(ds, avg_dim=('lon','lat')):
    """cos(latitude)-weighted spatial mean, evaluated eagerly.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Field with a 'lat' coordinate in degrees.
    avg_dim : tuple of str, optional
        Dimensions to average over (default ('lon', 'lat')).

    Returns
    -------
    Same type as ``ds``, reduced over ``avg_dim`` (NaNs skipped), holding the
    result of ``.compute()`` instead of a lazy dask graph.
    """
    # cos(lat) weights; materialise the small 1-D weight array up front
    weights = np.cos(np.deg2rad(ds.lat))
    weights = weights.compute()

    ds_mean_w = ds.weighted(weights).mean(avg_dim, skipna=True)

    # .compute() returns a new object; it must be reassigned, otherwise the
    # caller still receives the lazy result
    ds_mean_w = ds_mean_w.compute()

    return ds_mean_w

#### CalcAnom

In [15]:
def CalcAnom(ds_mean_w, period):
    """Anomalies relative to the mean along `period`.

    Parameters
    ----------
    ds_mean_w : xarray object
        Data to turn into anomalies.
    period : str
        Dimension to average over for the reference mean (NaNs skipped).

    Returns
    -------
    ``ds_mean_w`` minus its mean along ``period``.
    """
    reference = ds_mean_w.mean(period, skipna=True)
    return ds_mean_w - reference

#### CalcDetAnom

In [16]:
def CalcDetAnom(ds_anom_w, period):
    """Remove the least-squares linear trend along `period`.

    For period == 'time', the datetime axis is temporarily replaced by 1..N
    (AddCoordTrend) for the fit and restored on the result.

    Parameters
    ----------
    ds_anom_w : xarray.DataArray
        Anomalies to detrend.
    period : str
        Dimension along which the degree-1 polynomial is fitted.

    Returns
    -------
    xarray.DataArray of detrended anomalies.
    """
    is_time = period == 'time'
    if is_time:
        # polyfit needs a numeric axis; keep the datetimes to restore later
        original_time = ds_anom_w.time
        ds_anom_w = AddCoordTrend(ds_anom_w, period)

    coefs = ds_anom_w.polyfit(dim=period, deg=1)
    slope = coefs.loc[dict(degree=1)]
    intercept = coefs.loc[dict(degree=0)]
    fitted = slope*ds_anom_w[period] + intercept

    detrended = ds_anom_w - fitted['polyfit_coefficients']

    if is_time:
        detrended = detrended.assign_coords(time=original_time)

    return detrended

#### CalcR

In [17]:
def CalcR(ds_obs, ds_mod, period, ens):
    """Pearson correlation between observations and the model along `period`.

    Parameters
    ----------
    ds_obs : xarray.DataArray
        Observed series/fields.
    ds_mod : xarray.DataArray
        Model series/fields.
    period : str
        Dimension to correlate along.
    ens : int
        1 - ds_mod has an 'ensemble_member' dimension; its ensemble mean is used.
        2 - ds_mod is used as is.

    Returns
    -------
    xarray.DataArray of correlation coefficients.

    Raises
    ------
    ValueError
        If ``ens`` is neither 1 nor 2.
    """
    if ens == 1:
        ds_mod_mean = ds_mod.mean('ensemble_member')
    elif ens == 2:
        ds_mod_mean = ds_mod
    else:
        raise ValueError("ens must be 1 (ensemble mean) or 2 (use as is), got "+repr(ens))

    return xr.corr(ds_obs, ds_mod_mean, dim=period)

#### CalcEnsSp

In [18]:
def CalcEnsSp(ds_mean_w):
    """Ensemble spread: square root of the variance across 'ensemble_member'."""
    member_variance = ds_mean_w.var('ensemble_member')
    return np.sqrt(member_variance)

#### DropNonPiC

In [19]:
def DropNonPiC(ds):
    """Keep only the PiC_UVnudge variables of a merged Dataset.

    Every data variable whose name does not contain 'PiC_UVnudge' (LENS2
    piControl, its mean, observations, ...) is dropped. When 'LENS2 piControl'
    is present, its 'slice' dimension is dropped too.

    Once 'LENS2 piControl' is removed, 'slice' may no longer exist (for example
    when it has no coordinate of its own), so it is dropped with
    errors='ignore' instead of raising.

    Parameters
    ----------
    ds : xarray.Dataset

    Returns
    -------
    xarray.Dataset containing only PiC_UVnudge variables.
    """
    non_pic_vars = [v for v in ds.keys() if 'PiC_UVnudge' not in v]
    non_pic_dims = ['slice'] if 'LENS2 piControl' in non_pic_vars else []

    ds_d = ds.drop_vars(non_pic_vars)
    ds_d = ds_d.drop_dims(non_pic_dims, errors='ignore')

    return ds_d

#### NudgeYears

In [20]:
def NudgeYears(ds_puv, time_type):
    """Relabel each PiC_UVnudge run's `time_type` axis as years nudged.

    Each listed run gets the coordinate 0..73 plus its offset from the table
    below (presumably years already nudged before the first step). The axis
    must therefore have exactly 74 steps. Variables not in the table are dropped.

    Parameters
    ----------
    ds_puv : xarray.Dataset
    time_type : str
        Name of the axis to relabel.

    Returns
    -------
    xarray.Dataset of the relabelled runs.
    """
    offsets = {'PiC_UVnudge': 0.0,
               'PiC_UVnudge_LM': 0.0,
               'PiC_UVnudge_MM': 0.0,
               'PiC_UVnudge_2006': 56.0,
               'PiC_UVnudge_LM2006': 56.0,
               'PiC_UVnudge_MM2006': 56.0,
               'PiC_UVnudgenew': 0.0,
               'PiC_UVnudge_1988': 38.0,
               'PiC_UVnudge_2006_2000': 106.0}
    base_years = np.arange(0.0, 74.0)

    relabelled = [
        da.assign_coords({time_type: base_years + offsets[name]}).rename(name)
        for name, da in ds_puv.items()
        if name in offsets
    ]
    return xr.merge(relabelled)

#### CalcMonthTrd

In [21]:
def CalcMonthTrd(weighted_avg):
    """Linear trend per calendar month, multiplied by 10.

    Each calendar month's series is fitted with a degree-1 polynomial against
    a 1..N index. For monthly input that is one step per year, so the x10
    slope is per decade.

    Parameters
    ----------
    weighted_avg : xarray.DataArray
        Series with a datetime 'time' dimension.

    Returns
    -------
    xarray.DataArray
        Slope (degree-1 coefficient) with a 'month' dimension labelled by the
        calendar months actually present in the data.
    """
    months = []
    ds_list = []
    for mon, dsmon in weighted_avg.groupby('time.month'):
        # polyfit needs a numeric coordinate, not datetimes
        new_time = np.arange(1, dsmon.sizes['time']+1)
        fit = dsmon.assign_coords(time=new_time).polyfit(dim='time', deg=1)
        fit *= 10
        months.append(mon)
        ds_list.append(fit)

    # Label by the groups found; a hard-coded 1..12 breaks when months are missing
    month_trd = xr.concat(ds_list, pd.Index(months, name='month'))

    return month_trd['polyfit_coefficients'].loc[dict(degree=1)]

#### CalcAnnTrd

In [22]:
def CalcAnnTrd(weighted_avg):
    """Linear trend of annual means, in units per decade.

    The input is averaged per calendar year ('time.year'). A degree-1
    polynomial is fitted against the 'year' coordinate, and the slope
    (degree-1 coefficient) is multiplied by 10.
    """
    yearly = weighted_avg.groupby('time.year').mean('time')
    coefs = yearly.polyfit(dim='year', deg=1)
    coefs *= 10
    return coefs['polyfit_coefficients'].sel(degree=1)

#### Regrid

In [23]:
def Regrid(ds_timeavg, regridder, regrid_type, pvals=False):
    """Regrid sea-ice (CICE) or ERA5 fields onto the CESM2 atmosphere grid.

    Parameters
    ----------
    ds_timeavg : xarray.Dataset or xarray.DataArray
        Fields to regrid. A bare DataArray is only handled for regrid_type
        'era' (ERA5 with 'latE'/'lonE' dimensions); it is returned named 'ERA5'.
    regridder : callable
        Regridding object (presumably an xesmf.Regridder built elsewhere). Its
        output is expected to use 'x'/'y' dims, which are renamed to 'lon'/'lat'.
    regrid_type : str
        'sic' - regrid every variable except 'ERA5' and ' pcrit' variables.
        'era' - regrid 'ERA5' only; every other variable is passed through with
                lon/lat reassigned.
    pvals : bool, optional
        'sic' only: missing cells are filled with 0.99 when True (p-values) and
        1e-6 otherwise. When True, 'ERA5' is left out of the output.

    Returns
    -------
    xarray.DataArray (DataArray input) or xarray.Dataset.

    Notes
    -----
    Uses module-level `lons` and `lats` (target-grid coordinates) that are not
    defined in this function. NOTE(review): confirm where they are set.
    """
    # regrid_type: 'sic' or 'era'
    sic_cond = (regrid_type == 'sic')
    era_cond = (regrid_type == 'era')

    if sic_cond:
        print('Regridding CICE grid -> ATM grid...')
    if era_cond:
        print('Regridding ERA5 grid -> ATM grid...')

    # Fill value for missing cells when regridding sic (only defined for 'sic')
    if sic_cond:
        nval = 0.99 if pvals else 0.000001

    # If regridding ERA5 DataArray
    if era_cond and type(ds_timeavg) == xr.core.dataarray.DataArray:
        da = ds_timeavg.rename({'latE': 'lat', 'lonE': 'lon'})
        da_re = regridder(da)
        da_re = da_re.rename({'x':'lon', 'y':'lat'})
        da_regrid = da_re.rename('ERA5')

        return da_regrid
        

    # Else, regridding dataset
    else:
        # Regrid data
        regrid_list = []
        for dsname, da in ds_timeavg.items():
            # True -> variable is NOT regridded for this regrid_type
            cond = {'sic': dsname == 'ERA5',
                    'era': dsname != 'ERA5'}
            # If SIC regridding & dataset name is ERA5, skip it here (added at the end unless pvals)
            # If ERA regridding & dataset name is not ERA5, don't regrid, just append
            if cond[regrid_type]:
                # Don't regrid, just add for era regridding
                if era_cond:
                    da = da.assign_coords({'lon':lons, 'lat': lats})
                    regrid_list.append(da.rename(dsname))
            # pcrit variables are passed through without regridding
            elif ' pcrit' in dsname:
                regrid_list.append(da)
            
            else:
                print('Regridding '+dsname)
    
                # Need to rename for era
                if era_cond:
                    da = da.rename({'latE': 'lat', 'lonE': 'lon'})
                da_re = regridder(da)
        
                # Rename x, y to be lon, lat (for sic, lon/lat values are reassigned below)
                da_re = da_re.rename({'x':'lon', 'y':'lat'})
    
                # Only reassign coordinates and fillna for sic
                if sic_cond:
                    da_re = da_re.assign_coords({'lon':lons, 'lat': lats})
                    da_re = da_re.fillna(nval)
                    
                regrid_list.append(da_re.rename(dsname))
    
        # Fill na for ERA5 only if sic & not pvals 
        if not pvals and sic_cond:
            regrid_list.append(ds_timeavg['ERA5'].fillna(nval))
        
        # join='left': indexes follow the first variable in regrid_list
        ds_regrid = xr.merge(regrid_list, join='left')
    
        return ds_regrid

#### CalcGridArea

In [24]:
def CalcGridArea(ds, dlat=0.25, dlon=0.25, R=6367.47):
    """Attach approximate grid-cell areas (km^2) as a 2-D 'area' coordinate.

    Each cell is treated as a dlat x dlon box on a sphere of radius R:
        area = (R * dlat_rad) * (R * dlon_rad * cos(lat))
    The defaults describe the regular 0.25 x 0.25 degree ERA5 grid.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Object with 1-D 'lat' and 'lon' coordinates in degrees.
    dlat, dlon : float, optional
        Grid spacing in degrees (default 0.25).
    R : float, optional
        Earth radius in km (default 6367.47).

    Returns
    -------
    ``ds`` with an 'area' coordinate on ('lat', 'lon').
    """
    _lon2d, lat2d = np.meshgrid(ds.lon.values, ds.lat.values)

    dy = R*np.deg2rad(dlat)
    dx = np.deg2rad(dlon)*R*np.cos(np.deg2rad(lat2d))

    area = np.abs(dx*dy)
    return ds.assign_coords(area=(('lat','lon'), area))

#### CalcSIA

In [25]:
def CalcSIA(ds, set_type):
    """Arctic sea ice area (million km^2) from sea ice concentration.

    Cells with concentration > 0.15 count as fully ice covered; all others,
    including NaN, count as 0. The areas of covered cells are summed over the
    whole domain.

    Parameters
    ----------
    ds : xarray.DataArray
        Concentration (presumably a 0-1 fraction, given the 0.15 threshold)
        carrying an area coordinate:
        - set_type 3 (observations): 'area' in km^2 on lat/lon (see CalcGridArea),
          scaled by 1e-6;
        - otherwise (CICE): 'tarea' on ni/nj, scaled by 1e-6*1e-6
          (presumably m^2 -> million km^2).
    set_type : int
        Dataset type; 3 means observations.

    Returns
    -------
    xarray.DataArray of sea ice area with 'units' and 'long_name' attributes.
    """
    ice_mask = xr.where(ds > .15, 1, 0)

    if set_type == 3:
        dsa = (ice_mask*ds['area']).sum(dim=['lon','lat'])*1e-6
    else:
        dsa = (ice_mask*ds['tarea']).sum(dim=['ni','nj'])*1e-6*1e-6

    dsa.attrs['units'] = 'million km^2'
    dsa.attrs['long_name'] = 'Arctic sea ice area'
    return dsa

#### CalcTrend

In [26]:
def CalcTrend(da, time_type):
    """Least-squares linear trend along `time_type`, multiplied by 10.

    scipy.stats.linregress is applied at every grid point (vectorised via
    xr.apply_ufunc, dask-parallelised); only the slope is kept. If each step
    of `time_type` is one year, the result is a per-decade trend.
    """
    results = xr.apply_ufunc(
        stats.linregress,
        da[time_type],
        da,
        input_core_dims=[[time_type], [time_type]],
        output_core_dims=[[], [], [], [], []],
        vectorize=True,
        dask='parallelized',
        dask_gufunc_kwargs=dict(allow_rechunk=True),
    )
    slope, _intercept, _rvalue, _pvalue, _stderr = results
    return slope*10

#### AddCoordTrend

In [27]:
def AddCoordTrend(da, time_type):
    """Replace the 'time' or 'year' coordinate with a 1..N integer index.

    Trend fits need a numeric axis. Any other `time_type` returns `da`
    unchanged (its size is still looked up first, so a missing dimension
    raises KeyError).
    """
    n_steps = da.sizes[time_type]
    index = np.arange(1, n_steps+1)
    if time_type not in ('time', 'year'):
        return da
    return da.assign_coords(**{time_type: index})

#### AddCyclic

In [28]:
def AddCyclic(da: xr.DataArray, londim) -> xr.DataArray:
    """Return a copy of `da` with a wrap-around longitude point appended.

    cartopy's add_cyclic_point is called with its default axis, so `londim` is
    presumably expected to be the last dimension of `da`. Only dimension
    coordinates are carried over; attrs and name are kept.
    """
    wrapped_data, wrapped_lon = add_cyclic_point(da.data, coord=da[londim])
    new_coords = {d: da.coords[d] for d in da.dims}
    new_coords[londim] = wrapped_lon
    return xr.DataArray(wrapped_data, dims=da.dims, coords=new_coords,
                        attrs=da.attrs, name=da.name)

#### VertLENS2StatsLoop

In [29]:
def VertLENS2StatsLoop(ds_plens, time_type, varname):
    """LENS2 piControl trend statistics, computed or loaded one pressure level at a time.

    Parameters
    ----------
    ds_plens : xarray.DataArray or Dataset
        LENS2 piControl data with a 'plev' dimension (and the 'slice' dimension
        that LENS2TrendsEnsemble averages over).
    time_type : str
        Time dimension used for the trend fit.
    varname : str
        Variable name used in the cache file names.

    Returns
    -------
    xarray.Dataset
        Per-level 'LENS2 piControl mu'/'sigma' concatenated along 'plev'.

    Each level is cached in path_to_plotdata as
    'Map.Trend.LENSstats.<varname>.<sd_str>.Z<lev>.<td_str>.<time_outstr>.<ens_str>.nc'.
    It is saved under the same prefix that is looked up, so later runs find the cache.
    """
    graph_type = 'Map.Trend'

    lev_stats_list = []
    for lev, dlev in ds_plens.groupby('plev'):
        lev_str = str(int(lev))
        path = (path_to_plotdata+graph_type+'.LENSstats.'+varname+'.'+sd_str+'.Z'+lev_str+'.'
                +td_str+'.'+time_outstr+'.'+ens_str+'.nc')
        if os.path.isfile(path):
            print('Loading LENS2 piControl trend distribution for Z'+lev_str)
            dlev_stats = xr.open_dataset(path)
        else:
            print('Calculating spatial trends for '+str(lev)+' hPa')
            dlev_stats = LENS2TrendsEnsemble(graph_type, dlev, time_type, varname, True)
            # Same prefix as `path` above; SaveData inserts 'Z<lev>.' for plot_level
            SaveData(dlev_stats, graph_type+'.LENSstats', varname, time_outstr, lev)

        lev_stats_list.append(dlev_stats)

    return xr.concat(lev_stats_list, dim='plev')

#### LENS2TrendsEnsemble

In [30]:
def LENS2TrendsEnsemble(graph_type, ds_lens, time_type, varname, loop):
    """Mean and spread of LENS2 piControl trends across slices, per month or season.

    Parameters
    ----------
    graph_type : str
        Plot-type prefix of the cache file ('<graph_type>.LENSstats...').
    ds_lens : xarray.DataArray
        LENS2 piControl data with 'time' and 'slice' dimensions (presumably a
        DataArray, since the results are renamed with a single name).
    time_type : str
        Time dimension used for the trend fit.
    varname : str
        Variable name used in the cache file name.
    loop : bool
        True when called per pressure level (VertLENS2StatsLoop). The
        single-file cache is then neither read nor written here.

    Returns
    -------
    xarray.Dataset
        'LENS2 piControl mu' and 'LENS2 piControl sigma': mean and standard
        deviation over 'slice' of the CalcTrend trends, along a <time_outstr>
        dimension (months 1-12, or quarter start months 3, 6, 9, 12).

    Raises
    ------
    ValueError
        If the statistics must be computed and time_avg is neither 0 (monthly)
        nor 2 (seasonal).
    """
    path = path_to_plotdata+graph_type+'.LENSstats.'+varname+'.'+sd_str+'.'+td_str+'.'+time_outstr+'.'+ens_str+'.nc'
    if os.path.isfile(path) and not loop:
        print('Loading LENS2 piControl trend distribution')
        return xr.open_dataset(path)

    print('Calculating LENS2 piControl trend distribution')
    if time_avg == 0:
        # Monthly: one group per calendar month
        grouped = ds_lens.groupby('time.month')
        time_list = np.arange(1,13)
    elif time_avg == 2:
        # Seasonal: quarterly means starting in December, grouped by the quarter's start month
        grouped = ds_lens.resample(time='QS-DEC').mean('time').groupby('time.month')
        time_list = np.arange(3,13,3)
    else:
        raise ValueError('LENS2TrendsEnsemble supports time_avg 0 (monthly) or 2 (seasonal), got '+str(time_avg))

    trends_avg_list = []
    trends_std_list = []

    # Cycle through months/seasons
    for m, ds in grouped:
        print('Spatial trends for '+str(m))

        # Add linear time coordinate for trends - doesn't work with datetime time coordinate
        ds = AddCoordTrend(ds, time_type)

        print('  Calculating trends for LENS2 piControl')
        ds_gp_trend = CalcTrend(ds.load(), time_type)
        trends_avg_list.append(ds_gp_trend.mean('slice', skipna=True).rename('LENS2 piControl mu'))
        trends_std_list.append(ds_gp_trend.std('slice', skipna=True).rename('LENS2 piControl sigma'))

    ds_ens_mu = xr.concat(trends_avg_list, dim=pd.Index(time_list, name=time_outstr))
    ds_ens_sigma = xr.concat(trends_std_list, dim=pd.Index(time_list, name=time_outstr))
    ds_ens_stats = xr.merge([ds_ens_mu, ds_ens_sigma])

    if not loop:
        # Save under the same name that is checked above so the cache is reused
        SaveData(ds_ens_stats, graph_type+'.LENSstats', varname, time_outstr)

    return ds_ens_stats

#### VertSptrendsLoop

In [31]:
def VertSptrendsLoop(ds_sptpl, ds_plens_ens, time_type, varname):
    """Run SpatZonTrends separately for every pressure level.

    The LENS2 statistics for each level are taken from ds_plens_ens grouped
    by 'plev' and indexed with the same level label.

    Returns
    -------
    (trends, p-values) : tuple of xarray.Dataset, each concatenated along 'plev'.
    """
    levels = ds_sptpl.groupby('plev')
    lens_by_level = ds_plens_ens.groupby('plev')

    trends_by_level = []
    pvals_by_level = []
    for lev, ds_level in levels:
        lens_level = lens_by_level[lev]
        print('Calculating spatial trends for '+str(lev)+' hPa')
        level_trends, level_pvals = SpatZonTrends(ds_level, lens_level, time_type, varname)
        trends_by_level.append(level_trends)
        pvals_by_level.append(level_pvals)

    return (xr.concat(trends_by_level, dim='plev'),
            xr.concat(pvals_by_level, dim='plev'))

#### SpatZonTrends

In [32]:
def SpatZonTrends(ds, ds_lens_ens, time_type, varname):
    """Trends of every dataset per month/season, with significance against LENS2 piControl.

    Parameters
    ----------
    ds : xarray.Dataset
        Merged datasets from CreateMasterDS (must contain 'LENS2 piControl' and 'ERA5').
    ds_lens_ens : xarray.Dataset
        'LENS2 piControl mu'/'sigma' trend statistics indexed by <time_outstr>
        (as produced by LENS2TrendsEnsemble).
    time_type : str
        Time dimension used for the trend fit.
    varname : str
        Variable name; 'Z3' changes how ERA5 is handled.

    Returns
    -------
    (ds_trend, ds_pval) : tuple of xarray.Dataset
        Trends (CalcTrend) of every variable, and p-values plus '<name> pcrit'
        for the datasets tested. Both are concatenated along <time_outstr> and
        labelled with the module-level `date_str`, which is defined outside this
        function. NOTE(review): date_str must hold one label per month/season group.

    Raises
    ------
    ValueError
        If time_avg is neither 0 (monthly) nor 2 (seasonal).
    """
    # The individual LENS2 slices are summarised by ds_lens_ens, so drop them.
    # Once 'LENS2 piControl' is gone the 'slice' dimension may no longer exist.
    ds_spz = ds.drop_vars(['LENS2 piControl'])
    ds_spz = ds_spz.drop_dims('slice', errors='ignore')

    # Bool for spatial/zonal differentiation: spatial fields keep ERA5's 'lonE'
    dimlist = ds_spz['ERA5'].dims
    zonspt_bool = 'lonE' in dimlist # True: spatial, False: zonal

    # Varname bool
    var_bool = varname == 'Z3'

    if time_avg == 0:
        ds_spz = ds_spz.groupby('time.month')
    elif time_avg == 2:
        # Quarterly means starting in December, grouped by the quarter's start month
        ds_spz = ds_spz.resample(time='QS-DEC').mean('time')
        ds_spz = ds_spz.groupby('time.month')
    else:
        raise ValueError('SpatZonTrends supports time_avg 0 (monthly) or 2 (seasonal), got '+str(time_avg))

    trends_tm_list = []
    pvals_tm_list = []

    # Cycle through months/seasons
    for m, ds_m in ds_spz:
        print('Trends for '+str(m))
        trend_ds_list = []
        pval_ds_list = []

        # Add linear time coordinate for trends - doesn't work with datetime time coordinate
        ds_m = AddCoordTrend(ds_m, time_type)

        # LENS2 trend statistics for this month/season
        ds_stats = ds_lens_ens.loc[{time_outstr: m}]

        # Cycle through all variables
        for dsname, da in ds_m.items():
            print('  Calculating trends for '+dsname)
            da_gp_trend = CalcTrend(da.load(), time_type)
            trend_ds_list.append(da_gp_trend.rename(dsname))

            era_bool = dsname == 'ERA5'

            # Test every dataset except the LENS2 slice mean. ERA5 is tested only
            # as a spatial field when file_bool is False (3-D or daily variable),
            # or as a zonal mean when the variable is not Z3.
            if dsname != 'LENS2 piControl mean' and (not era_bool or ((not file_bool and zonspt_bool) or (not var_bool and not zonspt_bool))):
                if era_bool and zonspt_bool:
                    # Z3 ERA5 trends are regridded to the atmosphere grid (regridderERA is a
                    # module-level regridder defined outside this function); other ERA5
                    # spatial fields only need their dimensions renamed
                    if var_bool:
                        da_gp_trend = Regrid(da_gp_trend, regridderERA, 'era')
                    else:
                        da_gp_trend = da_gp_trend.rename({'latE': 'lat', 'lonE': 'lon'})
                elif era_bool and not zonspt_bool:
                    # ERA5 zonal mean (never Z3 here): only latitude to rename
                    da_gp_trend = da_gp_trend.rename({'latE': 'lat'})

                print('  Calculating p-values and p-critical values for '+dsname)
                da_pval = CalcSlopeSig(ds_stats, da_gp_trend)

                # Store p-values under the dataset name and the critical value as '<name> pcrit'
                da_pval = da_pval.rename(dsname)
                da_pval = da_pval.to_dataset()
                da_pval = da_pval.assign({dsname+' pcrit': da_pval['pcrit']})
                da_pval = da_pval.drop_vars('pcrit')
                pval_ds_list.append(da_pval)

        trends_tm_list.append(xr.merge(trend_ds_list, compat='no_conflicts'))
        pvals_tm_list.append(xr.merge(pval_ds_list, compat='no_conflicts'))

    ds_trend = xr.concat(trends_tm_list, dim=pd.Index(date_str, name=time_outstr))
    ds_pval = xr.concat(pvals_tm_list, dim=pd.Index(date_str, name=time_outstr))

    return ds_trend, ds_pval

#### CalcSlopeSig

In [33]:
def CalcSlopeSig(ds_lens, da):
    """Two-sided p-values of the trends `da` against the LENS2 piControl trend distribution.

    A normal distribution with mean 'LENS2 piControl mu' and standard deviation
    'LENS2 piControl sigma' is assumed:
        z = (da - mu) / sigma,    p = 2 * P(Z > |z|)
    The critical p-value from Wilks_pcrit at alpha = 0.05 (presumably the
    Wilks false-discovery-rate threshold; see Processing_functions) is attached
    as a 'pcrit' coordinate and printed.

    Returns
    -------
    Copy of `da` holding the p-values, with a 'pcrit' coordinate.
    """
    mu_lens = ds_lens['LENS2 piControl mu']
    sig_lens = ds_lens['LENS2 piControl sigma']
    z_score = (da - mu_lens)/sig_lens

    p_two_sided = stats.norm.sf(abs(z_score))*2
    da_pval = da.copy(data=p_two_sided)

    pcrit = Wilks_pcrit(p_two_sided, 0.05)
    da_pval['pcrit'] = pcrit
    print(pcrit)

    return da_pval

#### PattCorr

In [34]:
def PattCorr(ds_trend):
    """cos(lat)-weighted pattern correlation of every field with ERA5, per month/season.

    Each month/season group of `ds_trend` (grouped by <time_outstr>) is
    correlated with that group's 'ERA5' over all of its dimensions.

    Returns
    -------
    xarray.Dataset
        One correlation per non-ERA5 variable, concatenated along <time_outstr>
        and labelled with the module-level `date_str` (defined outside this function).
    """
    pcorr_m_list = []
    for m, ds_m in ds_trend.groupby(time_outstr):
        print('Pattern correlation for '+str(m))
        era = ds_m['ERA5']
        weights = np.cos(np.deg2rad(ds_m.lat))
        corr_list = [xr.corr(era, da, weights=weights).rename(dsname)
                     for dsname, da in ds_m.items() if dsname != 'ERA5']
        pcorr_m_list.append(xr.merge(corr_list))

    return xr.concat(pcorr_m_list, dim=pd.Index(date_str, name=time_outstr))

#### SpatTrendRatio

In [35]:
# def SpatTrendRatio(ds_trend, rat_type):
#     ## warm: contribution to obs warming trends only
#     ## cold: contribution to obs cooling trends only

    

#### SpatZonAvg

In [36]:
def SpatZonAvg(ds):
    """Climatological mean per calendar month (time_avg 0) or per season (time_avg 2).

    Seasonal: quarterly means starting in December (resample 'QS-DEC') are
    averaged per quarter. The resulting 'month' dimension is relabelled with
    the module-level `seas_str` (defined outside this function) and renamed 'season'.

    Raises
    ------
    ValueError
        If time_avg is neither 0 nor 2 (no average would be defined).
    """
    if time_avg == 0:
        return ds.groupby('time.month').mean('time')

    if time_avg == 2:
        ds_avg = ds.resample(time='QS-DEC').mean('time')
        ds_avg = ds_avg.groupby('time.month').mean('time')
        ds_avg = ds_avg.assign_coords(month=seas_str)
        return ds_avg.rename({'month':'season'})

    raise ValueError('SpatZonAvg supports time_avg 0 (monthly) or 2 (seasonal), got '+str(time_avg))

#### SpatialVar

In [37]:
def SpatialVar(ds):
    """Interannual variance of ``ds`` per calendar month or season, selected by ``time_avg``.

    time_avg == 0 : variance across years of each calendar month -> ``month`` dim.
    time_avg == 2 : seasonal means are formed first (quarters anchored in
                    December), then their variance across years -> ``season``
                    dim labelled with ``seas_str``.

    Variances use xarray's default ``ddof=0``.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        Data with a ``time`` dimension.

    Returns
    -------
    Same type as ``ds``, with ``time`` replaced by ``month`` or ``season``.

    Raises
    ------
    ValueError
        If ``time_avg`` is neither 0 nor 2. Previously this case surfaced as an
        ``UnboundLocalError``.
    """
    # Monthly variances
    if time_avg == 0:
        return ds.groupby('time.month').var('time')

    # Seasonal variances
    if time_avg == 2:
        # 'QS-DEC' start months sort to 3, 6, 9, 12 -> MAM, JJA, SON, DJF (= seas_str)
        ds_avg = ds.resample(time='QS-DEC').mean('time')
        ds_var = ds_avg.groupby('time.month').var('time')
        ds_var = ds_var.assign_coords(month=seas_str)
        return ds_var.rename({'month': 'season'})

    raise ValueError(f'SpatialVar: unsupported time_avg={time_avg!r}; expected 0 (monthly) or 2 (seasonal)')

#### AddAllCyclic

In [38]:
def AddAllCyclic(ds_trend):
    """Apply ``AddCyclic`` to every variable of ``ds_trend`` except p-critical ones.

    ERA5 is indexed by ``latE``/``lonE``; every other dataset uses ``lat``/``lon``.
    Before ``AddCyclic`` is called, each array is transposed to
    ``(time_outstr, lat, lon)`` for single-level files (``file_bool`` true) or
    ``(time_outstr, 'plev', lat, lon)`` otherwise. Variables whose name contains
    ``'pcrit'`` are passed through untouched.

    Parameters
    ----------
    ds_trend : xr.Dataset

    Returns
    -------
    xr.Dataset
        The processed variables merged back into one Dataset.
    """
    def _horizontal_dims(name):
        # ERA5 keeps its own grid dimension names
        return ('latE', 'lonE') if name == 'ERA5' else ('lat', 'lon')

    processed = []
    for name, arr in ds_trend.items():
        if 'pcrit' not in name:
            lat_name, lon_name = _horizontal_dims(name)
            if file_bool:
                dim_order = (time_outstr, lat_name, lon_name)
            else:
                dim_order = (time_outstr, 'plev', lat_name, lon_name)
            arr = AddCyclic(arr.transpose(*dim_order), lon_name)
        processed.append(arr.rename(name))

    return xr.merge(processed)

#### SaveData

In [39]:
def SaveData(ds, plot_type, varname, tavg, plot_level=None):
    """Write ``ds`` as NetCDF4 under ``path_to_plotdata`` with a standardised name.

    File name pattern::

        <plot_type>.<varname>.<sd_str>.[Z<level>.]<td_str>.<tavg>.<ens_str>.nc

    i.e. (plot type, including special averaging such as anomalies).varname.
    spatialdomain.[level.]timedomain.timeaveraging.ensemble_type.nc, where
    ``sd_str``, ``td_str`` and ``ens_str`` are module-level globals and the
    ``Z<level>.`` segment appears only when ``plot_level`` is given.

    Parameters
    ----------
    ds : xr.Dataset
        Data to write.
    plot_type : str
        Plot-type prefix, e.g. ``'Map.Trend.pval'``.
    varname : str
        Variable name, e.g. ``'Z3'`` or ``'U'``.
    tavg : str
        Time-averaging label, e.g. ``'month'``, ``'year'``, or a month/season name.
    plot_level : number, optional
        Vertical level (presumably pressure); truncated with ``int()`` in the name.

    Returns
    -------
    None
    """
    # `is None` (not `== None`): == on numpy/xarray scalars is elementwise
    level_str = '' if plot_level is None else 'Z'+str(int(plot_level))+'.'

    filename = plot_type+'.'+varname+'.'+sd_str+'.'+level_str+td_str+'.'+tavg+'.'+ens_str+'.nc'

    print('Saving '+filename)
    # path_to_plotdata is concatenated directly, so it must end with a path separator
    ds.to_netcdf(path_to_plotdata+filename,
                format='NETCDF4')
    return None

### Process and load data

In [40]:
%%time

# Build the merged Dataset of all configured datasets for the primary variable `var`
# (CreateMasterDS is defined earlier; its log below reports load, time-fix, slice and merge steps).
ds_proc = CreateMasterDS(var)

Figure out processing calculations to do
LENS2 piControl
  Initial data loading complete
  Fixed time dimension
  Sliced data
  No proceesing, spatial
  No proceesing, spatial
  Processing on dataset complete
PiC_UVnudge_2006
  Initial data loading complete
  Fixed time dimension
  Sliced data
  No proceesing, spatial
  No proceesing, spatial
  Calculating ensemble mean
  Processing on dataset complete
PiC_UVnudge_LM2006
  Initial data loading complete
  Fixed time dimension
  Sliced data
  No proceesing, spatial
  No proceesing, spatial
  Calculating ensemble mean
  Processing on dataset complete
PiC_UVnudge_MM2006
  Initial data loading complete
  Fixed time dimension
  Sliced data
  No proceesing, spatial
  No proceesing, spatial
  Calculating ensemble mean
  Processing on dataset complete
ERA5
  Initial data loading complete
  Fixed time dimension
  Sliced data
  No proceesing, spatial
  No proceesing, spatial
  Processing on dataset complete
All datasets merged
CPU times: user 7.0

In [41]:
%%time

# Add U & V variables if doing spatial trends or patterns and original variable is geopotential
# NOTE: ds_u / ds_v exist ONLY when this condition holds; later cells must guard on var == 'Z3'.
if (plot_types['zonal'][0] or plot_types['spatial'][0]) and var == 'Z3':
    # Create U & V datasets
    ds_u = CreateMasterDS('U')
    ds_v = CreateMasterDS('V')

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 6.44 µs


### Plotting set-up

In [42]:
%%time

# Month labels; index with month number m as mon_str[m-1]
mon_str = np.array(['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec'])
# Season labels in the order produced by SpatZonAvg/SpatialVar: 'QS-DEC' bins grouped by
# start month give 3, 6, 9, 12 -> MAM, JJA, SON, DJF. Do not reorder.
seas_str = np.array(['MAM','JJA','SON','DJF'])


CPU times: user 17 µs, sys: 0 ns, total: 17 µs
Wall time: 18.6 µs


## Month trend plots
### Set up

In [43]:
if plot_types['mtrd'][0]:
    # File-name prefix passed to SaveData for monthly/annual trend outputs
    graph_type_str = 'Linear.Trend'
    ## Select plot type - timeseries or monthly - to make and assign variables accordingly
    

    # Timeseries
    # NOTE(review): dim_avg/period are only set for time_avg == 4; other values leave them undefined.
    if time_avg == 4:
        dim_avg = 'time.month'
        period = 'month'

### Data processing

In [44]:
%%time

if plot_types['mtrd'][0]:

    # Calculate monthly trends (CalcMonthTrd is defined earlier), one variable per dataset
    mon_trd_list = []
    for dsname, da in ds_proc.items():
        da_mon_trd = CalcMonthTrd(da)
        mon_trd_list.append(da_mon_trd.rename(dsname))

    ds_mon_trd = xr.merge(mon_trd_list)
    # Envelope of the piControl trends across the 'slice' dimension
    ds_mon_trd['LENS2 piControl min'] = ds_mon_trd['LENS2 piControl'].min('slice')
    ds_mon_trd['LENS2 piControl max'] = ds_mon_trd['LENS2 piControl'].max('slice')

    # Calculate annual trends (CalcAnnTrd is defined earlier)
    ann_trd_list = []
    for dsname, da in ds_proc.items():
        da_ann_trd = CalcAnnTrd(da)
        ann_trd_list.append(da_ann_trd.rename(dsname))

    ds_ann_trd = xr.merge(ann_trd_list)
    # Envelope of the piControl trends across the 'slice' dimension
    ds_ann_trd['LENS2 piControl min'] = ds_ann_trd['LENS2 piControl'].min('slice')
    ds_ann_trd['LENS2 piControl max'] = ds_ann_trd['LENS2 piControl'].max('slice')

    # Write out data
    SaveData(ds_mon_trd, graph_type_str, var, 'month')
    SaveData(ds_ann_trd, graph_type_str, var, 'year')

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.25 µs


## Line plots
### Set up

In [45]:
if plot_types['line'][0]:
    # File-name prefix passed to SaveData for line-plot outputs
    graph_type_str = 'Linear'
        

    # Time averaging for yearly plot
    # NOTE(review): only time_avg 1 (yearly) and 4 (single month) set dim_avg/period.
    if time_avg == 1:
        dim_avg = 'time.year'  
        period='year'
    if time_avg == 4:
        dim_avg = 'time.month'
        period ='time'

### Data processing

In [46]:
%%time

if plot_types['line'][0]:

    ## Absolute (only one that will be used for SIA-one month, TOA, AMOC)
    # Yearly averaging for absolute
    if time_avg == 1:
        ds_abs = ds_proc.groupby(dim_avg).mean('time')

        # Envelope of the piControl members across 'slice'
        ds_abs['LENS2 piControl min'] = ds_abs['LENS2 piControl'].min('slice')
        ds_abs['LENS2 piControl max'] = ds_abs['LENS2 piControl'].max('slice')
        
        SaveData(ds_abs, graph_type_str+'.abs', var, period)

    elif time_avg == 4:
        # If only plotting one month from SIA data
        # ds_abs stays a GroupBy object here; later cells iterate over it again.
        ds_abs = ds_proc.groupby(dim_avg)

        for m, ds in ds_abs:
            # min/max are added to this month's group only (used for the file written below)
            ds['LENS2 piControl min'] = ds['LENS2 piControl'].min('slice')
            ds['LENS2 piControl max'] = ds['LENS2 piControl'].max('slice')

            # m is the month number (1..12)
            SaveData(ds, graph_type_str+'.abs', var, mon_str[m-1])

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.25 µs


In [47]:
%%time

# Only for the 1950 time domain (t_domain == 1950)
if plot_types['line'][0] and t_domain == 1950:
    # Variables in terms of how many years they've been nudged
    # NOTE(review): DropNonPiC / NudgeYears are defined earlier; presumably they keep only the
    # PiC-based runs and re-index time as years since nudging began — confirm against their definitions.
    if time_avg == 1:
        ds_pic = DropNonPiC(ds_abs)
        ds_nudyr = NudgeYears(ds_pic, period)

        SaveData(ds_nudyr, graph_type_str+'.nudyr', var, period)

    elif time_avg == 4:
        # If only plotting one month from SIA data
        if plots['sia'][0]:
            for m, ds in ds_abs:
                ds = DropNonPiC(ds)

                ds_nudyr = NudgeYears(ds, period)
                SaveData(ds_nudyr, graph_type_str+'.nudyr', var, mon_str[m-1])

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.01 µs


In [48]:
%%time

if plot_types['line'][0]:
    # If sea ice area, make absolute correlation coefficients
    if time_avg == 4:
        # Calculate R for each month
        for m, ds in ds_abs:
            abs_r_list = []
            era5_abs = ds['ERA5']
            for dsname, attrs in ds_names.items():
                if attrs[0] and (attrs[1] == 1 or attrs[1] == 2):
                    da_abs_r = CalcR(era5_abs, ds[dsname], period, attrs[1])
                    abs_r_list.append(da_abs_r.rename(dsname))
                    
            ds_abs_r = xr.merge(abs_r_list)
            SaveData(ds_abs_r, graph_type_str+'.absR', var, mon_str[m-1])


        # Anomalies
        anom_list = []
        for m, ds in ds_abs:
            da_anom = CalcAnom(ds, period)
            da_anom['LENS2 piControl min'] = da_anom['LENS2 piControl'].min('slice')
            da_anom['LENS2 piControl max'] = da_anom['LENS2 piControl'].max('slice')
            
            anom_list.append(da_anom)
            SaveData(da_anom, graph_type_str+'.anom', var, mon_str[m-1])
            
        ds_anom = xr.concat(anom_list, dim='time')

        # Calculate R for anomalies
        ds_anom = ds_anom.groupby(dim_avg)
        for m, ds in ds_anom:
            anom_r_list = []
            era5_anom = ds['ERA5']
            for dsname, attrs in ds_names.items():
                if attrs[0] and (attrs[1] == 1 or attrs[1] == 2):
                    da_anom_r = CalcR(era5_anom, ds[dsname], period, attrs[1])
                    anom_r_list.append(da_anom_r.rename(dsname))
                    
            ds_anom_r = xr.merge(anom_r_list)
            SaveData(ds_anom_r, graph_type_str+'.anomR', var, mon_str[m-1])

        # Detrended anomalies
        dtrd_anom_list = []
        for m, ds in ds_anom:
            dtrd_m_list = []
            for dsname, da in ds.items():
                da_dtrd = CalcDetAnom(da, period)
                dtrd_m_list.append(da_dtrd)
                
            da_dtrd = xr.merge(dtrd_m_list)
            da_dtrd['LENS2 piControl min'] = da_dtrd['LENS2 piControl'].min('slice')
            da_dtrd['LENS2 piControl max'] = da_dtrd['LENS2 piControl'].max('slice')
            
            dtrd_anom_list.append(da_dtrd)
            SaveData(da_dtrd, graph_type_str+'.anomdtrd', var, mon_str[m-1])
            
        ds_dtrd = xr.concat(dtrd_anom_list, dim='time')

        # Calculate R for detrended anomalies
        ds_dtrd = ds_dtrd.groupby(dim_avg)
        for m, ds in ds_dtrd:
            dtrd_r_list = []
            era5_dtrd = ds['ERA5']
            for dsname, attrs in ds_names.items():
                if attrs[0] and (attrs[1] == 1 or attrs[1] == 2):
                    da_dtrd_r = CalcR(era5_dtrd, ds[dsname], period, attrs[1])
                    dtrd_r_list.append(da_dtrd_r.rename(dsname))
                    
            ds_dtrd_r = xr.merge(dtrd_r_list)
            SaveData(ds_dtrd_r, graph_type_str+'.anomdtrdR', var, mon_str[m-1])
    
    
    # If surface temperature, make anomaly, detrended anomaly (yearly only), and ensemble spread plots
    elif plots['ts'][0]:
        # Calculate R for absolute temperature
        abs_r_list = []
        era5_abs = ds_abs['ERA5']
        for dsname, attrs in ds_names.items():
            if attrs[0] and (attrs[1] == 1 or attrs[1] == 2):
                da_abs_r = CalcR(era5_abs, ds_abs[dsname], period, attrs[1])
                abs_r_list.append(da_abs_r.rename(dsname))
                
        ds_abs_r = xr.merge(abs_r_list)
        SaveData(ds_abs_r, graph_type_str+'.absR', var, period)
        
        # Anomalies
        ds_anom = CalcAnom(ds_abs, period)
        ds_anom['LENS2 piControl min'] = ds_anom['LENS2 piControl'].min('slice')
        ds_anom['LENS2 piControl max'] = ds_anom['LENS2 piControl'].max('slice')
        
        SaveData(ds_anom, graph_type_str+'.anom', var, period)

        # Calculate R for anomalies
        anom_r_list = []
        era5_anom = ds_anom['ERA5']
        for dsname, attrs in ds_names.items():
            if attrs[0] and (attrs[1] == 1 or attrs[1] == 2):
                da_anom_r = CalcR(era5_anom, ds_anom[dsname], period, attrs[1])
                anom_r_list.append(da_anom_r.rename(dsname))
                
        ds_anom_r = xr.merge(anom_r_list)
        SaveData(ds_anom_r, graph_type_str+'.anomR', var, period)

        # Ensemble spread
        enspd_list = []
        for dsname, attrs in ds_names.items():
            if attrs[0] and attrs[1] == 1:
                da_enspd = CalcEnsSp(ds_abs[dsname])
                enspd_list.append(da_enspd.rename(dsname))
                
        ds_enspd = xr.merge(enspd_list)
        SaveData(ds_enspd, graph_type_str+'.espd', var, period)

        # Detrended anomalies only if yearly
        # Detrended anomalies
        dtrd_anom_list = []
        for dsname, da in ds_anom.items():
            da_dtrd = CalcDetAnom(da, period)
            dtrd_anom_list.append(da_dtrd.rename(dsname))

        ds_dtrd = xr.merge(dtrd_anom_list)
        ds_dtrd['LENS2 piControl min'] = ds_dtrd['LENS2 piControl'].min('slice')
        ds_dtrd['LENS2 piControl max'] = ds_dtrd['LENS2 piControl'].max('slice')

        SaveData(ds_dtrd, graph_type_str+'.anomdtrd', var, period)

        # Calculate R for detrended anomalies
        dtrd_r_list = []
        era5_dtrd = ds_dtrd['ERA5']
        for dsname, attrs in ds_names.items():
            if attrs[0] and (attrs[1] == 1 or attrs[1] == 2):
                da_dtrd_r = CalcR(era5_dtrd, ds_dtrd[dsname], period, attrs[1])
                dtrd_r_list.append(da_dtrd_r.rename(dsname))

        ds_dtrd_r = xr.merge(dtrd_r_list)
        SaveData(ds_dtrd_r, graph_type_str+'.anomdtrdR', var, period)

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.48 µs


## Create Regridders

In [49]:
%%time
if plot_types['spatial'][0]:
    ## Only run after time averaging!!!
    # Target grid: latitudes evenly spaced by ~0.9424 deg (180/191) from ~50.4N to 90N, and
    # longitudes every 1.25 deg from -180 to 178.75. Presumably the CESM2 atmosphere grid
    # restricted to the Arctic domain — TODO confirm. Values are kept literal so the grid
    # matches the model coordinates exactly.
    lats = np.array([50.4188481675393, 51.3612565445026, 52.303664921466, 
        53.2460732984293, 54.1884816753927, 55.130890052356, 56.0732984293194, 
        57.0157068062827, 57.9581151832461, 58.9005235602094, 59.8429319371728, 
        60.7853403141361, 61.7277486910995, 62.6701570680628, 63.6125654450262, 
        64.5549738219895, 65.4973821989529, 66.4397905759162, 67.3821989528796, 
        68.3246073298429, 69.2670157068063, 70.2094240837696, 71.151832460733, 
        72.0942408376963, 73.0366492146597, 73.979057591623, 74.9214659685864, 
        75.8638743455497, 76.8062827225131, 77.7486910994764, 78.6910994764398, 
        79.6335078534031, 80.5759162303665, 81.5183246073298, 82.4607329842932, 
        83.4031413612565, 84.3455497382199, 85.2879581151832, 86.2303664921466, 
        87.17277486911, 88.1151832460733, 89.0575916230366, 90])
    lons = np.array([-180, -178.75, -177.5, -176.25, -175, -173.75, -172.5, -171.25, -170, 
        -168.75, -167.5, -166.25, -165, -163.75, -162.5, -161.25, -160, -158.75, 
        -157.5, -156.25, -155, -153.75, -152.5, -151.25, -150, -148.75, -147.5, 
        -146.25, -145, -143.75, -142.5, -141.25, -140, -138.75, -137.5, -136.25, 
        -135, -133.75, -132.5, -131.25, -130, -128.75, -127.5, -126.25, -125, 
        -123.75, -122.5, -121.25, -120, -118.75, -117.5, -116.25, -115, -113.75, 
        -112.5, -111.25, -110, -108.75, -107.5, -106.25, -105, -103.75, -102.5, 
        -101.25, -100, -98.75, -97.5, -96.25, -95, -93.75, -92.5, -91.25, -90, 
        -88.75, -87.5, -86.25, -85, -83.75, -82.5, -81.25, -80, -78.75, -77.5, 
        -76.25, -75, -73.75, -72.5, -71.25, -70, -68.75, -67.5, -66.25, -65, 
        -63.75, -62.5, -61.25, -60, -58.75, -57.5, -56.25, -55, -53.75, -52.5, 
        -51.25, -50, -48.75, -47.5, -46.25, -45, -43.75, -42.5, -41.25, -40, 
        -38.75, -37.5, -36.25, -35, -33.75, -32.5, -31.25, -30, -28.75, -27.5, 
        -26.25, -25, -23.75, -22.5, -21.25, -20, -18.75, -17.5, -16.25, -15, 
        -13.75, -12.5, -11.25, -10, -8.75, -7.5, -6.25, -5, -3.75, -2.5, -1.25, 
        0, 1.25, 2.5, 3.75, 5, 6.25, 7.5, 8.75, 10, 11.25, 12.5, 13.75, 15, 
        16.25, 17.5, 18.75, 20, 21.25, 22.5, 23.75, 25, 26.25, 27.5, 28.75, 30, 
        31.25, 32.5, 33.75, 35, 36.25, 37.5, 38.75, 40, 41.25, 42.5, 43.75, 45, 
        46.25, 47.5, 48.75, 50, 51.25, 52.5, 53.75, 55, 56.25, 57.5, 58.75, 60, 
        61.25, 62.5, 63.75, 65, 66.25, 67.5, 68.75, 70, 71.25, 72.5, 73.75, 75, 
        76.25, 77.5, 78.75, 80, 81.25, 82.5, 83.75, 85, 86.25, 87.5, 88.75, 90, 
        91.25, 92.5, 93.75, 95, 96.25, 97.5, 98.75, 100, 101.25, 102.5, 103.75, 
        105, 106.25, 107.5, 108.75, 110, 111.25, 112.5, 113.75, 115, 116.25, 
        117.5, 118.75, 120, 121.25, 122.5, 123.75, 125, 126.25, 127.5, 128.75, 
        130, 131.25, 132.5, 133.75, 135, 136.25, 137.5, 138.75, 140, 141.25, 
        142.5, 143.75, 145, 146.25, 147.5, 148.75, 150, 151.25, 152.5, 153.75, 
        155, 156.25, 157.5, 158.75, 160, 161.25, 162.5, 163.75, 165, 166.25, 
        167.5, 168.75, 170, 171.25, 172.5, 173.75, 175, 176.25, 177.5, 178.75])
    
    # NEITHER REGRIDDER INCLUDES CYCLIC POINT!
    # (AddAllCyclic adds it after regridding.)
    
    if var == 'aice':
        ## Create Regridder for CICE grid
        # Curvilinear (2-D) lat/lon target on dims (y, x)
        lon2d, lat2d = np.meshgrid(lons, lats)
        
        # Create target grid and sample nj ni grid 
        target_gridSIC = xr.Dataset({'lat': (['y', 'x'], lat2d),'lon': (['y', 'x'], lon2d)})
        # First element along the leading dim, used only as a source-grid template
        ds_samplens = ds_proc['LENS2 piControl mean'][0]
        
        # Create regridder for sea ice
        # Nearest-neighbour (source -> destination) rather than bilinear for the sea ice field
        regridderSIC = xe.Regridder(ds_samplens, target_gridSIC, 'nearest_s2d', reuse_weights=False)
    
    if plots['strd'][0]:
    
        # Create Regridder for ERA5->CESM2
        # Rectilinear (1-D) lat/lon target
        target_gridERA = xr.Dataset({'lat': ('y', lats), 'lon': ('x', lons)})
        ds_sampera = ds_proc['ERA5'][0]
        # xESMF looks for 'lat'/'lon'; ERA5 is stored on 'latE'/'lonE'
        ds_sampera = ds_sampera.rename({'latE':'lat', 'lonE': 'lon'})
    
        # Create regridder for ERA5
        regridderERA = xe.Regridder(ds_sampera, target_gridERA, 'bilinear', reuse_weights=False)

CPU times: user 2.25 s, sys: 106 ms, total: 2.36 s
Wall time: 2.6 s


In [50]:
# Spatial plots: choose the output labels for the active time-averaging mode
if plot_types['spatial'][0]:
    # 0 = monthly, 2 = seasonal; any other mode leaves period/date_str untouched
    if time_avg in (0, 2):
        period = 'time'
        date_str = mon_str if time_avg == 0 else seas_str

## Spatial trend plots

### Set up

In [51]:
if plots['strd'][0]:
    # File-name prefix passed to SaveData for spatial-trend outputs
    graph_type_str = 'Map.Trend'

### Data processing

In [52]:
%%time

if plots['strd'][0]:
    # If not 3D
    if file_bool:
        # Calculate statistics of piControl ensemble
        ds_lens_stats = LENS2TrendsEnsemble(graph_type_str, ds_proc['LENS2 piControl'], period, var, False)
        
        # Calculate trends
        print('Calculating spatial trends in '+var)
        ds_sptrend, ds_sppval = SpatZonTrends(ds_proc, ds_lens_stats, period, var)

    else:
        ds_proc = ds_proc.where(ds_proc.plev.isin(plot_levels), drop=True)  
        ds_u = ds_u.where(ds_u.plev.isin(plot_levels), drop=True)
        ds_v = ds_v.where(ds_v.plev.isin(plot_levels), drop=True)
        
        # Calculate statistics for each level in the piControl ensemble
        ds_lens_stats = VertLENS2StatsLoop(ds_proc['LENS2 piControl'], period, var)
        ds_ulens_stats = VertLENS2StatsLoop(ds_u['LENS2 piControl'], period, 'U')
        ds_vlens_stats = VertLENS2StatsLoop(ds_v['LENS2 piControl'], period, 'V')
        
        # Calculate trends
        print('Calculating spatial trends in '+var)
        ds_sptrend, ds_sppval = VertSptrendsLoop(ds_proc, ds_lens_stats, period, var)
        print('Calculating spatial trends in U')
        ds_usptrend, ds_usppval = VertSptrendsLoop(ds_u, ds_ulens_stats, period, 'U')
        print('Calculating spatial trends in V')
        ds_vsptrend, ds_vsppval = VertSptrendsLoop(ds_v, ds_vlens_stats, period, 'V')

Loading LENS2 piControl trend distribution
Calculating spatial trends in aice
Trends for 3
  Calculating trends for PiC_UVnudge_2006
  Calculating p-values and p-critical values for PiC_UVnudge_2006
0.00014783770321974375
  Calculating trends for PiC_UVnudge_LM2006
  Calculating p-values and p-critical values for PiC_UVnudge_LM2006
0.00013909321403120852
  Calculating trends for PiC_UVnudge_MM2006
  Calculating p-values and p-critical values for PiC_UVnudge_MM2006
5.4586272872836186e-05
  Calculating trends for ERA5
  Calculating trends for LENS2 piControl mean
Trends for 6
  Calculating trends for PiC_UVnudge_2006
  Calculating p-values and p-critical values for PiC_UVnudge_2006
0.0002683124521338889
  Calculating trends for PiC_UVnudge_LM2006
  Calculating p-values and p-critical values for PiC_UVnudge_LM2006
0.000730232199461608
  Calculating trends for PiC_UVnudge_MM2006
  Calculating p-values and p-critical values for PiC_UVnudge_MM2006
5.880285704696379e-05
  Calculating trends f

In [53]:
%%time

if plots['strd'][0]:
    # Sea ice fields: move trends and p-values from the CICE grid to the ATM target grid
    if var == 'aice':
        ds_sptrend = Regrid(ds_sptrend, regridderSIC, 'sic')
        ds_sppval = Regrid(ds_sppval, regridderSIC, 'sic', True)
        
    # Regrid ERA5 data
    # (separate copy for the pattern correlation; ds_sptrend keeps ERA5 on latE/lonE for AddAllCyclic)
    ds_sptrend_corr = Regrid(ds_sptrend, regridderERA, 'era')

    # Calculate pattern coefficients between ERA5-PiC_UVnudgeX
    ds_trdcor = PattCorr(ds_sptrend_corr)

    SaveData(ds_trdcor, graph_type_str+'.pcorr', var, time_outstr)

    if var == 'Z3':
        # Rename ERA5 lat/lon
        # NOTE(review): ERA5 U/V trends are only relabelled (latE/lonE -> lat/lon), not regridded;
        # this assumes they already share the model grid — confirm.
        ds_era = ds_usptrend['ERA5']
        ds_era = ds_era.rename({'latE':'lat', 'lonE':'lon'})
        ds_usptrend_corr = ds_usptrend.drop_vars('ERA5')
        ds_usptrend_corr = ds_usptrend_corr.drop_dims(['latE','lonE'])
        ds_usptrend_corr = ds_usptrend_corr.assign({'ERA5':ds_era})

        ds_era = ds_vsptrend['ERA5']
        ds_era = ds_era.rename({'latE':'lat', 'lonE':'lon'})
        ds_vsptrend_corr = ds_vsptrend.drop_vars('ERA5')
        ds_vsptrend_corr = ds_vsptrend_corr.drop_dims(['latE','lonE'])
        ds_vsptrend_corr = ds_vsptrend_corr.assign({'ERA5':ds_era})
        
        # Calculate pattern coefficients between ERA5-PiC_UVnudgeX
        ds_utrdcor = PattCorr(ds_usptrend_corr)
        ds_vtrdcor = PattCorr(ds_vsptrend_corr)
    
        SaveData(ds_utrdcor, graph_type_str+'.pcorr', 'U', time_outstr)
        SaveData(ds_vtrdcor, graph_type_str+'.pcorr', 'V', time_outstr)

Regridding CICE grid -> ATM grid...
Regridding PiC_UVnudge_2006
Regridding PiC_UVnudge_LM2006
Regridding PiC_UVnudge_MM2006
Regridding LENS2 piControl mean
Regridding CICE grid -> ATM grid...
Regridding PiC_UVnudge_2006
Regridding PiC_UVnudge_LM2006
Regridding PiC_UVnudge_MM2006
Regridding ERA5 grid -> ATM grid...
Regridding ERA5
Pattern correlation for DJF
Pattern correlation for JJA
Pattern correlation for MAM
Pattern correlation for SON
Saving Map.Trend.pcorr.aice.Arctic.1980.season.Mean.nc
CPU times: user 636 ms, sys: 35.9 ms, total: 672 ms
Wall time: 711 ms


In [54]:
%%time

if plots['strd'][0]:
    # Add cyclic data
    # (AddAllCyclic passes variables whose name contains 'pcrit' through unchanged)
    ds_sptrend = AddAllCyclic(ds_sptrend)
    ds_sppval = AddAllCyclic(ds_sppval)
    
    SaveData(ds_sptrend, graph_type_str, var, time_outstr)
    SaveData(ds_sppval, graph_type_str+'.pval', var, time_outstr)

    # Wind components only exist for geopotential (Z3) runs
    if var == 'Z3':
        ds_usptrend = AddAllCyclic(ds_usptrend)
        ds_usppval = AddAllCyclic(ds_usppval)

        ds_vsptrend = AddAllCyclic(ds_vsptrend)
        ds_vsppval = AddAllCyclic(ds_vsppval)

        SaveData(ds_usptrend, graph_type_str, 'U', time_outstr)
        SaveData(ds_usppval, graph_type_str+'.pval', 'U', time_outstr)

        SaveData(ds_vsptrend, graph_type_str, 'V', time_outstr)
        SaveData(ds_vsppval, graph_type_str+'.pval', 'V', time_outstr)

Saving Map.Trend.aice.Arctic.1980.season.Mean.nc
Saving Map.Trend.pval.aice.Arctic.1980.season.Mean.nc
CPU times: user 57.1 ms, sys: 4.03 ms, total: 61.1 ms
Wall time: 123 ms


## Spatial plots
### Set up

In [55]:
if plots['map'][0]:
    # File-name prefix passed to SaveData for spatial-map outputs
    graph_type_str = 'Map'

### Data processing

In [56]:
%%time

if plots['map'][0]:
    # Drop CESM2-LENS individual members
    if file_bool:
        ds_avg = ds_proc.drop_vars(['LENS2 piControl'])
    else:
        ds_avg = ds_proc
    
    # Calculating averages
    ds_sp = SpatZonAvg(ds_avg)

    if var == 'Z3':
        ds_spv = SpatZonVar(ds_avg)
        ds_usp = SpatZonAvg(ds_u)
        ds_vsp = SpatZonAvg(ds_v)

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 6.44 µs


In [57]:
%%time

if plots['map'][0]:
    # Sea ice (aice) is on the CICE grid; move it to the ATM target grid first
    if var == 'aice':
        ds_sp = Regrid(ds_sp, regridderSIC, 'sic')
        
    # Add cyclic data 
    ds_sp = AddAllCyclic(ds_sp)
    SaveData(ds_sp, graph_type_str, var, time_outstr)
    
    # Variance and wind-component maps only for geopotential (Z3) runs
    if var == 'Z3':
        ds_spv = AddAllCyclic(ds_spv)
        ds_usp = AddAllCyclic(ds_usp)
        ds_vsp = AddAllCyclic(ds_vsp)

        SaveData(ds_spv, graph_type_str+'.var', var, time_outstr)
        SaveData(ds_usp, graph_type_str, 'U', time_outstr)
        SaveData(ds_vsp, graph_type_str, 'V', time_outstr)

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.72 µs


## Zonal trend plots
### Set up

In [58]:
# Zonal plots: choose the output labels for the active time-averaging mode
if plot_types['zonal'][0]:
    # 0 = monthly, 2 = seasonal; any other mode leaves period/date_str untouched
    if time_avg in (0, 2):
        period = 'time'
        date_str = mon_str if time_avg == 0 else seas_str

In [59]:
if plots['ztrd'][0]:
    # File-name prefix passed to SaveData for zonal-trend outputs
    graph_type_str = 'Zonal.Trend'

### Data processing

In [60]:
%%time

if plots['ztrd'][0]:
    # Calculate statistics for LENS2 piControl
    ds_lens_stats = LENS2TrendsEnsemble(graph_type_str, ds_proc['LENS2 piControl'], period, var, False)
    ds_ulens_stats = LENS2TrendsEnsemble(graph_type_str, ds_u['LENS2 piControl'], period, 'U', False)
    ds_vlens_stats = LENS2TrendsEnsemble(graph_type_str, ds_v['LENS2 piControl'], period, 'V', False)

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 6.2 µs


In [61]:
%%time

if plots['ztrd'][0]:
    # Zonal trends and p-values for the primary variable
    print('Calculating zonal trends in '+var)
    ds_ztrend, ds_zpval = SpatZonTrends(ds_proc, ds_lens_stats, period, var)

    # Wind components only exist for geopotential (Z3) runs
    if var == 'Z3':
        print('Calculating zonal trends in U')
        ds_uztrend, ds_uzpval = SpatZonTrends(ds_u, ds_ulens_stats, period, 'U')
        print('Calculating zonal trends in V')
        ds_vztrend, ds_vzpval = SpatZonTrends(ds_v, ds_vlens_stats, period, 'V')    

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.96 µs


In [62]:
%%time

if plots['ztrd'][0]:
    # Write zonal trends and their p-values
    SaveData(ds_ztrend, graph_type_str, var, time_outstr)
    SaveData(ds_zpval, graph_type_str+'.pval', var, time_outstr)

    # Wind components only exist for geopotential (Z3) runs
    if var == 'Z3':
        SaveData(ds_uztrend, graph_type_str, 'U', time_outstr)
        SaveData(ds_uzpval, graph_type_str+'.pval', 'U', time_outstr)
        
        SaveData(ds_vztrend, graph_type_str, 'V', time_outstr)
        SaveData(ds_vzpval, graph_type_str+'.pval', 'V', time_outstr)

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.72 µs


## Zonal Plots
### Set up

In [63]:
if plots['zon'][0]:
    # File-name prefix passed to SaveData for zonal-mean outputs
    graph_type_str = 'Zonal'

### Data processing

In [64]:
%%time

if plots['zon'][0]:
    # Time averaging (monthly or seasonal climatology, per time_avg)
    ds_zon = SpatZonAvg(ds_proc)
    
    # Wind components only exist for geopotential (Z3) runs
    if var == 'Z3':
        # Calculate for U & V
        ds_uzon = SpatZonAvg(ds_u)
        ds_vzon = SpatZonAvg(ds_v)    

CPU times: user 0 ns, sys: 4 µs, total: 4 µs
Wall time: 5.96 µs


In [65]:
%%time

if plots['zon'][0]:
    # Write the zonal climatologies
    SaveData(ds_zon, graph_type_str, var, time_outstr)

    # Wind components only exist for geopotential (Z3) runs
    if var == 'Z3':
        SaveData(ds_uzon, graph_type_str, 'U', time_outstr)
        SaveData(ds_vzon, graph_type_str, 'V', time_outstr)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.44 µs


In [66]:
# Shut down the dask client created earlier in the notebook (presumably the PBSCluster-backed Client)
client.shutdown()