## CESM2 - LARGE ENSEMBLE (LENS2)

- This notebook serves as an example of how to extract surface (or any other 2D spatial field) properties from a selected spatial region across all LENS2 members for the ocean component. For 3D fields (e.g. TEMP), a single model level is selected first.

## Imports

In [None]:
import intake
import intake_esm
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import fsspec
import cmocean
import cartopy
import cartopy.feature as cfeature
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import pop_tools
import sys
from distributed import Client
from ncar_jobqueue import NCARCluster
sys.path.append('../functions')
import util
from cartopy.util import add_cyclic_point
from misc import get_ij
import warnings, getpass, os

<div class="alert alert-block alert-info">
<b>Note:</b> comment out the following line when debugging, so that warnings are shown.
</div>

In [None]:
warnings.filterwarnings("ignore")

### Local functions

In [None]:
def rms_da(da, dims=('nlat', 'nlon'), weights=None,  weights_sum=None):
  """
  Calculates the rms in DataArray da (optional weighted rms).

  Parameters
  ----------
  da : xarray.DataArray
        DataArray for which to compute (weighted) rms. NaN values
        (e.g. masked points) are skipped.

  dims : tuple, str
    Dimension(s) over which to apply reduction. Default is ('nlat', 'nlon').

  weights : xarray.DataArray, optional
    weights to apply. It can be a masked array.

  weights_sum : xarray.DataArray, optional
    Total weight used to normalize the weighted sum. If not provided, it is
    computed as the sum of `weights` over the points where `da` is not NaN,
    so that missing points do not bias the rms low. If you pass it yourself,
    it should likewise only include the weights of valid points.

  Returns
  -------
  reduction : xarray.DataArray
      (Optionally weighted) rms of da over `dims`. In the weighted case the
      attributes of da are copied; in the unweighted case keep_attrs=True is
      passed to the mean.
  """
  if weights is None:
    return np.sqrt((da**2).mean(dim=dims, keep_attrs=True))

  if weights_sum is None:
    # The weighted sum below skips NaNs, so the normalization must too:
    # only count the weights of points where da has data.
    weights_sum = weights.where(da.notnull()).sum(dim=dims)
  out = np.sqrt((da**2 * weights).sum(dim=dims) / weights_sum)
  # copy attrs (a new dict, so later edits of out.attrs do not alter da.attrs)
  out.attrs = dict(da.attrs)
  return out

### Dask workers

In [None]:
# Start a dask cluster through ncar_jobqueue and attach a client to it.
# NOTE(review): `memory` is presumably the memory of one whole job (shared by
# its 3 worker processes), and resource_spec requests ncpus=6 while cores=4 —
# confirm these values are intended.
mem_per_worker = 300  # in GB
num_workers = 80
cluster = NCARCluster(
    cores=4,
    processes=3,
    memory=f'{mem_per_worker} GB',
    resource_spec=f'select=1:ncpus=6:mem={mem_per_worker}GB',
)
cluster.scale(num_workers)
client = Client(cluster)
print(client)
client

### Data Ingest

In [None]:
%%time
catalog = intake.open_esm_datastore(
    '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
)

In [None]:
# Output frequencies available for the ocean component
catalog.search(component='ocn').unique('frequency')

### Let's search for variables with monthly frequency

In [None]:
# Choose the variable: SHF (Total Surface Heat Flux), XMXL (Maximum Mixed
# Layer Depth) or TEMP (Sea temperature).
var = 'TEMP'
# Search for `var` (not a hard-coded name) so that changing `var` above
# actually changes the data that is loaded.
cat_subset = catalog.search(component='ocn',
                            frequency='month_1',
                            variable=[var])

In [None]:
%%time
# Open every entry matched by cat_subset; the result is a dict keyed by entry name
dset_dict_raw = cat_subset.to_dataset_dict()

In [None]:
# Display the catalog subset matched by the search above
cat_subset

In [None]:
# Keys of the loaded datasets (e.g. 'ocn.historical.pop.h.cmip6.TEMP')
list(dset_dict_raw)

In [None]:
# For each forcing set, append the ssp370 run to the historical run along
# time; then stack the two forcing sets along member_id into ds_{var}.
# Built without exec and without shadowing the builtin `str`.
periods = ('historical', 'ssp370')
forcings = ('cmip6', 'smbb')
ds_by_forcing = [
    xr.combine_nested(
        [dset_dict_raw[f'ocn.{period}.pop.h.{forcing}.{var}'] for period in periods],
        concat_dim=['time'],
    )
    for forcing in forcings
]
globals()[f'ds_{var}'] = xr.combine_nested(ds_by_forcing, concat_dim=['member_id'])
del ds_by_forcing
print('Done!')

