In [None]:
from __future__ import absolute_import, division, print_function

In [None]:
# License: MIT

# Dynamical analysis of FEM-BV-VAR model for NAO

This notebook contains all the necessary routines for identifying the optimal FEM-BV-VAR model for the NAO and its dynamical properties as presented in the manuscript:

"Dynamical analysis of a reduced model for the NAO" (Quinn, Harries, and O'Kane, 2020)

## Packages

In [None]:
%matplotlib inline

from copy import deepcopy
import itertools
import os
import time

import cartopy.crs as ccrs
import matplotlib
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import scipy
import scipy.linalg as linalg
import scipy.stats as stats
import pandas as pd
import seaborn as sns

from cartopy.util import add_cyclic_point
from joblib import Parallel, delayed
from scipy.signal import correlate
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.utils import check_random_state

from statsmodels.nonparametric.smoothers_lowess import lowess
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from clustering_dynamics.dynamics import *

## File paths

In [None]:
## Set these as appropriate
# PROJECT_DIR is the parent of the directory that contains this notebook; the
# notebook name is resolved against the current working directory.
PROJECT_DIR = os.path.join(
    os.path.dirname(os.path.abspath('produce_figures_tables.ipynb')),
    os.pardir)

# Input data and derived results both sit directly under the project root.
DATA_DIR, RESULTS_DIR = (os.path.join(PROJECT_DIR, sub_dir) for sub_dir in ('data', 'results'))

# Saved FEM-BV-VAR fits and the NetCDF EOF files are sub-folders of the results.
FEM_BV_VAR_DIR = os.path.join(RESULTS_DIR, 'fembv_varx_fits')
EOF_DIR = os.path.join(RESULTS_DIR, 'eofs', 'nc')

## Fit parameters

In [None]:
# Identifiers of the input field; var_name, var_lev and timespan are joined into
# the data/EOF/model file names further down.
reanalysis, var_name, var_lev, var_ext = 'nnr1', 'hgt', '500', 'anom'
timespan = '1948_2018'

# First and last day of the base period.
base_period = [np.datetime64(day) for day in ('1979-01-01', '2018-12-31')]

# Domain and EOF settings; these also appear verbatim in the file names.
hemisphere, region, season = 'NH', 'atlantic', 'ALL'
pc_scaling, lat_weights = 'unit', 'scos'
max_eofs = 200

# Base period rendered as 'YYYYMMDD_YYYYMMDD' for use in file names.
base_period_str = '_'.join(pd.to_datetime(day).strftime('%Y%m%d') for day in base_period)

# Number of principal components (used for the model file names and the
# state-space dimension later on).
n_PCs = 20

## Load reanalysis data

In [None]:
# The input file is named '<var_name>.<var_lev>.<timespan>.nc' and lives in DATA_DIR.
data_filename = '{}.{}.{}.{}'.format(var_name, var_lev, timespan, 'nc')
data_file = os.path.join(DATA_DIR, data_filename)

# Opened dataset is kept in `hpa500` for the anomaly calculation below.
hpa500 = xr.open_dataset(data_file)

In [None]:
## calculate anomalies based on the 1979-2018 daily climatology
# keep only the samples whose calendar year lies in 1979..2018 (inclusive)
base_period_da = hpa500.where(
            (hpa500['time'].dt.year >= 1979) &
            (hpa500['time'].dt.year <= 2018), drop=True)

# mean over all years for each day of the year
clim_mean_da = base_period_da.groupby(
            base_period_da['time'].dt.dayofyear).mean('time')

# subtract the matching day-of-year mean from every sample
anom_da = (base_period_da.groupby(
            base_period_da['time'].dt.dayofyear) - clim_mean_da)


## create data array of anomalies
lats = anom_da.variables['lat'][:]
lons = anom_da.variables['lon'][:]
Zg = anom_da.variables['hgt'][:]

# Rotate the longitude axis so that the smallest longitude comes first.  The
# shift is taken from the underlying NumPy data as a plain integer, matching
# the Figure A1 cell, instead of passing a 0-d xarray object to np.roll.
roll_to = -int(np.argmin(lons.data))
lons = np.roll(lons, roll_to)
data = np.roll(Zg.squeeze(), roll_to, axis=-1)

# append a wrap-around longitude column (and the matching coordinate value)
data, lons = add_cyclic_point(data, coord=lons)

# keep the first 36 latitude rows only and attach named coordinates
data = xr.DataArray(data[:,0:36,:], coords=[anom_da.time, lats[0:36], lons[:]], 
                         dims=['time','lat','lon'])

## Load EOFs

In [None]:
# EOF file name: the fit parameters joined with dots (the base-period string
# appears twice in the name).
eofs_filename = '.'.join((
    var_name, var_lev, timespan, base_period_str, 'anom',
    hemisphere, region, base_period_str, season,
    'max_eofs_{:d}'.format(max_eofs),
    lat_weights, pc_scaling, 'eofs', 'nc',
))
eofs_file = os.path.join(EOF_DIR, eofs_filename)
eofs = xr.open_dataset(eofs_file)

### Figure A1

In [None]:
# Rebuild the rolled, cyclically wrapped coordinate and field arrays.
# NOTE: this rebinds the notebook-level names `lats`, `lons`, `Zg` and `data`.
lats = anom_da.variables['lat'][:]
lons = anom_da.variables['lon'][:]
Zg = anom_da.variables['hgt'][:]

# shift so that the smallest longitude comes first, then add the wrap-around point
roll_to = -np.argmin(lons.data)
lons = np.roll(lons, roll_to)
data, lons = add_cyclic_point(np.roll(Zg.squeeze(), roll_to, axis=-1), coord=lons)

# 4 x 5 grid of maps, one per EOF, viewed from above the North Pole
fig = plt.figure(figsize=(10,10))
for j in range(20):
    ax = fig.add_subplot(4, 5, j+1,
                         projection=ccrs.Orthographic(central_longitude=0.0, central_latitude=90.0))
    ax.set_global()

    # coordinates of the plotted sub-domain: longitudes from position 101
    # onwards and the first 29 latitude rows
    lon, lat = np.meshgrid(lons[101:], lats[0:29])
    fill = ax.contourf(lons[101:], lats[0:29],
                       eofs.eofs[j,0,0:29,:],
                       60, transform=ccrs.PlateCarree(), cmap='PRGn', vmin=-0.1, vmax=0.1)

    # draw coastlines
    ax.coastlines()

    ax.set_title('EOF {}'.format(j+1))

#plt.tight_layout()

#plt.savefig('../figures/figA1.pdf')

# Compare FEM-BV-VAR models

In [None]:
## identify optimal model parameters
# Common prefix of every saved fit.  It is assembled from the fit parameters
# (same recipe as the EOF and optimal-model file names) so that the three file
# names cannot drift apart when a parameter is changed.
model_prefix = '.'.join([
    var_name, var_lev, timespan, base_period_str, 'anom', hemisphere, region, base_period_str,
    season, 'max_eofs_{:d}'.format(max_eofs), lat_weights, pc_scaling, 'fembv_varx',
    'n_pcs{:d}'.format(n_PCs)])

# hyper-parameter grid: number of states (k), memory (m), state length (p)
n_components = [1, 2, 3]
memory = [0, 1, 2, 3, 4, 5]
state_lengths = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60]

# per-fold test metrics stored in each fit file as 'test_<metric>'
cv_metrics = ('cost', 'rmse', 'log_likelihood')

cv_results = {'n_components': [], 'memory': [], 'state_length': []}
cv_results.update({'mean_test_' + metric: [] for metric in cv_metrics})
cv_results.update({'stderr_test_' + metric: [] for metric in cv_metrics})

n_samples = None
for k, m, p in itertools.product(n_components, memory, state_lengths):

    fit_file = os.path.join(
        FEM_BV_VAR_DIR,
        '.'.join([model_prefix, 'k{:d}.m{:d}.state_length{:d}.nc'.format(k, m, p)]))

    # the context manager closes the file even when the consistency check raises
    with xr.open_dataset(fit_file) as model_ds:

        # every fit must have been made on the same number of time samples
        if n_samples is None:
            n_samples = model_ds.sizes['time']
        elif model_ds.sizes['time'] != n_samples:
            raise RuntimeError('Number of samples do not match')

        n_folds = model_ds.sizes['fold']

        cv_results['n_components'].append(k)
        cv_results['memory'].append(m)
        cv_results['state_length'].append(p)

        for metric in cv_metrics:
            fold_scores = model_ds['test_' + metric]
            cv_results['mean_test_' + metric].append(fold_scores.mean('fold').item())
            # standard error of the mean across folds
            cv_results['stderr_test_' + metric].append(
                fold_scores.std('fold').item() / np.sqrt(n_folds))

# convert the collected lists to arrays so that they can be masked and arg-sorted
for field in cv_results:
    cv_results[field] = np.asarray(cv_results[field])

min_rmse_idx = np.argmin(cv_results['mean_test_rmse'])

print('Min. test RMSE k = ', cv_results['n_components'][min_rmse_idx])
print('Min. test RMSE m = ', cv_results['memory'][min_rmse_idx])
print('Min. test RMSE p = ', cv_results['state_length'][min_rmse_idx])

## Figure 1

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# zoomed inset placed on the right-hand side of the main axes
axins = inset_axes(ax, width='50%', height='45%', loc=5)

unique_n_components = np.unique(n_components)
unique_memory = np.unique(memory)

# Horizontal offsets (in units of `width`) that spread the markers of the
# different memory values symmetrically around each state length.
n_memory_vals = len(unique_memory)
width = 0.8
half = n_memory_vals // 2
if n_memory_vals % 2 == 0:
    positive_side = [j + 0.5 for j in range(half)]
    offsets = [-o for o in reversed(positive_side)] + positive_side
else:
    positive_side = list(range(1, half + 1))
    offsets = [-o for o in reversed(positive_side)] + [0] + positive_side

colors = itertools.cycle(('#fdcc8a', '#fc8d59', '#d7301f', '#fef0d9'))
linestyles = itertools.cycle(('-', '--', ':', '-.'))

# one colour per number of states, one marker per memory value
for k, c, ls in zip(unique_n_components, colors, linestyles):

    markers = itertools.cycle(('.', 'x', 's', 'd', 'v', '^', '<', '>'))

    for i, (m, marker) in enumerate(zip(unique_memory, markers)):

        mask = np.logical_and(cv_results['n_components'] == k, cv_results['memory'] == m)

        xcoords = cv_results['state_length'][mask] + offsets[i] * width
        cv_mean = cv_results['mean_test_rmse'][mask]
        cv_std_err = cv_results['stderr_test_rmse'][mask]

        # identical styling on the main axes and in the inset; only the main
        # axes carries a legend label
        point_style = dict(yerr=cv_std_err, capsize=5, markersize=8, color=c, ls='none', marker=marker)
        ax.errorbar(xcoords, cv_mean, label='$K = {:d}, m = {:d}$'.format(k, m), **point_style)
        axins.errorbar(xcoords, cv_mean, **point_style)

# main axes
ax.set_title('$d = 20$, $N_{init} = 20$, $N_{folds} = 10$', fontsize=18)
ax.set_xlabel('$p$ (days)', fontsize=16)
ax.set_ylabel('Test set RMSE', fontsize=16)
ax.tick_params(axis='both', labelsize=14)
ax.grid(ls='--', color='gray', alpha=0.5)
ax.legend(ncol=3, fontsize=14, bbox_to_anchor=(0.5, -0.3), loc='center', borderaxespad=0.)

# inset: restrict the view to short state lengths and a narrow RMSE band
axins.set_xlim(-3, 25)
axins.set_ylim(180, 200)
axins.tick_params(axis='both', labelsize=14)
axins.grid(ls='--', color='gray', alpha=0.5)

#plt.savefig('../figures/fig1.pdf', bbox_inches='tight')

plt.show()

plt.close()

# Properties of optimal FEM-BV-VAR model

## Load optimal model

In [None]:
# Hyper-parameters of the model analysed in the rest of the notebook; they map
# onto the 'k', 'm' and 'state_length' fields of the fit file name.
k = 3   # number of FEM-BV states
m = 3   # memory
p = 5   # state length

