In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import xarray as xr
import cartopy
import cartopy.crs as ccrs
import cmocean
import datetime
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
# show inline figures without the inline backend's default bbox_inches='tight' cropping
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# apply the matplotlib rc settings from ../rc_file_paper to all figures below
matplotlib.rc_file('../rc_file_paper')
# reload edited modules (e.g. the project modules imported below) before each execution,
# but exclude numpy, scipy and matplotlib.pyplot from reloading
%load_ext autoreload
%autoreload 2
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
from scipy.optimize import curve_fit
from paths import path_samoc, path_results
from xr_regression import xr_quadtrend

Color conventions (matplotlib color codes used in the plots):
- Global: black (`k`)
- Atlantic: blue (`C0`)
- Pacific: orange (`C1`)
- Southern: red (`C3`)

# Fig 1: exponential decay of $\Delta$OHC and $\Delta$SST
from `OHC-ctrl_vs_lpd.ipynb`

In [None]:
# OHC data
# Integrated OHC of both runs; the time coordinate is divided by 365
# (presumably days -> model years, matching the axis labels used below).
ctrl, lpd = [
    ds.assign_coords(time=ds.time.values/365)
    for ds in (xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl.nc'),
               xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd.nc' ))
]

# `_qd` versions of the same integrals, opened without CF time decoding (raw time values kept)
ctrl_qd, lpd_qd = (
    xr.open_dataset(fn, decode_times=False)
    for fn in (f'{path_samoc}/OHC/OHC_integrals_ctrl_qd.nc',
               f'{path_samoc}/OHC/OHC_integrals_lpd_qd.nc' )
)

In [None]:
# SST data
# Global-mean SST series of both runs; time divided by 365 exactly as for the OHC data.
SST_ctrl, SST_lpd = [
    sst.assign_coords(time=sst.time.values/365)
    for sst in (xr.open_dataarray(f'{path_samoc}/SST/SST_global_mean_timeseries_ctrl.nc'),
                xr.open_dataarray(f'{path_samoc}/SST/SST_global_mean_timeseries_lpd.nc'))
]


In [None]:
# Quick look: year-to-year change of global-mean SST in both runs (the drift fitted in Fig 1).
# Labelled so the two runs can be told apart.
(SST_ctrl-SST_ctrl.shift(time=1)).plot(label='CTRL')
(SST_lpd-SST_lpd.shift(time=1)).plot(label='LPD')
plt.axhline(0)
plt.ylabel(r'$\Delta$SST [K]')
plt.legend();

In [None]:
def exp_decay(x, a, b):
    """Exponential decay model a * exp(-x / b).

    Parameters
    ----------
    x : array-like
        independent variable (here: time)
    a : float
        amplitude, the value at x = 0
    b : float
        e-folding scale in units of x
    """
    scaled = x / b
    return a * np.exp(-scaled)

def exp_fit(da_fit):
    """Least-squares fit of `exp_decay` to a 1D series against its `time` coordinate.

    The amplitude is initialised with the first value of the series, the
    e-folding time with 1000 (in units of `time`).

    Returns
    -------
    np.ndarray
        optimal `(a, b)` from `scipy.optimize.curve_fit`; the covariance is discarded
    """
    series = da_fit.values
    first_guess = (series[0], 1000.)
    popt, _pcov = curve_fit(exp_decay, da_fit.time, series, p0=first_guess)
    return popt
   
def adjustment_time(da):
    """Fit an exponential decay to the year-to-year change of `da`.

    The increments `da(t) - da(t-1)` (first time step dropped) are fitted with
    `exp_decay(t, a, b)` via `exp_fit`; `b` is the e-folding adjustment time in
    units of the `time` coordinate.

    Parameters
    ----------
    da : xr.DataArray
        series with a `time` dimension; every other dimension is treated as a set
        of independent series that are fitted one by one

    Returns
    -------
    np.ndarray of shape (2,)
        `(a, b)` for a 1D time series
    xr.Dataset
        for multi-dimensional input: variables `popt0` (amplitude a) and `popt1`
        (time scale b) over the non-time dimensions; NaN wherever the series
        contains NaNs (no fit attempted there)
    """
    # select by name, not position, so `time` need not be the first dimension
    Delta_da = (da - da.shift(time=1)).isel(time=slice(1, None))
    if da.ndim == 1:  # 1D time series (decided by dims; extra non-dim coords are irrelevant)
        return exp_fit(Delta_da)  # array of size ((2))

    other_dims = [d for d in Delta_da.dims if d != 'time']
    stacked = len(other_dims) > 1
    if stacked:  # collapse all non-time dims into one so we can loop over single series
        Delta_da = Delta_da.stack(stacked_coord=other_dims)
        coord = 'stacked_coord'
    else:
        coord = other_dims[0]

    n = Delta_da.sizes[coord]
    A = np.full((n, 2), np.nan)  # NaN, not 0, marks series that were not fitted
    for i in range(n):
        series = Delta_da.isel({coord: i})
        # curve_fit rejects NaNs, so skip series with any missing data
        if np.isnan(series.values).any():  continue
        A[i,:] = exp_fit(series)

    D = Delta_da.isel({'time': 0}, drop=True)  # skeleton DataArray with correct dimensions
    popt0 = D.copy(data=A[:,0])
    popt1 = D.copy(data=A[:,1])
    if stacked:  # restore the original non-time dimensions
        popt0 = popt0.unstack()
        popt1 = popt1.unstack()
    popt0.name, popt1.name = 'popt0', 'popt1'
    return xr.merge([popt0, popt1])

In [None]:
# time series: depth-lat-lon integrated OHC
def plot_adjustment_timeseries(da, ax):
    """Draw the year-to-year change of `da` and its fitted exponential decay on `ax`.

    The fit parameters come from `adjustment_time(da)` and are printed. They are
    indexed as an `(a, b)` array, so `da` is expected to be a 1D time series. The
    e-folding time b is annotated as tau, shown as '>1000' / '<5' outside that range.
    """
    increments = da - da.shift(time=1)
    popt = adjustment_time(da)
    print(popt)

    t = increments.time[1:]  # first increment is NaN (no previous step)
    ax.axhline(0, c='k', lw=.5)
    ax.plot(t, increments.values[1:], lw=.2)
    ax.plot(t, exp_decay(t, *popt), 'r-', label="Fitted Curve")

    tau_fit = popt[1]
    tau = '>1000' if tau_fit > 1000 else ('<5' if tau_fit < 5 else f'{tau_fit:3.0f}')
    tau_symbol = r'$\tau$'
    ax.text(.1, .85, f'{tau_symbol} = {tau} years', transform=ax.transAxes, fontsize=8)

# Fig 1 grid: columns = SST CTRL, SST LPD, OHC CTRL, OHC LPD; rows = ocean basins
# (SST panels are only filled in the first row).
# NOTE: must run before the spinup-selection cell further down, which overwrites
# `ctrl` and `lpd` with time-sliced versions.
Delta = r'$\Delta$'
fig, ax = plt.subplots(4,4, figsize=(6.4,5), sharex='col')
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
    # share y within each CTRL/LPD pair (SST: cols 0-1, OHC: cols 2-3); Axes.sharey replaces
    # the get_shared_y_axes().join(...) idiom, which was removed in matplotlib 3.8
    ax[i,1].sharey(ax[i,0])
    ax[i,3].sharey(ax[i,2])
    for j in range(2):
        # hide labels via tick_params: shared axes share one tick formatter, so
        # set_yticklabels([]) on the right panel would also blank the left one
        ax[i,2*j+1].tick_params(labelleft=False)
        if i==0: plot_adjustment_timeseries([SST_ctrl, SST_lpd][j], ax[i,j])
        plot_adjustment_timeseries([ctrl,lpd][j][f'OHC_{ocean}_Ocean']/1e21, ax[i,j+2])  # J -> ZJ

    ax[i,0].set_ylabel(f'{ocean} Ocean\n{Delta}SST [K]')
    ax[i,2].set_ylabel(f'{Delta}OHC [ZJ]')

for j in range(4):
    ax[-1,j].set_xlabel('time (model years)')
    ax[0,j].title.set_text(['CTRL', 'LPD'][j%2])

fig.align_ylabels()
fig.savefig(f'{path_results}/paper/equilibration_SST_OHC')


In [None]:
# Discard the CTRL spinup and keep equally long segments (300 - spinup time steps) of both
# runs: CTRL from step `spinup` onward, LPD from its start.
# NOTE: this overwrites `ctrl` and `lpd`. Re-run the OHC data-loading cell before running
# this one again; a second pass would index CTRL out of range.
spinup = 50  # how many years to ignore due to spinup effects: data from year 51 of ctrl run

ctrl = ctrl.isel(time=np.arange(spinup, 300))
lpd  = lpd .isel(time=np.arange(0, 300-spinup))

# Fig 2: SST index regression patterns
from `OHC-ctrl_vs_lpd.ipynb`

In [None]:
# Zonal OHC per basin: time mean (left) and std of the quadratically detrended series (right).
# CTRL solid, LPD dashed. LPD values are divided by 100 -- presumably a unit conversion
# for the LPD integrals; TODO confirm.
lpd_lat = lpd.TLAT.mean(axis=1)  # LPD latitude per row of the curvilinear grid

f, ax = plt.subplots(1, 2 , sharex=True, figsize=(6.4,2.5))
ocean_colors = [('Global', 'k'), ('Atlantic', 'C0'), ('Pacific', 'C1'), ('Southern', 'C3')]
for ocean, c in ocean_colors:
    key = f'OHC_zonal_{ocean}_Ocean'
    ctrl_zonal, lpd_zonal = ctrl[key], lpd[key]
    ax[0].plot(ctrl.t_lat, ctrl_zonal.mean(dim='time'), c=c, label=ocean)
    ax[0].plot(lpd_lat, lpd_zonal.mean(dim='time')/100, c=c, ls='--')
    ax[1].plot(ctrl.t_lat, (ctrl_zonal - xr_quadtrend(ctrl_zonal)).std(dim='time'), c=c)
    ax[1].plot(lpd_lat, (lpd_zonal - xr_quadtrend(lpd_zonal)).std(dim='time')/100, c=c, ls='--')
ax[0].set_ylabel('mean OHC [J/m]')
ax[1].set_ylabel('standard deviation OHC [J/m]')

lat_ticks = np.arange(-60,100,30)
for panel in ax:
    panel.set_xlabel('latitude')
    panel.set_xlim((-80,92))
    panel.set_xticks(lat_ticks)
    panel.set_xticklabels(lat_ticks)
ax[0].legend(fontsize=8, frameon=False)

plt.savefig(f'{path_results}/paper/OHC_zonal_mean_std_ctrl_lpd')

# Fig 3: spatial correlation

# Fig 4: SST index spectra
from `SST_indices.ipynb`

In [None]:
from bb_analysis_timeseries import AnalyzeTimeSeries as ATS

In [None]:
def _load_sst_index(idx, dt, run):
    """Open SST index `idx` of `run` with detrending `dt`; TPI = TPI2 - (TPI1 + TPI3)/2."""
    if idx != 'TPI':
        return xr.open_dataarray(f'{path_samoc}/SST/{idx}_{dt}_raw_{run}.nc')
    da1 = xr.open_dataarray(f'{path_samoc}/SST/{idx}1_{dt}_raw_{run}.nc')
    da2 = xr.open_dataarray(f'{path_samoc}/SST/{idx}2_{dt}_raw_{run}.nc')
    da3 = xr.open_dataarray(f'{path_samoc}/SST/{idx}3_{dt}_raw_{run}.nc')
    return da2 - (da1+da3)/2

# filter type and cutoff passed to both the multitaper and the AR(1) red-noise spectrum
ft, fc = 'lowpass', 13
# columns: runs (HadISST-style 'had', CTRL, LPD); rows: indices
f, ax = plt.subplots(3, 3, figsize=(6.4,4), sharex=True, constrained_layout=True)
for i, run in enumerate(['had', 'ctrl', 'lpd']):
    dt = 'GMST_tfdt' if run=='had' else 'quadratic_pwdt'

    for j, idx in enumerate(['AMO', 'SOM', 'TPI']):
        da = _load_sst_index(idx, dt, run)
        spec = ATS(da).spectrum(filter_type=ft, filter_cutoff=fc)  # spectrum
        rnsp = ATS(da).mc_ar1_spectrum(filter_type=ft, filter_cutoff=fc)  # red noise spectrum

        panel = ax[j,i]
        panel.plot(1/rnsp[1,:], rnsp[3,:], c='C1', label='AR(1) 95% C.I.')
        panel.plot(1/spec[1], spec[0], c='C0', label='MT spectrum')
        panel.loglog(1/spec[1], spec[2].T[1], c='C0', lw=.5, ls='--', label='jackknife est.')
        panel.set_xscale('log')
        panel.set_yscale('log')
        panel.set_xlim((10,100))
        panel.set_ylim((1e-4, 1.5*np.max(spec[2].T[1])))
        panel.set_yticklabels([])
        if i == 0:
            panel.set_ylabel(f'{idx} index\nspectral power')
            if j == 0:  panel.legend(loc=4, fontsize=8, frameon=False)
        panel.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())

    period_ticks = [10,20,40,80]
    ax[2,i].set_xticks(period_ticks)
    ax[2,i].set_xticklabels(period_ticks)
    ax[0,i].title.set_text(run.upper())
    ax[2,i].set_xlabel('period [years]')
