## CESM2 - LARGE ENSEMBLE (LENS2)

- This notebook computes the heat balance of the South Atlantic, defined as the difference between the total surface heat flux (area integral) and the difference of the meridional heat transport across the northern and southern boundaries.

### Imports

In [None]:
import xarray as xr
import pandas as pd
import numpy as np 
import dask
import cf_xarray
import intake
import cftime
import nc_time_axis
import intake_esm
import matplotlib.pyplot as plt
import pop_tools
from dask.distributed import Client, wait
from ncar_jobqueue import NCARCluster
import warnings, getpass, os
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy.crs as ccrs
import cmocean
import dask
from scipy import stats

### Dask

In [None]:
# Start a Dask cluster: one single-core job of `mem_per_worker` GB per worker
# (see resource_spec), scaled to `num_workers` workers.
mem_per_worker = 60  # memory per worker in GB
num_workers = 40     # number of workers
cluster = NCARCluster(
    cores=1,
    processes=1,
    memory=f'{mem_per_worker} GB',
    resource_spec=f'select=1:ncpus=1:mem={mem_per_worker}GB',
    walltime='2:00:00',
)
cluster.scale(num_workers)
client = Client(cluster)
print(client)
client

### Read the data

In [None]:
# intake-esm datastore describing the CESM2 Large Ensemble files on GLADE
catalog_path = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
catalog = intake.open_esm_datastore(catalog_path)

In [None]:
# Monthly ocean fields needed for the heat balance
cat_subset = catalog.search(
    component='ocn',
    variable=['TEND_TEMP', 'SHF', 'N_HEAT'],
    frequency='month_1',
)

In [None]:
# Open every catalog entry in the subset as an xarray dataset, keyed by entry name
dset_dict_raw = cat_subset.to_dataset_dict(
    zarr_kwargs={'consolidated': True},
    storage_options={'anon': True},
)
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

### Concatenation of variables

In [None]:
ff = ('cmip6', 'smbb')                # Forcing variants (part of the catalog key)
fb = ['TEND_TEMP', 'SHF', 'N_HEAT']   # Variables

ds_dict = dict()
for var in fb:
    # 1- For each forcing variant, concatenate historical and ssp370 along time.
    # NOTE(review): time-independent variables such as dz/TAREA appear to gain
    # time/member_id dimensions here (later cells index them with time=0,
    # member_id=0) -- presumably combine_nested's default data_vars handling.
    ds_dict_tmp = {
        scenario: xr.combine_nested(
            [dset_dict_raw[f'ocn.historical.pop.h.{scenario}.{var}'],
             dset_dict_raw[f'ocn.ssp370.pop.h.{scenario}.{var}']],
            concat_dim=['time'])
        for scenario in ff
    }
    # 2- Concatenate the forcing variants along member_id, in the order given by `ff`
    ds_dict[var] = xr.combine_nested([ds_dict_tmp[scenario] for scenario in ff],
                                     concat_dim=['member_id'])
    del ds_dict_tmp

In [None]:
# Vertical layer thickness in metres, taken from the first member and time step
# (dz is stored in cm, hence the factor 0.01).
dz_cm = ds_dict['TEND_TEMP']['dz'].isel(time=0, member_id=0)
dz = dz_cm * 0.01
del dz_cm
# Quick visual check of the vertical grid spacing
dz.plot()

### Import POP grid

In [None]:
# Store the POP gx1v7 T- and U-point longitudes/latitudes in ds_dict
pop_grid = pop_tools.get_grid('POP_gx1v7')
for grid_coord in ('TLONG', 'TLAT', 'ULONG', 'ULAT'):
    ds_dict[grid_coord] = pop_grid[grid_coord]
del pop_grid, grid_coord

### Calculate the difference in meridional heat transport between the latitude closest to the equator and 34°S
- We chose 34°S instead of 34.5°S because at 34°S we are sure that there is no water leakage south of Africa.

In [None]:
%%time
#ilan = 0 # northernmost latitude
ilas = -34 # southernmost latitude
Vt_n=ds_dict['N_HEAT']['N_HEAT'].isel(transport_reg=1).sum(dim='transport_comp').isel(lat_aux_grid=190).load()
Vt_s=ds_dict['N_HEAT']['N_HEAT'].isel(transport_reg=1).sum(dim='transport_comp').sel(lat_aux_grid=ilas,method='nearest').load()

In [None]:
# Timestamp at position 3000 of the monthly axis (used as the right x-limit of the MHT figure)
Vt_n.coords['time'].isel(time=3000)

### Tendency

In [None]:
def myfunc(x, slope=None, intercept=None):
    """Evaluate the straight line ``slope * x + intercept``.

    Parameters
    ----------
    x : scalar or array-like
        Abscissa (in this notebook, the year index used in the trend fits).
    slope, intercept : float, optional
        Line coefficients. When omitted, they are read from the notebook-level
        globals ``slope`` and ``intercept`` (as set by ``stats.linregress``),
        which keeps the original one-argument calls working. Pass them
        explicitly to avoid depending on whichever regression ran last.

    Raises
    ------
    NameError
        If a coefficient is omitted and no global of that name exists.
    """
    namespace = globals()
    if slope is None:
        if 'slope' not in namespace:
            raise NameError("name 'slope' is not defined")
        slope = namespace['slope']
    if intercept is None:
        if 'intercept' not in namespace:
            raise NameError("name 'intercept' is not defined")
        intercept = namespace['intercept']
    return slope * x + intercept

In [None]:
# Linear trends (2015-2100) of the annual, ensemble-mean meridional heat transport
# at 34°S, at the northern boundary, and of their difference. For each series we
# keep the fitted line (for plotting), the slope per decade, the p-value and r^2.

def _annual_ssp_mean(da):
    """Annual means of `da` restricted to 2015-2100, averaged over member_id."""
    return (da.resample(time='1Y', closed='left').mean('time')
              .sel(time=slice('2015-01-01', '2100-12-31'))
              .mean('member_id'))


def _linear_trend(series):
    """Least-squares fit of `series` against its year index 0, 1, 2, ...

    Returns (fitted_values, slope_per_decade, p_value, r_squared). The fit uses
    a yearly index, so the per-decade slope is slope * 10.
    """
    years = np.arange(len(series))
    slope, intercept, r, p, _ = stats.linregress(years, series)
    fitted = list(slope * years + intercept)
    return fitted, slope * 10, p, r * r


Vt_s_trends = _annual_ssp_mean(Vt_s)
Vt_n_trends = _annual_ssp_mean(Vt_n)
Vt_sn_trends = _annual_ssp_mean(Vt_n - Vt_s)

mymodel_Vt_s_trends, m_Vt_s_trends, p_Vt_s_trends, r_Vt_s_trends = _linear_trend(Vt_s_trends)
mymodel_Vt_n_trends, m_Vt_n_trends, p_Vt_n_trends, r_Vt_n_trends = _linear_trend(Vt_n_trends)
mymodel_Vt_sn_trends, m_Vt_sn_trends, p_Vt_sn_trends, r_Vt_sn_trends = _linear_trend(Vt_sn_trends)

In [None]:
%%time
fig, ax = plt.subplots(figsize=(10, 8))
Vt_s.mean('member_id').sel(time=slice('1850-01-01','2100-12-31')).resample(time='1Y', closed='left').mean('time').plot(
    x="time",color='purple',linewidth=1,label=r'$\rm{MHT}: 34^oS$')
plt.plot(Vt_s.resample(time='1Y', closed='left').mean('time').sel(time=slice('2015-01-01','2100-12-31')).coords['time'],mymodel_Vt_s_trends,color='purple',linestyle='dashed')
plt.annotate(f'{m_Vt_s_trends*1000:.2f} TW per decade', xy=(0.4, 0.1), color='purple',fontsize=20,xycoords=ax.transAxes)


Vt_n.mean('member_id').sel(time=slice('1850-01-01','2100-12-31')).resample(time='1Y', closed='left').mean('time').plot(
    x="time",color='orange',linewidth=1,label=r'$\rm{MHT}: 0^o$')
plt.plot(Vt_n.resample(time='1Y', closed='left').mean('time').sel(time=slice('2015-01-01','2100-12-31')).coords['time'],mymodel_Vt_n_trends,color='orange',linestyle='dashed')
plt.annotate(f'{m_Vt_n_trends*1000:.2f} TW per decade', xy=(0.4, 0.6), color='orange',fontsize=20,xycoords=ax.transAxes)


(Vt_n-Vt_s).mean('member_id').sel(time=slice('1850-01-01','2100-12-31')).resample(time='1Y', closed='left').mean('time').plot(
    x="time",color='red',linewidth=1,label=r'$\rm{MHTD}: 0^o-34^oS$')
plt.plot((Vt_n-Vt_s).resample(time='1Y', closed='left').mean('time').sel(time=slice('2015-01-01','2100-12-31')).coords['time'],mymodel_Vt_sn_trends,color='red',linestyle='dashed')
plt.annotate(f'{m_Vt_sn_trends*1000:.2f} TW per decade', xy=(0.6, 0.38), color='red',fontsize=20,xycoords=ax.transAxes)


plt.tight_layout()
plt.legend(fontsize=20)
plt.ylabel(r'Heat Flux [PW]',fontsize=20)
plt.grid(color='gray', linestyle='-', linewidth=0.7)
plt.title(None)
plt.ylim(0.15,0.85)
plt.xlim(Vt_n.coords['time'][1319].values,Vt_n.coords['time'][3000].values)
plt.vlines(Vt_n.sel(time=slice('2015-01-01','2100-12-31')).coords['time'][0].values,0,0.9,linestyle='dashed',color="black")
plt.xlabel('Time [Years]',fontsize=20)
ax.tick_params(axis='x', labelsize=20)
ax.tick_params(axis='y', labelsize=20)
plt.savefig('MHT.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
%%time
ilan = 0 # northernmost latitude
ilas = -34 # southernmost latitude
ds_N_HEAT_diff=(ds_dict['N_HEAT']['N_HEAT'].isel(transport_reg=1,lat_aux_grid=190)-ds_dict['N_HEAT']['N_HEAT'].isel(transport_reg=1).sel(lat_aux_grid=ilas,method='nearest')).sum(dim='transport_comp').load()

### Cut and center the variables in the South Atlantic

In [None]:
%%time
# Cutting out and centering the variables in the South Atlantic
dask.config.set({"array.slicing.split_large_chunks": True})
ilon1, flon1, ilon2, flon2 = 307, 320, 0, 54 # longitude (initial (i), final (f)) 

fb=(['TEND_TEMP','SHF'])

for var in fb:
    ds_dict[f'{var}']=xr.combine_nested([[
        ds_dict[f'{var}'].where((ds_dict[f'{var}'].TLAT >= ilas) & (ds_dict[f'{var}'].TLAT <= ilan), drop=True).isel(
            nlon = slice(ilon1,flon1)),
        ds_dict[f'{var}'].where((ds_dict[f'{var}'].TLAT >= ilas) & (ds_dict[f'{var}'].TLAT <= ilan), drop=True).isel(
            nlon = slice(ilon2,flon2))]],
        concat_dim=['nlat','nlon'])   
    ds_dict[f'{var}'].coords['nlon'] = (ds_dict[f'{var}'].coords['nlon'] + 180) % 360 - 180 
    ds_dict[f'{var}'] = ds_dict[f'{var}'].sortby(ds_dict[f'{var}'].nlon)
del ilan, ilas, ilon1, flon1, ilon2, flon2 

In [None]:
# Sanity check of the cut: first z_t level, first member, first time step
tend_temp_check = ds_dict['TEND_TEMP']['TEND_TEMP'].isel(member_id=0, time=0, z_t=0)
tend_temp_check.plot()

### Mask the continent 

In [None]:
# Mask land in TAREA: cells where the field is NaN at the first time step are
# treated as continent and excluded, so the area integrals only cover ocean.
# (This is exactly the condition the previous 2 = ocean / 1 = land code array
# tested with `!= 1`.)
fb = ['TEND_TEMP', 'SHF']
for var in fb:
    is_ocean = ds_dict[var][var].isel(time=0).notnull()
    ds_dict[var]['TAREA'] = ds_dict[var]['TAREA'].where(is_ocean).isel(time=0) * 1e-4 # 1e-4 to convert cm2 into m2
    del is_ocean

### Integrate the SHF over the area

In [None]:
# Rechunk the SHF fields before the area integral. 67 is the nlon width of the
# cut box (13 + 54 columns). NOTE(review): the tuples are positional, so they
# assume a fixed dimension order -- confirm against .dims.
shf_ds = ds_dict['SHF']
shf_ds['TAREA'] = shf_ds['TAREA'].chunk(chunks=(50, 1, 67))
shf_ds['SHF'] = shf_ds['SHF'].chunk(chunks=(50, 1980, 1, 67))
del shf_ds

In [None]:
%%time
ds_SHF=ds_dict['SHF']['TAREA']*ds_dict['SHF']['SHF']*(1e-15) # PW (1e-15 to convert the units from W to PW) 
ds_SHF=ds_SHF.sum(dim=['nlat','nlon'],skipna=True).load() # PW

In [None]:
# Test: annual ensemble-mean MHT difference vs. area-integrated SHF
fig, ax = plt.subplots()
for da, label in ((ds_N_HEAT_diff, 'MHTD'), (ds_SHF, 'SHF')):
    da.mean('member_id').resample(time='1Y', closed='left').mean('time').sel(
        time=slice('1851-01-01', '2100-12-31')).plot(ax=ax, label=label)
ax.legend()
ax.grid()
ax.set_xlabel('Time [Years]')
ax.set_ylabel('PW')
# To save: plt.savefig('Heat_balance.png',dpi=300,bbox_inches='tight')
plt.show()

### Compute the heat balance

In [None]:
# Heat balance residual [PW]: surface heat flux minus the meridional heat
# transport difference (northern minus southern boundary).
ds_SHF_N_HEAT_diff = ds_SHF - ds_N_HEAT_diff

In [None]:
# Test: annual ensemble-mean heat balance residual
fig, ax = plt.subplots()
ds_SHF_N_HEAT_diff.mean('member_id').resample(time='1Y', closed='left').mean('time').sel(
    time=slice('1851-01-01', '2100-12-31')).plot(ax=ax, label='HS', color='red')
ax.legend()
ax.grid()
ax.set_title(None)
ax.set_xlabel('Time [Years]')
ax.set_ylabel('PW')
# To save: plt.savefig('Heat_storage.png',dpi=300,bbox_inches='tight')
plt.show()

<div class="alert alert-block alert-info">
Here the heat balance is the difference, not the sum, of the two terms. This follows from the sign conventions: the SHF is positive into the ocean (heat gain), while the meridional heat transport is positive northward (positive y-direction). A positive MHT difference (northern boundary minus southern boundary) therefore means net heat export from the region. In the balance, all heat entering through the surface must equal all heat leaving through the meridional heat transport, so the MHT difference is subtracted from the surface heat flux: balance = SHF − (MHT(north) − MHT(south)).
</div>

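### Compute the heat storage (HS) to compare with the heat storage implied by the difference between the heat fluxes
- the vertical integral of the temperature tendency (in the code below it is also integrated over the area, which gives a heat storage rate for the whole region in PW)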

#### Equation: $$\mathrm{HS} = \rho_\theta\,C_p\int_{z_2}^{z_1}\theta'(z)\,dz,$$
##### where:
##### * HS is heat storage ($\mathrm{J~m^{-2}}$),
##### * $\rho_\theta$ is the density of sea water (1026 $\mathrm{kg~m^{-3}}$ in the code below),
##### * $C_p$ is the specific heat of sea water (3996 $\mathrm{J~kg^{-1}~{}^{\circ}C^{-1}}$ in the code below),
##### * $z_1$ and $z_2$ are the depth limits of the integration, in meters,
##### * and $\theta'$ is the monthly potential temperature anomaly (following month minus previous month) at each depth, in kelvin or degrees Celsius, i.e. the temperature tendency.

### Calculate the heat stored per layer

In [None]:
%%time
warnings.simplefilter("ignore")
#layers=('0','1e+5','2e+5')
layers=('0','1e+5','6e+5')
for layer in range(0,len(layers)-1):
    print(f'Done with layer: {layer}')
    var_array = list() # Build a list
    for member_id in range(len(ds_dict['TEND_TEMP']['TEND_TEMP'].coords['member_id'])): # per member
        st=f'ds_HS_TEMP=ds_dict[\'TEND_TEMP\'][\'TEND_TEMP\'].isel(member_id=member_id).sel(z_t=slice({layers[layer]},{layers[layer+1]}))*dz.sel(z_t=slice({layers[layer]},{layers[layer+1]}))' # 1- Multiply by dz. Unit: oC.s-1.m
        exec(st); del st
        st=f'ds_HS_TEMP=ds_HS_TEMP*ds_dict[\'TEND_TEMP\'][\'TAREA\'].isel(member_id=member_id).sel(z_t=slice({layers[layer]},{layers[layer+1]}))' # 2- Multiply by the area. Unit: oC.s-1.m3
        exec(st); del st
        ds_HS_TEMP=ds_HS_TEMP.sum(dim=['z_t','nlon','nlat']) # 3- Integral in dz,dy,dx. Unit: oC.s-1.m3
        ds_HS_TEMP=ds_HS_TEMP*1026 # 4- Multiply by the density of the sea water. Unit: oC.s-1.kg
        ds_HS_TEMP=ds_HS_TEMP*3996 # 5- Multiply by the heat capacity of the sea water. Unit: W
        ds_HS_TEMP=ds_HS_TEMP*1e-15 # 6- Get the variable in PW 
        var_small=ds_HS_TEMP.load() # 7- Annual mean and load
        var_array.append(var_small) # 8- Add items to the end of a given list
        del ds_HS_TEMP 
        print(f'Done with member: {member_id}') # 9- Go to the next member
    st=f'ds_HS_TEMP_merged_{layer} = xr.concat(var_array, dim=\'member_id\', compat=\'override\', join=\'override\', coords=\'minimal\')' # 10- Concat the members   
    exec(st); del st

In [None]:
# Annual ensemble-mean heat storage rate per layer
fig, ax = plt.subplots()
for hs_layer, label in ((ds_HS_TEMP_merged_0, '0-1000'), (ds_HS_TEMP_merged_1, '1000-6000')):
    hs_layer.sel(time=slice('1852-01-01', '2100-12-31')).mean('member_id').resample(
        time='1Y', closed='left').mean('time').plot(ax=ax, label=label)
ax.legend()
plt.show()

In [None]:
# Collect the per-layer heat storage series into one dataset and write it to netCDF.
ds_out_var = xr.merge([ds_HS_TEMP_merged_0.rename('HS_0'), # Heat Storage (first layer)
                       ds_HS_TEMP_merged_1.rename('HS_1'), # Heat Storage (second layer)
                      ])
ds_out_var.attrs['description'] = 'Heat balance components per layers for the South Atlantic: HS_0 (0-1000m), HS_1 (1000-6000m)'
ds_out_var.attrs['units'] = 'PW'
ds_out_var.attrs['author'] = 'Mauricio Rocha'
ds_out_var.attrs['email'] = 'mauricio.rocha@usp.br'
# Output directory on the current user's scratch space, created if missing
path = '/glade/scratch/{}/Data/LENS2/HEAT_BALANCE/'.format(getpass.getuser())
os.makedirs(path, exist_ok=True)
ds_out_var.to_netcdf(os.path.join(path, 'heat_storage_per_layer_0_6000m.nc'))