In [None]:
# Annual mean of the selected variable.
# 3D fields (e.g. TEMP) are first reduced to a single model level (index
# Z_LEVEL along z_t), so the resampling only processes that level; 2D fields
# such as SHF or XMXL have no z_t dimension and are used as-is.
# NOTE: this replaces ds_{var} (a Dataset) by a DataArray, so re-running this
# cell requires re-running the data-combination cell first.
Z_LEVEL = 50
da_var = globals()[f'ds_{var}'][var]
if 'z_t' in da_var.dims:
    da_var = da_var.isel(z_t=Z_LEVEL)
globals()[f'ds_{var}'] = da_var.resample(time='1Y', closed='left').mean('time')
del da_var

In [None]:
# Inspect the annual-mean DataArray of the selected variable (follows `var`)
globals()[f'ds_{var}']

### Import the POP grid

If you choose the ocean component of LENS2, you will need to import the POP grid. For the other components, you can use the ensemble's own grid.

In ds, TLONG and TLAT have missing values (NaNs), so we need to override them with the values from pop_grid, which does not have missing values.

In [None]:
# Read the pop 1 deg grid from pop_tools
# Its TLONG and TLAT have no missing values, so they replace the (partially
# NaN) coordinates of ds_{var}. Assigning once is enough.
pop_grid = pop_tools.get_grid('POP_gx1v7')
ds_target = globals()[f'ds_{var}']
ds_target['TLONG'] = pop_grid.TLONG  # longitudes
ds_target['TLAT'] = pop_grid.TLAT    # latitudes
del ds_target

In [None]:
# Quick look at the first member and first year on the global POP grid.
da_map = globals()[f'ds_{var}'].isel(member_id=0, time=0)
plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = da_map.plot.pcolormesh(ax=ax,
                            transform=ccrs.PlateCarree(),
                            cmap=cmocean.cm.balance,
                            x='TLONG',
                            y='TLAT',
                            cbar_kwargs={'orientation': 'horizontal'})
# One labelled gridlines call; a second ax.gridlines() would draw the lines twice.
ax.gridlines(draw_labels=True)
ax.coastlines()
del da_map

### Center the South Atlantic
The domain has to be stitched together in the east-west direction so that the South Atlantic is centered.

In [None]:
# Index box of the South Atlantic on the POP grid. The box wraps around the
# end of the nlon index range, so it is assembled from two nlon slices,
# [ilon1, flon1) followed by [ilon2, flon2), both over nlat [ilat, flat).
ilat, flat = 85, 187
ilon1, flon1, ilon2, flon2 = 308, 320, 0, 54
da_full = globals()[f'ds_{var}']
sa_da = xr.combine_nested(
    [[da_full.isel(nlat=slice(ilat, flat), nlon=slice(ilon1, flon1)),
      da_full.isel(nlat=slice(ilat, flat), nlon=slice(ilon2, flon2))]],
    concat_dim=['nlat', 'nlon'],
)
# change the longitudes: -180 0 180
sa_da.coords['TLONG'] = (sa_da.coords['TLONG'] + 180) % 360 - 180
globals()[f'sa_ds_{var}'] = sa_da
del da_full, sa_da

In [None]:
# simple check: time index 2 of the first member over the stitched region
globals()[f'sa_ds_{var}'].isel(time=2, member_id=0).plot();

In [None]:
%%time
str=f'ds_var = sa_ds_{var}.isel(member_id=0,time=0)'; exec(str)
plt.figure(figsize=(10,6));
ax = plt.axes(projection=ccrs.Robinson());
pc = ds_var.plot.pcolormesh(ax=ax,
                            transform=ccrs.PlateCarree(),
                            cmap=cmocean.cm.balance,
                            x='TLONG',
                            y='TLAT',
                            cbar_kwargs={"orientation": "horizontal"})                                    
ax.gridlines(draw_labels=True);
ax.coastlines()
ax.gridlines();

### Extract the corresponding area

In [None]:
# Cell areas (TAREA) for the same index box, stitched in the same order as
# the data so that area_sa lines up point-by-point with the regional field.
box_lat = slice(ilat, flat)
box_lons = (slice(ilon1, flon1), slice(ilon2, flon2))
area_pieces = [pop_grid.TAREA.isel(nlat=box_lat, nlon=lon_part) for lon_part in box_lons]
area_sa = xr.combine_nested([area_pieces], concat_dim=['nlat', 'nlon'])
del box_lat, box_lons, area_pieces

In [None]:
# simple check: TAREA over the stitched South Atlantic box
area_sa.plot()