plt.savefig(f'{path_results}/paper/spectra')

# Fig 5: OHC time series

In [None]:
from filters import lowpass

In [None]:
# Fig 5: basin OHC anomaly time series from the `_qd` datasets, CTRL (left) and LPD (right).
# Panel widths are proportional to the number of time steps of the series actually plotted
# (`ctrl_qd`/`lpd_qd`); these are not affected by the spinup selection applied to `ctrl`/`lpd`.
f, ax = plt.subplots(1, 2 , figsize=(6.4,2.5), sharey=True,
                     gridspec_kw={"width_ratios":[len(ctrl_qd.time), len(lpd_qd.time)]})
for i in range(2):
    ax[i].axhline(0, c='grey', lw=.5)
    ax[i].set_xlabel('time [model years]')

for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
    key = f'OHC_{ocean}_Ocean'
    c = ['k' ,'C0','C1','C3'][i]
    for j, da in enumerate([ctrl_qd, lpd_qd]):
        x = da[key]/1e21  # J -> ZJ
        ax[j].plot(x.time, x, lw=.5, c=c, label=ocean)

ax[0].set_ylabel('OHC anomaly [ZJ]')
ax[1].legend(fontsize=8, ncol=2, frameon=False)
plt.savefig(f'{path_results}/paper/OHC_time_series')

