## CESM2 - LARGE ENSEMBLE (LENS2)
- In this notebook we decompose the temperature flux into its velocity and temperature components. In addition, we compute the advective terms and the eddy contribution.

### Imports

In [None]:
# modules I am using in this example
import xarray as xr
import xgcm
import numpy as np
from xgcm import Grid
import pop_tools
from dask.distributed import Client, wait
from ncar_jobqueue import NCARCluster
import dask
import cartopy.crs as ccrs
import intake
import matplotlib.pyplot as plt
import intake_esm
import warnings, getpass, os

### Dask

In [None]:
# Size of the Dask cluster requested through NCARCluster.
mem_per_worker = 30  # memory per worker in GB
num_workers = 30     # number of workers

# Every worker job is single-core / single-process; the same memory figure is
# passed both as `memory` and inside the `resource_spec` string.
worker_job = dict(
    cores=1,
    processes=1,
    memory=f'{mem_per_worker} GB',
    resource_spec=f'select=1:ncpus=1:mem={mem_per_worker}GB',
    walltime='4:00:00',
)
cluster = NCARCluster(**worker_job)
cluster.scale(num_workers)

# Connect a client to the cluster, print its summary and display it.
client = Client(cluster)
print(client)
client

### Load data: using mfdataset
- We need to open a single member in this way to get the attributes. There is an issue when we open the data from the catalog, because it does not keep all the attributes. We need the attributes for xgcm.

In [None]:
# Time-series files of one ensemble member, listed in the order in which they
# are merged (with compat='override' the merge order is kept as in the
# original cell: VNT, VVEL, TEMP, UET, UVEL).
member_files = [
    '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/VNT/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.VNT.205501-206412.nc',
    '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/VVEL/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.VVEL.205501-206412.nc',
    '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/TEMP/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.TEMP.205501-206412.nc',
    '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/UET/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.UET.205501-206412.nc',
    '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/UVEL/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.UVEL.205501-206412.nc',
]

# Open every file and combine them into the reference dataset `dsa`, whose
# attributes are copied onto the catalog dataset further down.
member_datasets = [xr.open_mfdataset(nc_file, parallel=True) for nc_file in member_files]
dsa = xr.merge(member_datasets, compat='override')

# Only `dsa` is needed from here on.
del member_files, member_datasets

### Load data: using the catalog

In [None]:
%%time
catalog = intake.open_esm_datastore(
    '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
)
cat_subset = catalog.search(component='ocn',variable=['VNT','UET','UVEL','VVEL','TEMP'],frequency='month_1')
# Load catalog entries for subset into a dictionary of xarray datasets
dset_dict_raw  = cat_subset.to_dataset_dict(zarr_kwargs={'consolidated': True},
                                            storage_options={'anon': True},
                                            cdf_kwargs={'chunks': {'nlon': 18}})
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

### Concatenation of variables
- You can use lists or dictionaries.

In [None]:
ff = ('cmip6', 'smbb')  # Forcings
fb = ['VNT', 'UET', 'UVEL', 'VVEL', 'TEMP']
# Notes:
# VNT: Flux of Heat in grid-y direction (degC/s)
# UET: Flux of Heat in grid-x direction (degC/s)
# VVEL: Velocity in grid-y direction (centimeter/s)
# UVEL: Velocity in grid-x direction (centimeter/s)
# TEMP: Potential Temperature (degC)

# Options shared by every combine_nested call below.
combine_opts = dict(data_vars='minimal', coords='minimal')

dsi = []
for var in fb:
    # 1- combine historical and ssp370 (concatenate in time), once per forcing
    ds_dict_tmp = []
    for scenario in ff:
        historical = dset_dict_raw[f'ocn.historical.pop.h.{scenario}.{var}']
        future = dset_dict_raw[f'ocn.ssp370.pop.h.{scenario}.{var}']
        ds_dict_tmp.append(
            xr.combine_nested([historical, future], concat_dim=['time'], **combine_opts)
        )
    # 2- combine cmip6 and smbb (concatenate in member_id)
    dsi.append(xr.combine_nested(ds_dict_tmp, concat_dim=['member_id'], **combine_opts))

# One dataset holding all five variables.
ds = xr.merge(dsi, compat='override')

# Add attributes from last scenario / variable to ds
#for key in ['Conventions', 'calendar']:
#    ds.attrs[key] = dset_dict_raw[f'ocn.historical.pop.h.{scenario}.{var}'].attrs[key]
#ds['TLONG'].attrs['axis']='X'
#ds['TLAT'].attrs['axis']='Y'
#list(dset_dict_raw[f'ocn.historical.pop.h.{scenario}.{var}'].attrs.keys())

