# SST modes vs. GMST regression

1. Regress modes of SST variability (the AMO, TPI and SOM indices) against GMST variability.

In [None]:
import sys
sys.path.append("..")
import scipy as sp
import numpy as np
import xarray as xr
import pandas as pd
import statsmodels.api as sm
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Notebook setup: inline figures without tight bounding-box cropping
# (bbox_inches=None), the paper's matplotlib rc settings, and autoreload of
# all imported modules (e.g. the local `paths` / `filters` modules) before each
# cell execution, except numpy, scipy and matplotlib.pyplot.
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
matplotlib.rc_file('../rc_file_paper')
%load_ext autoreload
%autoreload 2
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
from paths import path_samoc
from filters import lowpass

## Merge time series into one dataframe per run

In [None]:
# Yearly GMST series (`_dt` files; presumably detrended) for the `had` record
# and the `ctrl` / `lpd` runs.
# `had`: keep time steps 9..157 and leave its time axis undecoded.
gmst_had = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_had.nc', decode_times=False).isel(time=slice(9, 158))
gmst_ctrl = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_ctrl.nc')
gmst_lpd = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_lpd.nc')
# Model runs: time divided by 365 and truncated to int (presumably days -> years).
gmst_ctrl['time'] = (gmst_ctrl.time / 365).astype(dtype=int)
gmst_lpd['time'] = (gmst_lpd.time / 365).astype(dtype=int)

In [None]:
# indices


In [None]:
# OHC
# OHC integrals of the two model runs, time axis left undecoded.
ctrl_qd, lpd_qd = (
    xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl_qd.nc', decode_times=False),
    xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd_qd.nc' , decode_times=False),
)

In [None]:
# Display the ctrl OHC integrals dataset to inspect its variables and coordinates.
ctrl_qd

In [None]:
def merge_time_series(run):
    """Create a pandas DataFrame of the GMST, SST index and OHC time series of one run.

    Parameters
    ----------
    run : str
        'had', 'ctrl' or 'lpd'; selects the input files (``..._{run}.nc``).

    Returns
    -------
    pandas.DataFrame
        Indexed by integer year (``time``) with the ``GMST`` column
        (``lowpass(..., 13)``), the ``AMO``, ``TPI`` and ``SOM`` indices and,
        for 'ctrl' / 'lpd' only, the ``lowpass(..., 13)`` OHC integrals
        ``OHC_{Global,Atlantic,Pacific,Southern}_Ocean``. The first and last
        7 time steps are dropped (presumably the low-pass filter edges).
    """
    # GMST
    gmst = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_{run}.nc', decode_times=False)
    gmst.name = 'GMST'
    if run == 'had':
        gmst = gmst.isel({'time': slice(9, 158)})
    else:
        # numeric time axis / 365 -> integer years
        gmst['time'] = (gmst.time / 365).astype(dtype=int)
    if run == 'lpd':
        gmst = gmst.isel({'time': slice(0, 250)})

    # SST indices; the model-run files carry the `_quadratic_pwdt` suffix
    dt = '' if run == 'had' else '_quadratic_pwdt'
    indices = []
    for index in ['AMO', 'TPI', 'SOM']:
        # decode_times=False, consistent with how the `_had` index files are
        # opened elsewhere in this notebook: the time axis is converted to
        # years arithmetically below, which fails on decoded datetimes
        da = xr.open_dataarray(f'{path_samoc}/SST/{index}{dt}_{run}.nc', decode_times=False)
        da.name = index
        indices.append(da)
    da_temp = xr.merge(indices)
    da_temp['time'] = (da_temp.time / 365).astype(dtype=int)
    if run == 'had':
        da_temp['time'] = da_temp.time + 1870
    if run == 'ctrl':
        da_temp = da_temp.isel({'time': slice(1, 251)})
    if run == 'lpd':
        da_temp = da_temp.isel({'time': slice(0, 250)})

    da_temp = xr.merge([lowpass(gmst, 13), da_temp])

    # OHC integrals only exist for the model runs
    if run in ['ctrl', 'lpd']:
        qd = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_{run}_qd.nc', decode_times=False)
        qd['time'] = qd.time.astype(dtype=int)
        ohc = [lowpass(qd[f'OHC_{ocean}_Ocean'], 13)
               for ocean in ['Global', 'Atlantic', 'Pacific', 'Southern']]
        da_temp = xr.merge([da_temp, *ohc])

    return da_temp.isel({'time': slice(7, -7)}).to_dataframe()

# One merged dataframe per run (positional `run` argument).
time_series_had = merge_time_series('had')
time_series_ctrl = merge_time_series('ctrl')
time_series_lpd = merge_time_series('lpd')