# Same naming recipe as the model comparison above.  The number of PCs is taken
# from `n_PCs` rather than hard-coded, so that changing it cannot silently load
# a fit made with a different number of PCs.
model_filename = '.'.join([var_name, var_lev, timespan, base_period_str, 'anom', hemisphere, region, base_period_str,
                           season, 'max_eofs_{:d}'.format(max_eofs), lat_weights, pc_scaling, 'fembv_varx',
                           'n_pcs{:d}'.format(n_PCs),'k{:d}'.format(k),'m{:d}'.format(m),
                           'state_length{:d}'.format(p),'nc'])
model_file = os.path.join(FEM_BV_VAR_DIR, model_filename)
model = xr.open_dataset(model_file)

## Plotting state composites

In [None]:
def viterbi_state_assignments(weights_da, time_name='time', state_name='fembv_state'):
    """Assign every time step to the state with the largest weight.

    Parameters
    ----------
    weights_da : xarray.DataArray
        Two-dimensional array of state weights with dimensions `time_name`
        and `state_name` (in either order).
    time_name : str
        Name of the time dimension.
    state_name : str
        Name of the state dimension; also used as the name of the result.

    Returns
    -------
    xarray.DataArray
        One value per time step, named `state_name`: the index of the state
        with the largest weight, stored as float.  Time steps at which any
        weight is non-finite are set to NaN.
    """
    n_samples = weights_da.sizes[time_name]

    # work on a (time, state) layout
    if weights_da.get_axis_num(state_name) != 1:
        weights_da = weights_da.transpose(time_name, state_name)

    weights = weights_da.values

    # only rows in which every state weight is finite can be assigned
    valid = np.all(np.isfinite(weights), axis=1)

    # np.nan rather than the np.NaN alias, which was removed in NumPy 2.0
    full_viterbi = np.full((n_samples,), np.nan)
    full_viterbi[valid] = np.argmax(weights[valid], axis=1)

    viterbi = xr.DataArray(
        full_viterbi,
        coords={time_name: weights_da[time_name]},
        dims=[time_name], name=state_name)

    return viterbi

In [None]:
def _bootstrap_percentile_scores(composite_da, bootstrap_samples):
    """Return, per grid point, the fraction of bootstrap composites <= the observed composite."""
    n_bootstrap = len(bootstrap_samples)

    composite = composite_da.squeeze()
    composite_dims = list(composite.dims)

    # stack the samples and make sure the sample dimension comes first
    stacked = xr.concat(bootstrap_samples, dim='bootstrap_sample')
    stacked = stacked.transpose(*(['bootstrap_sample'] + composite_dims))

    original_shape = [composite.sizes[d] for d in composite_dims]
    n_features = int(np.prod(original_shape))

    flat_composite = np.reshape(composite.data, (n_features,))
    flat_bootstrap = np.reshape(stacked.data, (n_bootstrap, n_features))

    scores = np.zeros((n_features,), dtype=np.float64)
    for i in range(n_features):
        # 'weak' percentile: share of bootstrap values <= the observed value
        scores[i] = stats.percentileofscore(flat_bootstrap[:, i], flat_composite[i], kind='weak') / 100.0

    return xr.DataArray(np.reshape(scores, original_shape), coords=composite.coords, dims=composite_dims)


def calculate_fembv_state_composites(model_ds, anom_da, bootstrap=True, bootstrap_type='independent',
                                     n_bootstrap=1000, time_name='time', random_seed=None):
    """Calculate FEM-BV-VARX state composites.

    Every time step is assigned to the state with the largest weight and the
    anomalies are averaged over the time steps of each state.  Optionally each
    composite is scored against composites of randomly drawn time steps.

    Parameters
    ----------
    model_ds : xarray.Dataset
        Fitted model; must provide a 'weights' variable and a 'fembv_state'
        dimension.
    anom_da : xarray.DataArray
        Anomaly field with a `time_name` dimension.  It is restricted to the
        time span covered by the finite weights.
        NOTE(review): the bootstrap indexes it by position, so it is assumed
        to have exactly one sample per assigned time step - confirm.
    bootstrap : bool
        If False, only the composites are returned.
    bootstrap_type : {'independent', 'multinomial'}
        'independent' draws, separately for each state, as many time steps
        (without replacement) as the state has members.  'multinomial' splits
        all time steps at random into groups with the observed state sizes.
    n_bootstrap : int
        Number of bootstrap replicates.
    time_name : str
        Name of the time dimension.
    random_seed : int or None
        Seed passed to `numpy.random.default_rng`.

    Returns
    -------
    xarray.DataArray or xarray.Dataset
        The composites if `bootstrap` is False, otherwise a dataset with the
        variables 'composites' and 'bootstrap_percentile' (values in [0, 1]).

    Raises
    ------
    ValueError
        If `bootstrap` is True and `bootstrap_type` is not recognised.
    """
    random_state = np.random.default_rng(random_seed)
    n_components = model_ds.sizes['fembv_state']

    # Use the dataset that was passed in (the original read the notebook-level
    # `model` here and silently ignored the `model_ds` argument).
    affs = model_ds['weights'].dropna(time_name)
    affs_start = affs[time_name].min()
    affs_end = affs[time_name].max()

    viterbi = viterbi_state_assignments(affs, time_name=time_name)

    anom_da = anom_da.where(
        (anom_da[time_name] >= affs_start) & (anom_da[time_name] <= affs_end),
        drop=True)

    composites_da = anom_da.groupby(viterbi).mean(time_name)

    if not bootstrap:
        return composites_da

    n_samples = viterbi.sizes[time_name]

    # number of time steps assigned to each state (does not change between replicates)
    event_counts = [np.sum(viterbi == k).item() for k in range(n_components)]

    def sample_composite(time_indices):
        # mean anomaly over an arbitrary set of time positions
        return anom_da.isel({time_name: time_indices}).mean(time_name).squeeze()

    percentile_scores_da = xr.zeros_like(composites_da)

    if bootstrap_type == 'independent':
        for k in range(n_components):
            bootstrap_composites = [
                sample_composite(random_state.choice(n_samples, size=event_counts[k], replace=False))
                for _ in range(n_bootstrap)]

            percentile_scores_da.loc[dict(fembv_state=k)] = _bootstrap_percentile_scores(
                composites_da.sel(fembv_state=k), bootstrap_composites)

    elif bootstrap_type == 'multinomial':

        bootstrap_composites = {k: [] for k in range(n_components)}

        for _ in range(n_bootstrap):

            # time positions not yet handed to a state in this replicate
            remaining = np.arange(n_samples)

            for k in range(n_components):
                t_boot = random_state.choice(remaining, size=event_counts[k], replace=False)
                # vectorised removal; keeps the order of the remaining positions
                remaining = remaining[~np.isin(remaining, t_boot)]
                bootstrap_composites[k].append(sample_composite(t_boot))

            # the state sizes must add up to the number of time steps
            assert remaining.size == 0

        for k in range(n_components):
            percentile_scores_da.loc[dict(fembv_state=k)] = _bootstrap_percentile_scores(
                composites_da.sel(fembv_state=k), bootstrap_composites[k])

    else:
        raise ValueError("Unrecognized bootstrap method '%r'" % bootstrap_type)

    composites_ds = xr.Dataset({'composites': composites_da, 'bootstrap_percentile': percentile_scores_da})

    return composites_ds

In [None]:
start_time = time.time()

bootstrap = True
fembv_composites = calculate_fembv_state_composites(model, anom_da['hgt'], bootstrap=bootstrap, bootstrap_type='multinomial', n_bootstrap=10,
                                                    random_seed=0)

# Two-sided significance mask: keep a grid point only if the observed composite
# lies in the lower or upper alpha/2 tail of the bootstrap distribution.
# NOTE(review): with n_bootstrap=10 the percentile can only take multiples of
# 0.1, so at alpha=0.01 only values of exactly 0 or 1 pass - confirm that this
# small replicate count is intended.
alpha = 0.01
if bootstrap and isinstance(fembv_composites, xr.Dataset):
    percentile = fembv_composites['bootstrap_percentile']
    significant = (percentile >= 1.0 - 0.5 * alpha) | (percentile <= 0.5 * alpha)
    # np.nan rather than the np.NaN alias, which was removed in NumPy 2.0
    fembv_composites = xr.where(significant, fembv_composites['composites'], np.nan)

end_time = time.time()
elapsed = (end_time-start_time)/60
print("Elapsed time: {} min".format(round(elapsed,4)))

## Figure 2

In [None]:
# one map per state, all in a single row
n_composites = fembv_composites.sizes['fembv_state']
n_cols, n_rows = n_composites, 1
wrap_lon = True

# view from above the North Pole
projection = ccrs.Orthographic(central_latitude=90, central_longitude=0)

vmins = np.full((n_composites,), fembv_composites.min().item())
vmaxs = np.full((n_composites,), fembv_composites.max().item())

# the extra, thin bottom row of the grid holds the colour bars
height_ratios = np.ones((n_rows + 1))
height_ratios[-1] = 0.1

fig = plt.figure(constrained_layout=False, figsize=(4 * n_cols, 4 * n_rows))

gs = gridspec.GridSpec(ncols=n_cols, nrows=n_rows + 1, figure=fig,
                       wspace=0.05, hspace=0.2,
                       height_ratios=height_ratios)

lat, lon = fembv_composites['lat'], fembv_composites['lon']

row_index = col_index = 0

for i in range(n_composites):

    composite_data = fembv_composites.sel(fembv_state=i).squeeze().values

    # symmetric colour limits around zero, set per panel
    vmin, vmax = np.nanmin(composite_data), np.nanmax(composite_data)
    ax_vmax = max(np.abs(vmin), np.abs(vmax))
    ax_vmin = -ax_vmax

    # optionally close the gap at the wrap-around longitude
    composite_lon = lon
    if wrap_lon:
        composite_data, composite_lon = add_cyclic_point(composite_data, coord=lon)

    lon_grid, lat_grid = np.meshgrid(composite_lon, lat)

    ax = fig.add_subplot(gs[row_index, col_index], projection=projection)
    ax.set_title('state {}'.format(i+1), fontsize=14)

    ax.coastlines()
    ax.set_global()

    cs = ax.pcolor(lon_grid, lat_grid, composite_data, vmin=ax_vmin, vmax=ax_vmax,
                   cmap=plt.cm.RdBu_r, transform=ccrs.PlateCarree())

    # grey background wherever the composite has been masked out
    if not np.all(np.isfinite(composite_data)):
        ax.patch.set_facecolor('lightgray')

    # horizontal colour bar directly below the map
    cb_ax = fig.add_subplot(gs[-1, col_index])
    cb = fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')
    cb.set_label(r'$Z_{g500\,\mathrm{hPa}}^\prime$ (gpm)', fontsize=14)

    ax.set_aspect('equal')
    fig.canvas.draw()

    # advance to the next grid cell, wrapping to a new row when needed
    col_index += 1
    if col_index == n_cols:
        col_index = 0
        row_index += 1

#plt.savefig('../figures/fig2.pdf', bbox_inches='tight')

# Separating by state and identifying transitions

In [None]:
## separate affiliation indices by state
# The first 5 samples are skipped here, as in the rest of this notebook
# (presumably they carry no fitted weights - TODO confirm).
n_skip = 5

# most likely state per day, computed once; indices below refer to model.time
affil_seq = model.weights[n_skip:].argmax(dim = 'fembv_state')
affil_values = np.asarray(affil_seq)

comp1_ind = np.flatnonzero(affil_values == 0) + n_skip
comp2_ind = np.flatnonzero(affil_values == 1) + n_skip
comp3_ind = np.flatnonzero(affil_values == 2) + n_skip

## extract transition indices (last day in state)
# a transition is a day whose successor is in a different state
trans_ind_all = np.flatnonzero(affil_values[1:] != affil_values[:-1]) + n_skip

## extract residency times
# Lengths of the consecutive runs: differences between successive "last day in
# state" indices, bracketed by the day before the first analysed sample and by
# the final sample.  This also works when there is no transition at all.
state_length_all = np.diff(np.concatenate((
    [n_skip - 1], trans_ind_all, [model.weights.shape[0] - 1])))
         

In [None]:
## separate transitions and residencies by state
affil_values = np.asarray(affil_seq)
valid_states = (0, 1, 2)

# state occupied on the last day before each transition
# (affil_seq starts at sample 5, hence the shift by 5)
from_states = affil_values[trans_ind_all - 5]

# stop at the first unexpected state, as the original loop did
invalid = np.flatnonzero(~np.isin(from_states, valid_states))
if invalid.size:
    print('invalid state at index {}'.format(invalid[0]))
    from_states = from_states[:invalid[0]]
