# SST Empirical Orthogonal Function Analysis

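This notebook loads HadISST sea surface temperature (SST) data, subsets it to the tropical Atlantic domain used in this project (10°S–10°N, 60°W–20°E), and performs an Empirical Orthogonal Function (EOF) analysis to derive the Central Atlantic Niño Index (CANI) and the Eastern Atlantic Niño Index (EANI).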

# Imports

In [17]:
import numpy as np
import pandas as pd
import xarray as xr
import dask
from dask_jobqueue import PBSCluster
from dask.distributed import Client
import xeofs as xe
import glob
from geocat.viz import util as gvutil
import cartopy.crs as ccrs
import cartopy.feature as cf
import cartopy.util as cutil
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import colormaps as cm
import matplotlib
import matplotlib.pyplot as plt

# PBSCluster

In [3]:
# Resource request for the PBS-backed Dask cluster, gathered in one place so
# the individual settings are easy to review and tweak.
pbs_options = dict(
    account='P93300313',
    job_name='ATLN-ENSO-CESMLE2',
    cores=1,
    memory='8GiB',
    processes=1,
    walltime='02:00:00',
    queue='casper',
    interface='ext',
    n_workers=1,
)
cluster = PBSCluster(**pbs_options)

# Do not scale out to many workers unless working with the LE data:
# cluster.scale(10)

# Attach a client to the cluster; leaving `client` as the final expression
# renders its summary as the cell output.
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/acruz/proxy/8787/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/acruz/proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.186:35985,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/acruz/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [4]:
# Ask the cluster to scale to 2 workers (it was created with n_workers=1).
cluster.scale(2)

In [5]:
# Disabled teardown reminder (original comment read `cluster.shutdow()`).
# NOTE(review): dask cluster objects are normally released with
# `cluster.close()` -- confirm the intended call before enabling it.
# cluster.shutdown()

# Show the cluster's current worker jobs as the cell output.
cluster.workers

{'PBSCluster-0': <dask_jobqueue.pbs.PBSJob: status=running>}

# Useful Functions

In [20]:
def ds_map(ds_to_plt, bounds=(20, -60, 10, -10), name='figure',
           levels=None, cmap='inferno'):
    """Draw a filled-contour map of a field on a PlateCarree projection.

    Parameters
    ----------
    ds_to_plt : object exposing ``longitude`` and ``latitude``
        Field to contour. The function reads ``ds_to_plt.longitude`` and
        ``ds_to_plt.latitude`` for the plot coordinates (assumed to be a 2-D
        latitude x longitude DataArray -- TODO confirm against callers).
    bounds : sequence of four numbers, optional
        Passed unchanged to ``ax.set_extent``. Longitude ticks run from
        ``bounds[1]`` to ``bounds[0]`` in steps of 10 and latitude ticks from
        ``bounds[3]`` to ``bounds[2]`` in steps of 5.
    name : str, optional
        Axes title.
    levels : array-like or int, optional
        Contour levels handed to ``contourf``. Defaults to
        ``np.arange(-0.4, 0.5, 0.05)``.
    cmap : str or Colormap, optional
        Colormap handed to ``contourf``.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    if levels is None:
        levels = np.arange(-0.4, 0.5, 0.05)

    fig, ax = plt.subplots(1, 1,
                           subplot_kw={'projection': ccrs.PlateCarree()})
    fig.subplots_adjust(hspace=0, wspace=0, top=0.925, left=0.1)
    # Placeholder axes for the colorbar; moved next to the map by
    # _resize_colorbar below.
    cbar_ax = fig.add_axes([0, 0, 0.1, 0.1])

    # A cyclic (wrap-around) column only makes sense when the longitudes span
    # the full globe. For a regional subset it would copy the westernmost
    # column to a made-up longitude east of the domain, so it is skipped.
    lon = np.asarray(ds_to_plt.longitude)
    spans_globe = False
    if lon.size > 1:
        lon_step = abs(lon[1] - lon[0])
        spans_globe = bool(np.isclose(abs(lon[-1] - lon[0]) + lon_step, 360))
    if spans_globe:
        cdat, clon = cutil.add_cyclic_point(ds_to_plt, ds_to_plt.longitude)
    else:
        cdat, clon = ds_to_plt, lon

    ax.set_title(name)
    lat_ticks = np.arange(bounds[3], bounds[2], 5)
    lon_ticks = np.arange(bounds[1], bounds[0], 10)
    ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
    ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
    lon_formatter = LongitudeFormatter(zero_direction_label=True)
    lat_formatter = LatitudeFormatter()
    ax.xaxis.set_major_formatter(lon_formatter)
    ax.yaxis.set_major_formatter(lat_formatter)
    ax.add_feature(cf.LAND)

    def _resize_colorbar(event):
        # Keep the colorbar flush with the right edge of the map and as tall
        # as the map axes; `event` is unused but required by mpl_connect.
        plt.draw()
        posn = ax.get_position()
        cbar_ax.set_position([posn.x0 + posn.width + 0.01, posn.y0,
                              0.04, posn.height])

    ax.set_extent(bounds, ccrs.PlateCarree())
    sst_contour = ax.contourf(clon, ds_to_plt.latitude, cdat,
                              levels=levels,
                              transform=ccrs.PlateCarree(), cmap=cmap,
                              extend='both')
    fig.canvas.mpl_connect('resize_event', _resize_colorbar)
    ax.coastlines()
    plt.colorbar(sst_contour, cax=cbar_ax)
    # Position the colorbar once up front; later resizes go through the callback.
    _resize_colorbar(None)
    # plt.savefig(name, dpi=300)
    plt.show()


def detrend_dim(da, dim, deg=1):
    """Remove a fitted polynomial trend of degree ``deg`` along ``dim``.

    The polynomial is fitted with ``da.polyfit``, evaluated on ``da[dim]``
    with ``xr.polyval``, and subtracted from ``da``.

    Parameters
    ----------
    da : object providing ``polyfit`` and item access by dimension name
        Data to detrend.
    dim : str
        Name of the dimension along which to fit and remove the trend.
    deg : int, optional
        Polynomial degree (1 = linear trend).

    Returns
    -------
    ``da`` minus the fitted polynomial.
    """
    coefficients = da.polyfit(dim=dim, deg=deg).polyfit_coefficients
    trend = xr.polyval(da[dim], coefficients)
    return da - trend


def index_plot(ds1, name1='', threshold=0.5):
    """Plot an index time series and shade excursions beyond +/- threshold.

    Parameters
    ----------
    ds1 : object exposing ``time``
        Index series; plotted against ``ds1.time``.
    name1 : str, optional
        Used as the line label and as the figure title.
    threshold : float, optional
        Values above ``threshold`` are shaded red, values below
        ``-threshold`` are shaded blue. The y-axis is limited to
        ``+/- 4 * threshold``.
    """
    y_limit = 4 * threshold
    times = ds1.time
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(times, ds1, color='black', label=name1)
    gvutil.add_major_minor_ticks(ax, x_minor_per_major=15, y_minor_per_major=3, labelsize=20)
    gvutil.set_axes_limits_and_ticks(ax, ylim=(-y_limit, y_limit))

    # First pass shades below -threshold in blue, second pass above
    # +threshold in red.
    shading = (
        (-threshold, ds1 < -threshold, 'blue'),
        (threshold, ds1 > threshold, 'red'),
    )
    for baseline, mask, shade in shading:
        ax.fill_between(times, ds1, y2=baseline,
                        where=mask, color=shade, interpolate=True)

    ax.set_title(f'{name1}')
    ax.set_xlabel('year', fontsize=24)
    ax.grid()
    plt.show()

# Load data

In [7]:
# Start with the HadISST file. `xr.load_dataset` opens the file, reads it
# fully into memory and closes the file again, so no open file handle is left
# behind in the kernel (unlike `open_dataset(...).compute()`).
ds = xr.load_dataset('/glade/campaign/collections/rda/data/d277003/HadISST_sst.nc.gz')
# CESM SST alternative (disabled):
# files = glob.glob('/glade/campaign/cgd/cesm/CESM2-LE/timeseries/atm/proc/tseries/month_1/SST/*.nc')
# CESM2_ds = xr.open_mfdataset(files)
ds

# Select data

In [8]:
# Analysis period: monthly fields from January 1920 through December 2024.
# Slicing by partial date strings selects exactly the months present in the
# file. The earlier approach matched month-end stamps to the mid-month data
# with method='nearest', which silently repeats the last available record if
# the file ends before the requested period does.
# Alternative period: slice('1970-01', '2021-12')
ds_period = ds.sel(time=slice('1920-01', '2024-12'))

# Tropical Atlantic box. A new name is used instead of overwriting `ds`, so
# re-running this cell always starts from the full dataset.
# NOTE(review): latitude is sliced 10 -> -10, which presumes the file stores
# latitude from north to south -- confirm for other data sources.
ATL_hadisst = ds_period['sst'].sel(latitude=slice(10, -10), longitude=slice(-60, 20)).compute()

# Anomalies

In [9]:
# Calendar month of every time step; used as the grouping key below.
months = ATL_hadisst['time'].dt.month

# Monthly climatology over the whole record: one mean field per calendar month.
ATL_clim = ATL_hadisst.groupby(months).mean(dim='time').compute()

# Anomalies. The data must be grouped by month here as well, so that each
# time step has the climatology of its own calendar month subtracted.
ATL_anom_pm = ATL_hadisst.groupby(months) - ATL_clim

# All anomaly fields are linearly detrended along time.
ATL_anom_dtrend = detrend_dim(ATL_anom_pm, dim='time')

# EOF

In [10]:
# EOF settings: keep three modes, no cos(latitude) weighting.
eof_settings = {'n_modes': 3, 'use_coslat': False}
model = xe.single.EOF(**eof_settings)

# Fit on the detrended anomalies of all calendar months, with time as the
# sample dimension.
model.fit(ATL_anom_dtrend, dim='time')

components = model.components()
explained_ratio = model.explained_variance_ratio()
xplained_var = explained_ratio.values

# PC and EOF normalization and scaling

In [11]:
# Raw PC time series. `normalized=True` in this package divides by the L2
# norm rather than the standard deviation, so the std scaling is done here.
PCs = model.scores(normalized=False)

# To test the package's L2-norm normalization instead:
# PCs = model.scores()

# Statistics are taken along time only, giving one mean/std per mode.
# Calling .std()/.mean() without `dim` would pool all modes into a single
# scalar, so the individual PCs would not end up with unit variance and every
# EOF would be scaled by the same number.
pc_std = PCs.std(dim='time')
pc_mean = PCs.mean(dim='time')

# Standardized PCs, and EOF patterns scaled by their own PC's standard
# deviation (broadcast along the `mode` dimension).
normalized_PCs = (PCs - pc_mean) / pc_std
scaled_EOF = components * pc_std

In [14]:
# Split the normalized PCs into one series per mode (selected by mode label).
PC1, PC2, PC3 = (normalized_PCs.sel(mode=number) for number in (1, 2, 3))

# EAN and CAN combined EOF patterns

In [15]:
# Combine the scaled EOF1 and EOF3 patterns: their sum gives the EAN pattern,
# their difference the CAN pattern; both are divided by sqrt(2).
eof_first = scaled_EOF.sel(mode=1)
eof_third = scaled_EOF.sel(mode=3)
root_two = 2 ** 0.5
EATLs = (eof_first + eof_third) / root_two
CATLs = (eof_first - eof_third) / root_two

# CANI and EANI

In [23]:
# Indices from the same PC1/PC3 combination as the patterns: sum -> EANI,
# difference -> CANI, both divided by sqrt(2).
root_two = 2 ** 0.5
EANI = (PC1 + PC3) / root_two
CANI = (PC1 - PC3) / root_two

# Centered 5-month running mean; the window length may be revisited
# according to the literature.
EANI_roll, CANI_roll = (
    index.rolling(time=5, center=True).mean() for index in (EANI, CANI)
)

# Variability

In [22]:
# Running variance of each index over a 60-step (trailing) window, and the
# CANI-to-EANI variance ratio.
EANI_var, CANI_var = (index.rolling(time=60).var() for index in (EANI, CANI))
CvE_r = CANI_var / EANI_var

# Plotting

## EOFs

In [None]:
# One map per EOF mode; `j` indexes the explained-variance array in step with
# the mode label `i`.
for j, i in enumerate(scaled_EOF['mode'].values):
    ds_map(scaled_EOF.sel(mode=i),
           name=f'EOF{i} scaled by PCs STD {xplained_var[j] * 100 }%')

## PCs

In [None]:
# All three normalized PC time series on a single axes.
fig, ax = plt.subplots()
for label, series in (('PC1', PC1), ('PC2', PC2), ('PC3', PC3)):
    ax.plot(series.time, series, label=label)
ax.legend()
ax.grid()
plt.show()

## CANI EANI patterns

In [None]:
# Map the combined EAN and CAN patterns, one figure each.
for pattern, title in ((EATLs, 'EAN scaled pattern'),
                       (CATLs, 'CAN scaled pattern')):
    ds_map(pattern, name=title)

## Variability

In [None]:
# Running variances of both indices and their ratio on one wide panel.
fig, ax = plt.subplots(figsize=(12, 4))
variance_series = (
    (EANI_var, dict(label='EANI Variance', color='blue', linestyle='--')),
    (CANI_var, dict(label='CANI Variance', color='orangered', linestyle='--')),
    (CvE_r, dict(label='Variance ratio C / E', color='black')),
)
for series, line_style in variance_series:
    ax.plot(series.time, series, **line_style)
# Horizontal reference line at y = 1 (equal variances).
ax.axline((0, 1), slope=0, color='gray', linestyle='--')
ax.legend()
plt.show()