# Ocean Heat Content - regrid
To ease the computational load, I regridded the yearly averaged temperature and potential density files from the original curvilinear 0.1 degree grid to a rectangular 0.4 degree grid.

This notebook compares the ocean heat content (OHC) behaviour of the `ctrl` and `lpd` runs.

In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import xarray as xr
import cmocean
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%load_ext autoreload
%autoreload 2
%aimport - numpy - scipy - matplotlib.pyplot
matplotlib.rc_file('../rc_file')

In [None]:
from tqdm import tqdm
from paths import path_samoc, path_results
from scipy.optimize import curve_fit
from xr_regression import xr_quadtrend

In [None]:
ctrl = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl.nc')
lpd  = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd.nc' )
spinup = 50   # how many years to ignore due to spinup effects: data from year 51 of ctrl run
t_end  = 300  # ctrl is kept up to this time index; lpd keeps the same number of steps (t_end-spinup)

# left panel ctrl, right panel lpd; panel widths proportional to the number of time steps
f, ax = plt.subplots(1, 2 , figsize=(8,3), sharey=True, gridspec_kw={"width_ratios":[len(ctrl.time), len(lpd.time)]})
for i in range(2):
    ax[i].axhline(0, c='grey', lw=.5)
    ax[i].set_xlabel('time [model years]')

colors = ['k', 'C0', 'C1', 'C3']
for ocean, c in zip(['Global', 'Atlantic', 'Pacific', 'Southern'], colors):
    key = f'OHC_{ocean}_Ocean'
    # anomaly w.r.t. the first time step in ZJ (1e21 J); time divided by 365 to get model years
    ax[0].plot(ctrl.time/365, (ctrl[key]-ctrl[key][0])/1e21, c=c)
    ax[1].plot(lpd .time/365, (lpd[key]-lpd[key][0]  )/1e21, c=c , ls='--', label=f'{ocean} Ocean')

ax[0].set_ylim((-200, 2000))
ax[0].set_ylabel('OHC anomaly [ZJ]')

# grey shading marks the data discarded by the selection below
ax[0].axvspan(0, spinup, alpha=0.3, color='grey')
ax[1].axvspan(lpd.time[t_end-spinup]/365, lpd.time[-1]/365, alpha=0.3, color='grey', label='discarded data')

ax[1].legend(frameon=False)

# keep records of equal length: ctrl after its spinup, lpd from its start
ctrl = ctrl.isel(time=slice(spinup, t_end))
ctrl = ctrl.assign_coords(time=ctrl.time.values/365)
lpd  = lpd .isel(time=slice(0, t_end-spinup))
lpd  = lpd .assign_coords(time=lpd .time.values/365)
# plt.savefig(f'{path_results}/OHC/OHC_anomalies_ctrl_lpd')

Grey-shaded areas indicate discarded data.

In [None]:
# display the ctrl OHC integrals dataset
ctrl

In [None]:
# ctrl OHC integral files with suffixes _0253 and _0254 (presumably model years 253 and 254)
da5 = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl_0253.nc')
da6 = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl_0254.nc')


In [None]:
# plot the OHC_vertical_0_6000m field of the _0253 file
da5.OHC_vertical_0_6000m.plot()

In [None]:
# same field of the _0254 file
da6.OHC_vertical_0_6000m.plot()

In [None]:
# ratio of the two fields (_0254 / _0253)
(da6.OHC_vertical_0_6000m/da5.OHC_vertical_0_6000m).plot()

In [None]:
# print every time stamp at which the global OHC is below its value at the first time step;
# covers the whole series instead of a hard-coded number of steps
ohc_global = ctrl.OHC_Global_Ocean.values
for t in ctrl.time.values[ohc_global - ohc_global[0] < 0]:
    print(t)

In [None]:
from ac_derivation_OHC import DeriveOHC as DO

In [None]:
# (re)generate the OHC files for each listed ctrl year; the loop variable is now actually used
for y in [274]:
    DO().generate_OHC_files(run='ctrl', year=y)

