# SST modes vs. GMST regression

1. regress modes of SST variability against GMST variability

In [None]:
import sys
sys.path.append("..")
import scipy as sp
import numpy as np
import xarray as xr
import pandas as pd
import statsmodels.api as sm
import matplotlib
import matplotlib.pyplot as plt
from itertools import combinations 

In [None]:
# render figures inline in the notebook
%matplotlib inline
# pass bbox_inches=None when inline figures are rendered
# (presumably to keep the figure size exactly as specified -- confirm)
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# matplotlib settings are read from the rc file one directory up
matplotlib.rc_file('../rc_file_paper')
# reload imported modules automatically before code is executed
%load_ext autoreload
%autoreload 2
# NOTE(review): presumably meant to exclude numpy, scipy and matplotlib.pyplot
# from autoreloading -- confirm that %aimport accepts this spelling
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
from paths import path_samoc, path_results
from filters import lowpass

## time series into dataframe

In [None]:
# GMST time series of the three data sets
# 'had': time axis left undecoded, restricted to entries 9..157
gmst_had = xr.open_dataarray(
    f'{path_samoc}/GMST/GMST_dt_yrly_had.nc', decode_times=False
).isel(time=slice(9, 158))

# 'ctrl' and 'lpd': time coordinate divided by 365 and truncated to integers
# (presumably days -> years; confirm the time units of the files)
gmst_ctrl = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_ctrl.nc')
gmst_ctrl.coords['time'] = (gmst_ctrl.time / 365).astype(int)

gmst_lpd = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_lpd.nc')
gmst_lpd.coords['time'] = (gmst_lpd.time / 365).astype(int)

In [None]:
# indices


In [None]:
# OHC
# OHC integral data sets of the 'ctrl' and 'lpd' runs; the time axis is left
# undecoded (decode_times=False). Both are used by the overview plot below.
ctrl_qd = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl_qd.nc', decode_times=False)
lpd_qd  = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd_qd.nc' , decode_times=False)

In [None]:
# display the dataset summary in the notebook output
ctrl_qd

In [None]:
def merge_time_series(run):
    """Collect the GMST, SST index and OHC time series of one data set in a pandas DataFrame.

    run: 'had', 'ctrl' or 'lpd'; selects the files read below `path_samoc`.

    GMST and (for 'ctrl' and 'lpd' only) the four OHC integrals are passed
    through `lowpass(..., 13)`; the AMO, TPI and SOM indices are merged as read
    from file. The first and last 7 time steps of the merged data are dropped
    before the conversion to a DataFrame (presumably to discard filter edge
    effects -- confirm).
    """
    # --- GMST ---
    gmst = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_{run}.nc', decode_times=False)
    gmst.name = 'GMST'
    if run == 'had':
        gmst = gmst.isel(time=slice(9, 158))
    else:
        # presumably converts a time axis in days to whole years -- confirm units
        gmst['time'] = (gmst.time / 365).astype(dtype=int)
        if run == 'lpd':
            gmst = gmst.isel(time=slice(0, 250))

    # --- SST indices ---
    # the model runs use the '_quadratic_pwdt' files, 'had' the plain ones
    dt = '' if run == 'had' else '_quadratic_pwdt'
    merged = None
    for index in ('AMO', 'TPI', 'SOM'):
        index_da = xr.open_dataarray(f'{path_samoc}/SST/{index}{dt}_{run}.nc')
        index_da.name = index
        merged = index_da if merged is None else xr.merge([merged, index_da])
    merged['time'] = (merged.time / 365).astype(dtype=int)

    # per-run adjustment of the index time axis
    if run == 'had':
        merged['time'] = merged.time + 1870
    elif run == 'ctrl':
        merged = merged.isel(time=slice(1, 251))
    elif run == 'lpd':
        merged = merged.isel(time=slice(0, 250))

    merged = xr.merge([lowpass(gmst, 13), merged])

    # --- OHC (only read for the two model runs) ---
    if run in ('ctrl', 'lpd'):
        qd = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_{run}_qd.nc', decode_times=False)
        qd['time'] = qd.time.astype(dtype=int)
        basins = (
            qd.OHC_Global_Ocean,
            qd.OHC_Atlantic_Ocean,
            qd.OHC_Pacific_Ocean,
            qd.OHC_Southern_Ocean,
        )
        merged = xr.merge([merged] + [lowpass(basin, 13) for basin in basins])

    trimmed = merged.isel(time=slice(7, -7))
    return trimmed.to_dataframe()

