# Pacific variability - PDO / IPO

Original definition by _Mantua et al. (1997)_:

> The leading EOF of monthly SST anomalies over the North Pacific (after removing the global mean SST anomaly) and its associated PC time series are termed the Pacific Decadal Oscillation (PDO)

---

0. create xarray DataArrays of monthly SST restricted to the Pacific (from the rectangular-grid data for the high-resolution run), for three domains:
    1. North of 20 N
    2. North of Equator
    3. North of 38S

1. deseasonalize and detrend monthly SST data (the emphasis is on consistency with the other analyses, not necessarily on the original definition)
    - HadISST:
        1. calculate monthly deviations (i.e. average difference) from annual mean, then remove this seasonal cycle
        2. two-factor detrending with natural and anthropogenic forcing estimates at each grid point
    - CESM output:
        1. calculate monthly deviations (i.e. average difference) from annual mean, then remove this seasonal cycle
        2. remove quadratic trend at each grid point (separately for each time segment)

2. EOF analysis of data

3. create annual index and low-pass filter it

4. analysis
    - spectra
    - regression patterns

In [None]:
import os
import sys
from tqdm import tqdm
import scipy as sp
import numpy as np
import pandas as pd
import xarray as xr
import cmocean
import cartopy
import cartopy.crs as ccrs
import matplotlib
import statsmodels.api as sm
import matplotlib.pyplot as plt

In [None]:
# notebook setup: inline figures, project matplotlib style file, module autoreload
%matplotlib inline
# disable the inline backend's default bbox cropping so figures keep their set size
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
matplotlib.rc_file('../rc_file')
%load_ext autoreload
%autoreload 2
# exclude these third-party modules from autoreloading
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
# make the project's helper modules (one directory up) importable
sys.path.append("..")
from paths import path_results, path_samoc, path_prace, file_HadISST
from filters import chebychev, lowpass
from regions import boolean_mask, global_ocean, gl_ocean_rect, gl_ocean_low, mask_box_in_region
from timeseries import IterateOutputCESM
from xr_DataArrays import xr_AREA, dll_dims_names
from xr_regression import xr_quadtrend

In [None]:
# times_ctrl / times_lpd: lists of (start, end) pairs; each pair is used below to
# slice one time segment and to name the per-segment files (..._{start}_{end}.nc)
from ab_derivation_SST import times_ctrl, times_lpd
from ab_derivation_SST import DeriveSST as DS
from bd_analysis_indices import AnalyzeIndex as AI
print(times_ctrl)
print(times_lpd)

# 1. data preparation

- concatenate monthly SST fields into a single file:
`DS().generate_monthly_SST_files('ctrl')  # when all SST rect data available` (lpd:  18 mins, years 154-566, 2.3 GB)
- generate yrly avg file for ocn_rect ctrl data:
`DS().generate_yrly_SST_ctrl_rect()`  (43 sec)
- deseasonalize:
`DS().deseasonalize_monthly_data(run)`  (1min 52s for ctrl)

In [None]:
# raw monthly SST fields for the ctrl, lpd and had (HadISST) data sets
monthly_ctrl = xr.open_dataarray(f'{path_prace}/SST/SST_monthly_ctrl.nc')
monthly_lpd  = xr.open_dataarray(f'{path_prace}/SST/SST_monthly_lpd.nc')  # proper datetime
monthly_had  = xr.open_dataarray(f'{path_prace}/SST/SST_monthly_had.nc')

In [None]:
# yearly SST fields (ctrl read from the rectangular-grid 'rect' file)
yrly_ctrl = xr.open_dataarray(f'{path_prace}/SST/SST_yrly_rect_ctrl.nc')
yrly_lpd  = xr.open_dataarray(f'{path_prace}/SST/SST_yrly_lpd.nc')
yrly_had  = xr.open_dataarray(f'{path_prace}/SST/SST_yrly_had.nc')

In [None]:
# deseasonalized monthly SST (presumably written by DS().deseasonalize_monthly_data);
# lpd and had are opened with decode_times=False, i.e. with a numeric time axis
monthly_ds_ctrl = xr.open_dataarray(f'{path_prace}/SST/SST_monthly_deseasonalized_ctrl.nc')
monthly_ds_lpd  = xr.open_dataarray(f'{path_prace}/SST/SST_monthly_deseasonalized_lpd.nc' , decode_times=False)
monthly_ds_had  = xr.open_dataarray(f'{path_prace}/SST/SST_monthly_deseasonalized_had.nc' , decode_times=False)