In [None]:
# Quick look at all columns of the merged `had` dataframe.
time_series_had.plot()

In [None]:
# Quick look at all columns of the merged `ctrl` dataframe.
time_series_ctrl.plot()

In [None]:
# GMST column only of the merged `lpd` dataframe.
time_series_lpd.GMST.plot()

### plot overview

In [None]:
# Overview figure: GMST (row 0), SST indices (row 1) and OHC (row 2) for the
# `had` record (column 0) and the `ctrl` / `lpd` runs (columns 1 and 2).
f, ax = plt.subplots(3, 3, figsize=(9,5), sharey='row', sharex='col', gridspec_kw={"width_ratios":[149, 250, 250]})

handles = []  # index lines of the last column, reused for the row-1 legend
for j, run in enumerate(['had','ctrl','lpd']):
    for i in range(3):
        # no zero line in the bottom-left panel, which is switched off below
        if not (j == 0 and i == 2):
            ax[i,j].axhline(0, c='grey', lw=.5)
    if j == 0:
        # GMST: raw and two low-passed versions (lowpass arguments 5 and 13)
        da0 = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_had.nc', decode_times=False)
        ax[0,0].plot(da0.time, da0, lw=.3)
        ax[0,0].plot(da0.time, lowpass(da0,5))
        ax[0,0].plot(da0.time, lowpass(da0,13))
        for index in ['AMO', 'TPI', 'SOM']:
            ds1 = xr.open_dataarray(f'{path_samoc}/SST/{index}_had.nc', decode_times=False)
            # time / 365 + 1870 -> years C.E. (presumably days since 1870)
            ax[1,0].plot(ds1.time/365+1870, ds1)
        ax[1,j].set_xlabel('time [years C.E.]')
        ax[2,j].set_xlim((1866, 2023))
        ax[2,j].axis('off')
    else:
        # GMST: raw and two low-passed versions (lowpass arguments 5 and 13)
        da0 = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_{run}.nc')
        ax[0,j].plot(da0.time/365, da0.values, lw=.3)
        ax[0,j].plot(da0.time/365, lowpass(da0.values,5))
        ax[0,j].plot(da0.time/365, lowpass(da0.values,13))

        # indices, without the first/last 7 time steps
        for index in ['AMO', 'TPI', 'SOM']:
            da1 = xr.open_dataarray(f'{path_samoc}/SST/{index}_quadratic_pwdt_{run}.nc')
            line, = ax[1,j].plot(da1.time[7:-7]/365, da1.values[7:-7], label=index)
            if j == 2:
                handles.append(line)

        # OHC integrals, J -> ZJ (1e21 J), low-passed
        qd = [ctrl_qd, lpd_qd][j-1]
        for ocean, c in zip(['Global', 'Atlantic', 'Pacific', 'Southern'], ['k', 'C0', 'C1', 'C2']):
            x = qd[f'OHC_{ocean}_Ocean']/1e21
            ax[2,j].plot(x.time, lowpass(x,13), c=c, label=f'{ocean}')

        ax[2,j].set_xlabel('time [model years]')
    ax[0,j].text(.05,.85, ['HIST', 'HIGH', 'LOW'][j], transform=ax[0,j].transAxes)

ax[2,2].set_xlim((150,408))
ax[0,0].set_ylabel('GMST [degC]')
ax[1,0].set_ylabel('indices [degC]')
ax[2,0].set_ylabel('OHC [ZJ]')
ax[1,2].legend(handles=handles, fontsize=8, ncol=3, frameon=False)
ax[2,2].legend(fontsize=8, ncol=2, frameon=False)

f.align_ylabels()

In [None]:
# Pairwise scatter plots (KDE on the diagonal) of all `had` columns; the
# trailing `;` suppresses display of the returned axes array.
pd.plotting.scatter_matrix(time_series_had, figsize=(10, 10), diagonal='kde');

In [None]:
# Pairwise scatter plots (KDE on the diagonal) of all `ctrl` columns; the
# trailing `;` suppresses display of the returned axes array.
pd.plotting.scatter_matrix(time_series_ctrl, figsize=(10, 10), diagonal='kde');

In [None]:
# Pairwise scatter plots (KDE on the diagonal) of all `lpd` columns; the
# trailing `;` suppresses display of the returned axes array.
pd.plotting.scatter_matrix(time_series_lpd, figsize=(10, 10), diagonal='kde');

## indices vs OHC

