# Obtain SST PCs

In [None]:
## Get equatorial Pacific SST
def get_pacific_SST():
    """Return equatorial-Pacific (10S-10N) SST anomalies scaled by their weighted std.

    Reads two module-level globals that are not defined in this cell:
    ``SST_anom`` (SST anomaly field with ``latitude``/``longitude`` coords) and
    ``weight`` (the weights handed to ``.weighted``). NOTE(review): make sure
    both exist before calling.

    Steps:
      1. drop all-NaN rows/columns of ``SST_anom``;
      2. download the WOA09 basin mask, put it on the SST grid (nearest
         neighbour) and keep only points whose basin code is 2 (Pacific);
      3. keep latitudes -10..10 and divide by the ``weight``-weighted standard
         deviation of that subset (reduced over all of its dims).
    """
    valid_sst = SST_anom.where(~np.isnan(SST_anom), drop=True)

    # Ocean basin mask (remote OPeNDAP dataset), renamed to the SST coordinate names
    mask_ds = xr.open_dataset('http://iridl.ldeo.columbia.edu/SOURCES/.NOAA/.NODC/.WOA09/.Masks/.basin/dods')
    mask_ds = mask_ds.rename({'X': 'longitude', 'Y': 'latitude'})

    # First level of the mask (presumably the surface), regridded onto the SST grid
    surface_mask = mask_ds.basin[0].interp_like(valid_sst, method='nearest')
    pacific_sst = valid_sst.where(surface_mask == 2)

    # NOTE(review): slice(-10, 10) assumes latitude is ascending; a descending
    # axis (e.g. raw ERA5) would give an empty selection.
    eq_pacific = pacific_sst.sel(latitude=slice(-10, 10))
    scale = eq_pacific.weighted(weight).std()
    return eq_pacific / scale

In [None]:
## Obtain PCs/EOFs of SST anomalies
def get_PCs_EOFs(n_modes=25):
    """Compute the leading EOFs and PCs of equatorial-Pacific SST with xMCA.

    Parameters
    ----------
    n_modes : int, optional
        Number of modes requested from the solver (default 25). Must be >= 2,
        because mode 1 is sign-flipped below.

    Returns
    -------
    pc_SST : xarray.DataArray
        The solver's 'left' PCs with the ``mode`` coordinate relabelled 0..n_modes-1.
    eof_SST : xarray.DataArray
        The solver's 'left' EOFs with ``latitude``/``longitude`` coordinate names
        and ``mode`` relabelled 0..n_modes-1.

    Raises
    ------
    ValueError
        If ``n_modes < 2``.
    """
    if n_modes < 2:
        raise ValueError(f"n_modes must be at least 2, got {n_modes}")

    # xMCA works with 'lat'/'lon' coordinate names
    solver = xMCA(get_pacific_SST().rename({'latitude':'lat','longitude':'lon'}))
    solver.solve()

    # SST EOFs back on 'latitude'/'longitude', with 0-based mode labels
    eof_SST = solver.eofs(n=n_modes)['left']
    eof_SST = eof_SST.rename({'lon':'longitude','lat':'latitude'})
    eof_SST['mode'] = np.arange(0,n_modes)
    # The sign of an EOF/PC pair is arbitrary. Mode 1 is flipped in both the EOF and
    # the PC, so EOF * PC is unchanged. Selecting by label (not position) keeps
    # this independent of the solver's dim order.
    # To flip mode 0 as well: eof_SST.loc[{'mode': 0}] = -eof_SST.loc[{'mode': 0}]
    eof_SST.loc[{'mode': 1}] = -eof_SST.loc[{'mode': 1}]

    # SST PCs with matching mode labels and sign convention
    pc_SST = solver.pcs(n=n_modes)['left']
    pc_SST['mode'] = np.arange(0,n_modes)
    # To flip mode 0 as well: pc_SST.loc[{'mode': 0}] = -pc_SST.loc[{'mode': 0}]
    pc_SST.loc[{'mode': 1}] = -pc_SST.loc[{'mode': 1}]

    return pc_SST, eof_SST

