In [None]:
import os
import sys
import numpy as np
import xarray as xr
import cartopy
import cartopy.crs as ccrs
import scipy.stats as stats
import cmocean
import datetime
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
matplotlib.rc_file('../rc_file_paper')
%load_ext autoreload
%autoreload 2
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
# Add the parent directory to the module search path; the modules imported below
# are presumably the project's own modules located there -- TODO confirm.
sys.path.append("..")
from maps import rect_polygon
# NOTE(review): `tqdm.tqdm_notebook` is deprecated in recent tqdm releases in favour of
# `tqdm.notebook.tqdm` -- confirm against the installed tqdm version.
from tqdm import tqdm_notebook
from paths import path_samoc, path_results, path_prace
from paths import file_ex_ocn_ctrl, file_ex_ocn_lpd, file_RMASK_ocn
from regions import SST_index_bounds
from filters import lowpass
from constants import spy, A_earth
from xr_regression import xr_quadtrend
from scipy.optimize import curve_fit
from SST_index_generation import times_ctrl, times_lpd, times_had
from bb_analysis_timeseries import AnalyzeTimeSeries as ATS

color conventions:
- global black
- Atlantic blue
- Pacific orange
- Southern red

In [None]:
# OHC data
# Integrated ocean heat content of both runs; time steps 50-299 of ctrl and 0-249 of lpd
# are selected, so both data sets contain 250 time steps.
ctrl = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl.nc').isel(time=np.arange(50,300))
lpd  = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd.nc' ).isel(time=np.arange(0,250))

# '_qd' versions of the OHC integrals (presumably quadratically detrended -- TODO confirm);
# the time coordinate is left undecoded and no time subset is taken here.
ctrl_qd = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_ctrl_qd.nc', decode_times=False)
lpd_qd  = xr.open_dataset(f'{path_samoc}/OHC/OHC_integrals_lpd_qd.nc' , decode_times=False)

# top of atmosphere imbalance
# same time windows as for the OHC data above (steps 50-299 for ctrl, 0-249 for lpd)
TOA_ctrl = xr.open_dataarray(f'{path_prace}/TOA/TOM_ctrl.nc', decode_times=False).isel(time=slice(50,300))
TOA_lpd  = xr.open_dataarray(f'{path_prace}/TOA/TOM_lpd.nc' , decode_times=False).isel(time=slice(0,250))

# surface heat flux into the ocean
# datasets with one variable per basin (e.g. 'Global_Ocean'), same time windows as above
SHF_ctrl = xr.open_dataset(f'{path_prace}/OHC/SHF_ctrl.nc', decode_times=False).isel(time=slice(50,300))
SHF_lpd  = xr.open_dataset(f'{path_prace}/OHC/SHF_lpd.nc' , decode_times=False).isel(time=slice(0,250))

# Fig 1 alternative: global SHF + TOA

In [None]:
# Global surface heat flux (SHF) time series of both runs, side by side (HIGH = ctrl, LOW = lpd).
# Each panel shows the raw series (thin), its quadratic trend (dashed) and a 13-step
# low-pass filtered version (thick) with 7 steps trimmed at either end.
# Only the SHF is drawn here; the previously loaded-but-unused OHC and TOA series were dropped.
f, ax = plt.subplots(1, 2, sharey=True, figsize=(6.4, 2.5))
for i, run in enumerate(['ctrl', 'lpd']):
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    ax[i].axhline(0, c='grey', lw=.5)
    # time/365: assumes the undecoded time axis is in days -- TODO confirm
    # /1e21: presumably converts J to ZJ (cf. the y-axis label) -- TODO confirm
    ax[i].plot(da_SHF.time/365, da_SHF/1e21, c='C0', label='SHF', lw=.5, alpha=.7)
    ax[i].plot(da_SHF.time/365, xr_quadtrend(da_SHF)/1e21, c='C0', ls='--')
    ax[i].plot(da_SHF.time[7:-7]/365, lowpass(da_SHF,13)[7:-7]/1e21, c='C0')
    ax[i].set_xlabel('time [model years]')
    ax[i].text(.05,.9, ['HIGH', 'LOW'][i], transform=ax[i].transAxes)
ax[0].set_ylabel('heat flux into ocean [ZJ/yr]')
plt.savefig(f'{path_results}/paper/SHF_ctrl_lpd')

## Table 2: $\Delta$OHC mean and std

In [None]:
# Print one LaTeX table row per basin with the mean and standard deviation of the
# step-to-step OHC change (OHC minus OHC shifted by one time step) for ctrl and lpd.
for ocean in ['Global', 'Atlantic', 'Pacific', 'Southern']:
    means, stds = [], []
    for da in [ctrl, lpd]:
        ohc = da[f'OHC_{ocean}_Ocean']
        # NOTE(review): `.sel(time=slice(200,300))` selects by coordinate VALUE, not by index;
        # this assumes the time coordinate is expressed in (model) years -- confirm,
        # other cells divide `time` by 365.
        # /1e21: presumably J -> ZJ -- TODO confirm
        da_ = (ohc - ohc.shift(time=1)).sel(time=slice(200,300))/1e21
        means.append(da_.mean().values)
        stds.append(da_.std().values)
    # backslashes are escaped explicitly ('\\pm'); the former '\p' was an invalid
    # escape sequence that only produced the same text by accident
    print(f'{ocean:8} & ${means[0]:4.1f} \\pm {stds[0]:3.1f}$ & ${means[1]:3.1f} \\pm {stds[1]:3.1f}$ \\\\')

# Fig 2: SST index regression patterns
from `OHC-ctrl_vs_lpd.ipynb`