In [None]:
# Lag-1 autocorrelation of the (lowpass(..., 13)) GMST column of each run.
# `ts` is only bound by the loop in the next cell, so reference the merged
# dataframes directly.
{run: df_['GMST'].corr(df_['GMST'].shift(1))
 for run, df_ in zip(['had', 'ctrl', 'lpd'],
                     [time_series_had, time_series_ctrl, time_series_lpd])}

In [None]:
def lagged_correlation_plot(ax, X, Y, kwargs=None, max_lag=20):
    """Plot the lagged correlation between two series onto an axis.

    For every lag ``dt`` in ``[-max_lag, max_lag]`` the correlation
    ``X.corr(Y.shift(dt))`` is computed, i.e. ``X`` at step ``t`` against
    ``Y`` at step ``t - dt``; positive lags therefore mean ``Y`` leads ``X``.

    Parameters
    ----------
    ax : matplotlib Axes (anything with a ``plot`` method)
    X, Y : pandas.Series
    kwargs : dict, optional
        Keyword arguments forwarded to ``ax.plot`` (e.g. ``label``, ``ls``).
    max_lag : int, default 20
        Largest absolute lag, in time steps.

    Returns
    -------
    Whatever ``ax.plot`` returns (a list of Line2D for matplotlib Axes).
    """
    if kwargs is None:
        kwargs = {}
    lags = np.arange(-max_lag, max_lag + 1)
    correlations = [X.corr(Y.shift(dt)) for dt in lags]
    return ax.plot(lags, correlations, **kwargs)

# Lagged correlations, one column per run:
#   row 0: GMST vs. the SST indices
#   row 1: OHC vs. GMST (solid) and their first differences (dashed)
#   row 2: indices vs. OHC (solid) and vs. the OHC first difference (dashed)
f, ax = plt.subplots(3, 3, figsize=(6.4,6), sharey=True, sharex=True,
                     gridspec_kw={"wspace":0.03, "hspace":0.05})

for j, run in enumerate(['had','ctrl','lpd']):
    # NOTE: `ts` stays bound to the last run's dataframe after the loop;
    # later cells use it.
    ts = [time_series_had , time_series_ctrl, time_series_lpd ][j]
    ax[2,j].set_xlabel('lag [years]')
    for i in range(3):
        if j == 0 and i in [1, 2]:
            # `had` has no OHC columns, so these panels stay empty
            ax[i,j].spines['right'].set_visible(False)
            ax[i,j].spines['top'].set_visible(False)
        else:
            ax[i,j].axhline(0, c='grey', lw=.5)
            ax[i,j].axvline(0, c='grey', lw=.5)
        ax[i,j].set_yticks([-1,0,1])
        ax[i,j].set_xlim((-20,20))
        ax[i,j].set_ylim((-1.1,1.1))

    for index in ['AMO', 'TPI', 'SOM']:
        lagged_correlation_plot(ax=ax[0,j], X=ts.GMST, Y=ts[index], kwargs={'label': index})
    ax[0,j].legend(frameon=False, fontsize=8)
    ax[0,j].text(10,.85,'index leads', fontsize=8, ha='center')
    ax[0,j].text(-10,.85,'GMST leads', fontsize=8, ha='center')

    if run in ['ctrl', 'lpd']:
        # each index is paired with the ocean at the same position
        # (AMO-Atlantic, TPI-Pacific, SOM-Southern)
        for ocean, index in zip(['Atlantic', 'Pacific', 'Southern'], ['AMO', 'TPI', 'SOM']):
            key = f'OHC_{ocean}_Ocean'
            dOHC = ts[key] - ts[key].shift(1)  # first difference of OHC

            lagged_correlation_plot(ax=ax[1,j], X=ts[key], Y=ts.GMST,
                                    kwargs={'label': ocean})
            lagged_correlation_plot(ax=ax[1,j], X=dOHC, Y=ts.GMST-ts.GMST.shift(1),
                                    kwargs={'label': ocean, 'ls': '--'})

            lagged_correlation_plot(ax=ax[2,j], X=ts[index], Y=ts[key],
                                    kwargs={'label': ocean, 'ls': '-'})
            lagged_correlation_plot(ax=ax[2,j], X=ts[index], Y=dOHC,
                                    kwargs={'label': ocean, 'ls': '--'})

        # legends once all lines of the panel exist
        for row in (1, 2):
            ax[row,j].legend(frameon=False, fontsize=8, loc=3, ncol=2, handlelength=.8)

        ax[1,j].text(10,.85,'GMST leads', fontsize=8, ha='center')
        ax[1,j].text(-10,.85,'OHC leads', fontsize=8, ha='center')

        ax[2,j].text(10,.85,'OHC leads', fontsize=8, ha='center')
        ax[2,j].text(-10,.85,'index leads', fontsize=8, ha='center')