# Use 1-2-1 low pass filter on PCs
def filtered(PCs, n_modes=2):
    """Smooth the leading PCs with a centered 1-2-1 running filter.

    Each output value is ``(p[t-1] + 2*p[t] + p[t+1]) / 4`` and is labelled with
    the centre time ``t``. The first and last time steps are dropped because the
    3-point window does not fit there.

    Parameters
    ----------
    PCs : xarray.DataArray
        PC time-series with a ``time`` dim and one other (mode) dim. Modes are
        taken by position along that other dim.
    n_modes : int, optional
        Number of leading modes to filter (default 2).

    Returns
    -------
    xarray.DataArray
        Dims ``('time', 'mode')``, with ``len(PCs.time) - 2`` time steps and
        ``mode`` = 0 .. n_modes-1.
    """
    # Samples along axis 0, modes along axis 1, regardless of input dim order
    values = PCs.transpose('time', ...).values[:, :n_modes]
    smoothed = (values[:-2] + 2 * values[1:-1] + values[2:]) / 4

    # Label each value with the centre of its 3-point window. Labelling with the
    # window's first time step would shift the series by one time step.
    return xr.DataArray(data=smoothed, dims=('time', 'mode'),
                        coords={'time': PCs.time.values[1:-1], 'mode': np.arange(n_modes)})

In [None]:
## Obtain the rotated EOFs/PCs by using a linear combination of the first two modes (Takahashi 2011)
def rotate(eofs):
    """Rotate modes 0 and 1 by 45 degrees into the E- and C-patterns (Takahashi 2011).

    Works on anything with ``.sel(mode=...)`` (EOF maps or PC time-series).

    Returns
    -------
    (e_pattern, c_pattern)
        ``(m0 - m1) / sqrt(2)`` and ``(m0 + m1) / sqrt(2)``.
    """
    first = eofs.sel(mode=0)
    second = eofs.sel(mode=1)
    root_two = np.sqrt(2)
    return (first - second) / root_two, (first + second) / root_two

# Leading SST modes, with the PCs smoothed by the 1-2-1 filter
pc_SST, eof_SST = get_PCs_EOFs()
pc_SST = filtered(pc_SST)

# Rotate the first two modes into E- and C-patterns (Takahashi 2011)
epc, cpc = rotate(pc_SST)
epat, cpat = rotate(eof_SST)

# Stack each rotated pair along a new 'mode' dim (0 = e-pattern, 1 = c-pattern),
# then reverse the dim order (same as .T)
eof_patterns = xr.concat([epat, cpat], dim='mode').assign_coords({'mode': [0, 1]}).transpose()
pc_patterns = xr.concat([epc, cpc], dim='mode').assign_coords({'mode': [0, 1]}).transpose()

# Obtain Radiation PCs