In [None]:
# test presence of index files
# For every data set (had, ctrl, lpd) and every SST index: report missing files and draw a
# quick-look plot of the 13-year low-pass filtered index (top row: AMO/SOM/TPI,
# bottom row: first three principal components of the PMV EOF analyses).
f, ax = plt.subplots(2,3, sharey='row', figsize=(12,8))
for i, run in enumerate(['had', 'ctrl', 'lpd']):
    # time-segment suffix used in the file names of each data set
    if run=='ctrl':   ts = '_51_301'
    elif run=='lpd':  ts = '_154_404'
    elif run=='had':  ts = ''
    for j, idx in enumerate(['AMO', 'SOM', 'TPI', 'PMV_EOF_20N', 'PMV_EOF_Eq', 'PMV_EOF_38S']):
        if idx in ['AMO', 'SOM', 'TPI']:  dt = '_ds_dt_raw'
        else:                             dt = ''
        fn = f'{path_prace}/SST/{idx}{dt}_{run}{ts}.nc'
        if not os.path.exists(fn):
            # report the missing file and skip it instead of failing when opening it below
            print(f'does not exists: {fn}')
            continue
        if idx in ['AMO', 'SOM', 'TPI']:
            da = xr.open_dataarray(fn, decode_times=False)
            # filter first, then drop 7*12 steps at both ends to remove filter edge effects
            ax[0,i].plot(da.time[7*12:-7*12], lowpass(da, 12*13)[7*12:-7*12], label=idx)
            ax[0,i].legend()
        else:
            ds = xr.open_dataset(fn, decode_times=False)
            for mode, ls in enumerate(['-', '--', ':']):
                pc = ds.pcs.isel(mode=mode).squeeze()
                # same order as for the indices above: low-pass filter the full series and
                # trim afterwards (previously the series was trimmed before filtering,
                # which left the filter edge effects inside the plotted range)
                ax[1,i].plot(ds.time[7*12:-7*12], lowpass(pc, 12*13)[7*12:-7*12], c=f'C{j}', ls=ls, label=idx)
            vfs = ds.variance_fractions.values*100
            print(f'{run:4} {idx:12} variance fraction:     mode 0: {vfs[0]:4.1f}%,  mode 1: {vfs[1]:4.1f}%,   mode 2: {vfs[2]:4.1f}%')
    # one zero line per panel is sufficient
    for k in range(2):  ax[k,i].axhline(c='k')

    print('\n')

In [None]:
# Regression-pattern maps: one row per index (AMO, PMV_EOF_20N, SOM; labelled AMV, PDO, SOM)
# and one map per data set (had, ctrl, lpd; labelled HIST, HIGH, LOW), plus a colorbar per row.
# All axes are positioned manually with `set_position`, hence constrained_layout=False.
f = plt.figure(figsize=(15,7.8), constrained_layout=False)

for i, idx in enumerate(['AMO', 'PMV_EOF_20N', 'SOM']):
#     maxv = [.4, .3, .25][i]
#     maxv = [3, 2, 1.5][i]
    maxv = [.23, .2, .13][i]  # symmetric colour limit of this row
    ticks = [-.3,-.2,-.1,0,.1,.2,.3]  # identical colorbar ticks for all rows
    # narrow, invisible axes at the left edge that only carries the vertical row label
    ax = f.add_subplot(3, 5, 1+i*5)
    ax.set_position([.009,.01+(2-i)*.32,.02,.3])
    ax.text(.5, .5, ['AMV', 'PDO', 'SOM'][i], transform=ax.transAxes, rotation='vertical', va='center', ha='right', fontsize=20)
    ax.axis('off')
    
    # polygon(s) outlining the index region of this row
    rects = [[rect_polygon(SST_index_bounds('AMO'))], 
             [rect_polygon(SST_index_bounds('PDO'))],
             [rect_polygon(SST_index_bounds('SOM'))]
            ][i]
    
    for j, run in tqdm_notebook(enumerate(['had', 'ctrl', 'lpd'])):
        # time-segment suffix used in the file names of each data set
        if run=='had':   ts = ''
        if run=='ctrl':  ts = '_51_301' #'_mean'
        if run=='lpd':   ts = '_154_404'#'_mean'
        # `std`: standard deviation of the 12*13-step low-pass filtered index, used below to
        # scale the regression slope; `xa`: dataset holding the regression `slope` and `pval`
        if idx in ['AMO', 'TPI', 'SOM']:
            std = lowpass(xr.open_dataarray(f'{path_prace}/SST/{idx}_ds_dt_raw_{run}{ts}.nc'), 12*13).std().values
            xa = xr.open_dataset(f'{path_prace}/SST/{idx}_regr_{run}{ts}.nc')
        else:
            std = lowpass(xr.open_dataset(f'{path_prace}/SST/{idx}_{run}{ts}.nc').pcs.isel(mode=0), 12*13).std().values
            # NOTE(review): the regression file name is hard-coded to 'PMV_20N' and does not
            # depend on `idx`; fine as long as 'PMV_EOF_20N' is the only non-box index used here
            xa = xr.open_dataset(f'{path_prace}/SST/PMV_20N_regr_{run}{ts}.nc')
            
        # coordinates are named differently per data set; had and ctrl provide 1D axes that
        # are expanded with meshgrid, lpd provides 2D TLAT/TLONG arrays
        if run=='had':     lats, lons = xa.latitude, xa.longitude; lons, lats = np.meshgrid(lons, lats)
        elif run=='ctrl':  lats, lons = xa.t_lat, xa.t_lon; lons, lats = np.meshgrid(lons, lats)
        elif run=='lpd':   lats, lons = xa.TLAT.values, xa.TLONG.values
        
        # map panel; the central longitude differs per row (-60, 200, -60)
        ax = f.add_subplot(3, 5, j+2+i*5, projection=ccrs.Robinson(central_longitude=[-60, 200, -60][i]))
        if i==0:  ax.text(.5, 1.05, ['HIST', 'HIGH', 'LOW'][j], transform=ax.transAxes, fontsize=20, ha='center')
        ax.set_global()
        ax.set_position([.02+j*.31,.01+(2-i)*.32,.305,.3])

        # regression slope scaled by the index standard deviation
        im = ax.pcolormesh(lons, lats, xa.slope*std, cmap='cmo.balance',
                           vmin=-maxv, vmax=maxv, transform=ccrs.PlateCarree())
        # contours of the p-value at 0.01 and 0.99; NaNs are replaced by 0.5 beforehand.
        # `plt.tricontour` draws into the current axes, i.e. the map axes created above.
        plt.tricontour(lons.flatten(), lats.flatten(), xa.pval.where(np.isnan(xa.pval.values)==False, .5).values.flatten(),
                       levels=[.01,.99], colors='purple', linestyles='dashed', transform=ccrs.PlateCarree())
        ax.add_feature(cartopy.feature.LAND, zorder=2, edgecolor='black', facecolor='grey')

        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False)
        gl.ylocator = matplotlib.ticker.FixedLocator([-90, -60, -30, 0, 30, 60, 90])
        
        # outline the index region on top of the land mask
        for rect in rects:
            ax.add_patch(matplotlib.patches.Polygon(xy=rect,
                                          facecolor='none', edgecolor='k',
                                          linewidth=1, zorder=3,
                                          transform=ccrs.PlateCarree(), ), )
            
    # colorbar of this row; `im` is the mesh of the last map drawn (all maps of a row share
    # the same vmin/vmax)
    ax = f.add_subplot(3, 5, 5+i*5)
    ax.set_position([.955,.03+(2-i)*.32,.01,.26])
    cbar = plt.colorbar(im, cax=ax, shrink=.9, pad=.0, orientation='vertical', extend='both', ticks=ticks)
    cbar.ax.set_yticklabels(ticks, fontsize=12)
    f.text(.985,.16+(2-i)*.32, '[K]', fontsize=14, va='center', rotation=90)