# Fig 6: OHC depth-zonal integral mean and std

# Fig 7: OHC depth-zonal integral Hovmöller diagram

In [None]:
# Fig 7: latitude-time Hovmoeller of quadratically detrended zonal OHC; rows = basins,
# columns = CTRL, LPD, colorbar. LPD values are divided by 100 as in the zonal mean/std
# figure (presumably a unit conversion; TODO confirm).
extents = [(-78,90), (-40,80), (-40,70), (-40,30)]  # latitude range shown per row
height_ratios = [a[1]-a[0] for a in extents]        # row heights proportional to latitude span
f, ax = plt.subplots(4, 3, figsize=(6.4,8), sharex='col',
                     gridspec_kw={"width_ratios":[1,1, 0.05], "height_ratios":height_ratios})
cY, cX = np.meshgrid(ctrl.t_lat, ctrl.time)
lY, lX = np.meshgrid(lpd.TLAT.mean(axis=1), lpd.time)
vex, ims = [3e16, 2e16, 2e16, 1e16], []  # symmetric colour limits per row
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Indian']):
    kwargs = {'cmap':'RdBu', 'vmin':-vex[i], 'vmax':vex[i]}
    key = f'OHC_zonal_{ocean}_Ocean'
    im = ax[i,0].pcolormesh(cX, cY, ctrl[key]-xr_quadtrend(ctrl[key]), **kwargs)
    ims.append(im)
    ax[i,1].pcolormesh(lX, lY, (lpd[key]-xr_quadtrend(lpd[key]))/100, **kwargs)
    for j in range(2):
        ax[i,j].axhline(0, c='grey', lw=.5, ls='--')
        ax[i,j].set_yticks(np.arange(-60,100,30))
        ax[i,j].set_ylim(extents[i])
    if i==0:
        ax[0,0].text(.05, .2, 'Southern Ocean', c='g', transform=ax[0,0].transAxes)
        for j in range(2):
            ax[0,j].axhline(-31.5, c='g', lw=.8)  # boundary line labelled 'Southern Ocean'
            ax[0,j].text(.5, 1.02, ['CTRL', 'LPD'][j], transform=ax[0,j].transAxes, ha='center')
            ax[-1,j].set_xlabel('time [model years]')
    ax[i,0].text(.05, .9, ocean, c='g', transform=ax[i,0].transAxes)
    ax[i,0].set_ylabel('latitude')
    # Axes.sharey replaces get_shared_y_axes().join(...), removed in matplotlib 3.8. Sharing
    # also shares the tick formatter, so hide LPD labels with tick_params instead of
    # set_yticklabels([]), which would blank the CTRL labels as well.
    ax[i,1].sharey(ax[i,0])
    ax[i,1].tick_params(labelleft=False)
    cb = f.colorbar(ims[i], cax=ax[i,2], ticks=np.arange(-3e16,4e16,1e16))
    cb.outline.set_visible(False)
