# Ocean Heat Content - high vs. low resolution
To ease the computational load, I regridded the yearly averaged temperature and potential density files from the original curvilinear 0.1° grid to a rectangular 0.4° grid.

This notebook compares the OHC behaviour of the `ctrl` and `lpd` runs.

In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import xarray as xr
import cmocean
import matplotlib
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs

from tqdm import tqdm
from dask.distributed import Client, LocalCluster
from scipy.optimize import curve_fit

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%load_ext autoreload
%autoreload 2
%aimport - numpy - scipy - matplotlib.pyplot
matplotlib.rc_file('../rc_file')

In [None]:
from paths import path_samoc, path_results
from xr_regression import xr_quadtrend
from regions import boolean_mask
from grid import create_dz_mean

In [None]:
# Start a local Dask cluster with two workers and connect a client to it.
cluster = LocalCluster(n_workers=2)
client = Client(cluster)

In [None]:
# Evaluate the client as the last expression so the notebook displays its summary.
client

To forward port 8787 from the remote machine to your local machine, execute:
`ssh -N -L 8787:127.0.0.1:8787 dijkbio@int2-bb.cartesius.surfsara.nl`

In [None]:
# Integrated OHC time series of both runs; their `time` values are divided by
# 365 (presumably days -> model years; TODO confirm the stored units).
ctrl = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl.nc')
ctrl = ctrl.assign_coords(time=ctrl.time.values/365)
lpd  = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd.nc' )
lpd  = lpd .assign_coords(time=lpd .time.values/365)

# '_qd' variants of the same integrals (presumably quadratically detrended --
# TODO confirm); opened with decode_times=False, so `time` keeps its raw values.
ctrl_qd = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl_qd.nc', decode_times=False)
lpd_qd  = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd_qd.nc' , decode_times=False)

spinup = 50  # how many years to ignore due to spinup effects: data from year 51 of ctrl run

def plot_global_OHC(ctrl, lpd, fn=None, offset=True):
    """Plot basin-integrated OHC time series of two runs side by side.

    Parameters
    ----------
    ctrl, lpd : xr.Dataset
        Datasets with variables `OHC_{Global,Atlantic,Pacific,Southern}_Ocean`
        along `time`. `ctrl` is drawn in the left panel and `lpd` in the right
        one; panel widths are proportional to the number of time steps.
    fn : str, optional
        If given, shade the discarded periods (the first `spinup` years of the
        left panel and `lpd` from index `300-spinup` to its end) and save the
        figure to `fn`. Uses the notebook-level constant `spinup`.
    offset : bool
        If true, plot the change relative to the first time step; otherwise
        plot the stored values. Both are divided by 1e21 (axis label: ZJ).
    """
    oceans = ['Global', 'Atlantic', 'Pacific', 'Southern']
    colors = ['k', 'C0', 'C1', 'C3']
    f, ax = plt.subplots(1, 2 , figsize=(8,3), sharey=True,
                         gridspec_kw={"width_ratios":[len(ctrl.time), len(lpd.time)]})
    for a in ax:
        a.axhline(0, c='grey', lw=.5)
        a.set_xlabel('time [model years]')

    for ocean, c in zip(oceans, colors):
        key = f'OHC_{ocean}_Ocean'
        for a, da in zip(ax, [ctrl, lpd]):
            x = (da[key]-da[key][0]) if offset else da[key]
            x = x/1e21
            a.plot(x.time, x, c=c, label=f'{ocean} Ocean')

    ax[0].set_ylabel('OHC anomaly [ZJ]')
    if fn is not None:
        ax[0].axvspan(0, spinup, alpha=0.3, color='grey')
        ax[1].axvspan(lpd.time[0+300-spinup], lpd.time[-1], alpha=0.3, color='grey', label='discarded data')
    ax[1].legend(frameon=False)  # legend after the span so 'discarded data' is included
    if fn is not None:
        plt.savefig(fn)

# Full-length comparison, saved with the discarded periods shaded.
plot_global_OHC(ctrl, lpd, fn=f'{path_results}/OHC/OHC_anomalies_ctrl_lpd')

# Drop the ctrl spin-up and keep the same number of lpd steps (300-spinup each);
# use the `spinup` constant rather than a second hard-coded 50.
ctrl = ctrl.isel(time=np.arange(spinup,300))
lpd  = lpd .isel(time=np.arange(0,300-spinup))

