In [None]:
from glob import glob
import xarray as xr
import cftime
import nc_time_axis
import numpy as np
import matplotlib.pyplot as plt

## Read the IPSL Data

In [None]:
# Root of the IPSL-CM6A-LR historical run, member r10i1p1f1, in the CMIP6 archive.
_ipsl_member_root = '/archive/uda/CMIP6/CMIP/IPSL/IPSL-CM6A-LR/historical/r10i1p1f1'
# CMIP6 table directories: monthly ocean (Omon) and fixed ocean fields (Ofx).
IPSL_Omon = _ipsl_member_root + '/Omon'
IPSL_Ofx = _ipsl_member_root + '/Ofx'
# Variable directories: grid label 'gn', dataset version v20180803.
IPSL_thetao = '/'.join((IPSL_Omon, 'thetao', 'gn', 'v20180803'))
IPSL_areacello = '/'.join((IPSL_Ofx, 'areacello', 'gn', 'v20180803'))

In [None]:
# glob() returns matches in arbitrary, filesystem-dependent order; sort them so
# the file lists (and anything displayed from them) are reproducible across runs.
filelist = sorted(glob(f'{IPSL_thetao}/*.nc'))       # thetao files
filelist1 = sorted(glob(f'{IPSL_areacello}/*.nc'))   # areacello file(s)

In [None]:
# Show which files were found for thetao and for areacello.
filelist,filelist1

In [None]:
# Open the data with chunk sizes equal to the on-disk netCDF chunking
# (check with `ncdump -sh file | grep -i chunk`). Without explicit chunks,
# open_mfdataset appeared to use a whole file per chunk, which led to
# severe memory problems.
horizontal_chunks = {'y': 332, 'x': 362}
IPSL_T = xr.open_mfdataset(filelist, chunks={'time': 1, 'olevel': 75, **horizontal_chunks})
IPSL_A = xr.open_mfdataset(filelist1, chunks=horizontal_chunks)

In [None]:
# Inspect the lazily opened thetao dataset (dimensions, chunks, variables).
IPSL_T

In [None]:
# Inspect the areacello dataset; the cells below use both its `areacello`
# and its `area` variable.
IPSL_A

In [None]:
# IPSL_A['areacello'].attrs.update({'coordinates': 'nav_lon nav_lat'})

In [None]:
# IPSL_A.drop(['nav_lat nav_lon'])

### Calculate cell volume (volcello)

In [None]:
# Thickness of each ocean level from its vertical bounds: diff() along
# `axis_nbounds` gives (second bound - first bound), presumably bottom - top in
# metres (TODO confirm the bound order is [top, bottom]).
# diff() leaves a length-1 `axis_nbounds` dimension. Squeeze only that one so
# that no other size-1 dimension is dropped by accident.
olevel_diff = IPSL_T['olevel_bounds'].diff('axis_nbounds').squeeze('axis_nbounds')

In [None]:
# Cell volume = level thickness x horizontal cell area from `areacello`.
# NOTE(review): this variant is not used below; the OHC calculation uses
# `volcello1`, which is built from the `area` variable instead.
volcello = olevel_diff*IPSL_A['areacello']

In [None]:
# Cell volume = level thickness x the `area` variable of the same
# areacello file(s). This is the volume the OHC calculation below uses.
volcello1 = olevel_diff * IPSL_A['area']

In [None]:
# Inspect the areacello-based cell volume.
volcello

In [None]:
# Inspect the area-based cell volume (the one used for OHC).
volcello1

### Calculate ocean heat content (OHC)

In [None]:
# Heat content per grid cell = cp * rho0 * thetao * cell volume
# (in J if volcello1 is in m^3). A constant cp and reference density are used
# rather than the model's equation of state.
CP_SEAWATER = 3992  # specific heat capacity of seawater [J kg-1 K-1]
RHO_0 = 1025        # reference seawater density [kg m-3]
OHC = CP_SEAWATER * RHO_0 * IPSL_T['thetao'] * volcello1

In [None]:
# Inspect the (lazy) per-cell OHC array.
OHC

In [None]:
# Global OHC time series over the levels whose `olevel` coordinate lies in
# [0, 700] (label-based slice, inclusive at both ends).
global_OHC_upper700m = (
    OHC
    .sel(olevel=slice(0, 700))
    .sum(dim=['olevel', 'y', 'x'])
)