In [None]:
# combine the yearly ctrl OHC integral files (see DeriveOHC.combine_yrly_OHC_integral_files)
DO().combine_yrly_OHC_integral_files(run='ctrl')

In [None]:
# project-specific correction for ctrl year 205 (see DeriveOHC.fix_ctrl_year_205)
DO().fix_ctrl_year_205()

In [None]:
from timeseries import IterateOutputCESM

In [None]:
# scatter TEMP (left) and PD (right) at one fixed grid point ([0,0,1000,200]; presumably
# time, top level, nlat, nlon — confirm) for every monthly ctrl file of model years 250-254
f, ax = plt.subplots(1,2)
for y, m, fname in IterateOutputCESM(domain='ocn', run='ctrl', tavg='monthly'):
    # `fname`, not `f`: reusing `f` would overwrite the figure handle created above
    if 250 <= y < 255:
        with xr.open_dataset(fname, decode_times=False) as ds:  # close each monthly file after use
            ax[0].scatter(y+m/12, ds.TEMP[0,0,1000,200])
            ax[1].scatter(y+m/12, ds.PD[0,0,1000,200])
        

In [None]:
da1 = xr.open_dataset(f'/projects/0/samoc/andre/CESM/ctrl_rect/TEMP_PD_yrly_0274.interp900x602.nc')

In [None]:
da2 = xr.open_dataset(f'/projects/0/samoc/andre/CESM/ctrl_rect/TEMP_PD_yrly_0275.interp900x602.nc')

In [None]:
da3 = xr.open_dataset(f'/projects/0/samoc/andre/CESM/ctrl/ocn_yrly_TEMP_PD_0274.nc')

In [None]:
da4 = xr.open_dataset(f'/projects/0/samoc/andre/CESM/ctrl/ocn_yrly_TEMP_PD_0275.nc')

In [None]:
# 2x2 comparison of TEMP[35,:,:]: top row da3/da4 (ctrl/ocn_yrly files), bottom row da1/da2 (ctrl_rect files)
f, ax = plt.subplots(2,2, figsize=(8,8))
for i, da in enumerate([da3, da4, da1, da2]):
    row, col = divmod(i, 2)
    # 'upperleft' is not a valid origin: newer Matplotlib raises, older versions treated every
    # value other than 'upper' as 'lower' — 'lower' keeps that rendered orientation
    im = ax[row, col].imshow(da.TEMP[35,:,:], origin='lower')
    plt.colorbar(im, ax=ax[row, col], orientation='horizontal')

## Equilibration time

In [None]:
def exp_decay(x, a, b):
    """Return the exponential decay ``a * exp(-x / b)``.

    x : scalar or array-like abscissa (time in this notebook)
    a : value at x = 0
    b : e-folding scale, in the units of x
    """
    decay_factor = np.exp(-x/b)
    return a*decay_factor


def exp_fit(da_fit):
    """Least-squares fit of ``exp_decay`` to a 1-D DataArray against its ``time`` coordinate.

    The amplitude is initialised with the first data value and the e-folding scale with 1000.
    Returns the optimal parameters ``[a, b]`` as given by ``scipy.optimize.curve_fit``.
    """
    data = da_fit.values
    first_guess = (data[0], 1000.)
    best_params, _covariance = curve_fit(exp_decay, da_fit.time, data, p0=first_guess)
    return best_params
   