plot_global_OHC(ctrl, lpd)
plot_global_OHC(ctrl_qd, lpd_qd, offset=False)

## equilibration time

In [None]:
def exp_decay(x, a, b):
    """Return a*exp(-x/b): amplitude `a`, e-folding scale `b`."""
    return a*np.exp(-x/b)

def exp_fit(da_fit):
    """Least-squares fit of `exp_decay` to a 1D series.

    `da_fit` must provide `.time` (x values) and `.values` (y values), e.g. a
    1D DataArray along `time`. The fit starts from amplitude = first value and
    e-folding scale = 1000. Returns the optimal (a, b) as a length-2 array.
    `scipy.optimize.curve_fit` raises if the fit fails or the data are not
    finite.
    """
    initial_guess = (da_fit.values[0], 1000.)
    return curve_fit(exp_decay, np.asarray(da_fit.time), da_fit.values, p0=initial_guess)[0]

def adjustment_time(da):
    """Fit an exponential decay to the step-to-step change of `da`.

    The first difference along `time` (value minus previous value, labelled
    with the later time) is fitted with `exp_decay`. The fitted `b` is the
    e-folding adjustment time in units of the `time` coordinate.

    Parameters
    ----------
    da : xr.DataArray
        Must have a `time` dimension. Each position along the other
        dimension(s) is fitted separately. Several non-time dimensions are
        stacked into one for the loop and unstacked afterwards.

    Returns
    -------
    np.ndarray
        (a, b) when `da` is 1D.
    xr.Dataset
        Otherwise, with variables `popt0` (a) and `popt1` (b) on the non-time
        dimensions. Points whose differenced series contains NaN are skipped
        and left as NaN.
    """
    Delta_da = (da - da.shift(time=1)).isel(time=slice(1, None))
    if da.ndim == 1:  # 1D time series
        return exp_fit(Delta_da)

    other_dims = [d for d in da.dims if d != 'time']
    stacked = len(other_dims) > 1
    if stacked:  # collapse all non-time dims into one to loop over
        Delta_da = Delta_da.stack(stacked_coord=other_dims)
        dim = 'stacked_coord'
    else:
        dim = other_dims[0]

    n = len(Delta_da[dim])
    A = np.full((n, 2), np.nan)  # NaN marks points that were not fitted
    for i in tqdm(range(n)):
        series = Delta_da.isel({dim: i})
        if np.any(np.isnan(series.values)):  # curve_fit rejects NaN data
            continue
        A[i, :] = exp_fit(series)

    D = Delta_da.isel(time=0, drop=True)  # skeleton DataArray with the non-time dims
    popt0 = D.copy(data=A[:, 0])
    popt1 = D.copy(data=A[:, 1])
    if stacked:
        popt0 = popt0.unstack(dim)
        popt1 = popt1.unstack(dim)
    popt0.name, popt1.name = 'popt0', 'popt1'
    return xr.merge([popt0, popt1])

In [None]:
# time series: depth-lat-lon integrated OHC
def plot_adjustment_timeseries(da):
    """Draw the step-to-step change of `da` and its exponential fit on the current axes.

    Both curves are divided by 1e21. The fitted e-folding time (second
    parameter from `adjustment_time`) is written near the top-left of the
    axes, and the y-range is fixed to (-10, 20).
    """
    diff = da - da.shift(time=1)
    popt = adjustment_time(da)
    t = diff.time[1:]
    axis = plt.gca()
    axis.axhline(0, c='k', lw=.5)
    axis.plot(t, diff.values[1:]/1e21)
    fitted = exp_decay(t, *popt)/1e21
    axis.plot(t, fitted, 'r-', label="Fitted Curve")
    label = r'$\tau$' + f' = {popt[1]:3.0f} years'
    axis.text(.1, .85, label, transform=axis.transAxes, fontsize=16)
    axis.set_ylim((-10,20))

# Step-to-step OHC change with the fitted exponential decay: ctrl in the left
# column, lpd in the right, one basin per row. Each column shares its own time
# axis (the two runs cover different time ranges); only the bottom row keeps
# its x tick labels.
fig, axs = plt.subplots(4, 2, figsize=(8,8), sharex='col')
Delta = r'$\Delta$'
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
    ax1, ax2 = axs[i]

    plt.sca(ax1)  # plot_adjustment_timeseries draws on the current axes
    ax1.set_ylabel(f'{Delta}OHC {ocean} [ZJ]')
    plot_adjustment_timeseries(ctrl[f'OHC_{ocean}_Ocean'])

    plt.sca(ax2)
    ax2.set_yticklabels([])
    plot_adjustment_timeseries(lpd[f'OHC_{ocean}_Ocean'])