In [None]:
## Regress y onto the principal components (or onto EOF patterns)
def regOnPcs(pcs,y,n,mode,pval_corr):
    """Least-squares regression (no intercept) of ``y`` onto PCs or onto EOF patterns.

    The behaviour depends on ``pcs.ndim``:

    * 2-D (PC time-series with a ``mode`` dim): every grid point of ``y`` is
      regressed on the PCs across time, giving regression maps.
    * 3-D (EOF patterns with ``latitude``/``longitude``/``mode``): every time step
      of ``y`` is regressed on the patterns across grid points, giving projected
      time-series.

    Parameters
    ----------
    pcs : xarray.DataArray
        PCs (2-D) or EOF patterns (3-D).
    y : xarray.DataArray
        Field with ``latitude``/``longitude`` dims. Grid points containing NaN are dropped.
    n : int
        If 0, use only the single mode ``mode``. Otherwise use modes 0..n
        (label slice, inclusive).
    mode :
        Mode label used when ``n == 0``.
    pval_corr : str
        ``'y'`` to add a ``pvals`` variable (2-D case only). ``pvals`` holds the
        Benjamini-Hochberg reject mask from ``multitest.fdrcorrection`` (True =
        significant at its default alpha), computed across grid points
        separately for each mode. This is the same quantity ``fdr_corr``
        returns, not adjusted p-values.

    Returns
    -------
    xarray.Dataset
        2-D case: ``coefs`` (and optionally ``pvals``) over mode and latitude/longitude.
        3-D case: ``coefs`` over ``time`` and ``mode``.

    Raises
    ------
    ValueError
        If ``pcs`` is neither 2-D nor 3-D.
    """
    if pcs.ndim == 2:
        # Regress onto PCs to get spatial maps
        if n == 0:
            X = pcs.sel(mode=mode).values.reshape(-1,1)
            m = [mode]
        else:
            X = pcs.sel(mode=slice(0,n)).to_pandas()
            m = np.arange(0,n+1)

        y_stack = y.stack(allpoints=('latitude','longitude')).dropna(dim='allpoints')

        reg = LinearRegression(fit_intercept=False).fit(X,y_stack)
        data_vars = {'coefs':(['dim_0','dim_1'],reg.coef_)}

        # Obtain p-values and apply FDR correction
        if pval_corr == 'y':
            # One OLS fit per grid point, with no constant to match fit_intercept=False.
            # Shape (n_points, n_modes).
            pvals = np.array([sm.OLS(y_stack.isel(allpoints=i).values,X).fit().pvalues
                              for i in range(y_stack.sizes['allpoints'])])
            # FDR across grid points, separately for each mode (column). fdr_corr
            # needs an xarray object with latitude/longitude dims, so the same
            # correction is applied directly to the columns here.
            pvals_fdr = np.column_stack([multitest.fdrcorrection(pvals[:,k])[0]
                                         for k in range(pvals.shape[1])])
            data_vars['pvals'] = (['dim_0','dim_1'],pvals_fdr)

        reg_ds = xr.Dataset(data_vars=data_vars,
                            coords={'dim_0':y_stack.allpoints.values,
                                    'dim_1':m})

        reg_rename = reg_ds.rename({'dim_0':'allpoints','dim_1':'mode'})
        reg_xr = reg_rename.reindex_like(y_stack)
        return reg_xr.unstack('allpoints')

    elif pcs.ndim == 3:
        # Regress onto EOFs to get principal components (samples = grid points)
        X_stack = pcs.stack(allpoints=('latitude','longitude')).transpose('allpoints',...).dropna(dim='allpoints')
        y_stack = y.stack(allpoints=('latitude','longitude')).transpose('allpoints',...).dropna(dim='allpoints')
        # Keep only grid points valid in both the patterns and the field
        a_stack, b_stack = xr.align(X_stack,y_stack)

        if n == 0:
            reg = LinearRegression(fit_intercept=False).fit(a_stack.sel(mode=mode).values.reshape(-1,1),b_stack)
            m = [mode]
        else:
            reg = LinearRegression(fit_intercept=False).fit(a_stack.sel(mode=slice(0,n)),b_stack)
            m = np.arange(0,n+1)

        # coef_ rows are the targets, i.e. y's time steps
        reg_ds = xr.Dataset(data_vars={'coefs':(['dim_0','dim_1'],reg.coef_)},
                            coords={'dim_0':y_stack.time.values,
                                    'dim_1':m})
        return reg_ds.rename({'dim_0':'time','dim_1':'mode'})

    raise ValueError(f"pcs must be 2-D (PCs) or 3-D (EOF patterns), got ndim={pcs.ndim}")
    
def fdr_corr(pvals):
    """Benjamini-Hochberg FDR test across grid points, separately for each mode.

    Parameters
    ----------
    pvals : xarray.DataArray
        P-values with a ``mode`` dim and ``latitude``/``longitude`` dims.

    Returns
    -------
    xarray.DataArray
        Boolean mask over ``mode``/``latitude``/``longitude``. True means the null
        hypothesis is rejected at ``multitest.fdrcorrection``'s default alpha. It
        is the reject flag (first output of ``fdrcorrection``), not adjusted
        p-values. NaN p-values (e.g. masked land points) are left out of the
        correction, so they do not count as hypotheses, and are reported as False.
        The input's ``mode`` labels are preserved.
    """
    stacked = pvals.stack(allpoints=('latitude','longitude'))

    masks = []
    for m in pvals.mode.values:
        p = stacked.sel(mode=m)
        finite = np.isfinite(p.values)
        reject = np.zeros(finite.shape, dtype=bool)
        if finite.any():
            # Only real tests enter the BH ranking and denominator
            reject[finite] = multitest.fdrcorrection(p.values[finite])[0]
        # copy(data=...) keeps the 'allpoints' MultiIndex and the scalar 'mode' coord
        masks.append(p.copy(data=reject))

    # concat uses each mask's scalar 'mode' coord as the new mode index
    return xr.concat(masks, dim='mode').unstack('allpoints')

# Time-series Analysis