def adjustment_time(da):
    """Fit an exponential decay to the step-to-step change of ``da`` along ``time``.

    The increments ``da(t) - da(t-1)`` are fitted with ``exp_decay(t, a, b) = a*exp(-t/b)``
    via ``exp_fit``; ``b`` is the adjustment (e-folding) time in units of the time coordinate.

    Parameters
    ----------
    da : xr.DataArray
        Data with a ``time`` dimension, either 1-D or with any number of further dimensions.

    Returns
    -------
    For 1-D input: the array ``[a, b]`` returned by ``exp_fit``.
    Otherwise: an ``xr.Dataset`` with variables ``popt0`` (a) and ``popt1`` (b) on the
    non-time dimensions. Columns containing NaNs are not fitted and are set to NaN.
    """
    Delta_da = da.diff('time')  # da(t) - da(t-1), labelled with the later time step
    if Delta_da.ndim == 1:  # 1D time series
        return exp_fit(Delta_da)

    other_dims = [dim for dim in Delta_da.dims if dim != 'time']
    stacked = len(other_dims) > 1
    if stacked:  # collapse all non-time dimensions into one so every column can be fitted
        Delta_da = Delta_da.stack(stacked_coord=other_dims)
        coord = 'stacked_coord'
    else:
        coord = other_dims[0]
    Delta_da = Delta_da.transpose('time', coord)

    n_columns = Delta_da.sizes[coord]
    A = np.full((n_columns, 2), np.nan)  # NaN marks columns that were skipped
    for i in tqdm(range(n_columns)):
        column = Delta_da.isel({coord: i})
        if np.any(np.isnan(column.values)):  continue  # skip columns with missing data
        A[i,:] = exp_fit(column)
    D = Delta_da.isel(time=0, drop=True)  # skeleton DataArray with the non-time dimensions
    popt0 = D.copy(data=A[:,0])
    popt1 = D.copy(data=A[:,1])
    if stacked:  # restore the original non-time dimensions
        popt0 = popt0.unstack()
        popt1 = popt1.unstack()
    popt0.name, popt1.name = 'popt0', 'popt1'
    return xr.merge([popt0, popt1])



In [None]:
# first time step with the scalar time coordinate dropped (the skeleton used in adjustment_time);
# isel(..., drop=True) replaces the deprecated DataArray.drop
ctrl.OHC_levels_Global_Ocean.isel(time=0, drop=True)

In [None]:
# exponential-decay fits (popt0 = amplitude, popt1 = e-folding time) for every depth level
ctrl_levels_at = adjustment_time(ctrl.OHC_levels_Global_Ocean)
lpd_levels_at  = adjustment_time(lpd .OHC_levels_Global_Ocean)

In [None]:
from grid import find_array_idx

In [None]:
# presumably the index of the lpd z_t level closest to 2e5 (z_t likely in cm, i.e. ~2000 m — confirm)
find_array_idx(lpd.z_t, 2e5)

In [None]:
# value of lpd z_t at level index 45, without decimals
f'{lpd.z_t[45].values:.0f}'

In [None]:
# depth-time Hovmoeller diagrams of the OHC anomaly per level; one figure per depth range,
# rows = basins, columns = ctrl, lpd
oceans = ['Global', 'Atlantic', 'Pacific', 'Southern', 'Indian', 'Mediterranean']
das = [ctrl, lpd]
depth_names = ['depth_t', 'z_t']  # depth coordinate of ctrl and lpd, respectively
depth_divisors = [1, 1e2]         # lpd z_t divided by 100 to match ctrl depth_t (presumably cm -> m)
n_ref = 30                        # anomalies w.r.t. the mean over the first 30 time steps
for ylim in [(-6000,0), (-1000,0)]:
    f, ax = plt.subplots(len(oceans),2, figsize=(8,12), sharey=True, sharex='col')
    for j, ocean in enumerate(oceans):
        # the Mediterranean variables carry no '_Ocean' suffix
        name = ocean if ocean == 'Mediterranean' else f'{ocean}_Ocean'
        # anomaly per level in ZJ (1e21 J), computed once per run
        anomalies = []
        for ds in das:
            levels = ds[f'OHC_levels_{name}']
            anomalies.append((levels - levels.isel(time=slice(0, n_ref)).mean(dim='time'))/1e21)
        # symmetric colour range shared by the ctrl and lpd panels of this row
        maxv = max(float(anom.max()) for anom in anomalies)/1.2
        for i, (ds, anom) in enumerate(zip(das, anomalies)):
            X, Y = np.meshgrid(anom.time, -ds.coords[depth_names[i]]/depth_divisors[i])
            im = ax[j,i].pcolormesh(X, Y, anom.T, vmin=-maxv, vmax=maxv, cmap=cmocean.cm.balance)
        f.colorbar(im, ax=ax[j,1], orientation='vertical', fraction=.1, label='[ZJ/m]')
        ax[j,0].set_ylabel(ocean)
        ax[j,0].set_ylim(ylim)
    for i in range(2):
        ax[0,i].text(.5,1.05,['CTRL', 'LPD'][i], transform=ax[0,i].transAxes)
        ax[-1,i].set_xlabel('time [model years]')
    plt.savefig(f'{path_results}/OHC/OHC_vertical_Hovmoeller_0_{ylim[0]}_ctrl_lpd')