for a in axs[-1]:
    a.set_xlabel('model years')
ax = list(axs.flat)  # flat list of the axes, row by row: [ctrl, lpd, ctrl, lpd, ...]

plt.subplots_adjust(bottom=0.06, left=.08, right=0.99, top=0.99, wspace=.1, hspace=.1)
# plt.savefig(f'{path_results}/OHC/OHC_equilibration_full')  # comment out time selection above
# plt.savefig(f'{path_results}/OHC/OHC_equilibration_select')


## depth warming patterns

In [None]:
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# def colorbar(mappable):
#     ax = mappable.axes
#     fig = ax.figure
#     divider = make_axes_locatable(ax)
#     cax = divider.append_axes("right", size="5%", pad=0.2)
#     return fig.colorbar(mappable, cax=cax)#, orientation='horizontal')

def plot_vertical_Hovmoeller(ctrl, lpd, ylim, fn, offset=True):
    """Depth-time Hovmoeller diagrams of OHC per level for two runs.

    Rows are basins (Global, Atlantic, Pacific, Southern, Indian,
    Mediterranean). The left column is `ctrl`, with depth coordinate `depth_t`
    used as stored. The right column is `lpd`, with depth coordinate `z_t`
    divided by 100 (presumably cm -> m; TODO confirm). Values are divided by
    1e21 (colour bar labelled ZJ/m). Each row uses a symmetric colour scale
    shared by both runs and derived from the plotted quantity.

    Parameters
    ----------
    ctrl, lpd : xr.Dataset
        Contain `OHC_levels_{basin}_Ocean` (`OHC_levels_Mediterranean` for the
        Mediterranean) with dims ordered (time, depth), as the transpose
        against the (depth, time) mesh requires.
    ylim : tuple
        Depth range of the y-axis (negative downwards).
    fn : str
        File the figure is saved to.
    offset : bool
        If true, plot anomalies relative to the mean of the first 30 time
        steps; otherwise plot the stored values.
    """
    oceans = ['Global', 'Atlantic', 'Pacific', 'Southern', 'Indian', 'Mediterranean']
    depth_coord = ['depth_t', 'z_t']  # vertical coordinate name in ctrl / lpd
    depth_scale = [1, 1e2]            # divisor applied to that coordinate

    def field(ds, name):
        """Quantity to plot: first-30-step anomaly if `offset`, else raw; divided by 1e21."""
        da = ds[f'OHC_levels_{name}']
        if offset:
            da = da - da.isel(time=slice(0,30)).mean(dim='time')
        return da/1e21

    f, ax = plt.subplots(len(oceans), 2, figsize=(8,12), sharey=True, sharex='col')
    for j, ocean in enumerate(oceans):
        name = ocean if ocean == 'Mediterranean' else f'{ocean}_Ocean'
        xs = [field(ds, name) for ds in (ctrl, lpd)]
        # colour limit from what is actually plotted (anomaly or raw), shared by both runs
        maxv = np.max([x.max() for x in xs])/1.2
        for i, x in enumerate(xs):
            X, Y = np.meshgrid(x.time, -x.coords[depth_coord[i]]/depth_scale[i])
            im = ax[j,i].pcolormesh(X, Y, x.T, vmin=-maxv, vmax=maxv, cmap=cmocean.cm.balance)
        f.colorbar(im, ax=ax[j,1], orientation='vertical', fraction=.1, label='[ZJ/m]')
        ax[j,0].set_ylabel(ocean)
        ax[j,0].set_ylim(ylim)
    for i, run in enumerate(['CTRL', 'LPD']):
        ax[0,i].text(.5,1.05, run, transform=ax[0,i].transAxes)
        ax[-1,i].set_xlabel('time [model years]')
    plt.savefig(fn)

        
# Hovmoeller figures for the full depth and the upper 1000 m: the original
# integrals as anomalies, the '_qd' integrals as stored values.
for c, l, ext, anomalies in [(ctrl, lpd, '', True), (ctrl_qd, lpd_qd, '_qd', False)]:
    for ylim in [(-6000,0), (-1000,0)]:
        fn = f'{path_results}/OHC/OHC_vertical_Hovmoeller_0-{-ylim[0]}m_ctrl_lpd{ext}'
        print(fn)
        plot_vertical_Hovmoeller(c, l, ylim, fn, offset=anomalies)