n_from = from_states.shape[0]

trans_ind_1, trans_ind_2, trans_ind_3 = (
    trans_ind_all[:n_from][from_states == s] for s in valid_states)

# residency i ends with transition i, so the same masks select the lengths
state_length_1, state_length_2, state_length_3 = (
    state_length_all[:n_from][from_states == s] for s in valid_states)

## classify final state residence
# the last residency is not closed by a transition; attribute it to the state
# of the final sample
final_state = affil_values[-1]
if final_state == 0:
    state_length_1 = np.append(state_length_1,state_length_all[-1])
elif final_state == 1:
    state_length_2 = np.append(state_length_2,state_length_all[-1])
elif final_state == 2:
    state_length_3 = np.append(state_length_3,state_length_all[-1])
else:
    # report the index of the final sample (the original printed a stale loop index)
    print('invalid state at index {}'.format(model.weights.shape[0]-1))

## separate transitions by state transitioned to
# state occupied on the first day after each transition
to_states = affil_values[trans_ind_all + 1 - 5]

invalid = np.flatnonzero(~np.isin(to_states, valid_states))
if invalid.size:
    print('invalid state at index {}'.format(invalid[0]))
    to_states = to_states[:invalid[0]]
n_to = to_states.shape[0]

# these hold the index of the first day in the new state
trans_ind_to_1, trans_ind_to_2, trans_ind_to_3 = (
    trans_ind_all[:n_to][to_states == s] + 1 for s in valid_states)

In [None]:
## assign times in states and at transitions
# time stamps of every day spent in each state
state_1_times, state_2_times, state_3_times = (
    model.time[day_inds] for day_inds in (comp1_ind, comp2_ind, comp3_ind))

# time stamps of the last day before leaving each state
trans_1_times, trans_2_times, trans_3_times = (
    model.time[last_day_inds] for last_day_inds in (trans_ind_1, trans_ind_2, trans_ind_3))

## Calculate statistics by state

In [None]:
## convert residencies to xarray and add time coordinates
# Each residency is stamped with the time of its last day.  The state occupied
# at the final sample gets one extra entry (index -1) for the residency that
# is still open at the end of the record.
trans_inds_collect = [trans_ind_1,trans_ind_2,trans_ind_3]
state_length_inds = [
    np.append(inds, -1) if affil_seq[-1] == ii else inds
    for ii, inds in enumerate(trans_inds_collect[:k])]

state_length_1, state_length_2, state_length_3 = (
    xr.DataArray(lengths, coords=[model.time[inds]], dims=['time'])
    for lengths, inds in zip((state_length_1, state_length_2, state_length_3), state_length_inds))

state_length_collect = [state_length_1,state_length_2,state_length_3]

# rows: states; columns: DJF, MAM, JJA, SON, all seasons
n_states = len(state_length_collect)
means, mins, maxs = (np.zeros((n_states, 5)) for _ in range(3))

seasons = ['DJF','MAM','JJA','SON','ALL']

for si, season_label in enumerate(seasons):
    for jj, lengths in enumerate(state_length_collect):
        # for a single season keep only the residencies that end in that season
        if season_label != 'ALL':
            lengths = lengths.where(model.time.dt.season==season_label,drop=True)
        means[jj,si] = np.mean(lengths)
        mins[jj,si]  = np.min(lengths)
        maxs[jj,si]  = np.max(lengths)


## Table 1

In [None]:
# One block of (min, mean, max) rows per state.  np.vstack replaces
# np.row_stack, which is deprecated since NumPy 2.0.
n_table_states = mins.shape[0]
stat_rows = [stat[state, :] for state in range(n_table_states) for stat in (mins, means, maxs)]
row_labels = ['state {} {}'.format(state + 1, label)
              for state in range(n_table_states) for label in ('min', 'mean', 'max')]

df = pd.DataFrame(np.vstack(stat_rows),
                  index=row_labels,
                  columns=['DJF','MAM','JJA','SON','All'])

h_styles = [dict(selector="th", props=[("font-size", "12pt")])]
df.style.set_table_styles(h_styles).set_properties(**{'font-size': '14pt'}).format("{:,.1f}")

## Seasonal behaviour of states and transitions

In [None]:
## separate by specific transition
# trans_ind_<a> holds the last day in state a; trans_ind_to_<b> holds the first
# day in state b.  A transition a -> b is therefore a last day in a whose
# successor index appears in trans_ind_to_<b>.
next_day_1 = trans_ind_1 + 1
next_day_2 = trans_ind_2 + 1
next_day_3 = trans_ind_3 + 1

into_2 = np.isin(next_day_1, trans_ind_to_2)
trans_ind_1_to_2 = trans_ind_1[into_2]
trans_ind_1_to_3 = trans_ind_1[~into_2 & np.isin(next_day_1, trans_ind_to_3)]

into_1 = np.isin(next_day_2, trans_ind_to_1)
trans_ind_2_to_1 = trans_ind_2[into_1]
trans_ind_2_to_3 = trans_ind_2[~into_1 & np.isin(next_day_2, trans_ind_to_3)]

into_1 = np.isin(next_day_3, trans_ind_to_1)
trans_ind_3_to_1 = trans_ind_3[into_1]
trans_ind_3_to_2 = trans_ind_3[~into_1 & np.isin(next_day_3, trans_ind_to_2)]

# matching time stamps (last day before each transition)
(trans_1_to_2_times, trans_1_to_3_times,
 trans_2_to_1_times, trans_2_to_3_times,
 trans_3_to_1_times, trans_3_to_2_times) = (
    model.time[inds] for inds in (trans_ind_1_to_2, trans_ind_1_to_3,
                                  trans_ind_2_to_1, trans_ind_2_to_3,
                                  trans_ind_3_to_1, trans_ind_3_to_2))

In [None]:
## number of transitions
# rows 0-5: the six directed transitions, row 6: any transition
# columns 0-3: DJF, MAM, JJA, SON, column 4: all seasons
# (built-in float replaces the np.float alias, which was removed in NumPy 1.24)
trans_num_season = np.zeros((7,5),dtype=float)

trans_ind_collect = [trans_ind_1_to_2,trans_ind_1_to_3,
                     trans_ind_2_to_1,trans_ind_2_to_3,
                     trans_ind_3_to_1,trans_ind_3_to_2]

seasons = ['DJF','MAM','JJA','SON']

for si, season_label in enumerate(seasons):
    in_season = model.time.dt.season==season_label
    for jj, trans_inds in enumerate(trans_ind_collect):
        # transitions of this type whose last day falls in the season
        trans_times_season = model.time[trans_inds].where(in_season,drop=True)
        trans_num_season[jj,si] = trans_times_season.shape[0]

# totals over transition types (last row) and over seasons (last column)
trans_num_season[6,:] = np.sum(trans_num_season[0:6,:],axis=0)
trans_num_season[:,4] = np.sum(trans_num_season[:,0:4],axis=1)

In [None]:
## days in each state
# rows 0-2: states 1-3, row 3: all analysed days
# columns 0-3: the entries of `seasons`, column 4: all seasons
# (built-in float replaces the np.float alias, which was removed in NumPy 1.24)
state_times_season = np.zeros((4,5),dtype=float)

state_times_collect = [state_1_times,state_2_times,state_3_times,model.time[5:]]

for si, season_label in enumerate(seasons):
    in_season = model.time.dt.season==season_label
    for jj, state_times in enumerate(state_times_collect):
        state_times_season[jj,si] = state_times.where(in_season,drop=True).shape[0]

# total over seasons
state_times_season[:,4] = np.sum(state_times_season[:,0:4],axis=1)

### Table 2

In [None]:
# transition counts on top, days-in-state counts below; one column per season
# plus the all-season total
table_values = np.concatenate((trans_num_season, state_times_season), axis=0)
column_labels = np.append(seasons, 'All seasons')
row_labels = ['1 to 2','1 to 3','2 to 1','2 to 3','3 to 1','3 to 2','Any trans',
              'state 1','state 2','state 3','Any state']

df = pd.DataFrame(table_values, columns=column_labels, index=row_labels)

h_styles = [dict(selector="th", props=[("font-size", "12pt")])]
df.style.set_table_styles(h_styles).set_properties(**{'font-size': '14pt'}).format("{:,.0f}")

# Compare with index

In [None]:
index_cpc = 'NAO'
cpc_file = os.path.join(DATA_DIR, 'cpc.{}.daily.csv'.format(index_cpc.lower()))
IND_cpc = np.genfromtxt(cpc_file, delimiter=',')

## calculating time variable
# columns 0, 1 and 2 are read as year, month and day; column 3 is used as the
# index value in the cells below
IND_cpc_time = np.array(
    [np.datetime64('{}-{}-{}'.format(int(row[0]),
                                     str(int(row[1])).zfill(2),
                                     str(int(row[2])).zfill(2)))
     for row in IND_cpc],
    dtype='datetime64[s]')

## Sliding window residency percent

In [None]:
## Teleconnection index

## sliding window anomaly residency
# Fraction of days with a non-negative / non-positive index value in each
# 365-day window.  Window number w covers samples w .. w+364.
end_ind = IND_cpc.shape[0]-365
IND_cpc_pos = np.empty(end_ind)
IND_cpc_neg = np.empty(end_ind)

for y_ind in range(end_ind):
    window = IND_cpc[y_ind:y_ind+365,3]
    IND_cpc_pos[y_ind] = np.count_nonzero(window>=0)/365
    IND_cpc_neg[y_ind] = np.count_nonzero(window<=0)/365

## Convert to xarray
# each window is stamped with the time of the sample that follows it
window_times = IND_cpc_time[365:]
IND_cpc_pos = xr.DataArray(IND_cpc_pos, coords=[window_times], dims=['time'])
IND_cpc_neg = xr.DataArray(IND_cpc_neg, coords=[window_times], dims=['time'])

In [None]:
## Model fit index

## sliding window anomaly residency
n_components = k
days = 365

# Most likely state on each day after the first 5 samples.  It is computed
# once instead of once per window and state.
state_seq = np.asarray(model.weights[5:].argmax(dim='fembv_state'))

# number of complete `days`-long windows
end_ind = model.time.shape[0]-days-5

# Running count of days spent in each state: row t holds the counts over
# state_seq[:t], so a window count is the difference of two rows.
state_onehot = state_seq[:, np.newaxis] == np.arange(k)[np.newaxis, :]
cum_counts = np.concatenate((np.zeros((1, k), dtype=int),
                             np.cumsum(state_onehot, axis=0)), axis=0)

# fraction of each window spent in each state; window w covers
# state_seq[w:w+days]
comp_freq_sw = (cum_counts[days:days+end_ind] - cum_counts[:end_ind]) / days

## convert to xarray
# coordinates follow `days` and `k` rather than the literals 365 and 3;
# states are labelled 1..k
comp_freq_sw = xr.DataArray(comp_freq_sw, coords=[model.time[5+days:],np.arange(1,k+1)], dims=['time','state'])

## Yearly average and LOWESS fit residency percent

In [None]:
## Teleconnection index

## fraction of each calendar year spent with a negative index value
# start_year: position of the first sample of the current year within IND_cpc
# y: running year counter
start_year = 0
y = 0
IND_binned = np.empty((IND_cpc_time.shape[0]))
num_years = round(IND_cpc_time.shape[0]/365)
year_inds_IND = np.zeros(num_years, dtype=int)

# NOTE(review): the year bookkeeping assumes the record ends in 2020 with a
# partial final year of 182 days, and that every year divisible by 4 is a leap
# year - confirm against the CSV file.
for year in np.arange(2021-num_years,2021):
    if year == 2020:
        days = 182
    elif np.mod(year,4)==0:
        days = 366
    else:
        days = 365
    
    # every day of the year receives the same value: the share of that year's
    # days on which column 3 of IND_cpc is negative
    IND_binned[start_year:start_year+days] = np.ones(days)*np.count_nonzero(IND_cpc[start_year:start_year+days,3]<0)/days
    # remember where this year starts
    year_inds_IND[y] = int(start_year)
    
    start_year += days
    y += 1


## convert to xarray
IND_binned = xr.DataArray(IND_binned, coords=[IND_cpc_time], dims=['time'])

