## CESM2 - LARGE ENSEMBLE (LENS2)

#### by Mauricio Rocha and Dr. Gustavo Marques

- This notebook serves as an example of how to extract surface (or any other 2D spatial field) properties from a selected spatial region across all LENS2 members.

## Imports

In [None]:
import intake
import intake_esm
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import fsspec
import cmocean
import cartopy
import cartopy.feature as cfeature
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import pop_tools
import sys
from distributed import Client
from ncar_jobqueue import NCARCluster
sys.path.append('../functions')
import util
from cartopy.util import add_cyclic_point
from misc import get_ij
import warnings, getpass, os

<div class="alert alert-block alert-info">
<b>Note:</b> comment out the following line when debugging, so that warnings are shown.
</div>

In [None]:
# Suppress all warnings for the rest of the session (keeps the notebook output clean).
warnings.filterwarnings("ignore")

### Local functions

In [None]:
def rms_da(da, dims=('nlat', 'nlon'), weights=None,  weights_sum=None):
    """
    Calculate the root mean square (rms) of a DataArray, optionally weighted.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray for which to compute the (weighted) rms.

    dims : tuple, str
        Dimension(s) over which to apply the reduction.
        Default is ('nlat', 'nlon').

    weights : xarray.DataArray, optional
        Weights to apply. It can be a masked array. When omitted, the plain
        (unweighted) rms is returned.

    weights_sum : xarray.DataArray, optional
        Total weight used to normalise the weighted sum. It is used exactly
        as given when provided. When omitted, it is computed as the sum of
        ``weights`` over ``dims``, counting only points where ``da`` is not
        NaN.

    Returns
    -------
    reduction : xarray.DataArray
        (Optionally weighted) rms of ``da``. In the weighted case the
        attributes of ``da`` are copied onto the result.
    """
    if weights is None:
        return np.sqrt((da**2).mean(dim=dims, keep_attrs=True))

    if weights_sum is None:
        # The weighted sum below skips NaN values of da (xarray's default for
        # float data), so the normalising weight must leave out the same
        # points; otherwise the rms is biased low wherever da has gaps.
        weights_sum = weights.where(da.notnull()).sum(dim=dims)

    out = np.sqrt((da**2 * weights).sum(dim=dims) / weights_sum)
    # copy attrs
    out.attrs = da.attrs
    return out

### Dask workers

In [None]:
# Request a Dask cluster through the batch scheduler and attach a client to it.
mem_per_worker = 60 # in GB 
num_workers = 80 
# NOTE(review): cores=4 here but the resource spec asks for ncpus=6 -- confirm
# that the two values are meant to differ.
cluster = NCARCluster(cores=4, processes=3, memory=f'{mem_per_worker} GB',resource_spec=f'select=1:ncpus=6:mem={mem_per_worker}GB')
cluster.scale(num_workers)
client = Client(cluster)
print(client)
# Leaving the client as the last expression makes the notebook display it.
client

### Data Ingest

In [None]:
%%time
# Open the intake-esm datastore described by the CESM2 large-ensemble JSON catalog.
catalog = intake.open_esm_datastore(
    '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
)

In [None]:
# Show the distinct 'frequency' values found among the ocean-component entries.
catalog.search(component='ocn').unique('frequency')

In [None]:
# this prints all ocean variables that have monthly frequency
#catalog.search(component='ocn', frequency='month_1').unique('variable')

### Let's search for TEMP with monthly frequency

In [None]:
# Restrict the catalog to ocean-component TEMP at 'month_1' frequency.
temp_query = {
    'component': 'ocn',
    'frequency': 'month_1',
    'variable': 'TEMP',
}
cat_subset = catalog.search(**temp_query)

In [None]:
%%time
# Build a dictionary of datasets from the catalog subset (entries are looked up by key below).
dset_dict_raw = cat_subset.to_dataset_dict()