plt.savefig(f'{path_results}/paper/regression_patterns_AMV_PDO_SOM')

# Fig 3: SST index spectra
from `SST_indices.ipynb`

**TODO:** recompute

### unfiltered periodogram

In [None]:
# Fill `per_dict_unfilt` with, for every index and data set, the periodogram
# ('<idx>_<run>_spec') and the Monte Carlo AR(1) spectrum ('<idx>_<run>_rnnh').
per_dict_unfilt = {}
ft = 'lowpass'

for i, run in enumerate(['had', 'ctrl', 'lpd']):
    # time-segment suffix of the file names
    ts = {'ctrl': '_51_301', 'lpd': '_154_404', 'had': ''}[run]
    for j, idx in tqdm_notebook(enumerate(['AMO', 'SOM', 'TPI', 'PMV_EOF_20N', 'PMV_EOF_Eq', 'PMV_EOF_38S'])):
        # only the AMO/SOM/TPI files carry the '_ds_dt_raw' tag
        dt = '_ds_dt_raw' if idx in ['AMO','SOM','TPI'] else ''
        fc = 12*13
        fn = f'{path_prace}/SST/{idx}{dt}_{run}{ts}.nc'
        assert os.path.exists(fn), f'{fn} does not exist'
        if dt:
            # index stored directly as a DataArray
            da = xr.open_dataarray(fn, decode_times=False)
        else:
            # EOF-based index: take the first principal component
            da = xr.open_dataset(fn, decode_times=False).pcs.isel(mode=0).squeeze()
        # keep at most the first 250*12 time steps
        da = da.isel(time=slice(0, 250*12))
        assert len(da) in [12*149, 12*250]
        per_dict_unfilt.update({
            f'{idx}_{run}_spec': ATS(da).periodogram(),
            # red noise spectrum
            f'{idx}_{run}_rnnh': ATS(da).mc_ar1_spectrum(spectrum='per', N=10000),
        })

In [None]:
# final spectra
# 3x3 panels: rows = indices (AMV, PDO, SOM), columns = data sets (HIST, HIGH, LOW).
# Each panel shows the periodogram together with the median, 95% and 99% levels of the
# AR(1) spectra stored in `per_dict_unfilt`.
xticks = [80,40,20,10,5]  # tick labels are periods; tick positions are 1/period
locmin = matplotlib.ticker.LogLocator(base=10.0,subs=[1/x for x in [70,60,50,30,9,8,7,6,4,3]],numticks=10)
f, ax = plt.subplots(3, 3, figsize=(6.4,4), sharex=True, sharey='row', constrained_layout=True)
for i, run in enumerate(['had', 'ctrl', 'lpd']):
    for j, idx in enumerate(['AMO', 'PMV_EOF_20N', 'SOM']):
        # frequency is multiplied and power divided by 12 -- presumably converting from
        # per-month to per-year units (monthly index data); TODO confirm
        factor = 12
        spec = per_dict_unfilt[f'{idx}_{run}_spec']
        rnnh = per_dict_unfilt[f'{idx}_{run}_rnnh']
        # Base 10 is matplotlib's default for log scales. The `basex`/`basey` keywords were
        # renamed to `base` in Matplotlib 3.3 and later removed, so no base keyword is
        # passed, which works with both old and new versions.
        ax[j,i].set_xscale('log')
        ax[j,i].set_yscale('log')
        ax[j,i].set_xlim((1/80,1/5))
        power = [0.5,1.5,0.5][j]
        # NOTE(review): `bottom` (10**power) is larger than `top` (10**(power-3)), which
        # inverts the y-axis -- confirm that this is intended.
        ax[j,i].set_ylim(bottom=10**(power), top=10**(power-3))
        ax[j,i].plot(rnnh['freq']*factor, rnnh['median']/factor, c='C1')
        ax[j,i].plot(rnnh['freq']*factor, rnnh['95']    /factor, c='C1', ls='--', label='AR(1) 95%')
        ax[j,i].plot(rnnh['freq']*factor, rnnh['99']    /factor, c='C1', ls=':' , label='AR(1) 99%')
        ax[j,i].plot(spec[0]     *factor, spec[1]       /factor, c='C0', label='periodogram')
        ax[j,0].set_ylabel(f'{["AMV", "PDO", "SOM"][j]} index\nspectral power')
        ax[j,i].xaxis.set_minor_locator(locmin)
        ax[j,i].xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
        ax[j,i].set_xticks([1/x for x in xticks])
        ax[j,i].set_xticklabels(xticks)

    ax[0,i].title.set_text(['HIST', 'HIGH', 'LOW'][i])
    ax[-1,i].set_xlabel('period [years]')