## LOWESS fit
# (this placeholder array is overwritten two statements below)
IND_binned_lowess = np.empty((year_inds_IND.shape[0]))
# one value per year (taken at each year's first sample), restricted to
# 1979-2018, smoothed with a LOWESS fraction of 0.25
binned_lowess_p = lowess(IND_binned[year_inds_IND].sel(time=slice("1979-01-01", "2018-12-31")),
                         IND_binned[year_inds_IND].time.sel(time=slice("1979-01-01", "2018-12-31")),frac=0.25)
# keep column 1 of the LOWESS output (the fitted values according to the
# statsmodels documentation)
IND_binned_lowess = binned_lowess_p[:,1]

# stamp each yearly value with the time 181 samples after the start of its year
IND_binned_lowess = xr.DataArray(IND_binned_lowess, 
                                 coords=[IND_binned[year_inds_IND+181].time.sel(time=slice("1979-01-01", "2018-12-31"))], dims=['time'])

In [None]:
## Model fit index
## index of the model state to compare with the NAO index
## NOTE(review): not read anywhere in this cell - confirm where it is used
model_NAO_ind = 1

## fraction of each calendar year spent in each state
# affil_binned covers the samples from position 5 onwards, one column per state
affil_binned = np.empty((model.time.shape[0]-5,k))
num_years = round(model.time.shape[0]/365)
year_inds = np.zeros(num_years, dtype=int)
affil_binned_lowess = np.zeros((year_inds.shape[0],k))

# start_year: position of the first sample of the current year within
# affil_binned; y: running year counter
start_year = 0
y = 0
# NOTE(review): the bookkeeping assumes the record ends in 2018, that every
# year divisible by 4 is a leap year, and that 1979 contributes 360 days
# (365 minus the 5 skipped samples) - confirm against the model file.
for year in np.arange(2019-num_years,2019):
    if np.mod(year,4)==0:
        days = 366
    elif year == 1979:
        days = 360
    else:
        days = 365
    
    # every day of the year receives the same value: the share of that year's
    # days whose largest weight belongs to `state`
    for state in np.arange(0,k):
        affil_binned[start_year:start_year+days,state] = np.ones(days)*(np.count_nonzero(model.weights[5+start_year:5+start_year+days].argmax(dim='fembv_state') == state,
                                                                               axis=0)/days)
    # Representative sample of each year.  From the second year on the index is
    # shifted by 5; the same indices are later applied to both affil_binned
    # and model.time.
    if y == 0:
        year_inds[y] = int(start_year)
    else:
        year_inds[y] = int(start_year)+5
    
    start_year += days
    y += 1
    
## LOWESS fit
# smooth the yearly values of each state with a LOWESS fraction of 0.25 and
# keep column 1 of the output (the fitted values according to the statsmodels
# documentation)
for state in np.arange(0,k):
    binned_lowess_p = lowess(affil_binned[year_inds,state],model.time[year_inds],frac=0.25)
    affil_binned_lowess[:,state] = binned_lowess_p[:,1]

## convert to xarray
affil_binned = xr.DataArray(affil_binned, coords=[model.time[5:], np.arange(0,k)], dims=['time','fembv_state'])
# yearly values are stamped with the time 181 samples after each year index
affil_binned_lowess = xr.DataArray(affil_binned_lowess, coords=[model.time[year_inds+181], np.arange(0,k)], dims=['time','fembv_state'])

## Figure 3

In [None]:
## choose state index to compare
state_ind = 1


def _format_residency_axes(axes, title):
    """Apply the limits, grid, labels and title shared by both panels."""
    axes.set_xlim([np.datetime64("1980-01-01"),np.datetime64("2018-12-31")])
    axes.set_ylim(0.0, 1.05)
    axes.tick_params(axis='both', labelsize=13)
    axes.grid(ls='--', color='gray', alpha=0.5)
    axes.set_xlabel('Year', fontsize=14)
    axes.set_ylabel('Residency fraction', fontsize=14)
    axes.set_title(title, fontsize=15)


fig = plt.figure(figsize=(10,6))

# top panel: 365-day sliding-window residency of the index and of the chosen state
ax1 = fig.add_subplot(2,1,1)
ax1.plot(IND_cpc_neg.time, IND_cpc_neg)
ax1.plot(comp_freq_sw.time, comp_freq_sw[:,state_ind])
_format_residency_axes(ax1, 'Residency percent (365 day sliding window)')
ax1.legend(['CPC NAO$^-$ index','state {}'.format(state_ind+1)],loc='upper right')
plt.tight_layout()

# bottom panel: yearly values (faint) with their LOWESS fits (dashed, same colour)
ax2 = fig.add_subplot(2,1,2)
p1 = ax2.plot(IND_binned.time, IND_binned,alpha=0.6)
color1 = p1[0].get_color()
ax2.plot(IND_binned_lowess.time, IND_binned_lowess,'--',color=color1,lw=3)
p2 = ax2.plot(affil_binned.time, affil_binned[:,state_ind],alpha=0.6)
color2 = p2[0].get_color()
ax2.plot(affil_binned_lowess.time, affil_binned_lowess[:,state_ind],'--',color=color2,lw=3)
_format_residency_axes(ax2, 'Residency percent (yearly average and LOWESS)')

plt.tight_layout()

#plt.savefig('../figures/fig3.pdf'.format(region,k,m,p))

### Correlations

In [None]:
def calculate_correlation(y1, y2):
    """Calculate the Pearson correlation of one series with another.

    Both series are first restricted to the index range they share (from the
    later of the two first labels to the earlier of the two last labels).
    Observations are then paired, and every pair in which either value is
    missing is discarded before the correlation is computed.

    Parameters
    ----------
    y1, y2 : pandas.Series
        Series whose indexes can be compared with each other (e.g. dates).

    Returns
    -------
    correlation : float
        Pearson correlation coefficient of the paired, non-missing values.

    Raises
    ------
    ValueError
        Raised by ``scipy.stats.pearsonr`` if fewer than two valid pairs remain.
    """
    start_time = max(y1.index.min(), y2.index.min())
    end_time = min(y1.index.max(), y2.index.max())

    # Trim with boolean indexing (not where().dropna()) so that missing values
    # keep their position; dropping them per series would shift the pairing.
    y1_da = y1[(y1.index >= start_time) & (y1.index <= end_time)]
    y2_da = y2[(y2.index >= start_time) & (y2.index <= end_time)]

    if len(y1_da) != len(y2_da):
        # The series are sampled differently inside the shared window, so
        # pair observations by index label instead of by position.
        y1_da, y2_da = y1_da.align(y2_da, join='inner')

    y1_values = y1_da.to_numpy(dtype=float)
    y2_values = y2_da.to_numpy(dtype=float)

    # keep only the pairs in which both values are present
    valid = np.logical_not(np.logical_or(np.isnan(y1_values), np.isnan(y2_values)))
    correlation = stats.pearsonr(y1_values[valid], y2_values[valid])[0]
    return correlation

In [None]:
# Correlate the second FEM-BV-VAR state (column 1) with the negative NAO index
# at three levels of smoothing.

# unfiltered sliding-window series
corr_cpc = calculate_correlation(IND_cpc_neg.to_pandas(),
                                 comp_freq_sw[:,1].to_pandas())

# binned series
corr_cpc_binned = calculate_correlation(IND_binned.to_pandas(),
                                        affil_binned[:,1].to_pandas())

# binned and LOWESS-smoothed series
corr_cpc_binned_lowess = calculate_correlation(IND_binned_lowess.to_pandas(),
                                               affil_binned_lowess[:,1].to_pandas())

corr_values = [corr_cpc,corr_cpc_binned,corr_cpc_binned_lowess]
corr_labels = ['no filter','binned','binned and LOWESS']

df = pd.DataFrame(np.round(corr_values,2), index=corr_labels, columns=[str(p)+' days'])

cell_properties = {'font-size': '12pt'}
df.style.set_caption('Correlations with NAO neg index').set_properties(**cell_properties).format("{:.2}")

# Dynamical analysis

## Compute matrix cocycle

In [None]:
time_len = model.weights.shape[0]-5
state_space = m*n_PCs
A = np.array(model.A)
gammas = np.array(model.weights)

start = time.time()

## top block rows of the companion matrices: for every lag, mix the per-state
## coefficient matrices with the affiliations, then stack the lags side by side
AT_lags = []
for mm in np.arange(0,m):
    AT_mm = np.matmul(gammas[:,:],A[:,mm,:,:].transpose(1, 0, 2))
    AT_lags.append(AT_mm.transpose(0,2,1))
AT = np.concatenate(AT_lags,axis=1)

## lower block rows: identity shift [I 0], repeated for every time step
shift_rows = n_PCs*(m-1)
I0 = np.concatenate((np.eye(shift_rows),np.zeros((shift_rows,n_PCs))),axis=1)
I0 = np.repeat(I0[:, :, np.newaxis], AT.shape[2], axis=2)

## stack both blocks and drop the first 5 time steps
matrix_cocycle = np.concatenate((AT,I0),axis=0)[:,:,5:]

end = time.time()
elapsed = end-start
print("Elapsed time: {} sec".format(round(elapsed,4)))

## Load CLVs

In [None]:
## push forward steps for which CLVs were precomputed
Ms = [3,10,30,50]
state_space = m*n_PCs
CLVs_all = [[]]*len(Ms)

## directory holding the CLV files and the part of the file name shared by all M
CLVs_dir = os.path.join(FEM_BV_VAR_DIR, 'CLVs','truncated')
CLVs_name_head = [var_name, var_lev, timespan, base_period_str, 'anom', hemisphere, region, 'ALL',
                  'max_eofs_{:d}'.format(max_eofs), lat_weights, pc_scaling, 'm{:d}'.format(m),
                  'state_length{:d}'.format(p),'CLVs']

i = 0
for M in Ms:
    CLVs_filename = '.'.join(CLVs_name_head + ['M{:d}'.format(M),'orth1','nc'])
    CLVs_file = os.path.join(CLVs_dir, CLVs_filename)

    CLVs_ds = xr.open_dataset(CLVs_file)
    CLVs = CLVs_ds['CLVs']

    CLVs_all[i] = CLVs
    i += 1

In [None]:
## time coordinate of every CLV set, one entry per push forward step in Ms
time_CLVs_all = [None]*len(Ms)

for i in np.arange(0,len(Ms)):
    CLVs = CLVs_all[i]
    CLVs_time = CLVs.time
    time_CLVs_all[i] = CLVs_time

## Calculate FTCLEs

In [None]:
## For every push forward step M, estimate one exponent per CLV and time step
## by applying the matrix cocycle to the CLVs and recording how the column
## norms change. The result has shape (number of CLVs, number of time steps).
Lyaps_all = [[]]*len(Ms)

for Mi in np.arange(0,len(Ms)):
    M = Ms[Mi]
    CLVs = CLVs_all[Mi]
    time_CLVs = time_CLVs_all[Mi]

    ## number of cocycle applications per estimate
    M_FTLE  = 1
    orth_win = 1
    ## Nk and Qp are assigned here but not read anywhere in this cell
    Nk = np.arange(0,M_FTLE+1,orth_win)
    Qp = np.eye(state_space)
    Lyaps = np.empty((CLVs.shape[1],CLVs.shape[2]))

    start = time.time()

    for i in np.arange(0,CLVs.shape[2]):
        ## all CLVs at time step i as columns of C
        C = np.array(CLVs[:,:,i])
        norm_C = linalg.norm(C,axis=0)
        Lyap_i = np.empty((CLVs.shape[1],M_FTLE))
        for tt in np.arange(0,M_FTLE):
            ## NOTE(review): the cocycle is indexed at i+tt+M, presumably because
            ## the CLV series starts M steps into the cocycle -- TODO confirm.
            C = np.matmul(matrix_cocycle[:,:,i+tt+M],C)
            ## NOTE(review): growth is measured as a difference of column norms,
            ## not as a log-ratio -- confirm this is the intended FTCLE definition.
            Lyap_i[:,tt] = linalg.norm(C,axis=0)-norm_C
            norm_C = linalg.norm(C,axis=0)
        ## average over the M_FTLE applications
        Lyaps[:,i] = np.mean(Lyap_i,axis=1)
    
    end = time.time()
    elapsed = end-start
    print("Elapsed time: {} sec, M = {}".format(round(elapsed,4),M))
    
    Lyaps_all[Mi] = Lyaps  

In [None]:
## wrap every FTCLE array in a DataArray; the FTLE coordinate counts from 1
## (re-running this cell on its own wraps the DataArrays again)
for Mi in np.arange(0,len(Ms)):
    FTLE_labels = np.arange(1,CLVs_all[Mi].shape[1]+1)
    Lyaps_all[Mi] = xr.DataArray(Lyaps_all[Mi],
                                 coords=[FTLE_labels, time_CLVs_all[Mi]],
                                 dims=['FTLE', 'time'])