In [None]:
# quick visual check at one grid point: deseasonalized vs raw ctrl SST for
# steps 900-1199 of the first axis (presumably time — confirm dim order)
monthly_ds_ctrl[900:1200,100,100].plot()
monthly_ctrl[900:1200,100,100].plot()

# 2. deseasonalize and detrend

## detrend

### ctrl/lpd: quadratic detrending

In [None]:
# %%time
# # 24min 3s total
# # 4min 51s for lpd
# NOTE(review): disabled cell that removes a quadratic trend per time segment and
# writes SST_monthly_ds_dt_{run}_{start}_{end}.nc. If it is re-enabled, check:
#  - lpd (i==1), the first segment (j==0) and ctrl segments with time[0]>=100
#    (`else: continue`) are skipped — presumably left over from a partial re-run;
#  - assigning to `.values` of the object returned by `da_sel.sel(...)` most likely
#    does not write into `da_sel`, so the February-99 fix may have no effect — TODO confirm;
#  - `y1, y2` are both computed from time[0] (and are not used afterwards).
# for i, da  in enumerate([monthly_ds_ctrl, monthly_ds_lpd]):
#     if i==1:  continue
#     print(i)
#     run = ["ctrl","lpd"][i]
#     times = [times_ctrl, times_lpd][i]
#     for j in tqdm(range(len(times))):
#         if j==0:  continue
#         time = times[j]
        
#         if i==0:     # ctrl
#             print(i, j)
#             da_sel = da.sel(time=slice(time[0], time[1]))
#             if time[0]<100:  # fix February of year 99 by interpolating Jan and Mar, something went wrong with the high to low res interpolation
#                 print('time[0]<100')
#                 da_sel.sel(time=99+1/8, method='nearest').values = (da_sel.sel(time=99+1/24, method='nearest').values\
#                                                                     +da_sel.sel(time=99+5/24, method='nearest').values)/2
#             else:  continue
        
#         elif i==1:   # lpd
#             da_sel = da.sel(time=slice(time[0]*365, time[1]*365))#+1e-4))
#             if len(da_sel.time) not in [1788, 3000]:
#                 da_sel = da.sel(time=slice(time[0]*365, time[1]*365+1e-4))
#             y1, y2 = int(time[0]/365), int(time[0]/365)
            
#         print(time, time[0], time[1], len(da_sel.time), da_sel.time[0].values, da_sel.time[-1].values)
#         (da_sel-xr_quadtrend(da_sel)).to_netcdf(f'{path_prace}/SST/SST_monthly_ds_dt_{run}_{time[0]}_{time[1]}.nc')

### had: two-factor detrending with natural and anthropogenic forcing signals

In [None]:
# CMIP5 natural and anthropogenic forcing series (presumably multi-model means, "MMM",
# with one value per year); each value is repeated 12 times and then put on the
# monthly HadISST time axis.
# NOTE(review): relies on np.repeat accepting an xr.DataArray and on
# len(MMM)*12 == len(monthly_had.time); assign_coords should raise on a size mismatch — confirm.
MMM_natural = xr.open_dataarray(f'{path_samoc}/GMST/CMIP5_natural.nc', decode_times=False)
MMM_anthro  = xr.open_dataarray(f'{path_samoc}/GMST/CMIP5_anthro.nc' , decode_times=False)
monthly_MMM_natural = np.repeat(MMM_natural, 12)
monthly_MMM_anthro  = np.repeat(MMM_anthro , 12)
monthly_MMM_natural = monthly_MMM_natural.assign_coords(time=monthly_had.time)
monthly_MMM_anthro  = monthly_MMM_anthro .assign_coords(time=monthly_had.time)
monthly_MMM_natural.plot()
monthly_MMM_anthro .plot()