plt.savefig(f'{path_results}/paper/OHC_zonal_Hovmoeller')

# Fig 8: OHC horizontal integral Hovmöller diagram

In [None]:
def plot_vertical_Hovmoeller(ctrl, lpd, ylim, fn, offset=True):
    """Depth-time Hovmoeller diagrams of horizontally integrated OHC, CTRL vs LPD.

    Parameters
    ----------
    ctrl, lpd : xr.Dataset
        must contain `OHC_levels_{ocean}_Ocean` for Global/Atlantic/Pacific/Southern,
        with a `time` coordinate and vertical coordinate `depth_t` (CTRL) resp.
        `z_t` (LPD)
    ylim : tuple
        depth-axis limits in km (negative downward), e.g. (-6, 0)
    fn : str
        output path passed to savefig
    offset : bool
        if True, plot anomalies w.r.t. the mean of the first 30 time steps;
        otherwise plot the fields as stored

    The colour range is symmetric, +-(max of the plotted CTRL/LPD fields)/1.2, so
    it always matches what is drawn (also for offset=False).
    """
    oceans = ['Global', 'Atlantic', 'Pacific', 'Southern']
    das = [ctrl, lpd]
    depth_names = ['depth_t', 'z_t']
    depth_to_km = [1e3, 1e5]  # divisors giving km (presumably CTRL depth in m, LPD z_t in cm)
    f, ax = plt.subplots(len(oceans),2, figsize=(6.4,5), sharey=True, sharex='col')
    for j, ocean in enumerate(oceans):
        fields = []
        for i in range(2):
            da = das[i][f'OHC_levels_{ocean}_Ocean']
            if offset:
                da = da - da.isel(time=slice(0,30)).mean(dim='time')
            fields.append(da/1e21)  # J/m -> ZJ/m
        maxv = np.max([float(x.max()) for x in fields])/1.2
        for i, x in enumerate(fields):
            X, Y = np.meshgrid(x.time, -x.coords[depth_names[i]]/depth_to_km[i])
            im = ax[j,i].pcolormesh(X, Y, x.T, vmin=-maxv, vmax=maxv, cmap=cmocean.cm.balance)
            if i==1:  f.colorbar(im, ax=ax[j,i], orientation='vertical', fraction=.1, label='[ZJ/m]')
        ax[j,0].set_ylabel(ocean)
        ax[j,0].set_ylim(ylim)
    for i in range(2):
        ax[0,i].text(.5, 1.05, ['CTRL', 'LPD'][i], transform=ax[0,i].transAxes, ha='center')
        ax[-1,i].set_xlabel('time [model years]')
    f.savefig(fn)

        
# Full-depth (0-6 km) and upper-ocean (0-1 km) versions of the vertical Hovmoeller figure
c, l = ctrl_qd, lpd_qd
ext = '_qd'  # dataset suffix, appended to the output filename
for j, ylim in enumerate([(-6,0), (-1,0)]):
    max_depth_km = -ylim[0]
    fn = f'{path_results}/paper/OHC_vertical_Hovmoeller_0-{max_depth_km}km_ctrl_lpd{ext}'
    print(fn)
    plot_vertical_Hovmoeller(c, l, ylim, fn, offset=False)