In [None]:
## Get Pacific-masked, equatorial-Pacific, Nino-3.4 and PC-reconstructed time-series
def get_time_series(data,PCs):
    """Build regional radiation series and a two-mode PC reconstruction.

    Relies on names defined elsewhere in the notebook: ``global_mean`` (applied
    here to regional subsets as well as the full field) and the module-level
    ``basin_surf_interp`` mask. NOTE(review): ``basin_surf_interp`` is only a
    local inside ``get_pacific_SST`` in the visible cells, so confirm it is
    defined globally before this runs.

    Parameters
    ----------
    data : xarray.DataArray
        Field with ``time``, ``latitude`` and ``longitude`` dims.
    PCs : xarray.DataArray
        PC time-series with ``time`` and ``mode`` dims.

    Returns
    -------
    data_pacific : ``data`` masked to basin code 2 (Pacific).
    data_trop : ``global_mean`` of ``data_pacific`` over latitudes -10..10.
    data_nino34 : ``global_mean`` of ``data`` over latitudes -5..5 and
        longitudes 190..240.
    data_recon : modes 0 and 1 of ``PCs``, each divided by its std over time,
        then summed.

    ``data_trop`` and ``data_nino34`` are trimmed (inner join) to the time steps
    where the global mean of ``data`` is not NaN.
    """
    # Global mean; used only to define the valid time axis for the regional series
    data_gm = global_mean(data).dropna(dim='time')
    data_recon = (PCs / PCs.std(dim='time')).sel(mode=slice(0,1)).sum(dim='mode')

    data_pacific = data.where(basin_surf_interp == 2)
    data_trop,_ = xr.align(global_mean(data_pacific.sel(latitude=slice(-10,10))), data_gm)
    data_nino34,_ = xr.align(global_mean(data.sel(latitude=slice(-5,5),longitude=slice(190,240))), data_gm)

    return data_pacific, data_trop, data_nino34, data_recon

## Calculate autocorrelation of a time-series and find its e-folding lag
def autocorrelation(x,lag):
    """Return the e-folding lag of the autocorrelation of ``x``.

    For each value in ``lag``, ``x`` shifted along ``time`` is correlated (Pearson)
    with ``x`` over their common times. The resulting curve is linearly
    interpolated onto a 0.01-spaced grid from ``lag[0]`` up to, but not
    including, ``lag[-1]``. The first grid lag where the correlation is <= 1/e is
    returned as a 0-d numpy array. If the curve never drops that low, returns None.

    NOTE(review): the search starts at ``lag[0]``, so ``lag`` is presumably meant
    to start at 0 and increase.
    """
    correlations = []
    for step in lag:
        shifted = x.shift(time=step).dropna(dim='time',how='all')
        lagged, base = xr.align(shifted, x)
        r, _ = st.pearsonr(lagged, base)
        correlations.append(r)

    acf = xr.DataArray(data=correlations,coords={'lag':lag})
    acf_fine = acf.interp(lag=np.arange(lag[0],lag[-1],0.01))

    # Index of the first grid point at or below 1/e (NaN compares as False)
    below = (acf_fine <= np.exp(-1)).values
    if below.any():
        return acf_fine[int(np.argmax(below))].lag.values
    return None

In [None]:
## Get power spectral density using Welch's method
def get_signal(ts_data):
    """Welch power spectral densities of three series packed in ``ts_data``.

    ``ts_data[0]``, ``ts_data[1]`` and ``ts_data[2]`` are read as the global-mean,
    reconstructed and tropical series. Any further items are ignored. Each
    estimate uses 120-sample segments with 60 samples of overlap.

    Returns
    -------
    (recon_sig, gm_sig, trop_sig)
        Each is the ``(frequencies, psd)`` pair returned by ``signal.welch``.
        The reconstructed series comes first.
    """
    welch_opts = {'nperseg': 120, 'noverlap': 60}
    ordered = (ts_data[1], ts_data[0], ts_data[2])  # recon, global mean, tropics
    return tuple(signal.welch(series, **welch_opts) for series in ordered)

## Get magnitude squared coherence
def get_coherence(data_gm,data_recon):
    """Magnitude-squared coherence between two series over their common times.

    The inputs are first aligned (inner join on shared coordinates). Then
    ``signal.coherence`` is called with 120-sample segments and 60 samples of overlap.

    Returns
    -------
    (f, Cxy)
        Frequencies and coherence values, as returned by ``signal.coherence``.
    """
    gm_common, recon_common = xr.align(data_gm, data_recon)
    freqs, coh = signal.coherence(gm_common, recon_common, nperseg=120, noverlap=60)
    return freqs, coh