## Latitudinal patterns

In [None]:
# One latitude per lpd grid row: TLAT averaged over axis 1 (presumably the
# longitude axis of the curvilinear grid -- TODO confirm). Read as a global by
# later cells and by plot_vertical_latitude_change.
lpd_lat = lpd.TLAT.mean(axis=1)

In [None]:
# Zonal OHC per latitude for each basin (ctrl solid, lpd dashed).
# Left panel: time mean. Right panel: std of the series minus `xr_quadtrend`.
# lpd values are divided by 100 (presumably a unit conversion -- TODO confirm).
f, ax = plt.subplots(1, 2 , figsize=(8,3))
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
    key = f'OHC_zonal_{ocean}_Ocean'
    c = ['k' ,'C0','C1','C3'][i]
# mean
    ax[0].plot(ctrl.t_lat, ctrl[key].mean(dim='time')    , c=c , label=ocean)
    ax[0].plot(lpd_lat   , lpd [key].mean(dim='time')/100, c=c , ls='--')

# std of quad. detrended
    ax[1].plot(ctrl.t_lat, (ctrl[key]-xr_quadtrend(ctrl[key])).std(dim='time'), c=c)
    ax[1].plot(lpd_lat   , (lpd[key] -xr_quadtrend(lpd[key] )).std(dim='time')/100, c=c , ls='--')
ax[0].legend(frameon=False)

plt.savefig(f'{path_results}/OHC/OHC_zonal_mean_std_ctrl_lpd')

In [None]:
# Testing whether latitude binning influences the std: std of the ctrl zonal
# Global OHC minus its `xr_quadtrend`, on the native latitude bins and after
# averaging every 5 bins (incomplete trailing bins are trimmed).
test = ctrl['OHC_zonal_Global_Ocean'].coarsen(dim={'t_lat':5}, boundary='trim').mean()
plt.plot(ctrl.t_lat,
         (ctrl['OHC_zonal_Global_Ocean']-xr_quadtrend(ctrl['OHC_zonal_Global_Ocean'])).std(dim='time'),
         c='C0', label='native latitude bins')
plt.plot(test.t_lat, (test-xr_quadtrend(test)).std(dim='time'), c='C1', label='5 bins averaged')
plt.legend(frameon=False)

In [None]:
# Time-latitude Hovmoeller diagrams of zonal OHC minus its `xr_quadtrend`:
# ctrl (left), lpd divided by 100 (middle), colour bar (narrow right column);
# one basin per row with its own symmetric colour range.
extents = [(-78,90), (-40,80), (-40,70), (-40,30)]  # latitude range shown in each row
height_ratios = [a[1]-a[0] for a in extents]  # row heights proportional to latitude span
f, ax = plt.subplots(4, 3 , figsize=(8,10), sharex='col', gridspec_kw={"width_ratios":[1,1, 0.05], "height_ratios":height_ratios})
cY, cX = np.meshgrid(ctrl.t_lat, ctrl.time)  # time on x, latitude on y
lY, lX = np.meshgrid(lpd.TLAT.mean(axis=1), lpd.time)
vex, ims = [3e16, 2e16, 2e16, 1e16], []  # colour limit per row; ctrl mesh handles per row
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Indian']):
    kwargs = {'cmap':'RdBu', 'vmin':-vex[i], 'vmax':vex[i]}
    key = f'OHC_zonal_{ocean}_Ocean'
    im = ax[i,0].pcolormesh(cX, cY, ctrl[key] -xr_quadtrend(ctrl[key]), **kwargs)
    ims.append(im)
    ax[i,1].pcolormesh(lX, lY, (lpd[key]-xr_quadtrend(lpd[key]))/100, **kwargs)
    for j in range(2):  
        ax[i,j].axhline(0, c='grey', lw=.5, ls='--')
        ax[i,j].set_yticks(np.arange(-60,100,30))
        ax[i,j].set_ylim(extents[i])
    if i==0:
        # top row: mark 31.5S (labelled 'Southern Ocean'), column titles, bottom x-labels
        ax[0,0].text(.05, .2, 'Southern Ocean', c='g', transform=ax[0,0].transAxes)
        for j in range(2):
            ax[0,j].axhline(-31.5, c='g', lw=.8)
            ax[0,j].text(.5, 1.02, ['CTRL', 'LPD'][j], transform=ax[0,j].transAxes, ha='center')
            ax[-1,j].set_xlabel('time [model years]')
    ax[i,0].text(.05, .9, ocean, c='g', transform=ax[i,0].transAxes)
    ax[i,0].set_ylabel('latitude')
    # NOTE(review): `get_shared_y_axes().join` is deprecated/removed in recent matplotlib
    # (Axes.sharey is the replacement) -- confirm against the installed version.
    ax[i,0].get_shared_y_axes().join(ax[i,0], ax[i,1])
    # colour-bar ticks are fixed at -3e16..3e16 for every row, whatever vex[i] is
    cb = f.colorbar(ims[i], cax=ax[i,2], ticks=np.arange(-3e16,4e16,1e16))
    cb.outline.set_visible(False)