## Calculate asymptotic Lyapunov exponents using QR

In [None]:
## third argument: the integers 0, 1, ..., number of cocycle steps
## (presumably the step indices used by calculate_FTLEs -- TODO confirm
## against clustering_dynamics.dynamics)
cocycle_steps = np.arange(matrix_cocycle.shape[2]+1)
Lyap_asymp = calculate_FTLEs(state_space,matrix_cocycle,cocycle_steps)

### Figure 4

In [None]:
## One panel per push forward step M: mean +/- std, min and max over time of
## the first 10 FTCLEs, next to the asymptotic exponents.
fig = plt.figure(figsize=(13,10))

for j in np.arange(0,len(Ms)):
    ## statistics over time; the first 5 time steps are left out
    mean_Lyaps = np.mean(Lyaps_all[j][:,5:],axis=1)
    min_Lyaps = np.min(Lyaps_all[j][:,5:],axis=1)
    max_Lyaps = np.max(Lyaps_all[j][:,5:],axis=1)
    std_Lyaps = np.std(Lyaps_all[j][:,5:],axis=1)
    
    ax = fig.add_subplot(len(Ms), 1, j+1)

    for ll in range(0,10):
        pl = ax.errorbar(ll+1, mean_Lyaps[ll], yerr=std_Lyaps[ll], fmt='o',lw=3,capsize=8,markeredgewidth=2)
        ## reuse the colour of the error bar for all markers of this exponent
        c = pl[0].get_color()
        ## asymptotic value, drawn slightly to the right of the FTCLE statistics
        ax.plot(ll+1.1, Lyap_asymp[ll],'o',ms=8,color=c,fillstyle='none',mew=2)
        ax.plot(ll+1, min_Lyaps[ll],'D',ms=8,color=c)
        ax.plot(ll+1, max_Lyaps[ll],'s',ms=8,color=c)

    ## black proxy artists so that the legend does not depend on the colour cycle
    mean = ax.errorbar(np.nan,np.nan,yerr=np.nan,fmt='o',lw=3,capsize=8,markeredgewidth=2,color=[0,0,0],label='mean/std')
    handles = [mean,
               matplotlib.lines.Line2D([],[],marker='D',ms=8,color=[0,0,0],linestyle='none'),
               matplotlib.lines.Line2D([],[],marker='s',ms=8,color=[0,0,0],linestyle='none'),
               matplotlib.lines.Line2D([],[],marker='o',ms=8,color=[0,0,0],fillstyle='none',mew=2,linestyle='none')]
    if j == 0:
        ax.legend(handles,('mean/std','min','max','asymp'),ncol=4)
    ## raw string: '\L' is not a valid escape sequence in a normal string
    ax.set_ylabel(r'$\Lambda_i$')
    ax.set_title('\n M = {}'.format(Ms[j]),fontsize = 13)
    ax.grid()
    ## only the bottom panel keeps its x tick labels
    if j == (len(Ms)-1):
        ax.set_xlabel('i')
    else:
        ax.get_xaxis().set_ticklabels([])
    ax.set_ylim([-0.75, 0.25])

## adjust the layout once, after all panels exist
plt.tight_layout()

#plt.savefig('../figures/fig4.eps')

## DimKY

In [None]:
## Kaplan-Yorke dimension at every time step, from the leading FTCLEs.
##
## With the exponents sorted in descending order, let i_min be the largest
## number of leading exponents whose sum is still non-negative. Then
##     dimKY = i_min + (sum of the first i_min exponents) / |exponent i_min+1|.
## If even the sum of all considered exponents is non-negative, the estimate
## saturates at the number of exponents considered.

## number of leading FTCLEs used for the estimate
n_dimKY_exps = 10

dimKY_all = [[]]*len(Ms)

for Mi in np.arange(0,len(Ms)):
    M = Ms[Mi]
    Lyaps = Lyaps_all[Mi]
    CLVs = CLVs_all[Mi]
    time_CLVs = time_CLVs_all[Mi]
    
    dimKY = np.empty((CLVs.shape[2]))

    start = time.time()

    ## leading exponents (FTLE x time), each column sorted in descending order,
    ## and their running sums; done once in NumPy instead of per time step
    Lyaps_ord = -np.sort(-np.asarray(Lyaps)[:n_dimKY_exps,:],axis=0)
    Lyaps_cumsum = np.cumsum(Lyaps_ord,axis=0)
    n_exps = Lyaps_ord.shape[0]

    for t in np.arange(0,dimKY.shape[0]):
        if not np.all(np.isfinite(Lyaps_ord[:,t])):
            ## no meaningful estimate if an exponent is missing
            dimKY[t] = np.nan
            continue

        negative_sums = np.flatnonzero(Lyaps_cumsum[:,t] < 0)
        if negative_sums.shape[0] == 0:
            ## all partial sums are non-negative: saturate at the number of
            ## exponents instead of reusing a value from an earlier time step
            dimKY[t] = n_exps
            continue

        ## first position at which the running sum turns negative, i.e. the
        ## number of leading exponents with a non-negative sum
        i_min = negative_sums[0]
        if i_min > 0:
            leading_sum = Lyaps_cumsum[i_min-1,t]
        else:
            leading_sum = 0.0
        ## Lyaps_ord[i_min,t] is strictly negative here, so the division is safe
        dimKY[t] = i_min + leading_sum/abs(Lyaps_ord[i_min,t])
    
    end = time.time()
    elapsed = end-start
    print("Elapsed time: {} sec, M = {}".format(round(elapsed,4),M))
    
    dimKY_all[Mi] = dimKY

In [None]:
## attach the CLV time coordinate to every dimKY series
## (re-running this cell on its own wraps the DataArrays again)
for Mi in np.arange(0,len(Ms)):
    dimKY_time = time_CLVs_all[Mi]
    dimKY_all[Mi] = xr.DataArray(dimKY_all[Mi], coords=[dimKY_time], dims=['time'])

### Probabilities of positive dimension by state

In [None]:
## fraction of time steps with positive dimKY: overall and within each state,
## one entry per push forward step in Ms
dimKY_pos_all = np.zeros(len(Ms))
dimKY_pos_state1_all = np.zeros(len(Ms))
dimKY_pos_state2_all = np.zeros(len(Ms))
dimKY_pos_state3_all = np.zeros(len(Ms))

## column labels of the table
M_labs = [[]]*len(Ms)

for M_ii in np.arange(0,len(Ms)):
    M_labs[M_ii] = 'M = {}'.format(Ms[M_ii])

    ## dimKY for this push forward step
    dimKY = dimKY_all[M_ii]

    ## overall probability that dimKY is positive
    dimKY_pos = (np.sum(dimKY>0))/dimKY.shape[0]
    dimKY_pos_all[M_ii] = dimKY_pos

    ## state 1: times in the state for which the dynamics were calculated,
    ## then the probability that dimKY is positive on those times
    state1_times_CLVs = state_1_times.where(state_1_times == time_CLVs_all[M_ii],drop=True)
    dimKY_pos_state1 = np.sum(dimKY.sel(time = state1_times_CLVs)>0)/state1_times_CLVs.shape[0]
    dimKY_pos_state1_all[M_ii] = dimKY_pos_state1

    ## state 2
    state2_times_CLVs = state_2_times.where(state_2_times == time_CLVs_all[M_ii],drop=True)
    dimKY_pos_state2 = np.sum(dimKY.sel(time = state2_times_CLVs)>0)/state2_times_CLVs.shape[0]
    dimKY_pos_state2_all[M_ii] = dimKY_pos_state2

    ## state 3
    state3_times_CLVs = state_3_times.where(state_3_times == time_CLVs_all[M_ii],drop=True)
    dimKY_pos_state3 = np.sum(dimKY.sel(time = state3_times_CLVs)>0)/state3_times_CLVs.shape[0]
    dimKY_pos_state3_all[M_ii] = dimKY_pos_state3

### Table 3

In [None]:
## rows of the table: conditional probabilities per state, then the overall one
table3_values = [dimKY_pos_state1_all,dimKY_pos_state2_all,dimKY_pos_state3_all,
                 dimKY_pos_all]
table3_index = ['P(FTLE > 0 | state 1)',
                'P(FTLE > 0 | state 2)','P(FTLE > 0 | state 3)','P(FTLE > 0)']

df = pd.DataFrame(np.round(table3_values,4), index=table3_index, columns=M_labs)

## header cells use a slightly smaller font than the body
h_styles = [{'selector': "th", 'props': [("font-size", "12pt")]}]
body_properties = {'font-size': '14pt'}
df.style.set_table_styles(h_styles).set_properties(**body_properties).format("{:,.3f}")

### Average dimension by state

In [None]:
## Average dimKY (first push forward step in Ms) per state:
##   column 0: over all days affiliated with the state,
##   column 1: only over days in the middle of a 5 day run within the state
##             (the day itself and the two days either side).
state_times_collection = [state_1_times, state_2_times, state_3_times]
comp_ind_collection = [comp1_ind, comp2_ind, comp3_ind]

dimKY_state_avg = np.zeros((3,2))
state_inds_long = [[]]*3

model_time_values = np.asarray(model.time)
n_model_times = model_time_values.shape[0]

for jj in np.arange(0,len(state_times_collection)):
    dimKY_state_avg[jj,0] = np.mean(dimKY_all[0].where(dimKY_all[0].time == state_times_collection[jj],drop=True))

    ## membership of every model day in this state, computed once per state
    in_state = np.isin(model_time_values, np.asarray(state_times_collection[jj]))

    ## Keep a day only if the full window [ii-2, ii+2] exists and lies in the
    ## state. The bounds check matters: for ii < 2 the slice start would be
    ## negative and the slice empty, and np.all of an empty slice is True.
    state_inds_long_temp = np.array([ii for ii in comp_ind_collection[jj]
                                     if ii >= 2 and ii+3 <= n_model_times
                                     and np.all(in_state[ii-2:ii+3])],dtype=int)

    state_inds_long[jj] = state_inds_long_temp
    dimKY_state_avg[jj,1] = np.mean(dimKY_all[0].where(dimKY_all[0].time == model.time[state_inds_long[jj]],drop=True))
    

### Table 4

In [None]:
## average dimKY per state, without and with the 5 day persistence filter
table4_index = ['state 1','state 2','state 3']
table4_columns = ['no filter','5 day filter']

df = pd.DataFrame(np.round(dimKY_state_avg,4), index=table4_index, columns=table4_columns)

## header cells use a slightly smaller font than the body
h_styles = [{'selector': "th", 'props': [("font-size", "12pt")]}]
body_properties = {'font-size': '14pt'}
df.style.set_table_styles(h_styles).set_properties(**body_properties).format("{:,.2f}")

## Calculate Alignment

In [None]:
## Alignment of pairs of CLVs at every time step, for every push forward step.
## align[clvi, clvj-1, t] = |<CLV clvi, CLV clvj>| at time step t for clvi < clvj
## (zero-based CLV indices); entries below the diagonal stay zero.
num_CLVs_all = [[]]*len(Ms)
align_all = [[]]*len(Ms)

## pairs are formed among the first num_CLV_test+1 CLVs
num_CLV_test = 6

for i in np.arange(0,len(Ms)):
    
    start = time.time()
    
    M = Ms[i]
    CLVs = CLVs_all[i]
    
    ## note: num_CLVs is the number of time steps for which CLVs are available
    num_CLVs = CLVs.shape[2]
    time_CLVs = time_CLVs_all[i]
    ## builtin float: the np.float alias no longer exists in current NumPy
    align = np.zeros((num_CLV_test,num_CLV_test,num_CLVs),dtype=float)

    for t in np.arange(0,num_CLVs):
        ## pull the leading CLVs at this time step into NumPy once rather than
        ## indexing the DataArray separately for every pair
        CLVs_t = np.asarray(CLVs[:,:num_CLV_test+1,t])
        for clvi in np.arange(0,num_CLV_test):
            for clvj in np.arange(clvi+1,num_CLV_test+1):
                align[clvi,clvj-1,t] = abs(np.dot(CLVs_t[:,clvi],CLVs_t[:,clvj]))
    
    num_CLVs_all[i] = num_CLVs
    align_all[i] = align
    
    end = time.time()
    elapsed = (end-start)/60
    print("Elapsed time: {} min, M = {}".format(round(elapsed,4),M))