In [None]:
# fitted e-folding time (popt1) versus depth; lpd z_t divided by 100 to match ctrl depth_t
plt.plot(ctrl_levels_at.depth_t, ctrl_levels_at.popt1)
plt.plot(lpd_levels_at .z_t/100, lpd_levels_at .popt1)
plt.ylim(-1000,10000)

In [None]:
# time series: depth-lat-lon integrated OHC
def plot_adjustment_timeseries(da):
    """Plot the step-to-step change of ``da`` and its fitted exponential decay on the current axes.

    Values are divided by 1e21 (ZJ), the fitted e-folding time ``popt[1]`` from
    ``adjustment_time`` is written on the panel as tau, and the y-range is fixed to (-10, 20).
    """
    popt = adjustment_time(da)
    increments = (da - da.shift(time=1))[1:]  # drop the leading NaN created by the shift
    t = increments.time
    axis = plt.gca()
    axis.axhline(0, c='k', lw=.5)
    axis.plot(t, increments.values/1e21)
    axis.plot(t, exp_decay(t, *popt)/1e21, 'r-', label="Fitted Curve")
    tau_symbol = r'$\tau$'
    axis.text(.1, .85, f'{tau_symbol} = {popt[1]:3.0f} years', transform=axis.transAxes, fontsize=16)
    axis.set_ylim((-10,20))

# 4x2 panels: rows = basins, left column ctrl, right column lpd.
# Each column shares its own x-axis (the runs cover different time ranges), and all panels share
# y; this replaces the removed get_shared_x_axes().join calls, which also paired the wrong columns.
fig, axs = plt.subplots(4, 2, figsize=(8,8), sharex='col', sharey=True)
Delta = r'$\Delta$'
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
    for j, run in enumerate([ctrl, lpd]):
        plt.sca(axs[i,j])  # plot_adjustment_timeseries draws on the current axes
        plot_adjustment_timeseries(run[f'OHC_{ocean}_Ocean'])
    axs[i,0].set_ylabel(f'{Delta}OHC {ocean} [ZJ]')
for j in range(2):
    axs[-1,j].set_xlabel('model years')
ax = list(axs.flat)  # same [ctrl, lpd, ctrl, lpd, ...] order as the list of axes built before

plt.subplots_adjust(bottom=0.06, left=.08, right=0.99, top=0.99, wspace=.1, hspace=.1)
# plt.savefig(f'{path_results}/OHC/OHC_equilibration_full')  # comment out time selection above
plt.savefig(f'{path_results}/OHC/OHC_equilibration_select')


## Latitudinal patterns

In [None]:
# 1D latitude axis for lpd: TLAT averaged over its axis 1 (presumably the zonal index — confirm)
lpd_lat = lpd.TLAT.mean(axis=1)