plt.savefig(f'{path_results}/OHC/OHC_zonal_Hovmoeller_ctrl_lpd')

### change in depth-latitude space

In [None]:
# Quick look at the first time step of the ctrl_qd zonal-level OHC.
# NOTE(review): `ocean` is not set in this cell; it is whatever an earlier cell
# left behind (the last basin of the previous loop if cells ran in order).
ctrl_qd[f'OHC_zonal_levels_{ocean}_Ocean'].isel(time=0).plot()

In [None]:
def plot_vertical_latitude_change(ctrl, lpd, ylim=(-6000,0), y=None, fn=None):
    """Latitude-depth sections of zonally integrated OHC per basin for two runs.

    Rows are basins (Global, Atlantic, Pacific, Indian).
    - Left column, `ctrl`: latitudes `ctrl.t_lat` and depths `depth_t` as stored.
    - Right column, `lpd`: the notebook-level `lpd_lat` and `z_t` divided by
      100. Values are multiplied by 1e-2 (presumably a unit conversion --
      TODO confirm).
    A green line marks 31.5S in the top row.

    Parameters
    ----------
    ctrl, lpd : xr.Dataset
        Contain `OHC_zonal_levels_{basin}_Ocean` along `time`. The remaining
        two dims must be ordered (depth, latitude) to match the mesh.
    ylim : tuple
        Depth range of the y-axis (negative downwards).
    y : int, optional
        If None, plot the mean of the last 30 time steps minus the mean of the
        first 30 (colour limit 8e13). Otherwise plot time step `y` (colour
        limit 2e13) and annotate it.
    fn : str, optional
        If given, save the finished figure to this file.

    Raises
    ------
    ValueError
        If `y` is given but is not an integer index into the time axis.
    """
    oceans = ['Global', 'Atlantic', 'Pacific', 'Indian']
    runs = ['CTRL', 'LPD']
    value_scale = [1, 1e-2]
    f, ax = plt.subplots(len(oceans), 2, figsize=(8,10), sharey=True, sharex='col')
    for i, ds in enumerate([ctrl, lpd]):
        X, Y = np.meshgrid([ctrl.t_lat, lpd_lat][i], -ds.coords[['depth_t', 'z_t'][i]]/[1, 1e2][i])
        ax[0,i].text(.5, 1.02, runs[i], transform=ax[0,i].transAxes, ha='center')
        ax[-1,i].set_xlabel('latitude')
        ax[0,i].axvline(-31.5, c='g', lw=.8)
        for j, ocean in enumerate(oceans):
            if i == 0:
                ax[j,0].set_ylabel(f'{ocean}')
            da = ds[f'OHC_zonal_levels_{ocean}_Ocean']
            if y is None:
                maxv = 8e13
                # slice(-30, None) holds the last 30 steps, including the final one
                x = (da.isel(time=slice(-30, None)).mean(dim='time') -
                     da.isel(time=slice(0, 30)).mean(dim='time'))*value_scale[i]
            else:
                if (isinstance(y, bool) or not isinstance(y, (int, np.integer))
                        or not 0 <= y < len(da.time)):
                    raise ValueError(f'y must be an integer time index in [0, {len(da.time)}), got {y!r}')
                maxv = 2e13
                x = da.isel(time=y)*value_scale[i]
            im = ax[j,i].pcolormesh(X, Y, x,
                                    vmin=-maxv, vmax=maxv,
                                    cmap=cmocean.cm.balance)
            if i == 1:
                f.colorbar(im, ax=ax[j,i], orientation='vertical', fraction=.1, label=r'[J/m$^{2}$]')
            ax[j,i].set_ylim(ylim)
    if y is not None:
        # drawn once; position (300, 100) is in data coordinates of the top-left panel
        ax[0,0].text(300, 100, f'year {y:3.0f}')
    if fn is not None:
        plt.savefig(fn)  # save once, after all panels are drawn