### Copy the attributes for each variable we open via mfdataset

In [None]:
# Restore per-variable metadata: every variable of the single-member dataset
# (dsa) hands its attributes to the variable of the same name in the catalog
# dataset (ds).
for name, reference_var in dsa.variables.items():
    ds[name].attrs = reference_var.attrs

### Copy the global attributes

In [None]:
# Replace the dataset-level (global) attributes of ds with those of dsa.
ds.attrs=dsa.attrs
ds  # last expression of the cell: display the resulting dataset

### Help functions

In [None]:
def to_index(dsf):
    """Rename staggered horizontal dimensions back to their plain index names.

    Every dimension of ``dsf`` called ``nlon_t``, ``nlat_t``, ``nlon_u`` or
    ``nlat_u`` has its coordinate variable dropped (if it has one) and is
    renamed to the same name without the two-character suffix, i.e. ``nlon``
    or ``nlat``.

    Parameters
    ----------
    dsf : object exposing ``copy``, ``dims``, ``drop_vars`` and ``rename``
        Presumably an xarray Dataset or DataArray -- confirm against callers.

    Returns
    -------
    The renamed copy; the object passed in is not rebound by this function.

    Notes
    -----
    The dimension names are read from ``dsf`` itself. The previous version
    iterated over ``ds_sa.dims``, a name that is not defined anywhere in the
    visible notebook, so the function only worked with leftover kernel state.
    NOTE(review): if both the ``_t`` and ``_u`` variant of one axis are
    present, the second rename targets a name that already exists; that case
    is not handled here.
    """
    staggered_dims = ("nlon_t", "nlat_t", "nlon_u", "nlat_u")
    dsf = dsf.copy()
    # Snapshot the names first: dsf is rebound inside the loop.
    for dim in list(dsf.dims):
        if dim in staggered_dims:
            # errors="ignore": a dimension without a coordinate variable has
            # nothing to drop but must still be renamed.
            dsf = dsf.drop_vars(dim, errors="ignore").rename({dim: dim[:-2]})
    return dsf
def to_coord(dsf, lon, lat):
    """Rename the plain ``nlon`` / ``nlat`` names to suffixed ones.

    ``nlon`` becomes ``"nlon_" + lon`` and ``nlat`` becomes ``"nlat_" + lat``;
    the result of ``dsf.rename`` is returned unchanged.
    """
    suffixed_names = dict(nlon="nlon_" + lon, nlat="nlat_" + lat)
    return dsf.rename(suffixed_names)

### Calculate the volume of the cell and use xgcm to determine the coordinates T and U

In [None]:
# Cell volume from the layer thickness and the T-cell horizontal spacings.
# NOTE(review): dz is used instead of DZT, which presumes there are no partial
# bottom cells in this configuration -- TODO confirm.
# DXT: x-spacing centered at T points (cm)
# DYT: y-spacing centered at T points (cm)
# dz: thickness of layer k (cm)
volume = ds.dz * ds.DXT * ds.DYT  # cm3 (check out later the unit)
# Adding attributes to the new variable
volume.attrs = {
    'long_name': 'cell volume',
    'units': 'cm3',
    'grid_loc': '3111',
    'cell_methods': 'time: mean',
}
ds['cell_volume'] = volume

# Distance metrics handed to xgcm; dz is listed twice for the Z axis.
metrics = {
    ("X",): ["DXU", "DXT"],     # X distances
    ("Y",): ["DYU", "DYT"],     # Y distances
    ("Z",): ["dz", "dz"],       # Z distances
}

# Build the xgcm grid and the dataset with renamed (staggered) dimensions.
# We gotta confirm if we may use just one variable to compute the grid and ds_ren
grid_options = dict(
    metrics=metrics,
    periodic=['X'],
    boundary={"Y": "extend", "Z": "fill"},
    fill_value={"Z": 0.},
)
grid, ds_ren = pop_tools.to_xgcm_grid_dataset(ds, **grid_options)