In [None]:
## label the alignment arrays: entry [a, b] belongs to the CLV pair
## (CLV_i = a+1, CLV_j = b+2) in one-based numbering
## (re-running this cell on its own wraps the DataArrays again)
CLV_i_labels = np.arange(1,num_CLV_test+1)
CLV_j_labels = np.arange(2,num_CLV_test+2)

for i in np.arange(0,len(Ms)):
    M = Ms[i]
    CLVs = CLVs_all[i]
    time_CLVs = time_CLVs_all[i]
    align = align_all[i]

    align_all[i] = xr.DataArray(align,
                                coords=[CLV_i_labels, CLV_j_labels, time_CLVs],
                                dims=['CLV_i','CLV_j', 'time'])

### Figure 7

In [None]:
## Two case studies, (a) and (b), for the first push forward step in Ms. Each
## uses three stacked panels: state affiliation markers with the alignments of
## the first three CLVs, the first three FTCLEs, and dimKY.
j = 0

fig = plt.figure(figsize=(10,16))
(ax1, ax2, ax3, ax4, ax5, ax6) = fig.subplots(6, 1, gridspec_kw={'height_ratios': [2, 1, 1, 2, 1, 1]})

## one column of axes per case study
axes_collect = np.array([[ax1, ax2, ax3],[ax4, ax5, ax6]]).T
plot_titles = ['(a)','(b)']

for pi in np.arange(0,axes_collect.shape[1]):
    axes = axes_collect[:,pi]
    
    ## affiliation markers, scaled to different heights so the states do not overlap
    axes[0].plot(model.time[comp1_ind], model.weights[comp1_ind,0]*0.6,'ko')
    axes[0].plot(model.time[comp2_ind], model.weights[comp2_ind,1]*0.55,'ks')
    axes[0].plot(model.time[comp3_ind], model.weights[comp3_ind,2]*0.5,'kd')
    axes[0].plot(align_all[j].time,align_all[j][0,0,:].T)
    axes[0].plot(align_all[j].time,align_all[j][1,1,:].T)
    axes[0].plot(align_all[j].time,align_all[j][0,1,:].T)
    axes[0].legend(['state 1','state 2','state 3','$\\theta_{1,2}$','$\\theta_{2,3}$','$\\theta_{1,3}$'])

    for kk in np.arange(0,3):
        axes[1].plot(time_CLVs_all[j],Lyaps_all[j][kk,:],'C{}'.format(kk))
    ## raw strings: '\L' is not a valid escape sequence in a normal string
    axes[1].legend([r'$\Lambda_1$',r'$\Lambda_2$',r'$\Lambda_3$'])

    axes[2].plot(time_CLVs_all[j],dimKY_all[j],'.-')
    ## the panel label (a)/(b) is placed as the x label of the bottom panel
    axes[2].set_xlabel(plot_titles[pi],fontsize = 20)
    if pi == 0:
        axes[2].legend(['dim_KY'],loc = 'lower right')
    else:
        axes[2].legend(['dim_KY'])
    
    
    ## zoom all three panels of a case study onto its period
    for axii in np.arange(0,axes_collect.shape[0]):
        if pi == 0:
            axes[axii].set_xlim([np.datetime64("2012-03-15"),np.datetime64("2012-08-01")])
        else: 
            axes[axii].set_xlim([np.datetime64("1993-11-01"),np.datetime64("1994-03-15")])

fig.tight_layout()

#plt.savefig('../figures/fig7.eps')

## Extracting transitions associated with persistent states 

In [None]:
## Extract transitions associated with long states
## A state counts as long when its length exceeds char_time.
char_time = 4

## transitions out of a long state, by the state that is left
trans_ind_1_long = np.array([],dtype=int)
trans_ind_2_long = np.array([],dtype=int)
trans_ind_3_long = np.array([],dtype=int)

## transitions into a long state, by the state that is entered
trans_ind_to_1_long = np.array([],dtype=int)
trans_ind_to_2_long = np.array([],dtype=int)
trans_ind_to_3_long = np.array([],dtype=int)

## NOTE(review): state_length_all[ii] is read as the length of the state left at
## transition ii and state_length_all[ii+1] as the length of the state entered;
## this assumes state_length_all has one more entry than trans_ind_all -- TODO confirm.
for ii in np.arange(0,trans_ind_all.shape[0]):
    if state_length_all[ii] > char_time:    
        ## classify the transition index by membership in trans_ind_1/2/3
        if np.isin(trans_ind_all[ii],trans_ind_1):
            trans_ind_1_long = np.append(trans_ind_1_long,trans_ind_all[ii])
        elif np.isin(trans_ind_all[ii],trans_ind_2):
            trans_ind_2_long = np.append(trans_ind_2_long,trans_ind_all[ii])
        elif np.isin(trans_ind_all[ii],trans_ind_3):
            trans_ind_3_long = np.append(trans_ind_3_long,trans_ind_all[ii])
        else:
            ## only reported, not raised: the index is skipped
            print('error: invalid transition index')
    if state_length_all[ii+1] > char_time:
        ## classify the following index by membership in trans_ind_to_1/2/3
        if np.isin(trans_ind_all[ii]+1,trans_ind_to_1):
            trans_ind_to_1_long = np.append(trans_ind_to_1_long,trans_ind_all[ii]+1)
        elif np.isin(trans_ind_all[ii]+1,trans_ind_to_2):
            trans_ind_to_2_long = np.append(trans_ind_to_2_long,trans_ind_all[ii]+1)
        elif np.isin(trans_ind_all[ii]+1,trans_ind_to_3):
            trans_ind_to_3_long = np.append(trans_ind_to_3_long,trans_ind_all[ii]+1)
        else:
            ## only reported, not raised: the index is skipped
            print('error: invalid transition index')
        
## model times of the selected indices
trans_1_long_times = model.time[trans_ind_1_long]
trans_2_long_times = model.time[trans_ind_2_long]
trans_3_long_times = model.time[trans_ind_3_long]

trans_to_1_long_times = model.time[trans_ind_to_1_long]
trans_to_2_long_times = model.time[trans_ind_to_2_long]
trans_to_3_long_times = model.time[trans_ind_to_3_long]

In [None]:
## categorize by specific transition
def _split_by_destination(trans_from, trans_to_first, trans_to_second):
    """Split transition indices by the long state entered on the next index.

    Parameters
    ----------
    trans_from : array of int
        Indices of transitions out of one long state.
    trans_to_first, trans_to_second : array of int
        Indices at which the two possible destination long states are entered.

    Returns
    -------
    into_first, into_second : arrays of int
        The entries of ``trans_from`` (original order) whose following index is
        in ``trans_to_first``, and those whose following index is not in
        ``trans_to_first`` but is in ``trans_to_second``. Entries matching
        neither destination are dropped.
    """
    trans_from = np.asarray(trans_from, dtype=int)
    into_first = np.isin(trans_from+1, trans_to_first)
    ## the first destination takes precedence if an index matches both
    into_second = np.logical_and(np.logical_not(into_first),
                                 np.isin(trans_from+1, trans_to_second))
    return trans_from[into_first], trans_from[into_second]

trans_ind_1_to_2_long, trans_ind_1_to_3_long = _split_by_destination(
    trans_ind_1_long, trans_ind_to_2_long, trans_ind_to_3_long)
trans_ind_2_to_1_long, trans_ind_2_to_3_long = _split_by_destination(
    trans_ind_2_long, trans_ind_to_1_long, trans_ind_to_3_long)
trans_ind_3_to_1_long, trans_ind_3_to_2_long = _split_by_destination(
    trans_ind_3_long, trans_ind_to_1_long, trans_ind_to_2_long)

## model times of the categorized transitions
trans_1_to_2_long_times = model.time[trans_ind_1_to_2_long]
trans_1_to_3_long_times = model.time[trans_ind_1_to_3_long]
trans_2_to_1_long_times = model.time[trans_ind_2_to_1_long]
trans_2_to_3_long_times = model.time[trans_ind_2_to_3_long]
trans_3_to_1_long_times = model.time[trans_ind_3_to_1_long]
trans_3_to_2_long_times = model.time[trans_ind_3_to_2_long]

## Extracting alignment behaviour associated with transitions

In [None]:
## work with the first push forward step in Ms (M=3) from here on
Mi = 0
M = Ms[Mi]
align, CLVs, time_CLVs = align_all[Mi], CLVs_all[Mi], time_CLVs_all[Mi]

In [None]:
## extract alignment for days around transitions
## Only transitions whose date lies in time_CLVs[start_ind:end_ind] are kept, so
## that all shifted dates used below fall inside the CLV record.
start_ind = 5
end_ind = -5
## number of day shifts stored per transition (dd = end_ind, ..., start_ind-1)
n_days = start_ind-end_ind

## For every transition type: the transition indices with a full window, and an
## array (CLV_i, CLV_j, transition, day shift) to hold the alignments.
## dtype=float: the np.float alias no longer exists in current NumPy.
trans_1_to_2_inds_CLVs = trans_ind_1_to_2_long[np.where(trans_1_to_2_long_times.isin(time_CLVs[start_ind:end_ind]))[0]]
align_trans_1_to_2 = np.zeros((num_CLV_test,num_CLV_test,trans_1_to_2_inds_CLVs.shape[0],n_days),dtype=float)

trans_1_to_3_inds_CLVs = trans_ind_1_to_3_long[np.where(trans_1_to_3_long_times.isin(time_CLVs[start_ind:end_ind]))[0]]
align_trans_1_to_3 = np.zeros((num_CLV_test,num_CLV_test,trans_1_to_3_inds_CLVs.shape[0],n_days),dtype=float)

trans_2_to_1_inds_CLVs = trans_ind_2_to_1_long[np.where(trans_2_to_1_long_times.isin(time_CLVs[start_ind:end_ind]))[0]]
align_trans_2_to_1 = np.zeros((num_CLV_test,num_CLV_test,trans_2_to_1_inds_CLVs.shape[0],n_days),dtype=float)

trans_2_to_3_inds_CLVs = trans_ind_2_to_3_long[np.where(trans_2_to_3_long_times.isin(time_CLVs[start_ind:end_ind]))[0]]
align_trans_2_to_3 = np.zeros((num_CLV_test,num_CLV_test,trans_2_to_3_inds_CLVs.shape[0],n_days),dtype=float)

trans_3_to_1_inds_CLVs = trans_ind_3_to_1_long[np.where(trans_3_to_1_long_times.isin(time_CLVs[start_ind:end_ind]))[0]]
align_trans_3_to_1 = np.zeros((num_CLV_test,num_CLV_test,trans_3_to_1_inds_CLVs.shape[0],n_days),dtype=float)

trans_3_to_2_inds_CLVs = trans_ind_3_to_2_long[np.where(trans_3_to_2_long_times.isin(time_CLVs[start_ind:end_ind]))[0]]
align_trans_3_to_2 = np.zeros((num_CLV_test,num_CLV_test,trans_3_to_2_inds_CLVs.shape[0],n_days),dtype=float)


## Slot dd-end_ind holds the alignment at (transition index - dd): slot 0 is
## 5 days after the transition, the last slot is 4 days before it.
for dd in np.arange(end_ind,start_ind):
    align_trans_1_to_2[:,:,:,dd-end_ind] = align.sel(time = model.time[trans_1_to_2_inds_CLVs-dd])
    align_trans_1_to_3[:,:,:,dd-end_ind] = align.sel(time = model.time[trans_1_to_3_inds_CLVs-dd])
    align_trans_2_to_1[:,:,:,dd-end_ind] = align.sel(time = model.time[trans_2_to_1_inds_CLVs-dd])
    align_trans_2_to_3[:,:,:,dd-end_ind] = align.sel(time = model.time[trans_2_to_3_inds_CLVs-dd])
    align_trans_3_to_1[:,:,:,dd-end_ind] = align.sel(time = model.time[trans_3_to_1_inds_CLVs-dd])
    align_trans_3_to_2[:,:,:,dd-end_ind] = align.sel(time = model.time[trans_3_to_2_inds_CLVs-dd])

### Figure 8

In [None]:
## Box plots of the alignment of the first three CLV pairs on the days around
## a transition, pooled over all transition types.
fig = plt.figure(figsize=(10,5))
ax = plt.gca()