## Obtain a 95% chi-squared confidence interval for a spectrum; dof = 2*phi
def conf_int(data,phi):
    """Scale ``data`` by the chi-squared 95% confidence factors with ``2*phi`` dof.

    Returns
    -------
    (upper, lower)
        ``2*phi / chi2.ppf(0.025) * data`` and ``2*phi / chi2.ppf(0.975) * data``.
        The larger bound comes first.
    """
    dof = 2*phi
    q_low, q_high = st.chi2.ppf([0.025,0.975],df=dof)
    upper = (dof / q_low) * data
    lower = (dof / q_high) * data
    return upper, lower

## Return a normalized PSD
def norm(data):
    """Return ``data`` divided by its total, so the result sums to 1."""
    total = data.sum()
    return data / total

In [None]:
## Monte Carlo white- and red-noise spectra matched to data's variance; return the 97.5th percentile across n realizations
def noise(data,n):
    """Monte Carlo significance levels for a Welch PSD of ``data``.

    ``n`` realizations are generated. Each one draws white noise with the same
    length and standard deviation as ``data``, and builds an AR(1) (red-noise)
    series from it using the lag-1 autocorrelation ``a`` of ``data``::

        red[0] = white[0]
        red[j] = a * red[j-1] + sqrt(1 - a**2) * white[j]

    This keeps the red-noise variance equal to the white-noise variance. The
    Welch PSD of both series is computed with 120-sample segments and 60 samples
    of overlap.

    Parameters
    ----------
    data : xarray.DataArray
        Time-series with a ``time`` dim.
    n : int
        Number of Monte Carlo realizations (>= 1).

    Returns
    -------
    (white_97_5, red_97_5) : numpy.ndarray
        97.5th percentile at each frequency, taken across realizations.

    Raises
    ------
    ValueError
        If ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1 realization, got {n}")

    # Lag-1 autocorrelation of the data
    x, y = xr.align(data,data.shift(time=1).dropna(dim='time'))
    auto = st.pearsonr(x,y)[0]
    innovation_scale = np.sqrt(1-auto**2)
    n_time = len(data.time)
    sigma = float(data.std())

    wh_psd = []
    rd_psd = []
    for _ in range(n):
        wh_ts = np.random.normal(loc=0,scale=sigma,size=n_time)

        rd_ts = np.empty(n_time)
        rd_ts[0] = wh_ts[0]
        for j in range(1,n_time):
            rd_ts[j] = auto*rd_ts[j-1] + wh_ts[j]*innovation_scale

        _,ps_wh = signal.welch(wh_ts,nperseg=120,noverlap=60)
        _,ps_rd = signal.welch(rd_ts,nperseg=120,noverlap=60)
        # Keep every realization so the percentile is taken across all of them
        wh_psd.append(ps_wh)
        rd_psd.append(ps_rd)

    return np.percentile(a=wh_psd,q=97.5,axis=0), np.percentile(a=rd_psd,q=97.5,axis=0)

# Lagged Regression Analysis

In [None]:
### Perform lagged regression (w/ the possibility of normal regression for lag = 0) on two datasets x and y.
def lagged_regress(x,y,lag):
    """Regress ``y(t)`` on ``x(t - lag)`` by least squares, with an intercept.

    ``x`` is shifted with ``x.shift(time=lag)``. For positive ``lag``, ``y`` at
    time t is therefore paired with ``x`` at time t - lag; ``lag = 0`` is ordinary
    regression. Time steps with NaN in the shifted ``x`` or anywhere in ``y`` are
    dropped, and only times common to both are used.

    Parameters
    ----------
    x : xarray.DataArray
        Predictor(s): 1-D over ``time``, or 2-D over ``time`` and ``mode``.
    y : xarray.DataArray
        Predictand. If 3-D it is stacked over ``latitude``/``longitude`` into
        ``allpoints``, and grid points with any NaN are dropped.
    lag : int
        Shift applied to ``x`` along ``time``.

    Returns
    -------
    xarray.DataArray
        Regression slopes (the intercept is discarded):

        * 3-D ``y``: over ``mode`` (one per predictor) and ``latitude``/``longitude``;
        * 1-D ``y`` with several predictors: over ``mode``, labelled with ``x.mode``;
        * otherwise: ``coef_`` wrapped as-is in a DataArray.
    """
    x_lag = x.shift(time=lag)
    # Stack y if it has dimensions of (lat, lon, t)
    if y.ndim == 3:
        y_stack = y.stack(allpoints=('latitude','longitude')).dropna(dim='allpoints')
    else:
        y_stack = y

    # Drop NaN time steps, then keep only times present in both
    x_lag_nonan = x_lag.dropna(dim='time')
    y_nonan = y_stack.dropna(dim='time')
    a_lag, b_stack = xr.align(x_lag_nonan, y_nonan)
    # scikit-learn expects samples (time) along the first axis
    b_stack = b_stack.transpose('time', ...)

    if x.ndim == 1:
        reg = LinearRegression().fit(a_lag.values.reshape(-1,1),b_stack)
    else:
        reg = LinearRegression().fit(a_lag.transpose('time', ...).values,b_stack)

    reg_da = xr.DataArray(data=reg.coef_)

    if reg_da.ndim == 2:
        # coef_ is (n_targets, n_features): rows are y's grid points, columns are x's predictors.
        # NOTE(review): the unstack assumes y was 3-D (stacked into 'allpoints'). A 2-D y
        # regressed on x would also reach this branch and fail at unstack.
        reg_rename = reg_da.rename({'dim_0':'allpoints','dim_1':'mode'})
        reg_reindex = reg_rename.reindex_like(y_stack)
        return reg_reindex.unstack('allpoints')
    elif reg_da.ndim == 1 and len(reg_da) > 1:
        # 1-D y with several predictors: one slope per predictor, so label with x's modes
        return reg_da.rename({'dim_0':'mode'}).assign_coords({'mode':x.mode.values})
    else:
        return reg_da

In [None]:
## Convert data from time to lag coords via lagged regression
def time2lag(x,y,lag):
    """Stack ``lagged_regress(x, y, l)`` for every ``l`` in ``lag`` along a new ``lag`` dim.

    The result has ``lag`` as its coordinate and is squeezed, so every
    length-1 dimension is dropped. This includes ``lag`` itself when only one
    lag is given.
    """
    per_lag = [lagged_regress(x,y,step) for step in lag]
    stacked = xr.concat(per_lag,dim='lag')
    stacked = stacked.assign_coords({'lag':lag})
    return stacked.squeeze()

In [None]:
## Get r2 between full and reconstructed radiation maps vs lag
def get_r2_lags(full_rad_data,EOFs,lags=None):
    """R^2 between full and reconstructed radiation maps at each (lag, mode).

    For each pair, the full map (flattened over ``latitude``/``longitude``) is
    regressed on the reconstructed map with statsmodels OLS, with no constant
    added and NaN points dropped, and ``rsquared`` is recorded.
    NOTE(review): without a constant statsmodels reports the uncentered R^2;
    confirm that is intended.

    Parameters
    ----------
    full_rad_data, EOFs : xarray.DataArray
        Maps with ``lag``, ``mode``, ``latitude`` and ``longitude`` dims.
    lags : sequence, optional
        Lag labels to evaluate. Defaults to the module-level ``lag_list``
        (defined elsewhere in the notebook).

    Returns
    -------
    xarray.DataArray
        Dims ``('lag', 'mode')``.
    """
    if lags is None:
        lags = lag_list

    # Alignment does not depend on the loop indices, so do it once
    dR_align, EOFs_align = xr.align(full_rad_data,EOFs)
    modes = EOFs.mode.values

    # Fill by position so any lag list / mode labels work (not just lags starting at -18)
    r2 = np.zeros([len(lags),len(modes)])
    for i, lag in enumerate(lags):
        for j, mode in enumerate(modes):
            target = dR_align.sel(lag=lag,mode=mode).stack(allpoints=('latitude','longitude')).values
            predictor = EOFs_align.sel(lag=lag,mode=mode).stack(allpoints=('latitude','longitude')).values
            r2[i,j] = sm.OLS(target,predictor,missing='drop').fit().rsquared

    return xr.DataArray(data=r2,dims=('lag','mode'),coords={'lag':lags,'mode':modes})