In [None]:
# Show the dataset keys that are available in the dictionary
list(dset_dict_raw)

In [None]:
# Historical TEMP dataset stored under the 'cmip6' key.
ds_hist_cmip6 = dset_dict_raw['ocn.historical.pop.h.cmip6.TEMP'] 

In [None]:
# Historical TEMP dataset stored under the 'smbb' key.
ds_hist_smbb = dset_dict_raw['ocn.historical.pop.h.smbb.TEMP'] 

In [None]:
# Join the two historical datasets along the member_id dimension.
ds_all = xr.concat([ds_hist_cmip6,ds_hist_smbb], dim='member_id', 
                     data_vars='minimal',coords="minimal",
                     compat="override")
# Size of the combined TEMP array, converted from bytes to terabytes.
ds_all.TEMP.nbytes*1e-12 # in TB

In [None]:
# Display the summary of the combined TEMP DataArray.
ds_all.TEMP

### Import the POP grid

If you choose the ocean component of LENS2, you will need to import the POP grid. For the other components, you can use the ensemble's own grid. 

In the dataset (`ds_all`), TLONG and TLAT have missing values (NaNs), so we need to override them with the values from pop_grid, which does not have missing values.

In [None]:
# Read the POP 1-degree grid from pop_tools.
# We will use the variables TLONG and TLAT to replace the ones in ds_all.
pop_grid = pop_tools.get_grid('POP_gx1v7')
ds_all['TLONG'] = pop_grid.TLONG   # Longitudes
ds_all['TLAT'] = pop_grid.TLAT     # Latitudes

### Plot SST

In [None]:
%%time
# First member, first time step, first vertical level of TEMP (plotted in the next cell).
temp_tmp = ds_all.TEMP.isel(member_id=0, time=0, z_t=0)

In [None]:
%%time
# Map of the selected TEMP snapshot on a Robinson projection.
plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = temp_tmp.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # x/y below are given as longitude/latitude
    cmap=cmocean.cm.balance,
    x='TLONG',
    y='TLAT',
    vmin=-3,
    vmax=30,
    cbar_kwargs={"orientation": "horizontal"},
)
ax.coastlines()
# Add the gridlines once, with labels. A second ax.gridlines() call would only
# draw another, unlabelled set on top of these.
ax.gridlines(draw_labels=True);

### Centralize the South Atlantic 
We need to join two blocks of the domain in the east/west direction to center the South Atlantic.

In [None]:
# Build the regional dataset from two longitude index ranges that share the
# same latitude rows, joined side by side along nlon.
sa_lat_rows = slice(115, 190)
sa_lon_part_a = ds_all.isel(nlat=sa_lat_rows, nlon=slice(300, 320))
sa_lon_part_b = ds_all.isel(nlat=sa_lat_rows, nlon=slice(0, 60))
sa_ds = xr.combine_nested(
    [[sa_lon_part_a, sa_lon_part_b]],
    concat_dim=['nlat', 'nlon'],
)

In [None]:
# Simple check: quick plot of one TEMP snapshot (third time step, first member,
# first vertical level) from the regional dataset.
sa_ds.TEMP.isel(time=2, member_id=0, z_t=0).plot()

In [None]:
%%time
# Map of one regional TEMP snapshot (third time step, first member, first
# vertical level) on a Robinson projection.
plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = sa_ds.TEMP.isel(time=2, member_id=0, z_t=0).plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # x/y below are given as longitude/latitude
    cmap=cmocean.cm.balance,
    x='TLONG',
    y='TLAT',
    vmin=20,
    vmax=35,
    cbar_kwargs={"orientation": "horizontal"},
)
ax.coastlines()
# Add the gridlines once, with labels. A second ax.gridlines() call would only
# draw another, unlabelled set on top of these.
ax.gridlines(draw_labels=True);

### Extract corresponding area 