In [None]:
# left: time-mean zonal OHC; right: std of the quadratically detrended zonal OHC.
# solid = ctrl, dashed = lpd; lpd values divided by 100 (NOTE(review): presumably a unit/grid scaling — confirm)
f, ax = plt.subplots(1, 2 , figsize=(8,3))
ax_mean, ax_std = ax
for ocean, color in zip(['Global', 'Atlantic', 'Pacific', 'Southern'], ['k', 'C0', 'C1', 'C3']):
    ctrl_zonal = ctrl[f'OHC_zonal_{ocean}_Ocean']
    lpd_zonal = lpd[f'OHC_zonal_{ocean}_Ocean']
    ax_mean.plot(ctrl.t_lat, ctrl_zonal.mean(dim='time'), c=color, label=ocean)
    ax_mean.plot(lpd_lat, lpd_zonal.mean(dim='time')/100, c=color, ls='--')
    ctrl_residual = ctrl_zonal - xr_quadtrend(ctrl_zonal)
    ax_std.plot(ctrl.t_lat, ctrl_residual.std(dim='time'), c=color)
    lpd_residual = lpd_zonal - xr_quadtrend(lpd_zonal)
    ax_std.plot(lpd_lat, lpd_residual.std(dim='time')/100, c=color, ls='--')
ax_mean.legend(frameon=False)

plt.savefig(f'{path_results}/OHC/OHC_zonal_mean_std_ctrl_lpd')

In [None]:
# latitude-time Hovmoeller diagrams of the quadratically detrended zonal OHC;
# rows = basins, columns = ctrl, lpd, colorbar
extents = [(-78,90), (-40,80), (-40,70), (-40,30)]  # latitude range shown for each basin
height_ratios = [a[1]-a[0] for a in extents]       # row heights proportional to that range
f, ax = plt.subplots(4, 3 , figsize=(8,10), sharex='col', gridspec_kw={"width_ratios":[1,1, 0.05], "height_ratios":height_ratios})
cY, cX = np.meshgrid(ctrl.t_lat, ctrl.time)
lY, lX = np.meshgrid(lpd.TLAT.mean(axis=1), lpd.time)
vex, ims = [3e16, 2e16, 2e16, 1e16], []  # symmetric colour limits per basin
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Indian']):
    kwargs = {'cmap':'RdBu', 'vmin':-vex[i], 'vmax':vex[i]}
    key = f'OHC_zonal_{ocean}_Ocean'
    im = ax[i,0].pcolormesh(cX, cY, ctrl[key] -xr_quadtrend(ctrl[key]), **kwargs)
    ims.append(im)
    # lpd values divided by 100 (NOTE(review): presumably a unit/grid scaling relative to ctrl — confirm)
    ax[i,1].pcolormesh(lX, lY, (lpd[key]-xr_quadtrend(lpd[key]))/100, **kwargs)
    for j in range(2):
        ax[i,j].axhline(0, c='grey', lw=.5, ls='--')
        ax[i,j].set_yticks(np.arange(-60,100,30))
        ax[i,j].set_ylim(extents[i])
    if i==0:
        ax[0,0].text(.05, .2, 'Southern Ocean', c='g', transform=ax[0,0].transAxes)
        for j in range(2):
            ax[0,j].axhline(-31.5, c='g', lw=.8)  # presumably the northern edge of the Southern Ocean region
            ax[0,j].text(.5, 1.02, ['CTRL', 'LPD'][j], transform=ax[0,j].transAxes)
            ax[-1,j].set_xlabel('time [model years]')
    ax[i,0].text(.05, .9, ocean, c='g', transform=ax[i,0].transAxes)
    ax[i,0].set_ylabel('latitude')
    # share the latitude axis of the ctrl and lpd panels of this row; Axes.sharey (Matplotlib >= 3.3)
    # replaces get_shared_y_axes().join, which was deprecated in 3.6 and removed in 3.8
    ax[i,1].sharey(ax[i,0])
    cb = f.colorbar(ims[i], cax=ax[i,2], ticks=np.arange(-3e16,4e16,1e16))
    cb.outline.set_visible(False)
plt.savefig(f'{path_results}/OHC/OHC_zonal_Hovmoeller_ctrl_lpd')

In [None]:
# depth patterns

## Depth-integrated patterns

In [None]:
ctrl.