## pool the transitions of all six types along the transition axis
align_trans_all = np.concatenate([align_trans_1_to_2, align_trans_1_to_3, 
                                  align_trans_2_to_1, align_trans_2_to_3,
                                  align_trans_3_to_1, align_trans_3_to_2],axis=2)

## rows: (pair, transition) with the pair varying slowest; columns: day slot
data = np.concatenate([align_trans_all[0,0,:,:],align_trans_all[1,1,:,:],align_trans_all[0,1,:,:]],axis=0)

## flatten row by row into a single column
data = np.reshape(data,(data.shape[0]*data.shape[1],1))

## pair label of every flattened entry
pair_labs = np.expand_dims(np.repeat(np.concatenate([np.repeat('$\\theta_{1,2}$',align_trans_all.shape[2]),
            np.repeat('$\\theta_{2,3}$',align_trans_all.shape[2]),
            np.repeat('$\\theta_{1,3}$',align_trans_all.shape[2])],axis=0),align_trans_all.shape[3],axis=0),axis=1)

## day relative to the transition of every flattened entry; the day slots run
## from the latest day to the earliest, hence the flip
day_labs = np.expand_dims(np.tile(np.flip(np.arange(end_ind+1,start_ind+1)),align_trans_all.shape[2]*3),axis=1)

## create the pandas DataFrame with typed columns directly, instead of joining
## numbers and labels into one string array and parsing the numbers back
df = pd.DataFrame({'alignment': data[:,0].astype(np.float64),
                   'pair': pair_labs[:,0],
                   'day': day_labs[:,0].astype(np.int64)},
                  columns = ['alignment','pair','day'])
    
sns.boxplot(x = 'day', y = 'alignment', hue = 'pair', data=df, ax=ax)

ax.set_xticklabels(np.arange(end_ind+1,start_ind+1))
ax.set_title('all transitions',fontsize=12)
ax.legend(loc='upper right')
ax.grid()
fig.tight_layout()

#plt.savefig('../figures/fig8.eps') 

### Figure 9

In [None]:
## One panel per transition type: alignment of the first three CLV pairs on the
## days around every sampled transition (one line per transition).
## Expects start_ind and end_ind as set in the alignment extraction cell.
fig = plt.figure(figsize=(10,12))

align_collections = [align_trans_1_to_2, align_trans_1_to_3, 
                     align_trans_2_to_1, align_trans_2_to_3,
                     align_trans_3_to_1, align_trans_3_to_2]

titles  = ['from 1 to 2 ({} samples)'.format(trans_1_to_2_inds_CLVs.shape[0]),
           'from 1 to 3 ({} samples)'.format(trans_1_to_3_inds_CLVs.shape[0]),
           'from 2 to 1 ({} samples)'.format(trans_2_to_1_inds_CLVs.shape[0]),
           'from 2 to 3 ({} samples)'.format(trans_2_to_3_inds_CLVs.shape[0]),
           'from 3 to 1 ({} samples)'.format(trans_3_to_1_inds_CLVs.shape[0]),
           'from 3 to 2 ({} samples)'.format(trans_3_to_2_inds_CLVs.shape[0])]

pair_legend = ['$\\theta_{1,2}$','$\\theta_{2,3}$','$\\theta_{1,3}$']

for j in np.arange(0,6):
    ax = fig.add_subplot(6,1,j+1)
    align_to_plot = align_collections[j]

    ## days relative to the transition, ordered like the stored day slots
    day_axis = np.flip(np.arange(end_ind+1,start_ind+1))

    p1 = ax.plot(day_axis,align_to_plot[0,0,:,:].T,color='C0')
    p2 = ax.plot(day_axis,align_to_plot[1,1,:,:].T,color='C1')
    p3 = ax.plot(day_axis,align_to_plot[0,1,:,:].T,color='C2')

    ax.set_title(titles[j],fontsize=12)
    ## one legend entry per pair: use the first line of each group as handle
    ax.legend([p1[0],p2[0],p3[0]],pair_legend,loc='upper right')
    ax.grid()
    fig.tight_layout()

#plt.savefig('../figures/fig9.eps') 

## Projection of CLVs in physical space

### Figure 5

In [None]:
## plot unstable CLVs during persistent states
## For each plotted state, take the first day flagged as persistent, pick the
## first CLV with a positive FTCLE on that day and map it onto the EOF patterns.

M_ii = 0
## rows of the CLV that are projected onto the EOFs
inds = np.arange(0,n_PCs)

fig = plt.figure(figsize=(8,3),constrained_layout=False)

State_titles = ['state 1','state 2','state 3']

## subplot counter (the grid has one row and two columns)
pp = 1
## the loop starts at index 1, so the state with index 0 is not plotted
for state in np.arange(1,k):
    FTLEs_persist_temp = Lyaps_all[M_ii].sel(time =  model.time[state_inds_long[state][0:1]])
        
    ## zero-based index of the first positive FTCLE; raises IndexError if none is positive
    unstable_ind = np.where(FTLEs_persist_temp>0)[0][0]
        
    CLV_persist_temp = CLVs_all[M_ii].sel(CLV = unstable_ind+1, time =  model.time[state_inds_long[state][0:1]])[inds,:]
    
    ### CLV has arbitrary direction
    ### to keep patterns consistent, manually change the direction where necessary:
    ###    CLV_persist_temp = -1*CLV_persist_temp
    
    ## here the sign is flipped for every plotted state
    CLV_persist_temp = -1*CLV_persist_temp
    
    ## combine the EOF patterns with the CLV components as weights; the mean over
    ## the last axis collapses the (single) selected time
    CLV_persist_comp = np.mean(np.matmul(CLV_persist_temp.values.T,
                               eofs.eofs.loc[0:19,500,:,:].values.transpose(1,0,2)).transpose(0,2,1),axis=2)
    
    ax = fig.add_subplot(1, 2, pp, projection=ccrs.Orthographic(central_longitude=0.0,central_latitude=90.0))
    ax.set_global()
    ## lon and lat are computed here but not used in this cell
    lon, lat = np.meshgrid(lons[100:], lats[0:37])
    fill = ax.pcolor(lons[100:-1],lats[0:37],CLV_persist_comp,
                     transform=ccrs.PlateCarree(), cmap='PRGn',vmin=-0.05,vmax=0.05)

    ax.set_title(State_titles[state] + ' CLV ' + str(unstable_ind+1))
    ax.coastlines()

    plt.tight_layout()
                                                                            
    pp += 1
    
#plt.savefig('../figures/fig5.pdf')

### Figures D1-D6

In [None]:
## plot transitions associated with persistent states
## One figure per transition type, for one example transition of that type:
##   row 0: alignment of the first three CLV pairs from 2 days before to
##          3 days after the transition,
##   row 1: first two FTCLEs over the same days,
##   rows 2-7: the first three CLVs mapped onto the EOF patterns, one column per day.
M_ii = 0

theta_labs = ['$\\theta_{1,2}$','$\\theta_{2,3}$','$\\theta_{1,3}$']

trans_ind_collect = [trans_1_to_2_inds_CLVs, trans_1_to_3_inds_CLVs,
                     trans_2_to_1_inds_CLVs, trans_2_to_3_inds_CLVs,
                     trans_3_to_1_inds_CLVs, trans_3_to_2_inds_CLVs]

titles  = ['from 1 to 2','from 1 to 3','from 2 to 1',
           'from 2 to 3','from 3 to 1','from 3 to 2']

## position (within each transition type) of the example that is plotted
example_ind = 3

## quantities that are the same for every panel, computed once
inds = np.arange(0,n_PCs)
eof_patterns = eofs.eofs.loc[0:19,500,:,:].values.transpose(1,0,2)
lon, lat = np.meshgrid(lons[100:], lats[0:37])

for pi in np.arange(0,len(trans_ind_collect)):
    
    ## one-element array holding the index of the example transition
    trans_ind = trans_ind_collect[pi][example_ind:example_ind+1]
    if trans_ind.shape[0] == 0:
        ## fewer transitions of this type than needed for the chosen example
        print('no example transition available {}'.format(titles[pi]))
        continue

    fig = plt.figure(figsize=(15,12),constrained_layout=False)
    gs = matplotlib.gridspec.GridSpec(8, 6)

    Lyap_i = Lyaps_all[M_ii].sel(time = model.time[trans_ind[0]-2:trans_ind[0]+4])
        
    align_i = align_all[M_ii].sel(time = model.time[trans_ind[0]-2:trans_ind[0]+4])
        

    ax_align = fig.add_subplot(gs[0, :])

    ax_align.plot(align_i.time,align_i[0,0,:].T,'.-')
    ax_align.plot(align_i.time,align_i[1,1,:].T,'.-')
    ax_align.plot(align_i.time,align_i[0,1,:].T,'.-')
    ax_align.legend(theta_labs,loc='center right',
                    bbox_to_anchor=(1.02, 0.5))
    ## transition type as the title of the top panel, so that it does not
    ## replace the title of one of the map panels
    ax_align.set_title(titles[pi])

    ax2 = fig.add_subplot(gs[1, :])
    ## zero reference line, built without relying on variables of other cells
    ax2.plot(Lyap_i.time,np.zeros(Lyap_i.time.shape[0]),'k')
    p1 = ax2.plot(Lyap_i.time,Lyap_i[0,:],'C0.-')
    p2 = ax2.plot(Lyap_i.time,Lyap_i[1,:],'C1.-')
    ## raw strings: '\L' is not a valid escape sequence in a normal string
    ax2.legend([p1[0],p2[0]],[r'$\Lambda_1$',r'$\Lambda_2$'],loc='center right',
            bbox_to_anchor=(1.015, 0.5))

    for j in np.arange(0,3):
        pp = 0
        for dd in np.arange(-2,4):
        
            CLV_trans_temp = CLVs_all[M_ii].sel(CLV = j+1, time = model.time[trans_ind+dd])[inds,:]
    
            ### CLV has arbitrary direction
            ### to keep patterns consistent, manually change the direction where necessary:
            ###    CLV_trans_temp = -1*CLV_trans_temp
    
            ## combine the EOF patterns with the CLV components as weights; the mean
            ## over the last axis collapses the (single) selected time
            CLV_trans_comp = np.mean(np.matmul(CLV_trans_temp.values.T,
                                  eof_patterns).transpose(0,2,1),axis=2)
    
            ax = fig.add_subplot(gs[j*2+2:j*2+4, pp], projection=ccrs.Orthographic(central_longitude=0.0,central_latitude=90.0))
            ax.set_global()
            fill = ax.pcolor(lons[100:-1],lats[0:37],CLV_trans_comp,
                       transform=ccrs.PlateCarree(), cmap='PRGn',vmin=-0.05,vmax=0.05)

            ax.set_title('CLV ' + str(j+1) + ', day ' + str(dd))
            ax.coastlines()
        
            pp +=1
        
    ## adjust the layout once per figure, after all panels exist
    fig.tight_layout()
        
    #plt.savefig('../figures/figC{}.pdf'.format(pi))

### Figure 6

In [None]:
## choose unstable patterns from above plots
## rows of the CLV that are projected onto the EOFs
inds = np.arange(0,n_PCs)

fig = plt.figure(figsize=(10,4),constrained_layout=False)
 
## selected by inspection of above plots: transition type, day shift relative to
## the transition and CLV number of every example
trans_ind_ex = [trans_2_to_1_inds_CLVs, trans_2_to_3_inds_CLVs,
                trans_3_to_1_inds_CLVs, trans_3_to_2_inds_CLVs]

day_ind_ex = [1,1,2,1]

CLV_ind_ex = [1,2,2,1]

titles = ['A','B','C','D']

## examples whose CLV sign is flipped (the CLV direction is arbitrary)
flipped_examples = (0, 1, 2)