In [None]:
# %%time
# # 04:38
# NOTE(review): disabled cell. For every HadISST grid point it fits an OLS model
# SST ~ const + anthro + natural and stores the two coefficient maps in
# SST_beta_anthro_MMM_monthly_had.nc / SST_beta_natural_MMM_monthly_had.nc
# (variables 'beta_anthro' / 'beta_natural'). The later cells that use
# `beta_anthro` / `beta_natural` need either this cell or those files.
# forcings = monthly_MMM_natural.to_dataframe(name='natural').join(
#             monthly_MMM_anthro.to_dataframe(name='anthro'))

# SST_stacked = monthly_ds_had.stack(z=('latitude', 'longitude'))
# ds_anthro   = SST_stacked[0,:].squeeze().copy()
# ds_natural  = SST_stacked[0,:].squeeze().copy()

# # multiple linear regression
# X = sm.add_constant(forcings[['anthro', 'natural']])
# for i, coordinate in tqdm(enumerate(SST_stacked.z)):
#     y = SST_stacked[:, i].values
#     model = sm.OLS(y, X).fit()
#     ds_anthro[i] = model.params['anthro']
#     ds_natural[i] = model.params['natural']

# beta_anthro  = ds_anthro .unstack('z')
# beta_natural = ds_natural.unstack('z')

# ds = xr.merge([{'forcing_anthro': monthly_MMM_anthro}, {'beta_anthro': beta_anthro}])
# ds.to_netcdf(f'{path_prace}/SST/SST_beta_anthro_MMM_monthly_had.nc')

# ds = xr.merge([{'forcing_natural': monthly_MMM_natural}, {'beta_natural':beta_natural}])
# ds.to_netcdf(f'{path_prace}/SST/SST_beta_natural_MMM_monthly_had.nc')

In [None]:
# The regression-coefficient maps are computed by the (commented-out) two-factor
# regression cell and saved to netCDF. Load them here so that this cell and the
# detrending cell that follows also work in a fresh session instead of raising
# NameError.
beta_anthro  = xr.open_dataset(f'{path_prace}/SST/SST_beta_anthro_MMM_monthly_had.nc' ).beta_anthro
beta_natural = xr.open_dataset(f'{path_prace}/SST/SST_beta_natural_MMM_monthly_had.nc').beta_natural

# maps of the per-grid-point coefficients on the natural and anthropogenic forcing
f, ax = plt.subplots(1,2, figsize=(12,5))
beta_natural.plot(ax=ax[0])
beta_anthro.plot(ax=ax[1])
ax[0].set_title('natural')
ax[1].set_title('anthropogenic')

In [None]:
%%time
# two-factor detrending: subtract beta_anthro*F_anthro + beta_natural*F_natural
# from the deseasonalized HadISST SST at every grid point and save the result.
# monthly_ds_had was opened with decode_times=False, so its time coordinate is
# replaced by the forcing's (set from monthly_had.time), presumably so the
# two time axes align in the subtraction.
# NOTE(review): needs beta_anthro / beta_natural in the namespace (from the
# regression cell or its SST_beta_*_MMM_monthly_had.nc output files).
monthly_ds_dt_had = monthly_ds_had.assign_coords(time=monthly_MMM_anthro.time) \
                    - beta_anthro*monthly_MMM_anthro \
                    - beta_natural*monthly_MMM_natural
monthly_ds_dt_had.to_netcdf(f'{path_prace}/SST/SST_monthly_ds_tfdt_had.nc')

# 3. EOF analysis

## subselect Pacific data

In [None]:
def shift_had(da):
    """ shifts lons to [0,360] to make Pacific contiguous

    Longitudes are mapped with ``(lon + 360) % 360`` and the object is then
    reordered by ascending longitude, so data and coordinate stay aligned.

    Sorting, rather than rolling by a fixed 180 cells, gives the same result
    for the 1-degree HadISST grid (-179.5 ... 179.5). It also works for any
    longitude resolution, and leaves input that is already in [0, 360)
    unchanged.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset with a ``longitude`` coordinate

    Returns
    -------
    same type as ``da``, with ``longitude`` ascending in [0, 360)
    """
    shifted = da.assign_coords(longitude=(da.longitude+360)%360)
    return shifted.sortby('longitude')