# one merged DataFrame per data set, built in the order had, ctrl, lpd
time_series_had, time_series_ctrl, time_series_lpd = (
    merge_time_series(run=run) for run in ('had', 'ctrl', 'lpd')
)

In [None]:
# quick look at the Atlantic OHC column of the merged 'ctrl' data frame
time_series_ctrl.OHC_Atlantic_Ocean.plot()

### plot overview

In [None]:
# Overview figure: rows show GMST, SST indices and OHC; columns show the
# HIST ('had'), HIGH ('ctrl') and LOW ('lpd') data sets.
# NOTE(review): the width ratios 149:250:250 presumably mirror the record
# lengths in years -- confirm.
f, ax = plt.subplots(3, 3, figsize=(9,5), sharey='row', sharex='col', gridspec_kw={"width_ratios":[149, 250, 250]})

handles = []  # line handles of the index curves, collected for the legend in panel [1,2]
for j, run in enumerate(['had','ctrl','lpd']):
    # zero line in every panel except the OHC panel of the first column,
    # which is switched off below
    for i in range(3):
        if not (j==0 and i==2):  ax[i,j].axhline(0, c='grey', lw=.5)
    if j==0:
        # GMST: raw series plus the lowpass(., 5) and lowpass(., 13) versions
        da0 = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_had.nc', decode_times=False)
        ax[0,0].plot(da0.time, da0, lw=.3)
        ax[0,0].plot(da0.time, lowpass(da0,5))
        ax[0,0].plot(da0.time, lowpass(da0,13))
        # indices: time axis is divided by 365 and offset by 1870
        for index in ['AMO', 'TPI', 'SOM']:
            ds1 = xr.open_dataarray(f'{path_samoc}/SST/{index}_had.nc', decode_times=False)
            ax[1,0].plot(ds1.time/365+1870, ds1)
        ax[1,j].set_xlabel('time [years C.E.]')
        ax[2,j].set_xlim((1866, 2023))
        # no OHC is plotted for this column, so the panel is hidden
        ax[2,j].axis('off')
    else:
        # GMST
        da0 = xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_{run}.nc')
        ax[0,j].plot(da0.time/365, da0.values, lw=.3)
        ax[0,j].plot(da0.time/365, lowpass(da0.values,5))
        ax[0,j].plot(da0.time/365, lowpass(da0.values,13))

        # indices (first and last 7 entries are not drawn)
        for index in ['AMO', 'TPI', 'SOM']:
            da1 = xr.open_dataarray(f'{path_samoc}/SST/{index}_quadratic_pwdt_{run}.nc')
            l, = ax[1,j].plot(da1.time[7:-7]/365, da1.values[7:-7], label=index)
            if j==2:  handles.append(l)

        # OHC, divided by 1e21 to match the [ZJ] axis label, lowpass(., 13)
        for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
            key = f'OHC_{ocean}_Ocean'
            c = ['k' ,'C0','C1','C2'][i]
            x = [ctrl_qd, lpd_qd][j-1][key]/1e21
            ax[2,j].plot(x.time, lowpass(x,13), c=c ,label=f'{ocean}')

        ax[2,j].set_xlabel('time [model years]')
    ax[0,j].text(.05,.85, ['HIST', 'HIGH', 'LOW'][j], transform=ax[0,j].transAxes)

ax[2,2].set_xlim((150,408))
ax[0,0].set_ylabel('GMST [degC]')
ax[1,0].set_ylabel('indices [degC]')
ax[2,0].set_ylabel('OHC [ZJ]')
ax[1,2].legend(handles=handles, fontsize=8, ncol=3, frameon=False)
ax[2,2].legend(fontsize=8, ncol=2, frameon=False)

f.align_ylabels()
plt.savefig(f'{path_results}/REGR/overview_GMST_indices_OHC_had_ctrl_lpd')