In [None]:
# Global OHC time series over the levels whose `olevel` coordinate lies in
# [0, 2000] (label-based slice, inclusive at both ends).
global_OHC_upper2000m = (
    OHC
    .sel(olevel=slice(0, 2000))
    .sum(dim=['olevel', 'y', 'x'])
)

In [None]:
# Global OHC time series over every level deeper than 2000 m.
# A strict `> 2000` mask is used instead of a label slice like slice(2000, 7000):
# label slices are inclusive, so a level at exactly 2000 would also be counted
# in the inclusive slice(0, 2000) "upper 2000 m" range. The mask also does not
# cap depth at an arbitrary 7000 m.
global_OHC_2000below = OHC.isel(olevel=(OHC['olevel'] > 2000).values).sum(dim=('olevel', 'y', 'x'))

In [None]:
# Full-depth global OHC time series: sum over every level and horizontal cell.
global_OHC = OHC.sum(dim=['olevel', 'y', 'x'])

In [None]:
# Global OHC per level: horizontal sum only, so `olevel` is kept.
global_OHC_level = OHC.sum(dim=['y', 'x'])

## Go Fast With Dask

In [None]:
import os

from dask.distributed import Client

# A dask scheduler's address changes every time the cluster is restarted, so a
# hard-coded address breaks Restart & Run All. Allow overriding it through
# DASK_SCHEDULER_ADDRESS (the environment variable dask maps to its
# `scheduler-address` setting), falling back to the address used previously.
scheduler_address = os.environ.get('DASK_SCHEDULER_ADDRESS', 'tcp://140.208.147.155:42776')
client = Client(scheduler_address)
client

In [None]:
# Run the dask computation and time it. `.load()` fills global_OHC_upper700m
# with computed values in place, so later cells reuse them. Assigning the
# return value to `_` keeps the cell from displaying the array.
%time _ = global_OHC_upper700m.load()

In [None]:
# Load and time the upper-2000 m series. Each .load() is a separate dask
# computation, so upstream work shared with the other series is redone.
%time _ = global_OHC_upper2000m.load()

In [None]:
# Load and time the below-2000 m series.
%time _ = global_OHC_2000below.load()

In [None]:
# Load and time the full-depth series.
%time _ = global_OHC.load()

In [None]:
# Load and time the per-level series (not used by the plots in this section).
%time _ = global_OHC_level.load()

## Compare with Zanna et al. 

In [None]:
# Zanna et al. OHC reconstruction. Rename the awkwardly named year variable to
# 'time' and make it a coordinate.
Zanna = (
    xr.open_dataset('/net2/rnd/Zanna_2018/OHC_GF_1870_2018.nc')
    .rename({'time (starting 1870)': 'time'})
    .set_coords(['time'])
)

In [None]:
# One cftime date per Zanna year: mid-July at noon, in a no-leap calendar.
# int() makes the year an explicit integer in case the file stores years as
# floats (TODO confirm the stored dtype).
dates = [cftime.DatetimeNoLeap(int(year), 7, 16, hour=12)
         for year in Zanna['time'].values]

In [None]:
# Attach the cftime dates to the Zanna dataset as a coordinate.
# NOTE(review): the new DataArray gets its own dimension named 'cftime', so it
# is not tied to the 'time' variable of the data and is not used by the plots
# below. Confirm whether dims='time' was intended.
Zanna['cftime'] = xr.DataArray(np.array(dates), dims='cftime')
Zanna = Zanna.set_coords(['cftime'])

In [None]:
# Inspect the Zanna dataset after the rename and the added cftime coordinate.
Zanna

In [None]:
def anom_yearly_avg(da, ref_year=1870):
    """Annual-mean anomaly of a time series relative to one reference year.

    Parameters
    ----------
    da : xarray.DataArray
        Data with a datetime-like ``time`` coordinate (``da.time.dt.year``
        must work).
    ref_year : int, optional
        Year whose annual mean is subtracted. Defaults to 1870, the reference
        year used by Zanna et al.

    Returns
    -------
    xarray.DataArray
        Annual means indexed by ``year``, minus the ``ref_year`` annual mean.

    Raises
    ------
    KeyError
        From ``.sel`` if ``ref_year`` is not among the years in ``da``.
    """
    # Average all time steps that fall in the same calendar year.
    yearly = da.groupby(da.time.dt.year).mean(dim='time')
    # drop=True discards the scalar 'year' coordinate of the reference slice so
    # it does not have to be reconciled with the 'year' dimension of `yearly`.
    reference = yearly.sel(year=ref_year, drop=True)
    return yearly - reference