def focus_data(da):
    """ drops data outside rectangle around Pacific

    Crops ``da`` to the bounding box of its valid data by removing every
    latitude row and then every longitude column that is entirely NaN.

    The horizontal dimension names come from the first latitude coordinate
    found, checked in this order:
    ``t_lat``/``t_lon`` (ctrl), ``nlat``/``nlon`` (lpd),
    ``latitude``/``longitude`` (had).

    Raises
    ------
    ValueError
        if none of these latitude coordinates is present in ``da.coords``.
    """
    candidates = (('t_lat', 't_lon'),           # ctrl
                  ('nlat', 'nlon'),             # lpd
                  ('latitude', 'longitude'))    # had
    for lat, lon in candidates:
        if lat in da.coords:
            break
    else:
        raise ValueError('xr DataArray does not have the right lat/lon coords.')
    return da.dropna(dim=lat, how='all').dropna(dim=lon, how='all')

In [None]:
%%time
# 4min 15s
f, ax = plt.subplots(3,3, figsize=(12,8), sharex='col')
for i, extent in enumerate(['38S', 'Eq', '20N']):
    if extent=='38S':     latS, lonE = -38, 300
    elif extent=='Eq':    latS, lonE =   0, 285
    elif extent=='20N':   latS, lonE =  20, 255
    for j, domain in enumerate(['ocn_rect', 'ocn_low', 'ocn_had']):
        run = ['ctrl', 'lpd', 'had'][j]
        AREA = xr_AREA(domain=domain)
        Pac_MASK = mask_box_in_region(domain=domain, mask_nr=2,
                                      bounding_lats=(latS,68),
                                      bounding_lons=(110,lonE))
        area = AREA.where(Pac_MASK)
        if j==2:  area = shift_had(area)
        area = focus_data(area)
        area.to_netcdf(f'{path_prace}/geometry/AREA_{extent}_{domain}.nc')
        Pac_MASK.plot(ax=ax[i,j])
        print(f'{domain:10}, {extent:10}, {AREA.where(Pac_MASK).sum().values:5.2e}')

In [None]:
# %%time
# # ca 15 mins per extent
# NOTE(review): disabled cell that crops the detrended monthly SST to each Pacific
# box and writes SST_monthly_ds_dt_{extent}_{run}[_{start}_{end}].nc (read by the
# EOF cell). `i>0`, `j>0` and `k<10` skips look like partial-re-run leftovers.
# `da.where(area)` uses the float area field as condition; NaN converts to True,
# so points with NaN area may not be masked out — consider `da.where(area.notnull())`, TODO confirm.
# monthly_ds_dt = [[f'{path_prace}/SST/SST_monthly_ds_dt_ctrl_{time[0]}_{time[1]}.nc' for time in times_ctrl],
#                  [f'{path_prace}/SST/SST_monthly_ds_dt_lpd_{time[0]}_{time[1]}.nc' for time in times_lpd], 
#                  [f'{path_prace}/SST/SST_monthly_ds_tfdt_had.nc']]

# for i, extent in enumerate(['38S', 'Eq', '20N']):
#     if i>0: continue
#     for j, domain in enumerate(['ocn_rect', 'ocn_low', 'ocn_had']):
#         if j>0: continue
#         run = ['ctrl', 'lpd', 'had'][j]
#         area = xr.open_dataarray(f'{path_prace}/geometry/AREA_{extent}_{domain}.nc')
#         monthly_fns = monthly_ds_dt[j]
#         for k, fn in tqdm(enumerate(monthly_fns)):
#             if k<10: continue
#             da = xr.open_dataarray(fn)
#             if j==2: da = shift_had(da)
#             da = focus_data(da)
#             da = da.where(area)
#             if j<2:
#                 time = [times_ctrl, times_lpd][j][k]
#                 fn = f'{path_prace}/SST/SST_monthly_ds_dt_{extent}_{run}_{time[0]}_{time[1]}.nc'
#             else:
#                 fn = f'{path_prace}/SST/SST_monthly_ds_dt_{extent}_{run}.nc'
#             da.to_netcdf(fn)

## actual EOF analysis