In [None]:
%%time
warnings.simplefilter("ignore")
# Get the total advection as saved by the model 
# Total advection saved by model --> the difference betweem this and the mean is the eddy component
# Horizontal components
print('Getting total advection term.')
#st = time.time()
# for member_id in range(len(ds1.UET.coords['member_id'])-99): # per member
#if True:
# member_id=0
# Zonal component
#uadv = -( grid.diff(ds1.UET*ds1.cell_volume.values, # degC/s
#                    axis="X",
#                    boundary="extend")
#         /ds_ren.cell_volume)
uadv = -( grid.diff(to_coord((ds.cell_volume*ds.UET),'u','t'), # degC/s
                    axis="X",
                    boundary="extend")
         /ds_ren.cell_volume)

# Meridional component
vadv = -( grid.diff(to_coord((ds.cell_volume*ds.VNT),'t','u'), # degC/s 
                    axis="Y", 
                    boundary="extend")
         /ds_ren.cell_volume)

# Total horizontal 
h_adv = uadv + vadv # degC/s

### Some tests

In [None]:
# Sanity check: list the coordinates attached to cell_volume in ds_ren.
ds_ren.cell_volume.coords

In [None]:
# Sanity check: list the coordinates attached to TEMP in ds_ren.
ds_ren.TEMP.coords

### Get the advection from the mean flow 

In [None]:
%%time
# Horizontal advection from mean flow -- you need this because the difference between this and the term above is the eddy term 
print('Getting horizontal advection from mean flow')
#st = time.time()
# for member_id in range(len(ds1.UET.coords['member_id'])): # per member
#if True:
#    member_id = 0
# u term
U_interp = grid.interp((ds_ren.UVEL*ds_ren.dz*ds_ren.DYU), # cm3/s
                       axis="Y",
                       boundary="extend")

T_interp_X = grid.interp(ds_ren.TEMP, # degC
                            axis="X",
                            boundary="extend")

uT = U_interp * T_interp_X # degC*cm3/s

H_ADV_mean = (-(grid.diff(uT, axis="X", boundary="extend") # degC/s
                /ds_ren.cell_volume))

# v term
V_interp = grid.interp((ds_ren.VVEL*ds_ren.dz*ds_ren.DXU), # cm3/s
                       axis="X",
                       boundary="extend")

T_interp_Y = grid.interp(ds_ren.TEMP, # degC
                            axis="Y",
                            boundary="extend")

vT = V_interp * T_interp_Y # degC*cm3/s

# Total term
H_ADV_mean = H_ADV_mean - (grid.diff(vT,  # degC/s
                                     axis="Y",
                                     boundary="extend")
                           /ds_ren.cell_volume)

# Calculate difference for eddy part  
T_h_ADV_eddy = h_adv - H_ADV_mean    # degC/s

#### Check the coordinates

In [None]:
# Coordinates of TEMP after interpolation along the Y axis.
T_interp_Y.coords

In [None]:
# Coordinates of VVEL*dz*DXU after interpolation along the X axis.
V_interp.coords

In [None]:
# Coordinates of TEMP after interpolation along the X axis.
T_interp_X.coords

In [None]:
# Coordinates of UVEL*dz*DYU after interpolation along the Y axis.
U_interp.coords

### ---------------------------------------------------------------------------------------------------------------------------------------------------------
### Meridional Heat transport decomposition 

#### Get the mean value from the long term average for the entire series
- We selected the period between 2015 and 2100.

In [None]:
# Period used for the long-term average and for the anomalies.
analysis_period = slice('2015-01-01', '2100-12-31')
V_interp_period = V_interp.sel(time=analysis_period)       # (cm3/s)
T_interp_Y_period = T_interp_Y.sel(time=analysis_period)   # (degC)

# Mean over time within the period
V_interp_mean = V_interp_period.mean(dim=['time'])         # (cm3/s)
T_interp_Y_mean = T_interp_Y_period.mean(dim=['time'])     # (degC)

# Anomaly: departure of every time step from that mean
V_interp_anom = V_interp_period - V_interp_mean            # (cm3/s)
T_interp_Y_anom = T_interp_Y_period - T_interp_Y_mean      # (degC)

#### Expanding the average

In [None]:
# Give each time-mean field a `time` dimension carrying the time coordinate of
# the matching anomaly.
# NOTE(review): the names are rebound in place, so re-running this cell on its
# own acts on fields that already have a `time` dimension.
V_interp_mean = V_interp_mean.expand_dims(time=V_interp_anom['time'])
T_interp_Y_mean = T_interp_Y_mean.expand_dims(time=T_interp_Y_anom['time'])

#### Group the data in a single dataset

