## CESM2 - LARGE ENSEMBLE (LENS2)

- This notebook computes the surface heat-balance terms over the South Atlantic.

### Imports

In [None]:
import xarray as xr
import pandas as pd
import numpy as np 
import dask
import cf_xarray
import intake
import cftime
import nc_time_axis
import intake_esm
import matplotlib.pyplot as plt
import pop_tools
from dask.distributed import Client, wait
from ncar_jobqueue import NCARCluster
import warnings, getpass, os
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy.crs as ccrs
import cmocean
import dask
from matplotlib.offsetbox import AnchoredText

### Dask

In [None]:
# Request Dask workers through ncar_jobqueue (one single-core process per job)
# and connect a distributed client to the resulting cluster.
mem_per_worker = 40 # memory per worker in GB 
num_workers = 40 # number of workers
# The job resource request and the Dask worker memory limit both use mem_per_worker.
cluster = NCARCluster(cores=1,
                      processes=1,
                      memory=f'{mem_per_worker} GB',
                      resource_spec=f'select=1:ncpus=1:mem={mem_per_worker}GB',
                      walltime='2:00:00',
                      log_directory='./dask-logs',
                     )
cluster.scale(num_workers)
client = Client(cluster)
print(client)
# Bare expression so the notebook renders the client summary.
client

### Read the data

In [None]:
# Open the CESM2 Large Ensemble intake-esm datastore (JSON descriptor on GLADE).
catalog = intake.open_esm_datastore(
    '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
)

#### Ocean Component

In [None]:
%%time

# Ocean-component surface flux fields used in the heat budget (monthly output).
all_vars = ['LWDN_F','SHF','SHF_QSW','LWUP_F','EVAP_F','SENH_F']
cat_subset = catalog.search(component='ocn',variable=all_vars,frequency='month_1')
# Load catalog entries for subset into a dictionary of xarray datasets.
# The next cell looks entries up with keys of the form
# 'ocn.<experiment>.pop.h.<forcing>.<variable>'.
dset_dict_raw  = cat_subset.to_dataset_dict(zarr_kwargs={'consolidated': True}, storage_options={'anon': True})
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

In [None]:
%%time

# Concatenation of variables
ff=('cmip6','smbb')                      # Forcings
ds_dict_ocn = dict()
for var in all_vars:
    # 1- combine historical and ssp370 (concatenate in time)
    ds_dict_tmp = dict()
    for scenario in ff:
        ds_dict_tmp[scenario] = xr.combine_nested([dset_dict_raw[f'ocn.historical.pop.h.{scenario}.{var}'], dset_dict_raw[f'ocn.ssp370.pop.h.{scenario}.{var}']],concat_dim=['time'])
        
        # 2- combine cmip6 and smbb (concatenate in member_id)
    ds_dict_ocn[var] = xr.combine_nested([ds_dict_tmp['cmip6'], ds_dict_tmp['smbb']], concat_dim=['member_id'])
    del ds_dict_tmp

ds_dict_ocn['EVAP_F']['latent_heat_vapor'] = ds_dict_ocn['EVAP_F']['latent_heat_vapor'].chunk(member_id=1)

### Import POP grid for the ocean component

In [None]:
pop_grid = pop_tools.get_grid('POP_gx1v7')
# Store the T-grid and U-grid longitude/latitude arrays of POP_gx1v7 next to the
# per-variable datasets.
for grid_coord in ('TLONG', 'TLAT', 'ULONG', 'ULAT'):
    ds_dict_ocn[grid_coord] = pop_grid[grid_coord]
del pop_grid, grid_coord

### Cut out and center the variables in the South Atlantic

In [None]:
%%time
# Cutting out and centering the variables in the South Atlantic
dask.config.set({"array.slicing.split_large_chunks": True})

# Ocean component
ilon1, flon1, ilon2, flon2 = 307, 320, 0, 54 # longitude (initial (i), final (f)) 
ilan=0
ilas=-34
fb=(['LWDN_F','SHF','SHF_QSW','LWUP_F','EVAP_F','SENH_F'])
for var in fb:
    if var not in ds_dict_ocn:
        continue
    ds_dict_ocn[f'{var}']=xr.combine_nested([[
        ds_dict_ocn[f'{var}'].where((ds_dict_ocn[f'{var}'].TLAT >= ilas) & (ds_dict_ocn[f'{var}'].TLAT <= ilan), drop=True).isel(
            nlon = slice(ilon1,flon1)),
        ds_dict_ocn[f'{var}'].where((ds_dict_ocn[f'{var}'].TLAT >= ilas) & (ds_dict_ocn[f'{var}'].TLAT <= ilan), drop=True).isel(
            nlon = slice(ilon2,flon2))]],
        concat_dim=['nlat','nlon'])
    ds_dict_ocn[f'{var}'].coords['nlon'] = (ds_dict_ocn[f'{var}'].coords['nlon'] + 180) % 360 - 180 
    ds_dict_ocn[f'{var}'] = ds_dict_ocn[f'{var}'].sortby(ds_dict_ocn[f'{var}'].nlon)
    # ds_tmp1 = ds_dict_ocn[var].where((ds_dict_ocn[var].TLAT >= ilas) & (ds_dict_ocn[var].TLAT <= ilan), drop=True).isel(
    #             nlon = slice(ilon1,flon1))
    # ds_tmp2 = ds_dict_ocn[var].where((ds_dict_ocn[var].TLAT >= ilas) & (ds_dict_ocn[var].TLAT <= ilan), drop=True).isel(
    #             nlon = slice(ilon2,flon2))
    # ds_tmp3 = xr.combine_nested([[ds_tmp1, ds_tmp2]], concat_dim=['nlat','nlon'])
    # ds_tmp4 = ds_tmp3.assign_coords({'nlon': (ds_tmp3[f'{var}'].coords['nlon'] + 180) % 360 - 180})
    # ds_dict_ocn[var] = ds_tmp4.sortby(ds_tmp4.coords['nlon'])
    # del ds_tmp1, ds_tmp2, ds_tmp3, ds_tmp4
ds_dict_ocn['EVAP_F']['latent_heat_vapor']

### Mask the continent
- The cell area (`TAREA`) saved by the ocean component must be masked over land.

In [None]:
%%time
fb=(['SHF'])
warnings.simplefilter("ignore")
for var in fb:
    mask_array = dict()
    mask_ocean = 2 * np.ones((len(ds_dict_ocn[f'{var}'][f'{var}'].coords['nlat']), # ocean
                          len(ds_dict_ocn[f'{var}'][f'{var}'].coords['nlon']))
                        ) * np.isfinite(ds_dict_ocn[f'{var}'][f'{var}'].isel(time=0))  
    mask_land  = 1 * np.ones((len(ds_dict_ocn[f'{var}'][f'{var}'].coords['nlat']), # continent
                          len(ds_dict_ocn[f'{var}'][f'{var}'].coords['nlon']))
                        ) * np.isnan(ds_dict_ocn[f'{var}'][f'{var}'].isel(time=0))  
    mask_array[f'{var}'] = mask_ocean + mask_land
    ds_dict_ocn[f'{var}']['TAREA']=ds_dict_ocn[f'{var}']['TAREA'].where(mask_array[f'{var}'] != 1.).isel(time=0)*1e-4 # cm -> m
    del mask_array
# ds_dict_ocn['SHF']['TAREA']=ds_dict_ocn['SHF']['TAREA'].chunk(chunks=(50,1,67)) 
ds_dict_ocn['SHF']['TAREA']=ds_dict_ocn['SHF']['TAREA'].compute()
ds_dict_ocn['SHF']['TAREA']

### Integrate over the area
- We rechunk the data and integrate each flux over the area of the region. This is done separately for each model component because each component has its own grid, i.e. different cell areas. The total area should nevertheless be the same; if grid differences make it differ, we may need to interpolate.

In [None]:
%%time
# Ocean component
fb=(['LWDN_F','SHF','SHF_QSW','LWUP_F','SENH_F'])
for var in fb:
    ds_dict_ocn[f'{var}'][f'{var}']=ds_dict_ocn[f'{var}'][f'{var}'].chunk(chunks=(50,1980,1,67))
    st=f'ds_{var}=[]' 
    exec(st)
    st=f'ds_{var}=ds_dict_ocn[\'SHF\'][\'TAREA\']*ds_dict_ocn[\'{var}\'][\'{var}\']*(1e-15)' # PW (1e-15 to convert the units from W to PW) 
    exec(st)   
    st=f'ds_{var}=ds_{var}.sum(dim=[\'nlat\',\'nlon\'],skipna=True).load()' # PW
    exec(st)
    print(f'Done with variable: {var}') 

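### The ocean component does not output the latent heat flux, but we can compute it as follows:
- $Q_{LH} = 10^{-15} \sum_{\text{cells}} A \, E \, L_v$ [PW], where $A$ is the cell area (m$^2$), $E$ is the evaporation mass flux `EVAP_F` (kg m$^{-2}$ s$^{-1}$), and $L_v$ is the latent heat of vaporization `latent_heat_vapor` (J kg$^{-1}$).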

In [None]:
%%time
# Persist latent_heat_vapor on the Dask workers and block until it is in memory,
# presumably so the latent-heat product below reuses the materialized data.
ds_dict_ocn['EVAP_F']['latent_heat_vapor'] = ds_dict_ocn['EVAP_F']['latent_heat_vapor'].persist()
wait(ds_dict_ocn['EVAP_F']['latent_heat_vapor'])
ds_dict_ocn['EVAP_F']['latent_heat_vapor']

In [None]:
%%time
# Persist the EVAP_F field on the Dask workers and block until it is in memory.
ds_dict_ocn['EVAP_F']['EVAP_F'] = ds_dict_ocn['EVAP_F']['EVAP_F'].persist()
wait(ds_dict_ocn['EVAP_F']['EVAP_F'])
ds_dict_ocn['EVAP_F']['EVAP_F']

In [None]:
%%time
ds_LATENT_tmp = ((ds_dict_ocn['SHF']['TAREA' # Area in m2
                              ]*ds_dict_ocn['EVAP_F']['EVAP_F' # Mass flux of water vapor in kg/m2/s 
                                                     ]*ds_dict_ocn['EVAP_F']['latent_heat_vapor' # Latent Heat Vapor in J/kg
                                                                            ])).persist()
wait(ds_LATENT_tmp)
ds_LATENT = ds_LATENT_tmp.sum(dim=['nlat','nlon'],skipna=True)*1e-15 # W -> PW
ds_LATENT

In [None]:
# Evaluate the region-integrated latent heat flux (PW) into memory.
ds_LATENT=ds_LATENT.compute()

In [None]:
def _annual_ensemble_mean(da, start='1960-01-01', end='2100-12-31'):
    """Average over ensemble members, then over calendar years, and keep [start, end]."""
    return (da.mean(dim=['member_id'])
              .resample(time='1Y', closed='left').mean('time')
              .sel(time=slice(start, end)))


# Compare the integrated SHF (SHF1) with the sum of the integrated component
# fluxes LWDN + LWUP + SW + SH + LH (SHF2).
_annual_ensemble_mean(ds_SHF).plot(label='SHF1')
_annual_ensemble_mean(ds_LWDN_F+ds_LWUP_F+ds_SHF_QSW+ds_SENH_F+ds_LATENT).plot(label='SHF2')

plt.legend()

In [None]:
def calculate_ticks(ax, ticks, round_to=0.01, center=True):
    """Return ``ticks`` evenly spaced y-tick positions covering the axis range.

    The current y bounds of ``ax`` are widened outward to multiples of
    ``round_to``. The span is then padded up to a multiple of ``ticks - 1``
    steps, so every tick lands on a multiple of ``round_to``.

    Parameters
    ----------
    ax : object with ``get_ybound()``
        Typically a matplotlib Axes; only ``get_ybound()`` is called.
    ticks : int
        Number of tick positions to return; must be at least 2.
    round_to : float, default 0.01
        Tick positions are multiples of this step; must be positive.
    center : bool, default True
        If True, the padding is split between the bottom (``floor`` of half)
        and the top; otherwise it is all added at the top.

    Returns
    -------
    numpy.ndarray
        ``ticks`` positions in data units.

    Raises
    ------
    ValueError
        If ``ticks < 2`` or ``round_to <= 0``, where the spacing is undefined
        (previously this silently produced NaN/inf ticks).
    """
    if ticks < 2:
        raise ValueError(f"ticks must be >= 2, got {ticks}")
    if round_to <= 0:
        raise ValueError(f"round_to must be positive, got {round_to}")
    ymin, ymax = ax.get_ybound()
    # Work in integer units of round_to.
    upperbound = np.ceil(ymax / round_to)
    lowerbound = np.floor(ymin / round_to)
    dy = upperbound - lowerbound
    # Step of floor(dy/(ticks-1)) + 1 units guarantees (ticks-1)*step > dy.
    fit = np.floor(dy / (ticks - 1)) + 1
    dy_new = (ticks - 1) * fit
    if center:
        offset = np.floor((dy_new - dy) / 2)
        lowerbound = lowerbound - offset
    values = np.linspace(lowerbound, lowerbound + dy_new, ticks)
    return values * round_to

In [None]:
from matplotlib.pyplot import figure
letts=['A','B','C','D','E','F']


def _annual_ensemble_mean(da, start='1960-01-01', end='2100-12-31'):
    """Average over ensemble members, then over calendar years, and keep [start, end]."""
    return (da.mean(dim=['member_id'])
              .resample(time='1Y', closed='left').mean('time')
              .sel(time=slice(start, end)))


fig, axs = plt.subplots(1,6, figsize=(25, 7))
# (panel index, region-integrated series in PW, legend label, line colour).
# Panel A overlays SHF with the sum of the five component fluxes.
panels = [
    (0, ds_SHF, 'SHF1', 'blue'),
    (0, ds_LWDN_F+ds_LWUP_F+ds_SHF_QSW+ds_SENH_F+ds_LATENT, 'SHF2', 'red'),
    (1, ds_LWDN_F, 'LWDN', 'maroon'),
    (2, ds_LWUP_F, 'LWUP', 'green'),
    (3, ds_SHF_QSW, 'SW', 'orange'),
    (4, ds_SENH_F, 'SH', 'purple'),
    (5, ds_LATENT, 'LH', 'c'),
]
for iax, series, label, color in panels:
    _annual_ensemble_mean(series).plot(ax=axs[iax], label=label, linewidth=1, color=color)

for ax, lett in zip(axs, letts):
    ax.legend(loc="upper right",fontsize=16, ncol=1)
    ax.grid(color='gray', linestyle='-', linewidth=0.7)
    ax.set_xlabel('Time [Years]',fontsize=16)
    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)
    # Panel letter in a rounded box.
    at = AnchoredText(lett, prop=dict(size=20), frameon=True, loc='lower left')
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    ax.add_artist(at)
    # Eight y ticks on multiples of calculate_ticks' default step.
    ax.set_yticks(calculate_ticks(ax, 8))

axs[0].set_ylabel('Heat Flux [PW]',fontsize=16)
plt.subplots_adjust(wspace=0.37)
plt.savefig('Qnet.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
def myfunc(x, slope=None, intercept=None):
    """Evaluate the straight line ``slope * x + intercept``.

    Parameters
    ----------
    x : scalar or array-like
        Abscissa value(s).
    slope, intercept : float, optional
        Line coefficients. When omitted, the module/notebook globals ``slope``
        and ``intercept`` are used (the names ``stats.linregress`` results are
        unpacked into). This keeps existing ``map(myfunc, x)`` calls working
        while allowing explicit, side-effect-free use.

    Returns
    -------
    The value(s) ``slope * x + intercept``.
    """
    namespace = globals()
    m = namespace['slope'] if slope is None else slope
    b = namespace['intercept'] if intercept is None else intercept
    return m * x + b

In [None]:
from scipy import stats

In [None]:
def _annual_ensemble_mean(da, start='1960-01-01', end='2100-12-31'):
    """Average over ensemble members, then over calendar years, and keep [start, end]."""
    return (da.mean(dim=['member_id'])
              .resample(time='1Y', closed='left').mean('time')
              .sel(time=slice(start, end)))


def _linear_trend(series):
    """Least-squares line of *series* against its sample index 0..n-1.

    Returns the ``scipy.stats.linregress`` result and the fitted values as a list.
    """
    idx = np.arange(len(series))
    fit = stats.linregress(idx, series)
    return fit, list(fit.slope * idx + fit.intercept)


# Trends over 2015-2100 (the ssp370 period). There is one sample per year, so
# slope*10 is the change per decade in PW.
SHF_sts = _annual_ensemble_mean(ds_SHF, '2015-01-01', '2100-12-31')
fit_SHF, mymodel_SHF_sts = _linear_trend(SHF_sts)
m_SHF_sts = fit_SHF.slope*10 # per decade
p_SHF_sts = fit_SHF.pvalue
r_SHF_sts = fit_SHF.rvalue*fit_SHF.rvalue

# Sum of the non-shortwave terms (longwave down/up, sensible, latent).
Neg_sts = _annual_ensemble_mean(ds_LWDN_F+ds_LWUP_F+ds_SENH_F+ds_LATENT, '2015-01-01', '2100-12-31')
fit_Neg, mymodel_Neg_sts = _linear_trend(Neg_sts)
m_Neg_sts = fit_Neg.slope*10 # per decade
p_Neg_sts = fit_Neg.pvalue
r_Neg_sts = fit_Neg.rvalue*fit_Neg.rvalue

# Shortwave term alone.
Pos_sts = _annual_ensemble_mean(ds_SHF_QSW, '2015-01-01', '2100-12-31')
fit_Pos, mymodel_Pos_sts = _linear_trend(Pos_sts)
m_Pos_sts = fit_Pos.slope*10 # per decade
p_Pos_sts = fit_Pos.pvalue
r_Pos_sts = fit_Pos.rvalue*fit_Pos.rvalue

# Keep the names this cell has always left in the namespace (the last fit);
# myfunc falls back to the globals `slope` and `intercept`.
x = np.arange(len(Pos_sts))
slope, intercept, r, p, std_err = fit_Pos

In [None]:
from matplotlib.pyplot import figure
letts=['A','B','C','D','E']


def _annual_ensemble_mean(da, start='1960-01-01', end='2100-12-31'):
    """Average over ensemble members, then over calendar years, and keep [start, end]."""
    return (da.mean(dim=['member_id'])
              .resample(time='1Y', closed='left').mean('time')
              .sel(time=slice(start, end)))


fig, axs = plt.subplots(1,3, figsize=(20, 7))
# (series, label, colour, fitted 2015-2100 line, its time axis, trend in PW/decade).
# Each fitted line is drawn on the time axis of the series it was fitted to.
# NOTE(review): panel A's trend was fitted to SHF but is drawn over SHF2 (sum of components).
panels = [
    (ds_LWDN_F+ds_LWUP_F+ds_SHF_QSW+ds_SENH_F+ds_LATENT, 'SHF2', 'red',
     mymodel_SHF_sts, SHF_sts['time'], m_SHF_sts),
    (ds_LWDN_F+ds_LWUP_F+ds_SENH_F+ds_LATENT, 'LWDN+LWUP+SH+LH', 'maroon',
     mymodel_Neg_sts, Neg_sts['time'], m_Neg_sts),
    (ds_SHF_QSW, 'SW', 'orange',
     mymodel_Pos_sts, Pos_sts['time'], m_Pos_sts),
]
for ax, (series, label, color, trend, trend_time, slope_pw) in zip(axs, panels):
    _annual_ensemble_mean(series).plot(ax=ax, label=label, linewidth=1, color=color)
    ax.plot(trend_time, trend, color=color, linestyle='dashed')
    # PW/decade * 1000 = TW/decade.
    ax.annotate(f'{slope_pw*1000:.2f} TW per decade', xy=(0.3, 0.2), color=color,
                fontsize=20, xycoords=ax.transAxes)

for ax, lett in zip(axs, letts):
    ax.legend(loc="upper right",fontsize=16, ncol=1)
    ax.grid(color='gray', linestyle='-', linewidth=0.7)
    ax.set_xlabel('Time [Years]',fontsize=16)
    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)
    at = AnchoredText(lett, prop=dict(size=20), frameon=True, loc='upper left')
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    ax.add_artist(at)
    ax.set_yticks(calculate_ticks(ax, 8))
axs[0].set_ylabel('Heat Flux [PW]',fontsize=16)
plt.subplots_adjust(wspace=0.2)
plt.savefig('Qnet_1.png',dpi=300,bbox_inches='tight')
plt.show()