In [None]:
# pairwise scatter plots of all columns of the 'had' data frame, kernel density
# estimates on the diagonal; the trailing ';' suppresses the cell output
pd.plotting.scatter_matrix(time_series_had, figsize=(10, 10), diagonal='kde');

In [None]:
# pairwise scatter plots of all columns of the 'ctrl' data frame, kernel density
# estimates on the diagonal; the trailing ';' suppresses the cell output
pd.plotting.scatter_matrix(time_series_ctrl, figsize=(10, 10), diagonal='kde');

In [None]:
# pairwise scatter plots of all columns of the 'lpd' data frame, kernel density
# estimates on the diagonal; the trailing ';' suppresses the cell output
pd.plotting.scatter_matrix(time_series_lpd, figsize=(10, 10), diagonal='kde');

## indices vs OHC

In [None]:
# lag-1 autocorrelation of the GMST column for each data set
# (uses the explicit data frames: `ts` is only assigned in a later cell, so
# reading it here fails when the notebook is run top to bottom)
{run: df['GMST'].corr(df['GMST'].shift(1))
 for run, df in zip(['had', 'ctrl', 'lpd'],
                    [time_series_had, time_series_ctrl, time_series_lpd])}

In [None]:
def lagged_correlation_plot(ax, X, Y, kwargs=None, max_lag=25):
    """Plot the correlation between X and lagged copies of Y against the lag.

    For every integer lag dt in [-max_lag, max_lag] the correlation of X with
    `Y.shift(dt)` is computed, i.e. a positive lag pairs X(t) with Y(t-dt)
    (Y leading X).

    ax:       object with a matplotlib-style `plot` method
    X, Y:     pandas Series
    kwargs:   optional dict of keyword arguments forwarded to `ax.plot`
    max_lag:  largest absolute lag that is evaluated (default 25)

    Returns whatever `ax.plot` returns.
    """
    # None instead of a shared mutable {} default
    plot_kwargs = {} if kwargs is None else kwargs
    lags = np.arange(-max_lag, max_lag + 1)
    correlations = [X.corr(Y.shift(dt)) for dt in lags]
    return ax.plot(lags, correlations, **plot_kwargs)

# 3x3 grid of lag-correlation panels: one column per data set (HIST, HIGH, LOW);
# rows: GMST vs. indices, OHC vs. GMST, indices vs. OHC.
# In lagged_correlation_plot a positive lag correlates X(t) with Y(t-lag),
# which is what the 'leads' text labels below refer to.
f, ax = plt.subplots(3, 3, figsize=(6.4,6), sharey=True, sharex=True,
                     gridspec_kw={"wspace":0.03, "hspace":0.05})

for j, run in enumerate(['had','ctrl','lpd']):
    ax[0,j].title.set_text(['HIST', 'HIGH', 'LOW'][j])
    # data frame of the current data set
    ts = [time_series_had , time_series_ctrl, time_series_lpd ][j]
    ax[2,j].set_xlabel('lag [years]')
    for i in range(3):
        # the two lower panels of the first column stay empty (the OHC curves
        # are only drawn for 'ctrl' and 'lpd'); they get no zero lines and
        # their top/right spines are hidden
        if j==0 and i in [1,2]:
            ax[i,j].spines['right'].set_visible(False)
            ax[i,j].spines['top'].set_visible(False)
        else:
            ax[i,j].axhline(0, c='grey', lw=.5)
            ax[i,j].axvline(0, c='grey', lw=.5)
        ax[i,j].set_yticks([-1,0,1])
        ax[i,j].set_xlim((-25,25))
        ax[i,j].set_ylim((-1.1,1.1))
    
    # row 0: GMST (X) against each SST index (Y)
    for i, index in enumerate(['AMO', 'TPI', 'SOM']):
        kwarg = {'label':index}
        lagged_correlation_plot(ax=ax[0,j], X=ts.GMST, Y=ts[index], kwargs=kwarg)
        ax[0,j].legend(frameon=False, fontsize=8, loc=8)
    ax[0,j].text(10,.85,'index leads', fontsize=8, ha='center')
    ax[0,j].text(-10,.85,'GMST leads', fontsize=8, ha='center')
    
    if run in ['ctrl', 'lpd']:
        ax[1,j].text(10,.85,'GMST leads', fontsize=8, ha='center')
        ax[1,j].text(-10,.85,'OHC leads', fontsize=8, ha='center')

        ax[2,j].text(10,.85,'OHC leads', fontsize=8, ha='center')
        ax[2,j].text(-10,.85,'index leads', fontsize=8, ha='center')
        
        # the same kwarg dict is mutated between calls: solid labelled lines use
        # the OHC series itself, dashed unlabelled lines its first difference
        for i, ocean in enumerate(['Atlantic', 'Pacific', 'Southern']):
            kwarg = {'label':ocean, 'c':f'C{i}'}
            key = f'OHC_{ocean}_Ocean'
            # row 1: basin OHC (X) against GMST (Y)
            lagged_correlation_plot(ax=ax[1,j], X=ts[key], Y=ts.GMST, kwargs=kwarg)