ax[0,0].legend(loc=3, fontsize=6, frameon=False, handlelength=1)
plt.savefig(f'{path_results}/paper/periodogram')

In [None]:
# Scratch test of the LaTeX `nicefrac` package in figure text.
# The preamble is given as a single string: list values for 'text.latex.preamble' were
# deprecated in Matplotlib 3.3 and are no longer handled as a list of lines afterwards.
# NOTE(review): this changes the global rcParams for all following cells; the preamble is
# presumably only used when 'text.usetex' is enabled (see the commented alternative),
# otherwise the built-in mathtext parser has to handle `\nicefrac` -- confirm.
# matplotlib.rcParams.update({"text.usetex": True, 'text.latex.preamble': r'\usepackage{nicefrac}'})
matplotlib.rcParams.update({'text.latex.preamble': r'\usepackage{nicefrac}'})
plt.figure()
plt.text(0.5,0.5,r'$\nicefrac{2}{2}$')

# Fig. 4 & 5: spectra

In [None]:
# GMST time series of the three data sets.
# had: read from a file named '..._dt_yrly_...' (presumably already detrended, yearly data
# -- TODO confirm); time steps 9-157 are used.
gmst_had  = xr.open_dataarray(f'{path_prace}/GMST/GMST_dt_yrly_had.nc', decode_times=False)
gmst_had  = gmst_had.isel({'time':slice(9,158)})
# ctrl: time steps 50-299; the time coordinate is replaced by integer(time/365)
# (assumes time is given in days -- TODO confirm) and the quadratic trend is removed in place.
gmst_ctrl = xr.open_dataset(f'{path_prace}/GMST/GMST_ctrl.nc').GMST.isel({'time':slice(50,300)})
gmst_ctrl['time'] = (gmst_ctrl.time/365).astype(dtype=int)
gmst_ctrl -= xr_quadtrend(gmst_ctrl)
# lpd: time steps 0-249, otherwise treated like ctrl.
gmst_lpd  = xr.open_dataset(f'{path_prace}/GMST/GMST_lpd.nc').GMST.isel({'time':slice(0,250)})
gmst_lpd['time'] = (gmst_lpd.time/365).astype(dtype=int)
gmst_lpd -= xr_quadtrend(gmst_lpd)

# top-of-atmosphere imbalance, same files and time windows as loaded near the top of the notebook
TOA_ctrl = xr.open_dataarray(f'{path_prace}/TOA/TOM_ctrl.nc', decode_times=False).isel(time=slice(50,300))
TOA_lpd  = xr.open_dataarray(f'{path_prace}/TOA/TOM_lpd.nc' , decode_times=False).isel(time=slice(0,250))

# surface heat flux, same files and time windows as loaded near the top of the notebook
SHF_ctrl = xr.open_dataset(f'{path_prace}/OHC/SHF_ctrl.nc', decode_times=False).isel(time=slice(50,300))
SHF_lpd  = xr.open_dataset(f'{path_prace}/OHC/SHF_lpd.nc' , decode_times=False).isel(time=slice(0,250))

In [None]:
# Spectra of GMST (HIST, HIGH, LOW) and of the detrended global TOA-minus-SHF residual
# (HIGH, LOW) in a single log-log panel.
plt.figure(figsize=(6.4,3), constrained_layout=True)

# h: one handle per quantity (colour legend); h_: one grey handle per run (line-style legend)
h, h_ = [], []
for i, gmst in enumerate([gmst_had, gmst_ctrl, gmst_lpd]):
    ls = [':','-','--'][i]
    (spec, freq, jackknife) = ATS(gmst).spectrum()
    l, = plt.plot(freq, spec, label='GMST', c='C9', ls=ls)
    # grey proxy line that only serves as legend handle for the line style of this run
    l_, = plt.plot([0,0],[0,1], ls=ls, c='grey', label=['HIST','HIGH','LOW'][i])
    h_.append(l_)

    if i==1: h.append(l)

for i in range(2):
    ls = ['-','--'][i]
    toa = [TOA_ctrl, TOA_lpd][i]
    # detrended TOA imbalance, multiplied by `spy` and divided by 1e21
    # (presumably W -> ZJ/yr -- TODO confirm)
    toa_ = (toa-xr_quadtrend(toa))*spy/1e21
    shf  = [SHF_ctrl, SHF_lpd][i]
    shf_ = (shf['Global_Ocean'] - xr_quadtrend(shf['Global_Ocean']))/1e21
    # put the TOA series on the SHF time axis before taking the difference
    div = toa_.assign_coords(time=shf_.time.values)-shf_
    (spec, freq, jackknife) = ATS(div).spectrum()
    l, = plt.plot(freq, spec, label=r'TOA$-$SHF', c='C8', ls=ls)
    if i==0: h.append(l)
# Two legends: the line-style legend is re-added as an artist because the second
# `plt.legend` call would otherwise replace it. (An earlier legend call that was
# immediately overwritten by this one has been removed.)
l1 = plt.legend(handles=h_, loc='lower left', frameon=False, ncol=3)
plt.gca().add_artist(l1)
l2 = plt.legend(handles=h, loc='upper left', frameon=False, ncol=2)

plt.xlim(1/64,1/2)
plt.xlabel('period [year]')
plt.ylabel(r'spectral power')

xticks = [80,40,20,10,5,2]  # tick labels are periods; tick positions are 1/period
locmin = matplotlib.ticker.LogLocator(base=10.0,subs=[1/x for x in [70,60,50,30,9,8,7,6,4,3]],numticks=10)
# Base 10 is the default for log scales; the `basex`/`basey` keywords were renamed to
# `base` in Matplotlib 3.3 and later removed, so no base keyword is passed.
plt.gca().set_yscale('log')
plt.gca().set_xscale('log')
plt.gca().xaxis.set_minor_locator(locmin)
plt.gca().xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
plt.gca().set_xticks([1/x for x in xticks])
plt.gca().set_xticklabels(xticks)

plt.savefig(f'{path_results}/paper/GMST_TOA-SHF_spectra')