ax[0,0].set_ylabel('GMST vs. indices')
ax[1,0].set_ylabel('OHC vs. GMST')
ax[2,0].set_ylabel('indices vs. OHC')

In [None]:
# AMO of `ts` (the last dataframe bound by the loop above) and its first difference.
ts.AMO.plot()
ts.AMO.diff().plot()

In [None]:
# Correlation of GMST with the AMO shifted by 10 steps (AMO leading), for `ts`
# as left bound by the loop above (time_series_lpd).
ts.GMST.corr(ts.AMO.shift(10))

## indices vs GMST

## OHC vs GMST

In [None]:
# `had` AMO index (time undecoded) as a one-column dataframe named 'AMO'.
ds1 = xr.open_dataarray(f'{path_samoc}/SST/AMO_had.nc', decode_times=False).rename('AMO')
df1 = ds1.to_dataframe()

In [None]:
# `had` TPI index (time undecoded) as a one-column dataframe named 'TPI'.
ds2 = xr.open_dataarray(f'{path_samoc}/SST/TPI_had.nc', decode_times=False).rename('TPI')
df2 = ds2.to_dataframe()

In [None]:
# `had` SOM index (time undecoded) as a one-column dataframe named 'SOM'.
ds3 = xr.open_dataarray(f'{path_samoc}/SST/SOM_had.nc', decode_times=False).rename('SOM')
df3 = ds3.to_dataframe()

In [None]:

# NOTE(review): `ds4` is not defined anywhere in the visible notebook; the
# cell that loads it (a dataset with `GMST` and `lin_fit` variables) appears
# to be missing, so this raises NameError when run top to bottom.
# Subtract `lin_fit` from GMST, low-pass the residual with lowpass(..., 13)
# and turn it into a dataframe. ds4['GMST'] is overwritten in place, so
# re-running this cell subtracts `lin_fit` and filters again.
ds4['GMST'] = lowpass(ds4['GMST']-ds4['lin_fit'], 13)
df4 = ds4['GMST'].to_dataframe()

In [None]:
# Visual check of the four time axes: plot each dataframe's index, offset by
# 10000*i so the lines do not overlap.
dfs = [df1, df2, df3, df4]
for i, df in enumerate(dfs):
    plt.plot(df.index+10000*i)

In [None]:
# Shift the GMST index by 31 so it lines up with the index dataframes.
# NOTE(review): magic offset; presumably compensates for a different time
# origin/units of the GMST file — confirm. Not idempotent: every re-run
# shifts by another 31.
df4.index = df4.index+31

In [None]:
# Join TPI, SOM and GMST onto the AMO dataframe (DataFrame.join default: left join on df1's index).
df = df1.join([df2, df3, df4])

In [None]:
# Plot all columns of the joined `had` dataframe.
df.plot()

In [None]:
# First rows of the joined dataframe.
df.head()

In [None]:
# NOTE(review): df.GMST and df.AMO are columns of the same DataFrame and share
# its index, so this is always True; to check alignment, compare df4's index
# against df1's before the join.
np.all(df.GMST.index==df.AMO.index)

In [None]:
# Plot the AMO and SOM index values, then their data values.
# NOTE(review): both columns share df's index, so the first two lines coincide.
plt.plot(df.AMO.index)
plt.plot(df.SOM.index)
plt.plot(df.AMO.values)
plt.plot(df.SOM.values)

In [None]:
# Scatter each SST index (y) against GMST (x) on the same axes.
plt.scatter(df.GMST, df.AMO)
plt.scatter(df.GMST, df.TPI)
plt.scatter(df.GMST, df.SOM)

In [None]:
# Pairwise scatter matrix of the joined `had` dataframe (KDE on the diagonal);
# the trailing `;` suppresses display of the returned axes array.
pd.plotting.scatter_matrix(df, figsize=(10, 10), diagonal='kde');

In [None]:
from itertools import combinations 

In [None]:
# Regress GMST on every non-empty subset of the SST indices (OLS with an
# intercept) and print R^2 and the fitted coefficients of each model.
clean = df.dropna(axis=0)  # rows with a NaN in any column are excluded from every fit
y = clean.GMST
for n in range(1, 4):
    for index in combinations(['AMO', 'TPI', 'SOM'], n):
        X = sm.add_constant(clean[list(index)])
        model = sm.OLS(y, X).fit()
        print(f'{index}    R^2: {model.rsquared:4.2e} \n params:\n{model.params}\n')