# Annual-mean anomalies (relative to 1870) for each depth range, in the same
# order as the global_OHC_* series they come from.
(
    gOHCanom_upper700m_annual,
    gOHCanom_upper2000m_annual,
    gOHCanom_2000below_annual,
    gOHCanom_annual,
) = map(anom_yearly_avg, (global_OHC_upper700m, global_OHC_upper2000m,
                           global_OHC_2000below, global_OHC))

## Plot the results

In [None]:
# Upper-700 m OHC anomaly: IPSL (J converted to ZJ) vs Zanna et al.
fig, ax = plt.subplots()
(gOHCanom_upper700m_annual / 1e21).plot(ax=ax, label='IPSL', color='k')
Zanna['OHC_700m'].plot(ax=ax, label='Zanna', color='r')
ax.legend(fontsize=16)
ax.set_title('OHC upper 700m')
# xarray labels the axes from whichever series was plotted last; set them
# explicitly so the units are shown.
ax.set_xlabel('year')
ax.set_ylabel('OHC anomaly [ZJ]')
ax.grid()

In [None]:
# Upper-2000 m OHC anomaly: IPSL (J converted to ZJ) vs Zanna et al.
fig, ax = plt.subplots()
(gOHCanom_upper2000m_annual / 1e21).plot(ax=ax, label='IPSL', color='k')
Zanna['OHC_2000m'].plot(ax=ax, label='Zanna', color='r')
ax.legend(fontsize=16)
ax.set_title('OHC upper 2000m')
# xarray labels the axes from whichever series was plotted last; set them
# explicitly so the units are shown.
ax.set_xlabel('year')
ax.set_ylabel('OHC anomaly [ZJ]')
ax.grid()

In [None]:
# Below-2000 m OHC anomaly: IPSL (J converted to ZJ) vs Zanna et al.
fig, ax = plt.subplots()
(gOHCanom_2000below_annual / 1e21).plot(ax=ax, label='IPSL', color='k')
Zanna['OHC_below_2000m'].plot(ax=ax, label='Zanna', color='r')
ax.legend(fontsize=16)
ax.set_title('OHC below 2000m')
# xarray labels the axes from whichever series was plotted last; set them
# explicitly so the units are shown.
ax.set_xlabel('year')
ax.set_ylabel('OHC anomaly [ZJ]')
ax.grid()

In [None]:
# Full-depth OHC anomaly: IPSL (J converted to ZJ) vs Zanna et al.
fig, ax = plt.subplots()
(gOHCanom_annual / 1e21).plot(ax=ax, label='IPSL', color='k')
Zanna['OHC_full_depth'].plot(ax=ax, label='Zanna', color='r')
ax.legend(fontsize=16)
ax.set_title('OHC full depth')
# xarray labels the axes from whichever series was plotted last; set them
# explicitly so the units are shown.
ax.set_xlabel('year')
ax.set_ylabel('OHC anomaly [ZJ]')
ax.grid()

### All depth ranges in one figure

In [None]:
# 2x2 summary figure: each panel compares an IPSL annual anomaly (J -> ZJ)
# with the matching Zanna et al. series.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=[10,10])

panels = [
    # (target axes, IPSL anomaly [J], Zanna variable, depth-range label)
    (axs[0, 0], gOHCanom_upper700m_annual, 'OHC_700m', 'upper 700m'),
    (axs[0, 1], gOHCanom_upper2000m_annual, 'OHC_2000m', 'upper 2000m'),
    (axs[1, 0], gOHCanom_2000below_annual, 'OHC_below_2000m', 'below 2000m'),
    (axs[1, 1], gOHCanom_annual, 'OHC_full_depth', 'full depth'),
]
for panel_ax, ipsl_anom, zanna_var, depth_label in panels:
    (ipsl_anom / 1e21).plot(ax=panel_ax, label='IPSLhist', color='k')
    Zanna[zanna_var].plot(ax=panel_ax, label='Zanna', color='r')
    panel_ax.legend(fontsize=16)
    panel_ax.set_title(f'OHC [ZJ] {depth_label}')
    # Units are carried by the title; blank out xarray's automatic axis labels.
    panel_ax.set_xlabel("")
    panel_ax.set_ylabel("")
    panel_ax.grid()