In [None]:
# Spectra of the detrended surface heat flux per basin (colours) for HIGH and LOW (line styles).
# The basin spectra are multiplied by `shift` to separate the curves vertically.
plt.figure(figsize=(6.4,4), constrained_layout=True)
# h: one handle per basin (colour legend); h_: one grey handle per run (line-style legend)
h, h_ = [], []
for i, ts in enumerate([SHF_ctrl, SHF_lpd]):
    ls = ['-', '--'][i]
    for j, basin in enumerate(['Global', 'Atlantic','Pacific', 'Southern']):
        c = ['k' ,'C0','C1','C3'][j]
        ts_ = (ts[f'{basin}_Ocean'] - xr_quadtrend(ts[f'{basin}_Ocean']))/1e21
        (spec, freq, jackknife) = ATS(ts_).spectrum()
        shift= [1,1/5,1/400,1/600][j]  # vertical offset factor applied to the plotted spectrum
        l, = plt.plot(freq, spec*shift,
                      label=f'SHF {basin}',
                      ls=ls, c=c, alpha=[1,.7,.7,.7][j])
        # grey proxy line that only serves as legend handle for the line style of this run
        l_, = plt.plot([0,0],[0,1], ls=ls, c='grey', label=['HIGH','LOW'][i])
        if j==0:  h_.append(l_)
        if i==0:
            h.append(l)
            # arrow from the unshifted to the shifted spectral value at frequency index 4
            # NOTE(review): the arrow is placed at x = 1/78 + 2**j (>= 2), which lies outside
            # the x-limits set below (1/80 ... 1/2) -- confirm the intended x position.
            if j>0:  plt.arrow(1/78+2**(j), spec[4], 0, -(spec[4]-spec[4]*shift),\
                               length_includes_head=True, width=1e-5, head_width=.0004, head_length=.01, color=c)
# the first legend is re-added as an artist because the second `plt.legend` call replaces it
l1 = plt.legend(handles=h_, loc='lower left', frameon=False, ncol=2)
plt.gca().add_artist(l1)
l2 = plt.legend(handles=h, loc='upper left', frameon=False, ncol=2)

plt.xlim(1/80,1/2)
plt.xlabel(r'period [year]')
plt.ylabel('surface heat flux anomaly spectral power')
xticks = [80,40,20,10,5,2]  # tick labels are periods; tick positions are 1/period
locmin = matplotlib.ticker.LogLocator(base=10.0,subs=[1/x for x in [70,60,50,30,9,8,7,6,4,3]],numticks=10)
# Base 10 is the default for log scales; the `basex`/`basey` keywords were renamed to
# `base` in Matplotlib 3.3 and later removed, so no base keyword is passed.
plt.gca().set_yscale('log')
plt.gca().set_xscale('log')
plt.gca().xaxis.set_minor_locator(locmin)
plt.gca().xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
plt.gca().set_xticks([1/x for x in xticks])
plt.gca().set_xticklabels(xticks)
plt.savefig(f'{path_results}/paper/SHF_spectra')

# Fig 5: OHC depth-zonal integral Hovmöller diagram + standard deviation (y)

In [None]:
# Hovmoeller diagrams (time vs. latitude) of the quadratically detrended zonal OHC integrals.
# Rows: Global, Atlantic, Pacific, Indian; columns: ctrl (HIGH), lpd (LOW), colorbar, and the
# standard deviation over time as a function of latitude.
# 1D latitude axis of the lpd grid: TLAT averaged along axis 1
lpd_lat = lpd.TLAT.mean(axis=1)
# latitude range shown in each row; the row heights are proportional to these ranges
extents = [(-78,90), (-34,70), (-34,70), (-34,30)]
height_ratios = [a[1]-a[0] for a in extents]
f, ax = plt.subplots(4, 4, figsize=(6.4,8), sharex='col',
                     gridspec_kw={"width_ratios":[1, 1, 0.05, .5], "height_ratios":height_ratios, "wspace":0.03, "hspace":0.03})
# coordinate meshes for pcolormesh (c: ctrl, l: lpd)
cY, cX = np.meshgrid(ctrl.t_lat, ctrl.time)
lY, lX = np.meshgrid(lpd.TLAT.mean(axis=1), lpd.time)
# vex: symmetric colour limits per row; ims: the ctrl meshes, used for the colorbars
vex, ims = [2.5e16, 1.5e16, 2e16, .7e16], []
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Indian']):
    kwargs = {'cmap':cmocean.cm.balance, 'vmin':-vex[i], 'vmax':vex[i]}
    key = f'OHC_zonal_{ocean}_Ocean'
    im = ax[i,0].pcolormesh(cX, cY, ctrl[key]-xr_quadtrend(ctrl[key]), **kwargs)
    ims.append(im)
    # the lpd values are divided by 100 (presumably a cm -> m unit conversion -- TODO confirm)
    ax[i,1].pcolormesh(lX, lY, (lpd[key]-xr_quadtrend(lpd[key]))/100, **kwargs)
    # equator line, latitude ticks and latitude range for the three data columns
    for j in [0,1,3]:  
        ax[i,j].axhline(0, c='grey', lw=.5, ls='--')
        ax[i,j].set_yticks(np.arange(-60,100,30))
        ax[i,j].set_ylim(extents[i])
    ax[i,1].set_yticklabels([])
    # basin name near the top of the ctrl panel (x position 60 in data coordinates)
    ax[i,0].text(60, extents[i][1]-10, ocean, c='g')
    ax[i,0].set_ylabel('latitude')
    
    # standard deviation over time of the detrended series vs. latitude (solid: ctrl, dashed: lpd)
    ax[i,3].plot((ctrl[key]-xr_quadtrend(ctrl[key])).std(dim='time')    , ctrl.t_lat, c='k')
    ax[i,3].plot((lpd[key] -xr_quadtrend(lpd[key] )).std(dim='time')/100, lpd_lat   , c='k', ls='--')
    ax[i,3].set_yticklabels([])
    
    # link the y-axes of the three data panels of this row
    # NOTE(review): joining via `get_shared_y_axes().join` is deprecated in recent Matplotlib
    # releases -- confirm against the installed version.
    ax[i,0].get_shared_y_axes().join(ax[i,0], ax[i,1])
    ax[i,0].get_shared_y_axes().join(ax[i,0], ax[i,3])
    
    cb = f.colorbar(ims[i], cax=ax[i,2])#, ticks=np.arange(-3e16,4e16,1e16))
    cb.outline.set_visible(False)
    