In [None]:
# Output name -> field, in the order the fields are merged.
decomposition_fields = {
    'V_mean': V_interp_mean,
    'T_mean': T_interp_Y_mean,
    'V_anom': V_interp_anom,
    'T_anom': T_interp_Y_anom,
}
ds_out = xr.merge(
    [field.rename(out_name) for out_name, field in decomposition_fields.items()]
)

# The individual fields now live in ds_out; drop the loose references.
del decomposition_fields
del V_interp_mean, T_interp_Y_mean, V_interp_anom, T_interp_Y_anom

# Global metadata of the output dataset.
ds_out.attrs.update({
    'description': 'Mean (2015-2100) and anomaly from the mean for meridional volume transport, temperature, and temperature flux.',
    'units': 'cm3/s, degC',
    'author': 'Mauricio Rocha',
    'email': 'mauricio.rocha@usp.br',
})

#### Subset
- Get the South Atlantic region.
- P.S.: It is important to subset after using the xgcm function.

In [None]:
# We will use variables TLONG and TLAT
pop_grid = pop_tools.get_grid('POP_gx1v7')

# 1 where REGION_MASK == 6 and NaN elsewhere (the masked field divided by itself).
region_six = pop_grid.REGION_MASK.where(pop_grid.REGION_MASK == 6)
atl = region_six / region_six

# Attach the T- and U-point longitudes/latitudes to the mask.
for coord_name in ('TLAT', 'TLONG', 'ULAT', 'ULONG'):
    atl[coord_name] = pop_grid[coord_name]
atl = atl.fillna(0)

# Restrict the mask to -34 <= TLAT <= 0, one bound at a time; points outside
# each bound become NaN in where() and are reset to 0.
south_atl = (atl * atl.where(atl.TLAT <= 0.)).fillna(0)
south_atl = (south_atl * south_atl.where(south_atl.TLAT >= -34.)).fillna(0)

# Map of the resulting mask.
plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = south_atl.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),
    x='TLONG',
    y='TLAT',
)
ax.gridlines(draw_labels=True)
ax.coastlines()
# NOTE(review): second gridlines() call, without labels, kept from the original cell.
ax.gridlines()

In [None]:
# Wrap the South Atlantic mask in a Dataset that carries the POP grid
# attributes, so that pop_tools can convert it to an xgcm-style dataset.
mask = xr.Dataset(
    {"REGION_MASK": south_atl},
    coords={'nlat': south_atl.coords['nlat'], 'nlon': south_atl.coords['nlon']},
)
mask.attrs = pop_grid.attrs
mask['REGION_MASK'].attrs = pop_grid['REGION_MASK'].attrs
grid_mask, ds_ren_mask = pop_tools.to_xgcm_grid_dataset(mask)

# Remove the existing nlat_u coordinate and reuse that name for the mask's
# nlat_t dimension. drop_vars replaces the deprecated Dataset.drop(..., dim=None).
ds_ren_mask = ds_ren_mask.drop_vars('nlat_u').rename({'nlat_t': 'nlat_u'})
ds_ren_mask.REGION_MASK.plot()

<div class="alert alert-block alert-info">
<b>Note:</b> As you can see above, the mask contains no U-coordinates. We therefore rename its latitude dimension to the U coordinate; the coordinate information lost in doing so is irrelevant to the final result.
</div>

In [None]:
# Overwrite the mask's nlat_u coordinate with the one from ds_out, presumably
# so that both objects carry identical labels before being combined in where().
ds_ren_mask['nlat_u']=ds_out['nlat_u']
# Keep ds_out only where the mask is non-zero.
ds_out_subset=ds_out.where(ds_ren_mask.REGION_MASK != 0.)

##### Let's split heat transport into velocity $(\rm{V})$ and temperature $(\rm{T})$ components as follows:
##### $$\rm{VT} = \rm{(\bar{V}+V^{'})(\bar{T}+T^{'})},$$
##### $$\rm{VT} = \rm{\bar{V}\bar{T}+\bar{V}T^{'}+V^{'}\bar{T}+V^{'}T^{'}}.$$

In [None]:
# Open the four terms of the VT decomposition from disk.
# NOTE(review): no cell visible in this notebook writes these files.
TprimeVprime, TbarVprime, VbarTprime, VbarTbar = (
    xr.open_dataset(nc_path)
    for nc_path in (
        '/glade/scratch/mauricio/Data/LENS2/N_Heat_Decomposition/TprimeVprime.nc',
        '/glade/scratch/mauricio/Data/LENS2/N_Heat_Decomposition/TbarVprime.nc',
        '/glade/scratch/mauricio/Data/LENS2/N_Heat_Decomposition/VbarTprime.nc',
        '/glade/scratch/mauricio/Data/LENS2/N_Heat_Decomposition/VbarTbar.nc',
    )
)