In [None]:
# %%time
# # 4:45 for 38S_ctrl, 5:07 for 38S_lpd, 3:42 for 38S_had : total 11:08
# # 2:50 for Eq_ctrl,  : total 11:08
# # total: 22min 19s
# NOTE(review): disabled cell that runs an area-weighted EOF analysis (1 EOF, 1 PC)
# on each cropped, detrended segment and writes PMV_EOF_{extent}_{run}[_{start}_{end}].nc.
# `i!=0`, `j>0` and `4<k<10` skips look like partial-re-run leftovers;
# `fn = monthly_fns[k]` just re-assigns the loop variable.
# for i, extent in tqdm(enumerate(['38S', 'Eq', '20N'])):
#     if i!=0:  continue
#     monthly_ds_dt = [[f'{path_prace}/SST/SST_monthly_ds_dt_{extent}_ctrl_{time[0]}_{time[1]}.nc' for time in times_ctrl],
#                      [f'{path_prace}/SST/SST_monthly_ds_dt_{extent}_lpd_{time[0]}_{time[1]}.nc' for time in times_lpd], 
#                      [f'{path_prace}/SST/SST_monthly_ds_dt_{extent}_had.nc']]
    
#     EOF_fns       = [[f'{path_prace}/SST/PMV_EOF_{extent}_ctrl_{time[0]}_{time[1]}.nc' for time in times_ctrl],
#                      [f'{path_prace}/SST/PMV_EOF_{extent}_lpd_{time[0]}_{time[1]}.nc' for time in times_lpd], 
#                      [f'{path_prace}/SST/PMV_EOF_{extent}_had.nc']]
    
#     for j, domain in tqdm(enumerate(['ocn_rect', 'ocn_low', 'ocn_had'])):
#         if j>0:  continue
#         area = xr.open_dataarray(f'{path_prace}/geometry/AREA_{extent}_{domain}.nc')
#         monthly_fns = monthly_ds_dt[j]
#         for k, fn in tqdm(enumerate(monthly_fns)):
#             if k > 4 and k<10: continue
#             fn = monthly_fns[k]
#             fn_EOF = EOF_fns[j][k]
#             da = xr.open_dataarray(fn)
#             AI().EOF_SST_analysis(xa=da, weights=area, neofs=1, npcs=1, fn=fn_EOF)

In [None]:
# leading EOF pattern of the lpd 154-404 segment, 38S extent
da = xr.open_dataset(f'{path_prace}/SST/PMV_EOF_38S_lpd_154_404.nc').eofs
da.plot()

In [None]:
# same segment, 20N extent
da = xr.open_dataset(f'{path_prace}/SST/PMV_EOF_20N_lpd_154_404.nc').eofs
da.plot()

In [None]:
# spatial mean of the 20N EOF pattern; the sign of such a mean is what the
# plotting cells use to fix the (arbitrary) EOF sign
da.mean()

In [None]:
# check if files are present, then plot the 5-year lowpass-filtered leading PC of
# every time segment: one group per extent (offset 35*i), one trace per segment
# (offset 3*k), colour by run (C0 ctrl, C1 lpd, C2 had)
plt.figure(figsize=(12,8))
ax = plt.gca()
ls = '-'
for i, extent in enumerate(['38S', 'Eq', '20N']):
    EOF_fns       = [[f'{path_prace}/SST/PMV_EOF_{extent}_ctrl_{time[0]}_{time[1]}.nc' for time in times_ctrl],
                     [f'{path_prace}/SST/PMV_EOF_{extent}_lpd_{time[0]}_{time[1]}.nc' for time in times_lpd], 
                     [f'{path_prace}/SST/PMV_EOF_{extent}_had.nc']]
    ax.text(-20, 35*i+25, extent)
    for j, domain in enumerate(['ocn_rect', 'ocn_low', 'ocn_had']):
        (d, lat, lon) = dll_dims_names(['ocn_rect', 'ocn', 'ocn_had'][j])
        c = f'C{j}'
        tf = [1,365,365][j]  # time factor
        to = [130,300,0][j]  # time offset
        for k, fn in enumerate(EOF_fns[j]):
            # explicit check: `assert` would be skipped under `python -O`
            if not os.path.exists(fn):
                raise FileNotFoundError(fn)
            # open each EOF file once (eofs and pcs live in the same dataset) and close it again
            with xr.open_dataset(fn, decode_times=False) as ds_EOF:
                cov = ds_EOF.eofs.mean(dim=[lat,lon])
                # EOF sign is arbitrary: flip so the spatially averaged pattern is positive
                factor = -1 if cov<0 else 1
                if j==1 and i==2: factor = factor*-1  # extra manual flip for lpd at 20N
                da = (ds_EOF.pcs*factor).load()
            ax.plot(da.time/tf+to, lowpass(da,5*12)+3*k+35*i, c=c, ls=ls)