# annotations of the top (Global) row: Southern Ocean label and a line at 31.5S
ax[0,0].text(60, -75, 'Southern Ocean', c='g')
for j in range(2):
    ax[0,j].axhline(-31.5, c='g', lw=.8)
    ax[0,j].text(.5, 1.02, ['HIGH', 'LOW'][j], transform=ax[0,j].transAxes, ha='center')
    ax[-1,j].set_xlabel('time [model years]')
    # negative column index: labels the std column (j=0) and the colorbar column (j=1)
    # NOTE(review): the plotted values are not rescaled in this cell -- confirm the unit '[ZJ/m]'.
    ax[-1,-j-1].set_xlabel('[ZJ/m]')
ax[0,3].text(.5, 1.02, 'st. dev.', transform=ax[0,3].transAxes, ha='center')
    
f.align_xlabels()
plt.savefig(f'{path_results}/paper/OHC_zonal_Hovmoeller')

In [None]:
# Same layout as the unfiltered zonal Hovmoeller figure, but the detrended zonal OHC integrals
# are low-pass filtered with `lowpass(..., 13)`, divided by 1e15, and plotted against time/365.
# 1D latitude axis of the lpd grid: TLAT averaged along axis 1
lpd_lat = lpd.TLAT.mean(axis=1)
# latitude range shown in each row; the row heights are proportional to these ranges
extents = [(-78,90), (-34,70), (-34,70), (-34,30)]
height_ratios = [a[1]-a[0] for a in extents]
f, ax = plt.subplots(4, 4, figsize=(6.4,8), sharex='col',
                     gridspec_kw={"width_ratios":[1, 1, 0.05, .5], "height_ratios":height_ratios, "wspace":0.03, "hspace":0.03},
                     constrained_layout=True)
# coordinate meshes for pcolormesh (c: ctrl, l: lpd); time/365 assumes time in days -- TODO confirm
cY, cX = np.meshgrid(ctrl.t_lat, ctrl.time/365)
lY, lX = np.meshgrid(lpd.TLAT.mean(axis=1), lpd.time/365)
# vex, ims = [15, 10, 10, 5], []

# vex: symmetric colour limits per row (in units of the data divided by 1e15); ims: ctrl meshes
vex, ims = [12, 9, 9, 6], []
for i, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Indian']):
    kwargs = {'cmap':cmocean.cm.balance, 'vmin':-vex[i], 'vmax':vex[i]}
    key = f'OHC_zonal_{ocean}_Ocean'
    im = ax[i,0].pcolormesh(cX, cY, lowpass(ctrl[key]-xr_quadtrend(ctrl[key]),13)/1e15, **kwargs)
    ims.append(im)
    # the lpd values are additionally divided by 100 (presumably cm -> m -- TODO confirm)
    ax[i,1].pcolormesh(lX, lY, lowpass((lpd[key]-xr_quadtrend(lpd[key]))/100,13)/1e15, **kwargs)
    # equator line, latitude ticks and latitude range for the three data columns
    for j in [0,1,3]:  
        ax[i,j].axhline(0, c='grey', lw=.5, ls='--')
        ax[i,j].set_yticks(np.arange(-60,100,30))
        ax[i,j].set_ylim(extents[i])
    ax[i,1].set_yticklabels([])
#     ax[i,0].text(60, extents[i][1]-10, ocean, c='g')
    # fixed label position so that the y-labels of all rows line up
    ax[i,0].yaxis.set_label_coords(-0.14,.5)
    ax[i,0].set_ylabel(f'{ocean} Ocean\nlatitude')
    
    # standard deviation over time of the filtered series vs. latitude (solid: ctrl, dashed: lpd)
    ax[i,3].plot(lowpass(ctrl[key]-xr_quadtrend(ctrl[key]),13).std(dim='time')/1e15    , ctrl.t_lat, c='k', label='HIGH')
    ax[i,3].plot(lowpass(lpd[key] -xr_quadtrend(lpd[key] ),13).std(dim='time')/1e15/100, lpd_lat   , c='k', label='LOW' , ls='--')
    ax[i,3].set_yticklabels([])
    
    # link the y-axes of the three data panels of this row
    # NOTE(review): joining via `get_shared_y_axes().join` is deprecated in recent Matplotlib
    # releases -- confirm against the installed version.
    ax[i,0].get_shared_y_axes().join(ax[i,0], ax[i,1])
    ax[i,0].get_shared_y_axes().join(ax[i,0], ax[i,3])
    
    # the same tick values (-12 ... 12 in steps of 3) are requested for every row's colorbar
    cb = f.colorbar(ims[i], cax=ax[i,2], ticks=np.arange(-12,13,3))#, ticks=np.arange(-3e16,4e16,1e16))
    cb.outline.set_visible(False)
    
ax[0,3].legend(fontsize=6, handlelength=1.5, bbox_to_anchor=(1.05, .88), loc='center right', frameon=False)
# ax[0,0].text(60, -75, 'Southern Ocean', c='g')
for j in range(2):
    # line at 31.5S in the Global row
    ax[0,j].axhline(-31.5, c='g', lw=.8)
    ax[0,j].text(.5, 1.02, ['HIGH', 'LOW'][j], transform=ax[0,j].transAxes, ha='center')
    ax[-1,j].set_xlabel('time [model years]')
    # negative column index: labels the std column (j=0) and the colorbar column (j=1)
    ax[-1,-j-1].set_xlabel('[PJ/m]')
ax[0,3].text(.5, 1.02, 'st. dev.', transform=ax[0,3].transAxes, ha='center')
    
f.align_xlabels()
plt.savefig(f'{path_results}/paper/OHC_zonal_Hovmoeller_lowpass13')

# Fig 6: OHC horizontal integral Hovmöller diagram

