# Intro

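This notebook calculates the ATLN and ENSO indices and lists the years in each phase of each index, for the CESM LENS2 (CESM2 Large Ensemble) project. The cells below currently use HadISST observations; loading of CESM2-LE SST is commented out.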

How the models represent the pattern differences.

How the variability has been changing between … (TODO: this note was left unfinished).

Regression with GPCP precipitation, version 2.3, 1979–2020 (note: the Precipitation cell below currently loads GPCP v3.2 monthly data).

# Imports

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
import cftime
import dask
from dask_jobqueue import PBSCluster
from dask.distributed import Client
import xeofs as xe
import glob
from geocat.viz import util as gvutil
import cartopy.crs as ccrs
import cartopy.feature as cf
import cartopy.util as cutil
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import colormaps as cm
import xarray_regrid

# PBS Cluster

In [None]:
# Create a PBS-backed Dask cluster and attach a distributed Client to it.
# Each PBS job: 1 core, 1 worker process, 8 GiB, 2 h walltime on the
# 'casper' queue, under project account P93300313. One worker is started.
# interface='ext' selects the network interface workers bind to
# (presumably the external one on casper -- confirm for other systems).
cluster = PBSCluster(account='P93300313',
                     job_name='ATLN-ENSO-CESMLE2',
                     cores=1,
                     memory='8GiB',
                     processes=1,
                     walltime='02:00:00',
                     queue='casper',
                     interface='ext',
                     n_workers=1)

# dont scale many workers unless using LE
# (i.e. only request many workers when processing the Large Ensemble data)
# cluster.scale(10)

client = Client(cluster)
# Display the client (dashboard link, worker summary) in the notebook.
client

In [None]:
# Request 2 workers in total; with processes=1 each worker runs in its own PBS job.
cluster.scale(2)

In [None]:
# client.shutdown()
# (uncomment the line above to shut the client down when finished)
# Inspect the workers currently attached to the cluster.
cluster.workers

# Useful functions

In [None]:
def ds_map(ds_to_plt, bounds=(20, -60, 10, -10), name='figure',
           levels=None, cmap='inferno'):
    """Plot a 2-D lat/lon field as filled contours on a PlateCarree map.

    Parameters
    ----------
    ds_to_plt : xarray.DataArray
        2-D field with ``latitude`` and ``longitude`` coordinates. Longitude
        is assumed to be the last axis, since ``add_cyclic_point`` appends
        along the last axis by default (TODO confirm for new inputs).
    bounds : sequence of 4 numbers, default (20, -60, 10, -10)
        Map extent passed to ``ax.set_extent``: two longitudes followed by
        two latitudes, in degrees. Tick ranges use the min/max of each pair,
        so the order within a pair does not matter. The default is a tuple,
        which avoids a shared mutable default argument.
    name : str
        Figure title.
    levels : array-like, optional
        Contour levels; defaults to ``np.arange(-0.4, 0.5, 0.05)``.
    cmap : str or Colormap, default 'inferno'
        Colormap for the filled contours.
    """
    if levels is None:
        levels = np.arange(-0.4, 0.5, 0.05)

    fig, ax = plt.subplots(1, 1,
                           subplot_kw={'projection': ccrs.PlateCarree()})
    fig.subplots_adjust(hspace=0, wspace=0, top=0.925, left=0.1)
    # Placeholder colorbar axes; _resize_colorbar moves it beside the map.
    cbar_ax = fig.add_axes([0, 0, 0.1, 0.1])
    # Add a wrap-around longitude column so contours join across the seam.
    cdat, clon = cutil.add_cyclic_point(ds_to_plt, ds_to_plt.longitude)

    ax.set_title(name)
    lon_min, lon_max = sorted(bounds[:2])
    lat_min, lat_max = sorted(bounds[2:])
    lat_step, lon_step = 5, 10
    # np.arange excludes its stop value. Adding half a step keeps the upper
    # bound (e.g. 10N, 20E) as a tick.
    lat_ticks = np.arange(lat_min, lat_max + lat_step / 2, lat_step)
    lon_ticks = np.arange(lon_min, lon_max + lon_step / 2, lon_step)
    ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
    ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
    ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True))
    ax.yaxis.set_major_formatter(LatitudeFormatter())
    ax.add_feature(cf.LAND)

    def _resize_colorbar(event):
        # Keep the colorbar flush with the right edge of the map axes. Also
        # registered as the canvas 'resize_event' handler below.
        plt.draw()
        posn = ax.get_position()
        cbar_ax.set_position([posn.x0 + posn.width + 0.01, posn.y0,
                              0.04, posn.height])

    ax.set_extent(bounds, ccrs.PlateCarree())
    sst_contour = ax.contourf(clon, ds_to_plt.latitude, cdat,
                              levels=levels,
                              transform=ccrs.PlateCarree(), cmap=cmap,
                              extend='both')
    fig.canvas.mpl_connect('resize_event', _resize_colorbar)
    ax.coastlines()
    plt.colorbar(sst_contour, cax=cbar_ax)
    _resize_colorbar(None)
    # plt.savefig(name, dpi=300)
    plt.show()


def detrend_dim(da, dim, deg=1):
    """Return ``da`` minus its least-squares polynomial fit along ``dim``.

    The polynomial of degree ``deg`` is fitted with ``DataArray.polyfit`` and
    evaluated at the coordinate values of ``dim`` with ``xr.polyval``.
    """
    coeffs = da.polyfit(dim=dim, deg=deg).polyfit_coefficients
    trend = xr.polyval(da[dim], coeffs)
    return da - trend


def detrend(da, dims, deg=1):
    """Detrend ``da`` successively along each dimension in ``dims``.

    ``dims`` may be a single dimension name or an iterable of names. A bare
    string is treated as one dimension; iterating it directly would instead
    loop over its characters (e.g. 't', 'i', 'm', 'e').

    Each dimension is detrended one after another with ``detrend_dim``. Per
    the original note, this sequential approach is only valid for linear
    detrending (``deg=1``).
    """
    if isinstance(dims, str):
        dims = [dims]
    da_detrended = da
    for dim in dims:
        da_detrended = detrend_dim(da_detrended, dim, deg=deg)
    return da_detrended


def index_plot(ds1, name1='', threshold=0.5):
    """Plot a 1-D index time series, shading excursions beyond +/- threshold.

    Values below ``-threshold`` are filled blue and values above
    ``threshold`` red. The y-axis is fixed to +/- 4 * ``threshold``.

    Parameters
    ----------
    ds1 : xarray.DataArray
        1-D series with a ``time`` coordinate.
    name1 : str
        Line label and plot title.
    threshold : float
        Magnitude of the positive/negative shading threshold.
    """
    times = ds1.time
    ylimit = threshold * 4
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(times, ds1, color='black', label=name1)
    gvutil.add_major_minor_ticks(ax, x_minor_per_major=15,
                                 y_minor_per_major=3, labelsize=20)
    gvutil.set_axes_limits_and_ticks(ax, ylim=(-ylimit, ylimit))

    # Negative excursions first (blue), then positive (red).
    shading = ((-threshold, ds1 < -threshold, 'blue'),
               (threshold, ds1 > threshold, 'red'))
    for edge, mask, colour in shading:
        ax.fill_between(times, ds1, y2=edge, where=mask, color=colour,
                        interpolate=True)

    plt.title(f'{name1}')
    ax.set_xlabel('year', fontsize=24)
    plt.grid()
    plt.show()

# Data Imports

## SST

In [None]:
# start with hadisst files
# Open the HadISST SST file and load it fully into memory (.compute()).
# Later cells read ds['sst'] on 'latitude'/'longitude'/'time' coordinates.
ds = xr.open_dataset('/glade/campaign/collections/rda/data/d277003/HadISST_sst.nc.gz').compute()
# # CESM SST path
# files = glob.glob('/glade/campaign/cgd/cesm/CESM2-LE/timeseries/atm/proc/tseries/month_1/SST/*.nc')
# CESM2_ds = xr.open_mfdataset(files)
ds

## Precipitation

In [None]:
# GPCP v3.2 monthly files sit one directory level below the root (*/*.nc4).
files = glob.glob('/glade/campaign/collections/rda/data/d728008/gpcp_v3.2_monthly/*/*.nc4')
# Open all files lazily as one dataset and keep only the precipitation field.
precip = xr.open_mfdataset(files, engine='netcdf4')['sat_gauge_precip']
# Display the computed values (.compute() returns a new object; `precip`
# itself stays lazy).
precip.compute()

# Regrid

In [None]:
# GPCP uses lat/lon; rename to HadISST's latitude/longitude so the regridder
# can match the grids. .rename returns a new object, so the result must be
# assigned (previously it was discarded and the names never changed).
precip = precip.rename({'lat': 'latitude', 'lon': 'longitude'})

# The HadISST subsets below use negative (western-hemisphere) longitudes.
# If GPCP is on a 0..360 grid, convert it first so the Atlantic west of 0
# is not outside the interpolation range. This is a no-op for -180..180 grids.
if precip['longitude'].max() > 180:
    precip = precip.assign_coords(
        longitude=((precip['longitude'] + 180) % 360) - 180
    ).sortby('longitude')

# Linearly interpolate onto the HadISST grid; 'time' is excluded from the
# spatial regridding.
precip = precip.regrid.linear(ds, time_dim='time')
# Keep the computed result so later cells do not redo the regridding.
precip = precip.compute()
precip

# select data

In [None]:
# selected dates
# Month-end timestamps for Jan 1920 - Dec 2024 (kept for any later use).
dates = xr.date_range(start='1920-01-16', end='2025-01-16', freq='1ME')
# Select the same whole months with a label slice. The previous
# ds.sel(time=dates, method='nearest') matched month-end stamps against the
# file's own time stamps (mid-month for HadISST, presumably -- verify), which
# is a near-tie. It also silently repeated the last available month whenever
# the file ends before Dec 2024.
ds = ds.sel(time=slice('1920-01', '2024-12'))
# alternative period: ds = ds.sel(time=slice('1970-01', '2021-12'))
# select summer
# summer_ds = ds.where(ds['time'].dt.month.isin([6, 7, 8]), drop=True)

In [None]:
# ATL area: 10N-10S, 60W-20E. The north-to-south latitude slice assumes
# latitude is stored in descending order.
ATL_hadisst = (
    ds['sst']
    .sel(latitude=slice(10, -10), longitude=slice(-60, 20))
    .compute()
)

# ENSO34 (Nino 3.4) area: 5N-5S, 170W-120W
ENSO34_hadisst = (
    ds['sst']
    .sel(latitude=slice(5, -5), longitude=slice(-170, -120))
    .compute()
)

# ATLN and ENSO indices

## ONI Index

In [None]:
# cos(latitude) area weights for the Nino 3.4 box, named "weights"
weights = np.cos(np.deg2rad(ENSO34_hadisst.latitude)).rename("weights")

In [None]:
# Monthly climatology over the selected period, then anomalies relative to it
ENSO34_clim = ENSO34_hadisst.groupby('time.month').mean(dim='time').compute()
ENSO34_anom = ENSO34_hadisst.groupby('time.month') - ENSO34_clim
# all anomaly fields were linearly detrended (following Zhang et al.)
ENSO34_anom_dtrend = detrend_dim(ENSO34_anom, dim='time')

In [None]:
# 3-month centered running mean of the detrended anomalies (ONI-style)
ENSO34_roll = ENSO34_anom_dtrend.rolling(time=3, center=True).mean()
# Area-average over the Nino 3.4 box, weighted by cos(latitude). The
# `weights` from the cell above were previously computed but unused.
ENSO34_index = (
    ENSO34_roll.weighted(weights)
    .mean(('longitude', 'latitude'), skipna=True)
    .compute()
)

In [None]:
# index plot: ONI with shading beyond +/- threshold
threshold = 0.5

fig, ax = plt.subplots(figsize=(12,6))
ax.plot(ENSO34_index.time, ENSO34_index, color='black')
# Optional geocat tick/limit helpers, intentionally left disabled:
# gvutil.add_major_minor_ticks(ax, x_minor_per_major=15, y_minor_per_major=3, labelsize=20)
# gvutil.set_axes_limits_and_ticks(ax,ylim=(-2., 2.))

# Blue below -threshold (La Nina side), red above +threshold (El Nino side)
for edge, mask, colour in ((-threshold, ENSO34_index < -threshold, 'blue'),
                           (threshold, ENSO34_index > threshold, 'red')):
    ax.fill_between(ENSO34_index.time, ENSO34_index, y2=edge, where=mask,
                    color=colour, interpolate=True)

plt.title('ONI HADISST1.1')
ax.set_xlabel('year', fontsize=24)
plt.grid()
plt.show()

## clim and anomalies

In [None]:
# Monthly climatology over all months of the selected period, and the
# anomalies relative to it (grouped by calendar month)
ATL_clim = ATL_hadisst.groupby('time.month').mean(dim='time').compute()
ATL_anom_pm = ATL_hadisst.groupby('time.month') - ATL_clim

# all anomaly fields were linearly detrended
ATL_anom_dtrend = detrend_dim(ATL_anom_pm, dim='time')

## EOFs

In [None]:
# Three leading EOF modes, without cos(latitude) weighting
model = xe.single.EOF(use_coslat=False, n_modes=3)
# Fit along time on the all-year, detrended ATL anomaly field
model.fit(ATL_anom_dtrend, dim='time')

In [None]:
# Spatial EOF patterns from the fitted model (one per mode)
components = model.components()

In [None]:
# Explained-variance ratio of each mode, as a plain numpy array in mode order
xplained_var = model.explained_variance_ratio().values

In [None]:
# scale by PC std
# In xeofs, normalized=True divides by the L2 norm, not the standard
# deviation, so the raw scores are standardized manually below.
PCs = model.scores(normalized=False)

# normalized by l2norm true as test
# PCs = model.scores()

# Per-mode statistics along time. Previously .std()/.mean() reduced over
# ALL dims (mode and time), yielding one pooled value. As a result the
# individual PCs were not unit-variance, and every EOF was scaled by the
# same pooled std.
pc_std = PCs.std(dim='time')
pc_mean = PCs.mean(dim='time')

# Unit-variance PCs, and EOFs expressed per one standard deviation of
# their own PC (broadcast over 'mode')
normalized_PCs = (PCs - pc_mean) / pc_std
scaled_EOF = components * pc_std

In [None]:
# Map each scaled EOF, titled with its explained variance in percent.
# zip pairs each mode with its ratio (replaces the manual j counter).
for mode_id, frac in zip(scaled_EOF['mode'].values, xplained_var):
    ds_map(scaled_EOF.sel(mode=mode_id),
           name=f'EOF{mode_id} scaled by PCs STD {frac * 100:.1f}%')

In [None]:
# Standardized PC time series of the first three modes
PC1, PC2, PC3 = (normalized_PCs.sel(mode=m) for m in (1, 2, 3))
# PC4 = normalized_PCs.sel(mode=4)
# PC5 = normalized_PCs.sel(mode=5)

In [None]:
# Overlay the standardized PCs on one time axis
for label, pc in (('PC1', PC1), ('PC2', PC2), ('PC3', PC3)):
    plt.plot(pc.time, pc, label=label)
# plt.plot(PC4.time, PC4, label='PC4')
# plt.plot(PC5.time, PC5, label='PC5')
plt.legend()
plt.grid()
plt.show()

## EOF differences

In [None]:
# Rotate EOF1 and EOF3 by 45 degrees: their sum and difference divided by
# sqrt(2) give the EATL and CATL patterns (unscaled components)
EATL = (components.sel(mode=1) + components.sel(mode=3)) / (2 ** 0.5)
CATL = (components.sel(mode=1) - components.sel(mode=3)) / (2 ** 0.5)

In [None]:
# Same EOF1 +/- EOF3 combination, using the PC-std-scaled EOFs
EATLs = (scaled_EOF.sel(mode=1) + scaled_EOF.sel(mode=3)) / (2 ** 0.5)
CATLs = (scaled_EOF.sel(mode=1) - scaled_EOF.sel(mode=3)) / (2 ** 0.5)

In [None]:
# Map the EOF1 +/- EOF3 combinations (plain strings: nothing to interpolate)
# ds_map(EATL, name='EAN pattern')
# ds_map(CATL, name='CAN pattern')
ds_map(EATLs, name='EAN scaled pattern')
ds_map(CATLs, name='CAN scaled pattern')

## CANI and EANI

In [None]:
# EANI and CANI: the same sqrt(2) rotation as the EATL/CATL patterns,
# applied to the standardized PCs
EANI = (PC1 + PC3) / (2 ** 0.5)
CANI = (PC1 - PC3) / (2 ** 0.5)
# 3-month centered running mean, as for the ENSO index
EANI_roll = EANI.rolling(time=3, center=True).mean()
CANI_roll = CANI.rolling(time=3, center=True).mean()

index_plot(EANI_roll, name1='EANI', threshold=1)
# index_plot only accepts name1; the previous name2= raised a TypeError.
index_plot(CANI_roll, name1='CANI', threshold=1)

In [None]:
# Keep only JJA values of the smoothed indices (other months become NaN)
EANI_s = EANI_roll.where(EANI['time.season'] == 'JJA')
CANI_s = CANI_roll.where(CANI['time.season'] == 'JJA')

# Years with at least one JJA month at or beyond the +/-1 threshold.
# The conditions are applied to the JJA-masked, smoothed series itself;
# NaN compares False, so non-JJA months are dropped. Previously:
#   - conditions tested the raw, all-month EANI/CANI, so any month passing
#     the threshold kept its year;
#   - EANI_n used <= 1 instead of <= -1;
#   - CANI_n used >= 1, duplicating CANI_p.
EANI_p = np.unique(EANI_s.where(EANI_s >= 1, drop=True).time.dt.year)
EANI_n = np.unique(EANI_s.where(EANI_s <= -1, drop=True).time.dt.year)
CANI_p = np.unique(CANI_s.where(CANI_s >= 1, drop=True).time.dt.year)
CANI_n = np.unique(CANI_s.where(CANI_s <= -1, drop=True).time.dt.year)

print(f'EANI positive: {EANI_p}')
print(f'EANI negative: {EANI_n}')
print(f'CANI positive: {CANI_p}')
print(f'CANI negative: {CANI_n}')

## Variability

In [None]:
# Trailing 60-step (60-month) rolling variance of each index, and the
# CANI/EANI variance ratio
EANI_var = EANI.rolling(time=60).var()
CANI_var = CANI.rolling(time=60).var()
CvE_r = CANI_var / EANI_var

plt.figure(figsize=(12, 4))
plt.plot(EANI_var.time, EANI_var, label='EANI Variance', color='blue', linestyle='--')
plt.plot(CANI_var.time, CANI_var, label='CANI Variance', color='orangered', linestyle='--')
plt.plot(CvE_r.time, CvE_r, label='Variance ratio C / E', color='black')
# Reference line at ratio = 1. axhline does not depend on the x-axis units;
# axline((0, 1), ...) anchored its point at x=0 on a datetime axis.
plt.axhline(1, color='gray', linestyle='--')
plt.legend()
plt.show()

# Precipitation

## JJA

In [None]:
# select summer: keep only JJA (June-August) time steps
summer_precip = precip.where(precip['time.season'] == 'JJA', drop=True)
summer_precip.compute()