In [None]:
# Cell areas (TAREA) for the same index ranges used to build sa_ds, joined
# side by side along nlon.
area_lat_rows = slice(115, 190)
area_lon_part_a = pop_grid.TAREA.isel(nlat=area_lat_rows, nlon=slice(300, 320))
area_lon_part_b = pop_grid.TAREA.isel(nlat=area_lat_rows, nlon=slice(0, 60))
area_sa = xr.combine_nested(
    [[area_lon_part_a, area_lon_part_b]],
    concat_dim=['nlat', 'nlon'],
)

In [None]:
# Simple check: quick plot of the regional cell areas.
area_sa.plot();

### Select time window 
<div class="alert alert-block alert-info">
<b>Note:</b> We should process the entire dataset once we fix the issue with Dask.
</div>

In [None]:
# Bounds of the analysis period, as date strings.
start = "2000-01-01"
end = "2009-12-31"
analysis_period = slice(start, end)
# TEMP at the first vertical level, restricted to the analysis period.
sst = sa_ds.TEMP.isel(z_t=0).sel(time=analysis_period)

### Perform computations
Calculate area mean, min, max, and rms for the surface temperature of the selected region

In [None]:
# Area-weighted mean of sst over the horizontal dimensions (weights: area_sa),
# loaded into memory.
sst_mean = sst.weighted(area_sa).mean(dim=("nlon", "nlat")).load()

In [None]:
# Maximum of sst over the horizontal dimensions, loaded into memory.
sst_max = sst.max(dim=("nlon", "nlat")).load()

In [None]:
# Minimum of sst over the horizontal dimensions, loaded into memory.
sst_min = sst.min(dim=("nlon", "nlat")).load()

In [None]:
# Area-weighted rms of sst.
# The normalising weight counts only the cells where sst holds data. Summing
# all of area_sa would include cells where sst is NaN, which the weighted sum
# inside rms_da skips, and would bias the rms low.
sst_valid_area = area_sa.where(sst.notnull()).sum(dim=("nlat", "nlon"))
sst_rms = rms_da(sst, weights=area_sa, weights_sum=sst_valid_area).load()

### TODO
plot some time series to check calculations

### Merge data and save on disk

In [None]:
# Quick-look time series of the statistics for the first ensemble member.
# The arrays computed above are plotted directly: ds_out is only assembled in
# the following cell, and the mean is stored as sst_mean (not sst_mea).
sst_min.isel(member_id=0).plot(label='sst_min')
sst_max.isel(member_id=0).plot(label='sst_max')
sst_mean.isel(member_id=0).plot(label='sst_mean')
plt.legend();

In [None]:
# Collect the four statistics into one Dataset, one named variable per statistic.
ds_out = xr.merge([sst_rms.rename('sst_rms'),
                    sst_mean.rename('sst_mean'),
                    sst_max.rename('sst_max'),
                    sst_min.rename('sst_min')])

# TODO: add more attrs, e.g., the date it was created and by whom (name, email);
# improve the description, add the lat/lon bounds of the region, etc.
ds_out.attrs['description'] = 'sst statistics for the South Atlantic'

In [None]:
# Directory on scratch that receives the output file.
# NOTE(review): the user name is hard-coded. A per-user alternative is
#   path = '/glade/scratch/{}/LENS2_south_atlantic/SST/'.format(getpass.getuser())
# The previous no-op .format(...) call (the string has no placeholder) was removed.
path = '/glade/scratch/mauricio/LENS2_south_atlantic/SST/'
# Create the directory (and any missing parents) so that writing the file does
# not fail on a missing folder; an existing directory is left untouched.
os.makedirs(path, exist_ok=True)

In [None]:
# Write the merged statistics to 'sst_stats.nc' inside the output directory.
ds_out.to_netcdf(path+'sst_stats.nc')

In [None]:
# Disconnect the client first, then stop the cluster it was connected to.
# Closing the cluster first leaves the client pointing at a scheduler that is
# already gone.
client.close()
cluster.close()