In [None]:
# Hovmoeller diagrams (time vs. depth) of the OHC per level from the '_qd' data sets.
# Rows: basins; columns: ctrl (HIGH), lpd (LOW), and the standard deviation over time.
# Every panel is split into an upper part (0 to -1500) and a lower part (-1500 to -6000)
# with different vertical scales.
oceans = ['Global', 'Atlantic', 'Pacific', 'Southern']
das = [ctrl_qd, lpd_qd]
maxv = 60  # symmetric colour limit (values are divided by 1e18 before plotting)

fig = plt.figure(figsize=(6.4,9), constrained_layout=True)
gs0 = matplotlib.gridspec.GridSpec(4, 3, left=.1, right=.98, bottom=.11, top=.98, wspace=.05, hspace=.045, width_ratios=[1, 1, .25])

# ax[0,3].text(.5, 1.02, 'st. dev.', transform=ax[0,3].transAxes, ha='center')

# if offset==True: x = (da-da.isel(time=slice(0,30)).mean(dim='time')).T/1e21

for i, ocean in enumerate(oceans):
    name = f'{ocean}_Ocean'    
    
    for j, da in enumerate(das):
        
        if j==0:
            # std panels (third column) are created once per row and shared by both runs
            gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[i,2], hspace=0)
            ax_t = fig.add_subplot(gs00[0])
            ax_t.set_ylim((-1500,0))
            ax_t.set_xticks([])
            ax_t.set_yticks([-1500, -1000, -500, 0])
#             ax_t.spines['top'].set_visible(False)
#             ax_t.spines['right'].set_visible(False)
            
            ax_b = fig.add_subplot(gs00[1])
            ax_b.set_ylim((-6000,-1500))
            ax_b.set_yticks([-6000,-4500,-3000])
#             ax_b.spines['top'].set_visible(False)
#             ax_b.spines['right'].set_visible(False)
            
        
        # upper and lower Hovmoeller panel of run j
        gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[i,j], hspace=0)
        ax_top = fig.add_subplot(gs00[0])
        ax_top.set_ylim((-1500,0))
        ax_top.set_xticks([])
        ax_top.set_yticks([-1500, -1000, -500, 0])
        
        ax_bot = fig.add_subplot(gs00[1])
        ax_bot.set_ylim((-6000,-1500))
        ax_bot.set_yticks([-6000,-4500,-3000])
        
        da = das[j][f'OHC_levels_{name}']
        x = da.T/1e18
        # the depth coordinate is named 'depth_t' in ctrl and 'z_t' in lpd; the lpd values are
        # divided by 100 (presumably cm -> m -- TODO confirm)
        X, Y = np.meshgrid(da.time, da.coords[['depth_t', 'z_t'][j]]/[1, 1e2][j])
        
        # standard deviation over time vs. depth (solid: ctrl, dashed: lpd), in both std panels
        for k, ax in enumerate([ax_t, ax_b]):
            ax.plot(da.std(dim='time').values/1e18, -da.coords[['depth_t', 'z_t'][j]]/[1, 1e2][j], ls=['-','--'][j], c='k')

        # the same field is drawn in both panels; they differ only in their y-limits
        for k, ax in enumerate([ax_top, ax_bot]):
            im = ax.pcolormesh(X, -Y, x, vmin=-maxv, vmax=maxv, cmap=cmocean.cm.balance)
        if i==0:
            ax_top.text(.5,1.05,['HIGH', 'LOW'][j], transform=ax_top.transAxes)

        # (a former `if j==2:` branch was removed here: `das` has only two entries,
        # so j is never 2 and the branch could not be reached)

        if j==0:
            ax_bot.set_ylabel(f'{ocean} Ocean', horizontalalignment = 'left')
        
        if j>0:
            ax_top.set_yticklabels([])
            ax_bot.set_yticklabels([])
            ax_t.set_yticklabels([])
            ax_b.set_yticklabels([])
            
                
        
        # x labels only in the bottom row, no x tick labels elsewhere
        if i==len(oceans)-1:
            ax_bot.set_xlabel('time [model years]')
            ax_b.set_xlabel('std [EJ/m]')
        else:
            ax_bot.set_xticklabels([])
            ax_t.set_xticklabels([])
            ax_b.set_xticklabels([])
            
            
# common horizontal colorbar below the panels; `im` is the last mesh drawn (all meshes share
# the same colour limits)
cax = fig.add_axes([0.1, 0.04, 0.88, 0.02])
fig.colorbar(im, cax=cax, orientation='horizontal', label='OHC anomaly [EJ/m]', extend='both')
plt.savefig(f'{path_results}/paper/OHC_vertical_Hovmoeller_0-6km_ctrl_lpd_qd')


In [None]:
# Low-pass filtered (`lowpass(..., 13)`) version of the vertical OHC Hovmoeller figure.
# Rows: basins; columns: ctrl (HIGH), lpd (LOW), and the standard deviation over time.
# Every panel is split into an upper and a lower part with different vertical scales.
oceans = ['Global', 'Atlantic', 'Pacific', 'Southern']
das = [ctrl_qd, lpd_qd]
maxv = 50  # symmetric colour limit (values are divided by 1e18 before plotting)

fig = plt.figure(figsize=(6.4,9), constrained_layout=True)
gs0 = matplotlib.gridspec.GridSpec(4, 3, left=.12, right=.98, bottom=.11, top=.98, wspace=.07, hspace=.06, width_ratios=[1, 1, .25])


#     
# if offset==True: x = (da-da.isel(time=slice(0,30)).mean(dim='time')).T/1e21
# invisible axes spanning the top-row cells that only carry the column titles
for i in range(3):
    ax_title = fig.add_subplot(gs0[0,i])
    ax_title.axis('off')
    ax_title.text(.5, 1.02, ['HIGH', 'LOW', 'st. dev.'][i], transform=ax_title.transAxes, ha='center')