pp = 1
for ti in np.arange(0,len(trans_ind_ex)):
    ## date of the example: fourth transition of its type, shifted by the chosen day
    time_ex = model.time[trans_ind_ex[ti][3:4]+day_ind_ex[ti]]

    ## check that FTCLE is positive; plotting stops at the first example that fails
    FTLEs_trans_temp = Lyaps_all[M_ii].sel(FTLE = CLV_ind_ex[ti], time = time_ex)
    
    if FTLEs_trans_temp<0:
        print('FTCLE {} is negative on '.format(CLV_ind_ex[ti]) + 
              np.datetime_as_string(time_ex[0].values, unit='D'))
        break
        
    CLV_trans_temp = CLVs_all[M_ii].sel(CLV = CLV_ind_ex[ti], time = time_ex)[inds,:]
    
    ### CLV has arbitrary direction
    ### to keep patterns consistent, manually change the direction where necessary:
    ###    CLV_trans_temp = -1*CLV_trans_temp
    
    if ti in flipped_examples:
        CLV_trans_temp = -1*CLV_trans_temp
    
    ## combine the EOF patterns with the CLV components as weights; the mean over
    ## the last axis collapses the (single) selected time
    eof_patterns_ex = eofs.eofs.loc[0:19,500,:,:].values.transpose(1,0,2)
    CLV_persist_comp = np.mean(np.matmul(CLV_trans_temp.values.T,
                                         eof_patterns_ex).transpose(0,2,1),axis=2)
    
    ax = fig.add_subplot(1, 4, pp, projection=ccrs.Orthographic(central_longitude=0.0,central_latitude=90.0))
    ax.set_global()
    lon, lat = np.meshgrid(lons[100:], lats[0:37])
    fill = ax.pcolor(lons[100:-1],lats[0:37],CLV_persist_comp,
                     transform=ccrs.PlateCarree(), cmap='PRGn',vmin=-0.05,vmax=0.05)

    ax.coastlines()
    ax.set_title(titles[ti])

    fig.tight_layout()
                                                                            
    pp += 1

#plt.savefig('../figures/fig6.pdf') 

### Table 5

In [None]:
## selected by inspection of above plots
## One entry per example: transition type, day shift relative to the transition,
## CLV number, pattern label and a description of the transition.

trans_ind_ex = [trans_2_to_1_inds_CLVs, trans_2_to_3_inds_CLVs, trans_2_to_3_inds_CLVs,
                trans_2_to_3_inds_CLVs, trans_3_to_1_inds_CLVs, trans_2_to_3_inds_CLVs,
                trans_3_to_2_inds_CLVs]


day_ind_ex = [1,1,2,1,2,2,1]

CLV_ind_ex = [1,1,2,2,2,1,1]

pattern_ex = ['A','A','A','B','C','D','D']

transition_ex = ['2 to 1','2 to 3','2 to 3','2 to 3',
                 '3 to 1','2 to 3','3 to 2']

## the table size follows the selection lists, which must all have the same length
n_examples = len(trans_ind_ex)
assert all(len(sel) == n_examples for sel in
           (day_ind_ex, CLV_ind_ex, pattern_ex, transition_ex)), 'selection lists differ in length'

FTLEs_ex = [[]]*n_examples

for ti in np.arange(0,n_examples):
    ## date of the example: fourth transition of its type, shifted by the chosen day
    time_ex = model.time[trans_ind_ex[ti][3:4]+day_ind_ex[ti]]
    FTLEs_temp = Lyaps_all[M_ii].sel(FTLE = CLV_ind_ex[ti], time = time_ex)
    
    ## stop at the first example whose FTCLE is negative; the remaining
    ## entries of FTLEs_ex then stay empty
    if FTLEs_temp < 0:
        print('FTCLE {} is negative on '.format(CLV_ind_ex[ti]) + 
              np.datetime_as_string(time_ex[0].values, unit='D'))
        break
    
    FTLEs_ex[ti] = np.round(FTLEs_temp[0].values,3)
    
df = pd.DataFrame([pattern_ex,transition_ex,day_ind_ex,CLV_ind_ex,FTLEs_ex],
                  index = ['pattern','transition','day','CLV','FTCLE'], columns = [' ']*n_examples)


df.T

### Figure 10

In [None]:
## One row per push forward step M: the alignment of the first CLV pair as a
## time series (left) and its normalised Welch power spectrum (right).
fig = plt.figure(figsize=(12,8))
## each of ax1..ax4 is one row of the 4x2 grid, i.e. a pair of axes
(ax1, ax2, ax3, ax4) = fig.subplots(4, 2, gridspec_kw={'width_ratios': [2, 1]})

axes_collect = np.array([ax1, ax2, ax3, ax4])

## power spectra, kept for every M
pxx_all = [[]]*len(Ms)

for j in np.arange(0,len(Ms)):
    axes = axes_collect[j]

    axes[0].plot(align_all[j].time,align_all[j][0,0,:].T)
    axes[0].set_xlim([np.datetime64("2010-01-01"),np.datetime64("2017-06-01")])
    axes[0].set_title('M = {}'.format(Ms[j]),fontsize = 13)
    
    ## NOTE(review): nperseg=4084 is not a power of two -- confirm it is intended
    freq, pxx = scipy.signal.welch(align_all[j][0,0,:],nperseg=4084)#,detrend='linear')
    ## NOTE(review): find_peaks' `threshold` constrains the vertical distance of a
    ## peak to its neighbouring samples, not its absolute height -- confirm that
    ## this is the intended criterion for the "2sd"/"3sd" peaks.
    peaks_2sd = scipy.signal.find_peaks(pxx,threshold=2*np.std(pxx))[0]
    peaks_3sd = scipy.signal.find_peaks(pxx,threshold=3*np.std(pxx))[0]
    
    ## spectrum normalised by its sum; detected peaks marked in red
    axes[1].loglog(freq,pxx/(np.sum(pxx)))
    axes[1].loglog(freq[peaks_2sd],pxx[peaks_2sd]/(np.sum(pxx)),'r.',mew=2,ms=6)
    axes[1].loglog(freq[peaks_3sd],pxx[peaks_3sd]/(np.sum(pxx)),'rx',mew=3,ms=8)
    axes[1].set_xlim([5e-4,5e-1])
    axes[1].set_title('M = {}'.format(Ms[j]),fontsize = 13)
    
    ## panel labels (a)/(b) go below the bottom row only
    if j == len(Ms)-1:
        axes[0].set_xlabel('(a)', fontsize=20)
        axes[1].set_xlabel('(b)', fontsize=20)
    
    pxx_all[j] = pxx
    
plt.tight_layout()

#plt.savefig('../figures/fig10.eps')

## Alignment and transition index

In [None]:
## Calculate transition index with window equal to push forward (here M=50)
## The index is the fraction of days inside a window of `window` days on which
## a transition occurs. Windows start at day 5, 6, ... of the model record and
## each value is labelled with the first day after its window.
j = -1
window = Ms[j]
## number of windows
## NOTE: this rebinds end_ind, which the alignment extraction cell sets to -5;
## re-run that cell before re-running the cells that depend on it.
end_ind = model.time.shape[0]-window-5

## flag the transition days once for the whole record ...
model_time_values = np.asarray(model.time)
is_trans_day = np.isin(model_time_values,model_time_values[trans_ind_all])
## ... and count them per window with prefix sums:
## n_trans_before[i] is the number of transition days among the first i days
n_trans_before = np.concatenate(([0],np.cumsum(is_trans_day)))

window_starts = 5+np.arange(0,end_ind)
trans_index = (n_trans_before[window_starts+window]-n_trans_before[window_starts])/window

## convert to xarray
trans_index = xr.DataArray(trans_index, coords=[model.time[5+window:5+end_ind+window]], dims=['time'])

### Figure 11

In [None]:
## alignment of the first CLV pair (last push forward step in Ms) together
## with the transition index computed over a window of the same length
j = -1
fig = plt.figure(figsize=(10,2.5))
ax = fig.gca()

ax.plot(align_all[j].time,align_all[j][0,0,:].T)
ax.plot(model.time[window+5:], trans_index)

ax.set_xlim([np.datetime64("2009-01-01"),np.datetime64("2018-12-31")])
ax.set_title('Transition index vs alignment (M={})'.format(window),fontsize=13)
ax.legend(['$\\theta_{1,2}$','transition index'],loc='lower right')
fig.tight_layout()

#plt.savefig('../figures/fig11.pdf')

In [None]:
def calculate_lagged_correlations(y1, y2, nlags=40):
    """Calculate lagged Pearson correlations of one series with another.

    The two series are first restricted to the index labels they have in
    common, so observations are paired by label rather than by position.
    For each lag ``i`` in ``0, ..., nlags - 1`` the correlation is taken over
    the pairs ``(y1[k + i], y2[k])``, i.e. ``y1`` is compared with the value
    of ``y2`` found ``i`` positions earlier. Pairs in which either value is
    NaN are excluded from that lag only, so a missing value does not shift
    the remaining observations relative to each other.

    Parameters
    ----------
    y1 : pandas.Series
        Series taken at the later position of each pair. Index labels are
        assumed to be unique.
    y2 : pandas.Series
        Series taken at the earlier position of each pair. Index labels are
        assumed to be unique.
    nlags : int, default 40
        Number of lags to evaluate, starting from lag 0.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nlags,)``; element ``i`` is the Pearson correlation
        coefficient at lag ``i``.

    Raises
    ------
    ValueError
        If fewer than two valid pairs are available for one of the lags.
    """
    # Pair the observations by label on the shared part of the two indexes.
    common_index = y1.index.intersection(y2.index)
    later = y1.loc[common_index].to_numpy(dtype=float)
    earlier = y2.loc[common_index].to_numpy(dtype=float)
    n_obs = later.shape[0]

    correlations = np.empty((nlags,))
    for lag in range(nlags):
        shifted = later[lag:]
        # max(..., 0) keeps the slice empty (instead of wrapping around to a
        # negative stop) when the lag exceeds the number of observations.
        reference = earlier[:max(n_obs - lag, 0)]
        # pearsonr has no NaN handling, so incomplete pairs are removed here.
        valid = ~(np.isnan(shifted) | np.isnan(reference))
        if np.count_nonzero(valid) < 2:
            raise ValueError(
                'fewer than two valid pairs available at lag {}'.format(lag))
        correlations[lag] = stats.pearsonr(shifted[valid], reference[valid])[0]
    return correlations

In [None]:
## lagged correlations between the transition index and the theta_{1,2} alignment, lags 0 to 364
corrs = calculate_lagged_correlations(
    trans_index.to_pandas(),
    align_all[j][0,0,:].to_pandas(),
    nlags=365,
)

In [None]:
## report the lag with the largest correlation magnitude (first one if tied)
abs_corrs = np.abs(corrs)
max_corr_ind = np.flatnonzero(abs_corrs == abs_corrs.max())[0]

print('Max correlation', round(corrs[max_corr_ind],2),'for',max_corr_ind,'day lag')

## Alignment by season

In [None]:
## separate alignment by season
seasons = ['DJF','MAM','JJA','SON']

## last entry of align_all, restricted to the time stamps falling in each season
align_season = [
    align_all[-1].where(time_CLVs_all[-1].dt.season==si,drop=True)
    for si in seasons
]

## seasonal mean over axis 2
align_season_avg = [np.mean(align_si,axis=2) for align_si in align_season]

### Figure 12

In [None]:
fig = plt.figure(figsize=[12, 12])

## Discrete 'seismic' colormap with every colour set to 70% opacity and masked
## cells drawn in white. It is the same for all panels, so it is built once.
n_levels = 10
## plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
cmap = plt.get_cmap('seismic',n_levels)
## builtin float is used because the np.float alias was removed in NumPy 1.24
cmap_opaque = np.array(cmap(np.arange(0,n_levels)),dtype=float)
cmap_opaque[:,3] = 0.7
cmap = matplotlib.colors.ListedColormap(cmap_opaque)
cmap.set_bad('w',1.)

for axi in np.arange(0,4):
    ## mask the entries strictly below the diagonal so they are drawn as 'bad' (white) cells
    mask =  np.tri(align_season_avg[axi].shape[0], k=-1)
    align_season_avg[axi] = np.ma.array(align_season_avg[axi], mask=mask)

    ## seasons fill a 2x2 grid row by row
    row, col = divmod(int(axi), 2)
    a1 = plt.subplot2grid((2,2), (row,col), colspan=1, rowspan=1)

    cax = a1.matshow(align_season_avg[axi].T,cmap=cmap,vmin=0,vmax=1)
    ## NOTE(review): the labels are assigned to matshow's automatic tick positions,
    ## which presumably start one cell before the first column/row -- confirm that
    ## the rendered labels line up with the intended indices.
    a1.set_xticklabels(np.arange(0,7))
    a1.set_yticklabels(np.arange(1,8))
    a1.set_title(seasons[axi] + '\n',fontsize=15)

## single colorbar shared by all panels (same cmap, vmin and vmax everywhere)
cbar_ax = fig.add_axes([0.93, 0.13, 0.03, 0.75])
fig.colorbar(cax, cax=cbar_ax)

#plt.savefig('../figures/fig12.png')