for i in range(3):
    ax.text([70, 300, 575][i], 105, ['OBS', 'HIGH', 'LOW'][i])
plt.yticks([])
plt.ylim((-4,110))

In [None]:
# TPI time series (presumably the IPO tripole index) for ctrl, lpd and had, numeric time axis
# NOTE(review): these three variables are not used in the visible notebook; the
# next cell re-opens the same files through its own TPI_fns list.
TPI_ctrl = xr.open_dataarray(f'{path_prace}/SST/TPI_ctrl.nc', decode_times=False)
TPI_lpd  = xr.open_dataarray(f'{path_prace}/SST/TPI_lpd.nc' , decode_times=False)
TPI_had  = xr.open_dataarray(f'{path_prace}/SST/TPI_had.nc' , decode_times=False)

In [None]:
# lowpass-filtered leading PCs of the three extents (rows 0-2) and the TPI scaled
# by 4 (row 3), colour by run: C0 ctrl, C1 lpd, C2 had
plt.figure(figsize=(8,5))
ax = plt.gca()
for i in range(4):
    ax.axhline(i, c='grey', lw=.5)
TPI_fns = [f'{path_prace}/SST/TPI_ctrl.nc',
           f'{path_prace}/SST/TPI_lpd.nc', 
           f'{path_prace}/SST/TPI_had.nc']
for i, fn in enumerate(TPI_fns):
    tf = [1,365,365][i]  # time factor
    to = [130,300,0][i]  # time offset
    # NOTE(review): `tf` is not used here; the TPI time axis is divided by 365 for
    # every run, whereas the PC traces below use tf=1 for ctrl — confirm intended.
    da = xr.open_dataarray(fn, decode_times=False)
    ax.plot(da.time[:250]/365+to, 4*lowpass(da[:250],13)+3, c=f'C{i}')
labels = []

for i, extent in enumerate(['38S', 'Eq', '20N']):
    EOF_fns       = [f'{path_prace}/SST/PMV_EOF_{extent}_ctrl_51_301.nc',
                     f'{path_prace}/SST/PMV_EOF_{extent}_lpd_154_404.nc',
                     f'{path_prace}/SST/PMV_EOF_{extent}_had.nc']
    for j, fn in enumerate(EOF_fns):
        (d, lat, lon) = dll_dims_names(['ocn_rect', 'ocn', 'ocn_had'][j])
        tf = [1,365,365][j]  # time factor
        to = [130,300,0][j]  # time offset
        # open each EOF file once (eofs and pcs live in the same dataset) and close it again
        with xr.open_dataset(fn, decode_times=False) as ds_EOF:
            cov = ds_EOF.eofs.mean(dim=[lat,lon])
            # EOF sign is arbitrary: flip so the spatially averaged pattern is positive,
            # plus extra manual flips for lpd and had at 20N
            factor = -1 if cov<0 else 1
            if i==2 and j in (1, 2): factor = factor*-1
            da = (ds_EOF.pcs*factor).load()
        ax.plot(da.time/tf+to, lowpass(da,13*12)+i, c=f'C{j}')
    # run labels; i runs over the three extents, which coincides with the three runs
    ax.text([70, 300, 575][i], 3.7, ['OBS', 'HIGH', 'LOW'][i])
    labels.append(f'PC(>{extent})')
labels.append('TPI (x4)')

plt.xlim((-70,750))
plt.ylim((-1,4))
ax.set_yticks(range(4))
ax.set_yticklabels(labels)
plt.xlabel('time [years]')

## correlation plots