#             kwarg['label'] = r'$\Delta$'+f' {ocean}'
            kwarg['label'] = None
            kwarg['ls'] = '--'
        
            lagged_correlation_plot(ax=ax[1,j], X=ts[key]-ts[key].shift(1), Y=ts.GMST, kwargs=kwarg)
            ax[1,j].legend(frameon=False, fontsize=8, loc=3)
            
            # row 2: the i-th index (X) is paired with the i-th basin OHC (Y):
            # AMO-Atlantic, TPI-Pacific, SOM-Southern
            kwarg['label'] = ocean
            kwarg['ls'] = '-'
            index = ['AMO', 'TPI', 'SOM'][i]
            lagged_correlation_plot(ax=ax[2,j], X=ts[index], Y=ts[key], kwargs=kwarg)
            kwarg['label'] = None
            kwarg['ls'] = '--'
            lagged_correlation_plot(ax=ax[2,j], X=ts[index], Y=ts[key]-ts[key].shift(1), kwargs=kwarg)
            ax[2,j].legend(frameon=False, fontsize=8, loc=3)
            

#         ax[j,0].axis('off')
        
ax[0,0].set_ylabel('GMST vs. indices')
ax[1,0].set_ylabel('OHC vs. GMST')
ax[2,0].set_ylabel('indices vs. OHC')
plt.savefig(f'{path_results}/REGR/lag_correlation_GMST_indices_OHC_had_ctrl_lpd')

In [None]:
def plot_bars(model, ax, position=None):
    """Draw the R^2 bar and one bar per regression coefficient of a fitted model.

    model:     fitted regression result providing `rsquared` and `params`
               (a pandas Series whose first entry is skipped; the calling
               cell puts the constant there via sm.add_constant)
    ax:        axes object with matplotlib-style `bar` and `text` methods
    position:  x position of the bar group; when omitted, the notebook-level
               loop variable `k` is used, as in the original cell

    A black bar of height R^2 is drawn at the position, with the rounded
    percentage written below it; the coefficient bars follow to the right,
    coloured by basin (Atlantic C0, Pacific C1, Southern C2). Regressors with
    an unknown name are drawn in grey.
    """
    x0 = k if position is None else position
    colors = {
        'AMO': 'C0', 'OHC_Atlantic_Ocean': 'C0',
        'TPI': 'C1', 'OHC_Pacific_Ocean': 'C1',
        'SOM': 'C2', 'OHC_Southern_Ocean': 'C2',
    }
    ax.bar(x0, model.rsquared, width=.3, color='k')
    ax.text(x0, -.15, f'{model.rsquared*100:2.0f}%', fontsize=6, ha='center')
    for p, (name, param) in enumerate(model.params.items()):
        if p == 0:
            continue  # first parameter is not drawn
        ax.bar(x0+.05+.2*p, param, width=.2, color=colors.get(name, 'grey'))
    return

# 4x7 grid of bar panels: one row per regression target / predictor set, one
# column per non-empty combination of the three predictors (3 singles, 3 pairs,
# 1 triple). Within a panel the three data sets sit at x = 0, 1, 2.
f, ax = plt.subplots(4, 7, figsize=(12,6), sharey='row', sharex=True,
                     gridspec_kw={"wspace":0.01, "hspace":0.01})