### Perform computations
Calculate the area mean, min, max, and rms of the selected variable over the selected region.

In [None]:
# First member, first year of the selected variable (follows `var`) over the region
globals()[f'sa_ds_{var}'].isel(member_id=0, time=0).plot();

In [None]:
%%time
# Mean
str=f'var_mean_{var} = sa_ds_{var}.weighted(area_sa).mean(dim=(\'nlon\',\'nlat\')).load()'
exec(str)
print(f'var_mean_{var}')
# Maximum
str=f'var_max_{var} = sa_ds_{var}.max(dim=(\'nlon\',\'nlat\')).load()'
exec(str)
print(f'var_max_{var}')
# Minimum
str=f'var_min_{var} = sa_ds_{var}.min(dim=(\'nlon\',\'nlat\')).load()'
exec(str)
print(f'var_min_{var}')
# RMS
str=f'var_rms_{var} = rms_da(sa_ds_{var}, weights=area_sa, weights_sum=area_sa.sum()).load()'
exec(str)
print(f'var_rms_{var}')

### TODO
Plot some time series to check calculations

### Merge data and save on disk

In [None]:
# Units and long name of the selected variable, used for the file metadata
# and for the plot labels below (TEMP keeps the original 'oC'/'Temperature').
var_meta = {
    'TEMP': ('oC', 'Temperature'),
    'SHF': ('W/m2', 'Total Surface Heat Flux'),
    'XMXL': ('cm', 'Maximum Mixed Layer Depth'),
}
units, long_name = var_meta.get(var, ('', var))
# One Dataset with variables {var}_rms, {var}_mean, {var}_max, {var}_min
ds_out = xr.merge([
    globals()[f'var_{stat}_{var}'].rename(f'{var}_{stat}')
    for stat in ('rms', 'mean', 'max', 'min')
])
ds_out.attrs['description'] = (f'{long_name} ({var}) statistics for the South Atlantic '
                               '(52.93749146W-20.18750056E and 33.81089045S-0.13356644S)')
ds_out.attrs['units'] = units
ds_out.attrs['author'] = 'Mauricio Rocha'
ds_out.attrs['email'] = 'mauricio.rocha@usp.br'
globals()[f'ds_out_{var}'] = ds_out
del ds_out

In [None]:
# Time series of the area statistics of the selected variable: every member
# (faint lines), the member mean of annual values (thin line) and the member
# mean of 5-year means (thick line), one panel per statistic.
ds_plot = globals()[f'ds_out_{var}']
# (statistic, member color, member-mean color, y-label word, panel title)
panels = (
    ('max', 'orange', 'r', 'Max', 'Area Max'),
    ('mean', 'gray', 'k', 'Mean', 'Area Mean'),
    ('min', 'c', 'b', 'Min', 'Area Min'),
    ('rms', 'y', 'g', 'Error', 'Area Error'),
)
fig, axes = plt.subplots(1, 4, figsize=(20, 8))
for ax, (stat, member_color, mean_color, ylabel, title) in zip(axes, panels):
    series = ds_plot[f'{var}_{stat}']
    annual = series.resample(time='1Y', closed='left').mean('time')
    annual.plot.line(ax=ax, x='time', color=member_color, alpha=0.01,
                     linewidth=1, add_legend=False)
    annual.mean('member_id').plot.line(ax=ax, x='time', alpha=0.3, color=mean_color,
                                       linewidth=1, label='Member Mean 1Y')
    five_year = series.resample(time='5Y', closed='left').mean('time')
    five_year.mean('member_id').plot.line(ax=ax, x='time', color=mean_color,
                                          linewidth=2, label='Member Mean 5Y')
    ax.set_xlabel('Time [Years]')
    ax.set_ylabel(f'{var} {ylabel} [{units}]')
    ax.set_title(title)
    ax.grid(color='k', linestyle='-', linewidth=0.7)
    ax.legend()
# Lay out only after all four panels are drawn and labelled.
fig.tight_layout(pad=2.0)
plt.show()

### Let's save the data in netcdf format

In [None]:
# create a directory on the current user's scratch space and save the output
# (the original '...'.format(getpass.getuser()) had no placeholder, so the
# user name was never inserted)
print(f'Variable: {var}')
path = f'/glade/scratch/{getpass.getuser()}/Data/LENS2/{var}/'
os.makedirs(path, exist_ok=True)
globals()[f'ds_out_{var}'].to_netcdf(os.path.join(path, f'{var}_stats.nc'))

In [None]:
# Shut down dask: disconnect the client first, then close the cluster it uses.
client.close()
cluster.close()