In [None]:
# %%time
# NOTE(review): disabled cell; `notch`, `SST_gm_rect_ds_ctrl` and `SST_gm_rect_ds_rcp`
# are not defined or imported in the visible notebook — re-import/define before enabling.
# SST_rect_ctrl = xr.open_dataarray(f'{path_samoc}/SST/SST_monthly_rect_ctrl.nc', decode_times=False)
# SST_rect_rcp  = xr.open_dataarray(f'{path_samoc}/SST/SST_monthly_rect_rcp.nc' , decode_times=False)
# SST_rect_ds_dt_ctrl = lowpass(lowpass(notch(SST_rect_ctrl, 12), 12), 12) - SST_gm_rect_ds_ctrl[:-7]
# SST_rect_ds_dt_rcp  = lowpass(lowpass(notch(SST_rect_rcp , 12), 12), 12) - SST_gm_rect_ds_rcp[:-1]
# SST_rect_ds_dt_ctrl.to_netcdf(f'{path_samoc}/SST/SST_monthly_rect_ds_dt_ctrl.nc')
# SST_rect_ds_dt_rcp .to_netcdf(f'{path_samoc}/SST/SST_monthly_rect_ds_dt_rcp.nc' )

In [None]:
# filtered monthly rectangular-grid SST for ctrl and rcp, as written by the disabled
# filtering cell; note these live under path_samoc, not path_prace
SST_rect_ds_dt_ctrl = xr.open_dataarray(f'{path_samoc}/SST/SST_monthly_rect_ds_dt_ctrl.nc', decode_times=False)
SST_rect_ds_dt_rcp  = xr.open_dataarray(f'{path_samoc}/SST/SST_monthly_rect_ds_dt_rcp.nc' , decode_times=False)

In [None]:
%%time
# 2:25 min
# regress the filtered SST fields onto the leading PC of the 38S (IPO) extent;
# dof_corr presumably corrects the degrees of freedom for the filtering — TODO confirm.
# NOTE(review): `lag_linregress_3D`, `Pac_38S_ctrl` and `Pac_38S_rcp` are not defined
# or imported anywhere in the visible notebook — confirm where they come from.
# ds_20N_ctrl = lag_linregress_3D(Pac_20N_ctrl.pcs[:-7,0], SST_rect_ds_dt_ctrl[24:-(24+7)], dof_corr=1./(12*13))
ds_38S_ctrl = lag_linregress_3D(Pac_38S_ctrl.pcs[:-7,0], SST_rect_ds_dt_ctrl[24:-(24+7)], dof_corr=1./(12*13))
# ds_20N_rcp  = lag_linregress_3D(-Pac_20N_rcp.pcs[:-7,0], SST_rect_ds_dt_rcp [24:-(24+7)], dof_corr=1./(12*13))
ds_38S_rcp  = lag_linregress_3D(Pac_38S_rcp .pcs[:-7,0], SST_rect_ds_dt_rcp [24:-(24+7)], dof_corr=1./(12*13))


In [None]:
# store the covered year range as attributes on each regression dataset
# NOTE(review): ds_20N_ctrl / ds_20N_rcp are only assigned in commented-out lines of
# the regression cell, so as written this cell raises NameError — confirm.
for ds in [ds_20N_ctrl, ds_38S_ctrl]:
    ds.attrs['first_year'] = 102
    ds.attrs['last_year']  = 297
for ds in [ds_20N_rcp, ds_38S_rcp]:
    ds.attrs['first_year'] = 2002
    ds.attrs['last_year']  = 2097

In [None]:
# display the 20N (PDO) ctrl regression dataset
ds_20N_ctrl

In [None]:
# SST regression map onto the PDO index (20N extent), ctrl
# NOTE(review): `regr_map` is not defined or imported in the visible notebook — confirm its source
regr_map(ds=ds_20N_ctrl, index='PDO', run='ctrl', fn=None)

In [None]:
# SST regression map onto the IPO index (38S extent), ctrl
regr_map(ds=ds_38S_ctrl, index='IPO', run='ctrl', fn=None)

In [None]:
# SST regression map onto the PDO index (20N extent), rcp
regr_map(ds=ds_20N_rcp, index='PDO', run='rcp', fn=None)

In [None]:
# SST regression map onto the IPO index (38S extent), rcp
regr_map(ds=ds_38S_rcp, index='IPO', run='rcp', fn=None)

In [None]:
# installed cartopy version
cartopy.__version__

In [None]:
# before