fn  = f'{path_results}/OHC_vertical_latitude_difference_ctrl_lpd'
# plot_vertical_latitude_change(ctrl, lpd, fn=fn)
# One saved frame per time step of the qd runs. Each figure is closed after
# saving so the loop does not accumulate 250 open figures in memory.
for i in range(250):
    fn  = f'{path_results}/OHC_vertical_latitude_anomaly_ctrl_lpd_{i:04d}'
    plot_vertical_latitude_change(ctrl_qd, lpd_qd, y=i, fn=fn)
    plt.close()

## depth integrated patterns

In [None]:
# Colour limits for the depth-integrated OHC maps, one per depth range: a third
# of the largest absolute value found in either qd run.
maxv = [np.max([np.abs(ds[f'OHC_vertical_{top}_{bottom}m']).max() for ds in (ctrl_qd, lpd_qd)])/3
        for top, bottom in [(0,6000), (0,100), (0,700), (700,2000)]]
print(maxv)

In [None]:
%%time
# Maps of depth-integrated OHC of the qd runs at time index `y`. Rows are depth
# ranges; columns are ctrl, lpd and a narrow third column. Colour limits come
# from the `maxv` list computed in the previous cell.
for y in range(1):  # only the first time step here
    f, ax = plt.subplots(4, 3 , figsize=(10,10),
                         gridspec_kw={"width_ratios":[1,1, 0.05]}, 
                         subplot_kw=dict(projection=ccrs.EqualEarth(central_longitude=300)))
    for i, ds in enumerate([ctrl_qd, lpd_qd]):
        name = ['CTRL', 'LPD'][i]
        # mask for the respective grid; exact meaning of mask_nr=0 is defined in `regions` (not visible here)
        MASK = boolean_mask(['ocn_rect', 'ocn_low'][i], mask_nr=0)
        # ctrl_qd: 1D t_lon/t_lat turned into a mesh; lpd_qd: 2D TLONG/TLAT used directly
        if i==0:   X, Y = np.meshgrid(ds.t_lon, ds.t_lat)
        else:      X, Y = ds.TLONG, ds.TLAT
        for j, depths in enumerate([(0,6000), (0,100), (0,700), (700,2000)]):
            key = f'OHC_vertical_{depths[0]}_{depths[1]}m'
            im = ax[j,i].pcolormesh(X, Y, ds[key][y,:,:].where(MASK),
                                    transform=ccrs.PlateCarree(),
                                    vmin=-maxv[j], vmax=maxv[j],
                                    cmap=cmocean.cm.balance)
            # land drawn in white above the field
            ax[j,i].add_feature(cartopy.feature.LAND,
                                zorder=2, edgecolor='black', facecolor='w')
            if j==0:
                year = f'{ds.time.values[y]:3.0f}'
                ax[0,i].text(.5, 1.1, f'{name} (year {year})',
                             transform=ax[0,i].transAxes, ha='center')
            ax[j,i].text(.5, 1.02, ['full depth (0-6000m)', 'surface (0-100m)',
                                    'upper ocean (0-700m)', 'lower ocean (700-2000m)'][j],
                         transform=ax[j,i].transAxes, ha='center')
            if i==1:
                # NOTE(review): `ax=ax[j,2]` makes the colour bar take space from the third-column
                # axes (also created with the EqualEarth projection) rather than filling it, as
                # `cax=` would; the zonal Hovmoeller cell uses `cax=` -- confirm intent.
                cb = f.colorbar(im, ax=ax[j,2], orientation='vertical', label=r'OHC [J/m$^{2}$]')#, ticks=np.arange(-3e16,4e16,1e16))
                cb.outline.set_visible(False)
#                 ax[j,0].set_ylabel(['full depth (0-6000m)', 'surface (0-100m)',
#                                     'upper ocean (0-700m)', 'lower ocean (700-2000m)'][j])