for i, ocean in enumerate(oceans):
    
    name = f'{ocean}_Ocean'    
    
    for j, da in enumerate(das):
        
        # std panels (third column) are created once per row and shared by both runs;
        # their y-limits are given in the un-scaled depth units (-1500 ... 0, -6000 ... -1500)
        if j==0:  # std plots
            gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[i,2], hspace=0)
            ax_t = fig.add_subplot(gs00[0])
            ax_t.set_ylim((-1500,0))
            ax_t.set_xticks([])
            ax_t.set_yticks([-1500, -1000, -500, 0])
            
            ax_b = fig.add_subplot(gs00[1])
            ax_b.set_ylim((-6000,-1500))
            ax_b.set_yticks([-6000,-4500,-3000])
            
        
        da = das[j][f'OHC_levels_{name}']
        x = da.T/1e18
        # the depth coordinate is named 'depth_t' in ctrl and 'z_t' in lpd; here it is divided by
        # 1e3 (ctrl) / 1e5 (lpd), matching the 'depth [km]' axis label (presumably m and cm
        # input units -- TODO confirm)
        X, Y = np.meshgrid(da.time, da.coords[['depth_t', 'z_t'][j]]/[1e3, 1e5][j])
        
        # upper and lower Hovmoeller panel of run j, y-limits in km
        gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[i,j], hspace=0)
        ax_top = fig.add_subplot(gs00[0])
        ax_top.set_ylim((-1.5,0.0))
        ax_top.set_xticks([])
        ax_top.set_yticks([-1.5, -1.0, -0.5, 0.0])
        
        ax_bot = fig.add_subplot(gs00[1])
        ax_bot.set_ylim((-6.0,-1.5))
        ax_bot.set_yticks([-6.0,-4.5,-3.0])
        
        # standard deviation over time of the filtered data vs. depth (solid: ctrl, dashed: lpd);
        # note that the depth is divided by 1 / 1e2 here, i.e. scaled differently from the
        # Hovmoeller panels, consistent with the std panels' y-limits set above
        for k, ax in enumerate([ax_t, ax_b]):
            ax.plot(lowpass(da,13).std(dim='time').values/1e18, -da.coords[['depth_t', 'z_t'][j]]/[1, 1e2][j], ls=['-','--'][j], c='k', label=['HIGH', 'LOW'][j])
                
        # the same filtered field is drawn in both panels; they differ only in their y-limits
        for k, ax in enumerate([ax_top, ax_bot]):
            im = ax.pcolormesh(X, -Y, lowpass(x.T,13).T, vmin=-maxv, vmax=maxv, cmap=cmocean.cm.balance)
        
        if j==0:
            # the label is anchored at the top of the lower panel, i.e. centred on the split panel
            ax_bot.yaxis.set_label_coords(-0.17,1)
            ax_bot.set_ylabel(f'{ocean} Ocean\ndepth [km]', horizontalalignment = 'center')
#             ax_bot.text(60, -5.5, f'{ocean} Ocean', horizontalalignment = 'left', c='green')
        
        if j>0:
            ax_top.set_yticklabels([])
            ax_bot.set_yticklabels([])
            ax_t.set_yticklabels([])
            ax_b.set_yticklabels([])
            
                
        
        # x labels only in the bottom row, no x tick labels elsewhere
        if i==len(oceans)-1:
            ax_bot.set_xlabel('time [model years]')
            ax_b.set_xlabel('[EJ/m]')
        else:
            ax_bot.set_xticklabels([])
            ax_t.set_xticklabels([])
            ax_b.set_xticklabels([])
            
            
    # HIGH/LOW legend only in the lower std panel of the first row
    if i==0: ax_b.legend(frameon=False, handlelength=1.5, fontsize=6, loc=4)
            
            
# common horizontal colorbar below the panels; `im` is the last mesh drawn (all meshes share
# the same colour limits)
cax = fig.add_axes([0.1, 0.04, 0.88, 0.015])
fig.colorbar(im, cax=cax, orientation='horizontal', label='OHC anomaly [EJ/m]', extend='both')
plt.savefig(f'{path_results}/paper/OHC_vertical_Hovmoeller_0-6km_ctrl_lpd_qd_lowpass13')


# Table 2: variances for different bandpass filters

How much area/volume do the major ocean basins represent?

In [None]:
# Grid fields needed for the basin area/volume shares.
# NOTE(review): these imports sit in the middle of the notebook; `file_RMASK_ocn` is already
# imported from `paths` near the top of the notebook.
from xr_DataArrays import xr_DZ, xr_AREA
from paths import file_RMASK_ocn
from regions import regions_dict
# 'ocn' selects the ocean grid; DZT and AREA are presumably the cell thickness and the
# horizontal cell area -- TODO confirm against xr_DataArrays
DZT = xr_DZ('ocn')
AREA = xr_AREA('ocn')
# region mask: integer region codes, used below with `MASK>0` and `MASK==i`
MASK = xr.open_dataarray(file_RMASK_ocn)

In [None]:
# display the AREA DataArray (dimensions, coordinates, attributes) for inspection
AREA

In [None]:
# Share of the total ocean area and volume covered by the regions with mask codes 1, 2 and 6.
# Cells with a positive mask value count as ocean.
wet_area = AREA.where(MASK>0)
total_area = wet_area.sum(dim=['nlat','nlon'], skipna=True).values
total_volume = (DZT*wet_area).sum(dim=['z_t','nlat','nlon'], skipna=True).values
print(f'total area {total_area} volume {total_volume}')
for i in (1, 2, 6):
    # horizontal cell areas restricted to region i
    basin_area = AREA.where(MASK==i)
    area = basin_area.sum(dim=['nlat','nlon'], skipna=True).values
    volume = (DZT*basin_area).sum(dim=['z_t','nlat','nlon'], skipna=True).values
    print(f'{regions_dict[i]:15}: area {area/total_area*100:2.0f}%; volume {volume/total_volume*100:2.0f}%')
    

In [None]:
# scratch calculation: sum of three hard-coded numbers (presumably the percentage shares of
# the three basins computed above -- TODO confirm which quantity and replace by computed values)
25.8+38.1+17.7

In [None]:
# scratch calculation: sum of three hard-coded numbers (presumably the percentage shares of
# the three basins computed above -- TODO confirm which quantity and replace by computed values)
26.8+41.2+17.8