ax[0,0].set_ylabel('GMST vs. indices')
ax[1,0].set_ylabel('GMST vs. OHC')
ax[2,0].set_ylabel(r'GMST vs. $\Delta$OHC')
ax[3,0].set_ylabel('OHC global vs. OHC')

# k indexes the data set; plot_bars reads it as a global for the x position
for k, df in enumerate([time_series_had, time_series_ctrl, time_series_lpd]):
    # i / ii: column counters for the index row and the OHC rows
    i, ii = 0, 0
    for n in np.arange(1,4):
        # row 0: GMST regressed on every n-element combination of the indices
        for index in combinations(['AMO', 'TPI', 'SOM'], n):
            X = sm.add_constant(df.dropna(axis=0)[list(index)])
            y = df.dropna(axis=0).GMST.dropna()
            model = sm.OLS(y, X).fit()
            plot_bars(model=model, ax=ax[0,i])
            # panel caption is written once, while the first data set is processed
            if k==0:  ax[0,i].text(.02,.05,', '.join(index), fontsize=8, transform=ax[0,i].transAxes)
            i += 1
            
        # rows 1-3 use OHC columns and are skipped for the first data set
        if k>0:
            for oceans in combinations(['Atlantic', 'Pacific', 'Southern'], n):
                ocean_keys = ['OHC_'+ocean+'_Ocean' for ocean in list(oceans)]
                
                # GMST vs OHC
                # predictors are divided by 1e22, GMST by its standard deviation
                X = sm.add_constant(df.dropna(axis=0)[ocean_keys]/1e22)
                y = df.dropna(axis=0).GMST.dropna()/df.GMST.std()
                model = sm.OLS(y, X).fit()
#                 print(f'{oceans}    R^2: {model.rsquared*100:4.2f} \n params:\n{model.params}\n')
                plot_bars(model=model, ax=ax[1,ii])
                if k==1:  ax[1,ii].text(.02,.05,', '.join(oceans), fontsize=8, transform=ax[1,ii].transAxes)
                
                # GMST vs \Delta OHC
                # predictors are first differences divided by 1e21; y drops its
                # first entry to match the row lost by the differencing
                X = sm.add_constant((df-df.shift(1)).dropna(axis=0)[ocean_keys]/1e21)
                y = df.dropna(axis=0).GMST.dropna()[1:]/df.GMST.std()
                # NOTE(review): the result of align() is discarded, so this line
                # has no effect on y or X -- confirm whether alignment was intended
                y.align(X)
                model = sm.OLS(y, X).fit()
                plot_bars(model=model, ax=ax[2,ii])
                if k==1:  ax[2,ii].text(.02,.05,', '.join(oceans), fontsize=8, transform=ax[2,ii].transAxes)
                
                # OHC global vs OHC
                # no rescaling of predictors or target here
                X = sm.add_constant(df.dropna(axis=0)[ocean_keys])
                y = df.dropna(axis=0).OHC_Global_Ocean.dropna()
                model = sm.OLS(y, X).fit()
                plot_bars(model=model, ax=ax[3,ii])
                if k==1:  ax[3,ii].text(.02,.05,', '.join(oceans), fontsize=8, transform=ax[3,ii].transAxes)

                ii +=1

# common decoration of all panels; note that k is reused here as the row index
for k in range(4):
    for i in range(7):
        ax[k,i].axhline(0, c='grey', lw=.5)
        ax[k,i].axhline(1, c='grey', lw=.5)
        ax[k,i].set_xlim((-.4,3))
        ax[k,i].set_xticks([.3,1.3,2.3])
        ax[k,i].set_xticklabels(['HIST', 'HIGH', 'LOW'])
        # rows 1-3 have no HIST bars; the HIST slot is shaded grey
        if k>0:  ax[k,i].axvspan(-.15,.75, color='grey', alpha=.2)
            
eqn = r'$Y(t) = \Sigma_{i} \alpha_i X_i(t) + const.$'
f.suptitle(f'Multiple linear regression models:  {eqn}    (all time series 13 year lowpass filtered)')
f.align_ylabels()
plt.savefig(f'{path_results}/REGR/multiple_linear_regression_GMST_indices_OHC_had_ctrl_lpd')