## CESM2 - LARGE ENSEMBLE (LENS2)
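- In this notebook we examine what controls the heat (temperature) transport: the velocity field and the temperature field. For the velocity contribution, we compute the advective terms (total and mean-flow) and obtain the eddy term as their difference.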

### Imports

In [1]:
# modules I am using in this example
import xarray as xr
import xgcm
from xgcm import Grid
import pop_tools
from dask.distributed import Client, wait
from ncar_jobqueue import NCARCluster
import dask
import intake
import intake_esm

### Dask

In [2]:
# Resources requested for each Dask worker and the size of the worker pool.
mem_per_worker = 40  # memory per worker in GB
num_workers = 40  # number of workers

# Build the memory string and the scheduler resource request from the same
# number so they cannot drift apart.
worker_memory = f'{mem_per_worker} GB'
worker_resource_spec = f'select=1:ncpus=1:mem={mem_per_worker}GB'

cluster = NCARCluster(
    cores=1,
    processes=1,
    memory=worker_memory,
    resource_spec=worker_resource_spec,
    walltime='2:00:00',
)
cluster.scale(num_workers)

# Attach a client to the cluster; the bare `client` on the last line renders
# the rich summary (dashboard link, workers, memory).
client = Client(cluster)
print(client)
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 41827 instead


<Client: 'tcp://10.12.206.60:35370' processes=0 threads=0, memory=0 B>


0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mauricio/mrocha/proxy/41827/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mauricio/mrocha/proxy/41827/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.60:35370,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mauricio/mrocha/proxy/41827/status,Total threads: 0
Started: Just now,Total memory: 0 B


### Load data

In [3]:
%%time
path = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/UET/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.UET.205501-206412.nc'
ds_UET = xr.open_mfdataset(path,parallel=True)
path = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/VNT/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.VNT.205501-206412.nc'
ds_VNT = xr.open_mfdataset(path,parallel=True)
path = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/VVEL/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.VVEL.205501-206412.nc'
ds_VVEL = xr.open_mfdataset(path,parallel=True)
path = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/UVEL/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.UVEL.205501-206412.nc'
ds_UVEL = xr.open_mfdataset(path,parallel=True)
path = '/glade/campaign/cgd/cesm/CESM2-LE/timeseries/ocn/proc/tseries/month_1/TEMP/b.e21.BSSP370smbb.f09_g17.LE2-1301.019.pop.h.TEMP.205501-206412.nc'
ds_TEMP = xr.open_mfdataset(path,parallel=True)
del path
dsa = xr.merge([ds_UET,ds_VNT,ds_VVEL,ds_UVEL,ds_TEMP],compat='override')
del ds_UET,ds_VNT,ds_VVEL,ds_UVEL,ds_TEMP

CPU times: user 438 ms, sys: 113 ms, total: 551 ms
Wall time: 6.12 s


In [5]:
# Open the CESM2-LE intake-esm catalog and keep only the monthly ocean
# variables needed for the advection terms.
catalog_json = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
catalog = intake.open_esm_datastore(catalog_json)
cat_subset = catalog.search(
    component='ocn',
    variable=['UET', 'VNT', 'VVEL', 'UVEL', 'TEMP'],
    frequency='month_1',
)

# Load the selected catalog entries into a dictionary of xarray datasets,
# keyed as 'component.experiment.stream.forcing_variant.variable'.
# NOTE(review): zarr_kwargs / storage_options look like leftovers from a
# cloud-hosted (zarr) catalog example - confirm they are needed for this catalog.
dset_dict_raw = cat_subset.to_dataset_dict(
    zarr_kwargs={'consolidated': True},
    storage_options={'anon': True},
)
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

  return pd.read_csv(catalog_path, **csv_kwargs), catalog_path



--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.stream.forcing_variant.variable'



Dataset dictionary keys:
 dict_keys(['ocn.ssp370.pop.h.smbb.UVEL', 'ocn.ssp370.pop.h.smbb.VNT', 'ocn.ssp370.pop.h.smbb.VVEL', 'ocn.historical.pop.h.cmip6.TEMP', 'ocn.historical.pop.h.smbb.VNT', 'ocn.historical.pop.h.cmip6.VVEL', 'ocn.historical.pop.h.cmip6.UVEL', 'ocn.historical.pop.h.smbb.UET', 'ocn.ssp370.pop.h.cmip6.UET', 'ocn.ssp370.pop.h.cmip6.VVEL', 'ocn.historical.pop.h.smbb.UVEL', 'ocn.ssp370.pop.h.cmip6.TEMP', 'ocn.ssp370.pop.h.cmip6.VNT', 'ocn.ssp370.pop.h.smbb.UET', 'ocn.historical.pop.h.smbb.TEMP', 'ocn.ssp370.pop.h.smbb.TEMP', 'ocn.historical.pop.h.cmip6.VNT', 'ocn.historical.pop.h.cmip6.UET', 'ocn.ssp370.pop.h.cmip6.UVEL', 'ocn.historical.pop.h.smbb.VVEL'])


In [6]:
# Forcing variants and variables to assemble.
ff = ('cmip6', 'smbb')  # forcing variants
# UET / VNT : flux of heat in grid-x / grid-y direction
# VVEL / UVEL: velocity in grid-y / grid-x direction
# TEMP       : potential temperature
fb = ['UET', 'VNT', 'VVEL', 'UVEL', 'TEMP']

dsi = []
for var in fb:
    # 1- combine historical and ssp370 (concatenate in time), once per forcing variant
    ds_dict_tmp = []
    for scenario in ff:
        ds_dict_tmp.append(
            xr.combine_nested(
                [
                    dset_dict_raw[f'ocn.historical.pop.h.{scenario}.{var}'],
                    dset_dict_raw[f'ocn.ssp370.pop.h.{scenario}.{var}'],
                ],
                concat_dim=['time'],
            )
        )
    # 2- combine the forcing variants (concatenate in member_id)
    dsi.append(xr.combine_nested(ds_dict_tmp, concat_dim=['member_id']))

# Merge the per-variable datasets; compat='override' takes shared variables
# from the first dataset without comparing them.
ds = xr.merge(dsi, compat='override')

# Copy selected global attributes from an explicitly chosen source dataset
# (last forcing variant, last variable) rather than from whatever the loop
# variables happen to hold after the loops.
attrs_source = dset_dict_raw[f'ocn.historical.pop.h.{ff[-1]}.{fb[-1]}']
for key in ['Conventions', 'calendar']:
    ds.attrs[key] = attrs_source.attrs[key]

# Tag the T-grid longitude/latitude with their axis.
ds['TLONG'].attrs['axis'] = 'X'
ds['TLAT'].attrs['axis'] = 'Y'

# Show which global attributes the source dataset carries.
list(attrs_source.attrs.keys())

['intake_esm_varname',
 'revision',
 'cell_methods',
 'source',
 'calendar',
 'contents',
 'history',
 'Conventions',
 'model_doi_url',
 'time_period_freq',
 'intake_esm_dataset_key']

In [None]:
# Copy the merged ensemble dataset; the next cells overwrite attributes on
# `ds1`, and `ds1` is what is later passed to pop_tools to build the grid.
ds1=ds.copy()

In [None]:
# Give every variable of `ds1` the attributes of the same-named variable in
# the single-member dataset `dsa`.
# NOTE(review): this replaces the attrs rather than merging them, so the
# 'axis' attrs set on TLONG/TLAT of `ds` are not kept on `ds1` unless `dsa`
# carries them too - confirm this is intended.
for var_name, reference_var in dsa.variables.items():
    ds1[var_name].attrs = reference_var.attrs

In [None]:
# Replace the global attributes of `ds1` with those of the single-member
# dataset `dsa`.
ds1.attrs=dsa.attrs

In [None]:
# Display `ds1` to inspect the result of the attribute copy.
ds1

### Helper functions

In [None]:
def to_index(dsf):
    """Rename staggered horizontal dims back to plain ``nlon`` / ``nlat``.

    For every dimension of ``dsf`` named ``nlon_t``, ``nlat_t``, ``nlon_u`` or
    ``nlat_u``, the same-named coordinate is dropped (if there is one) and the
    dimension is renamed by stripping the two-character suffix.

    Parameters
    ----------
    dsf : object with ``copy``, ``dims``, ``drop_vars`` and ``rename``
        (presumably an xarray Dataset or DataArray - the code only relies on
        these four members).

    Returns
    -------
    A renamed copy; the input object is not rebound.
    """
    dsf = dsf.copy()
    # Iterate over the argument's own dims (the previous version read an
    # undefined global). Snapshot them because `dsf` is rebound in the loop.
    for dim in list(dsf.dims):
        if dim in ("nlon_t", "nlat_t", "nlon_u", "nlat_u"):
            # errors="ignore": a dim without a coordinate variable is still renamed.
            dsf = dsf.drop_vars(dim, errors="ignore").rename({dim: dim[:-2]})
    return dsf
def to_coord(dsf, lon, lat):
    """Rename plain ``nlon`` / ``nlat`` to their staggered names.

    ``nlon`` becomes ``"nlon_" + lon`` and ``nlat`` becomes ``"nlat_" + lat``;
    the result of ``dsf.rename`` is returned.
    """
    staggered_names = {
        "nlon": f"nlon_{lon}",
        "nlat": f"nlat_{lat}",
    }
    return dsf.rename(staggered_names)

### Get the POP_grid

In [None]:
# Cell volume on the original (un-renamed) dataset, built from dz, DXT and DYT.
# Assumes there are no partial bottom cells, so dz can stand in for DZT - TODO confirm.
# TODO: check the units of the resulting volume.
ds['cell_volume'] = ds['dz'] * ds['DXT'] * ds['DYT']

# xgcm grid plus the dataset renamed for it, both derived from `ds1`.
# TODO: confirm a single variable is enough to build `grid` and `ds_ren`.
grid, ds_ren = pop_tools.to_xgcm_grid_dataset(ds1)

# Same cell volume on the renamed dataset (TODO: check units).
ds_ren['cell_volume'] = ds_ren['dz'] * ds_ren['DXT'] * ds_ren['DYT']

### Advection

In [None]:
# Display the renamed dataset returned by pop_tools.
ds_ren

In [None]:
# NOTE(review): this cell repeats the previous display of `ds_ren`; one of the two can go.
ds_ren

In [None]:
%%time
# get the total advection as saved by the model 
# total advection saved by model --> the difference betweem this and the mean is the eddy component
# horizontal components
print('Getting total advection term.')
#st = time.time()
uadv = -( grid.diff(to_coord((ds.cell_volume * ds.UET), 'u', 't'), 
                   axis="X", boundary="extend")
        / ds_ren.cell_volume )
vadv = -( grid.diff(to_coord((ds.cell_volume * ds.VNT), 't', 'u'),
                   axis="Y", boundary="extend")
        / ds_ren.cell_volume )

In [None]:
# Total horizontal advection: zonal plus meridional contribution.
# NOTE(review): the original author suspected only the meridional (VNT) part
# is needed; both are kept here for completeness.
h_adv = uadv + vadv

In [None]:
# Display the (lazy) total horizontal advection array.
h_adv

In [None]:
# Horizontal advection by the mean flow. It is needed because the eddy term
# is the total advection (computed above) minus this quantity.
print('Getting horizontal advection from mean flow.')

# --- u term ---
# UVEL weighted by dz * DYU, interpolated along Y, times TEMP interpolated along X.
u_transport = ds_ren.UVEL * ds_ren.dz * ds_ren.DYU
U_interp = grid.interp(u_transport, axis="Y", boundary="extend")
temp_on_x = grid.interp(ds_ren.TEMP, axis="X", boundary="extend")
uT = U_interp * temp_on_x
u_adv_mean = -(grid.diff(uT, axis="X", boundary="extend") / ds_ren.cell_volume)

# --- v term ---
# VVEL weighted by dz * DXU, interpolated along X, times TEMP interpolated along Y.
v_transport = ds_ren.VVEL * ds_ren.dz * ds_ren.DXU
V_interp = grid.interp(v_transport, axis="X", boundary="extend")
temp_on_y = grid.interp(ds_ren.TEMP, axis="Y", boundary="extend")
vT = V_interp * temp_on_y
v_divergence = grid.diff(vT, axis="Y", boundary="extend") / ds_ren.cell_volume

# Total mean-flow term: u term minus the v divergence.
# NOTE(review): the original author suspected only the v term is needed - confirm.
H_ADV_mean = u_adv_mean - v_divergence

In [None]:
# Eddy part: total horizontal advection minus the advection by the mean flow.
T_h_ADV_eddy = h_adv - H_ADV_mean

In [None]:
# Extract the first T-grid point and load it into memory.
# The array comes out of grid.diff on the renamed grid, so its horizontal dims
# are nlon_t / nlat_t; selecting on nlon / nlat raises because those dims do not exist.
# NOTE(review): this overwrites the full field, so the cell cannot be re-run
# as is; also confirm index (0, 0) is the intended location.
T_h_ADV_eddy = T_h_ADV_eddy.isel(nlon_t=0, nlat_t=0).compute()

In [None]:
# Extract the first T-grid point of the total advection and load it into memory.
# The horizontal dims are nlon_t / nlat_t (output of grid.diff on the renamed
# grid); selecting on nlon / nlat raises because those dims do not exist.
# NOTE(review): this overwrites the full field, so the cell cannot be re-run as is.
h_adv = h_adv.isel(nlon_t=0, nlat_t=0).compute()

In [None]:
# Extract the first T-grid point of the mean-flow advection and load it into memory.
# The horizontal dims are nlon_t / nlat_t (output of grid.diff on the renamed
# grid); selecting on nlon / nlat raises because those dims do not exist.
# NOTE(review): this overwrites the full field, so the cell cannot be re-run as is.
H_ADV_mean = H_ADV_mean.isel(nlon_t=0, nlat_t=0).compute()