In [None]:
%%time
VbarTbar.VbarTbar.isel(nlat_u=85).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='purple',linewidth=1,label=r'$\rm{\bar{T}\bar{V}}: 34^oS$')                                                                           
VbarTbar.VbarTbar.isel(nlat_u=186).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='plum',linewidth=1,label=r'$\rm{\bar{T}\bar{V}}: 0^o$')

TbarVprime.TbarVprime.isel(nlat_u=85).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='blue',linewidth=1,label=r'$\rm{\bar{T}V{^\prime}}: 34^oS$')
TbarVprime.TbarVprime.isel(nlat_u=186).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='green',linewidth=1,label=r'$\rm{\bar{T}V{^\prime}}: 0^o$')

VbarTprime.VbarTprime.isel(nlat_u=85).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='orange',linewidth=1,label=r'$\rm{\bar{V}T{^\prime}}: 34^oS$')
VbarTprime.VbarTprime.isel(nlat_u=186).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='red',linewidth=1,label=r'$\rm{\bar{V}T{^\prime}}: 0^o$')

TprimeVprime.TprimeVprime.isel(nlat_u=85).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='aquamarine',linewidth=1,label=r'$\rm{T{^\prime}V{^\prime}}: 34^oS$')
TprimeVprime.TprimeVprime.isel(nlat_u=186).mean('member_id').resample(time='1Y', closed='left').mean('time').plot(x="time",color='lime',linewidth=1,label=r'$\rm{T{^\prime}V{^\prime}}: 0^o$')                                                                           

(TprimeVprime.TprimeVprime.isel(nlat_u=85).mean('member_id')+VbarTprime.VbarTprime.isel(nlat_u=85).mean('member_id')+TbarVprime.TbarVprime.isel(nlat_u=85).mean('member_id')+VbarTbar.VbarTbar.isel(nlat_u=85).mean('member_id')).resample(time='1Y', closed='left').mean('time').plot(x="time",color='black',linewidth=1,label=r'$\rm{VT}: 34^oS$')
(TprimeVprime.TprimeVprime.isel(nlat_u=186).mean('member_id')+VbarTprime.VbarTprime.isel(nlat_u=186).mean('member_id')+TbarVprime.TbarVprime.isel(nlat_u=186).mean('member_id')+VbarTbar.VbarTbar.isel(nlat_u=186).mean('member_id')).resample(time='1Y', closed='left').mean('time').plot(x="time",color='gray',linewidth=1,label=r'$\rm{VT}: 0^o$')                                                                           
                                                                                                                      
plt.tight_layout()
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.ylabel(r'Temperature transport [$\rm{Sv.^oC}$]')
plt.grid(color='gray', linestyle='-', linewidth=0.7)
plt.title(None)
plt.xlabel('Time [Years]')
plt.savefig('MHT_components.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
def _annual_total(lat_index):
    """Sum the four decomposition terms at one nlat_u index and average per year.

    Each term is first averaged over member_id; the sum is then resampled to
    yearly means.
    """
    total = (TprimeVprime.TprimeVprime.isel(nlat_u=lat_index).mean('member_id')
             + VbarTprime.VbarTprime.isel(nlat_u=lat_index).mean('member_id')
             + TbarVprime.TbarVprime.isel(nlat_u=lat_index).mean('member_id')
             + VbarTbar.VbarTbar.isel(nlat_u=lat_index).mean('member_id'))
    return total.resample(time='1Y', closed='left').mean('time')


# Same nlat_u indices as in the component plot (labelled 34S and 0 degrees there).
Total_34S = _annual_total(85)
Total_equator = _annual_total(186)

# Factors applied to the transport difference.
# NOTE(review): presumably seawater density (kg/m3) and specific heat
# (J/kg/K), turning a temperature transport into a heat transport -- confirm.
RHO_SW = 1026
CP_SW = 3996

# The curve is the difference between the two latitudes, so it is labelled as
# such (the previous label was the equator-only label copied from the plot above).
((Total_equator - Total_34S) * RHO_SW * CP_SW).plot(
    x="time", color='gray', linewidth=1, label=r'$\rm{VT}: 0^o - 34^oS$'
)
plt.show()