# Explore the COAWST US East Coast and Gulf of Mexico Forecast Archive Dataset
This is a cloud-optimized version of the NetCDF files from the USGS ScienceBase item [Collection of COAWST model forecast for the US East Coast and Gulf of Mexico](https://www.sciencebase.gov/catalog/item/610acd4fd34ef8d7056893da). The original daily forecast files were converted into weekly NetCDF files, each with 168 points in the time dimension, to make time series access more efficient.

In [1]:
import os
import fsspec
import xarray as xr
import hvplot.xarray
import intake
import cf_xarray
import numpy as np
import panel as pn
from matplotlib import path
import xoak
import zarr


## Open Dataset

The details of data loading are stored in an `intake` catalog, which simplifies use.  Metadata and coordinate data are loaded, but not the actual data variables, which are loaded only as needed by subsequent analysis and visualization. 

In [None]:
intake_catalog_url = 'https://usgs-coawst.s3.amazonaws.com/useast-archive/coawst_intake.yml'
cat = intake.open_catalog(intake_catalog_url)

In [None]:
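# Pick the catalog entry that matches the installed Zarr major version
zarr_major = zarr.__version__.split('.')[0]
if zarr_major == '2':
    dataset = 'COAWST-USEAST-zarr2'
elif zarr_major == '3':
    dataset = 'COAWST-USEAST-zarr3'
else:
    raise RuntimeError(f'Unsupported zarr version: {zarr.__version__}')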

This is a big dataset, so opening it can take up to 30 s, because all the metadata and index coordinate variable data must be read. Here we load the data into xarray using `.to_dask()` so that, if we have a Dask cluster, we can speed up processing by loading and processing chunks of data in parallel.

In [None]:
%%time
ds = cat[dataset].to_dask()

Let's look at that metadata.  We can explore the different attributes and variables by clicking on the variables and icons below. 

In [2]:
ds.nbytes/1e12   # total uncompressed size of the dataset, in terabytes


We can also explore a specific variable of interest:

In [None]:
var = 'Hwave'
da = ds[var]
da

Use the CF conventions to identify the coordinate variables for longitude, latitude, and time:

In [None]:
x = da.cf['longitude']
y = da.cf['latitude']
t = da.cf['time']
print(x.name, y.name, t.name)
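
As a rough illustration of what the `.cf` lookups above rely on, the sketch below matches variables against CF metadata attributes such as `standard_name`, `units`, and `axis`. The attribute dictionaries and the `find_coord` helper are hypothetical stand-ins, and the matching logic is deliberately simplified; it is not how `cf_xarray` is implemented.

```python
# Hypothetical attribute dictionaries, similar in spirit to the coordinate
# variables used in this notebook (names and attributes are assumptions).
coords = {
    "lon_rho": {"standard_name": "longitude", "units": "degrees_east"},
    "lat_rho": {"standard_name": "latitude", "units": "degrees_north"},
    "ocean_time": {"standard_name": "time", "axis": "T"},
}

# Simplified CF-style criteria: any one matching attribute is enough.
CRITERIA = {
    "longitude": {"standard_name": "longitude", "units": "degrees_east"},
    "latitude": {"standard_name": "latitude", "units": "degrees_north"},
    "time": {"standard_name": "time", "axis": "T"},
}

def find_coord(coords, key):
    """Return the names of variables whose attributes match the criteria for `key`."""
    wanted = CRITERIA[key]
    return [
        name
        for name, attrs in coords.items()
        if any(attrs.get(attr) == value for attr, value in wanted.items())
    ]

for key in ("longitude", "latitude", "time"):
    print(key, "->", find_coord(coords, key))
```

The benefit of this style of lookup is that analysis code does not need to hard-code model-specific variable names.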

## Example: Load the entire spatial domain for a variable at a specific time step
Loading the entire spatial domain at one time step requires reading only 8 chunks of data, so it takes just a few seconds. A Dask cluster doesn't help much in this case because the load is already fast.


In [None]:
%%time
da2d = da.cf.sel(T='2012-10-29 12:00', method='nearest').load()

In [None]:
da2d.hvplot.quadmesh(x=x.name, y=y.name, rasterize=True, geo=True, tiles='OSM', cmap='viridis')

## Example: Load a time series for a variable at a specific lon,lat location for a specified time range

To identify a point, we start with its lon,lat coordinates. If lon and lat were 1D coordinates, we could select by lon,lat values directly with xarray. Because they are 2D, we instead need to find the indices of the nearest grid cell and extract using those. For this we use the `xoak` package:

In [None]:
lat,lon = 42.5, -70.0  # Gulf of Maine, 100km east of Boston, MA

In [None]:
da.xoak.set_index([y.name, x.name], 'scipy_kdtree')

In [None]:
ds_point = xr.Dataset({"lon": ("point", [lon]), "lat": ("point", [lat])})
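
For intuition, the nearest-cell lookup on 2D coordinates can be pictured as a brute-force search over the whole grid. The sketch below is a NumPy-only alternative, not what `xoak` does internally (it uses a tree index, which scales much better). The function name `nearest_ij` and the synthetic grid are assumptions, and distances are measured in degrees rather than true distance on the sphere.

```python
import numpy as np

def nearest_ij(lon2d, lat2d, lon0, lat0):
    """Return (j, i) of the grid cell closest to (lon0, lat0), in degree space."""
    dist2 = (lon2d - lon0) ** 2 + (lat2d - lat0) ** 2
    j, i = np.unravel_index(np.argmin(dist2), dist2.shape)
    return int(j), int(i)

# Synthetic 0.5-degree grid standing in for the model's 2D lon/lat arrays
lon2d, lat2d = np.meshgrid(np.linspace(-72, -68, 9), np.linspace(41, 44, 7))

j, i = nearest_ij(lon2d, lat2d, lon0=-70.0, lat0=42.5)
print(j, i, lon2d[j, i], lat2d[j, i])
```

A brute-force search like this is fine for a single point on a small grid, but an indexed search is the better choice for many points or large grids.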

Before we read the data, let's see how many chunks we will be reading:

In [None]:
da.xoak.sel(lat_rho=ds_point.lat, lon_rho=ds_point.lon).cf.sel(T='2012-10')

Loading this one month means reading 5 chunks of data, so we still don't need a cluster:

In [None]:
%%time
da1d = da.xoak.sel(lat_rho=ds_point.lat, lon_rho=ds_point.lon).cf.sel(T='2012-10').load()

In [None]:
da1d.hvplot(x=t.name, grid=True)

How many chunks of data will we read to load the entire time series of record at a point?

In [None]:
da.xoak.sel(lat_rho=ds_point.lat, lon_rho=ds_point.lon)

Since we now need to read 669 chunks of data, we should use a Dask cluster if we have access to one.
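
Chunk counts like these follow from how a selection overlaps the chunk grid. Here is a minimal sketch of that arithmetic. The chunk sizes, grid shape, and hourly time step are hypothetical values chosen only to illustrate the idea; they are not the documented layout of this dataset.

```python
def chunks_touched(selection, chunk_shape):
    """Count the chunks overlapped by a selection.

    selection   : list of (start, stop) half-open index ranges, one per dimension
    chunk_shape : chunk length along each dimension
    """
    count = 1
    for (start, stop), size in zip(selection, chunk_shape):
        first = start // size        # index of the first chunk touched
        last = (stop - 1) // size    # index of the last chunk touched
        count *= last - first + 1
    return count

# Hypothetical layout: dimensions (time, eta, xi) with chunks of (168, 200, 200)
chunk_shape = (168, 200, 200)

# One time step over a hypothetical 400 x 800 grid
full_map = chunks_touched([(0, 1), (0, 400), (0, 800)], chunk_shape)

# 744 time steps (a 31-day month, if hourly) at a single grid cell
one_month = chunks_touched([(0, 744), (10, 11), (20, 21)], chunk_shape)

print(full_map, one_month)
```

The same reasoning explains why a long time series at one point is expensive: every chunk along the time dimension must be read, even though only one value per time step is kept.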

### Parallelize with Dask 
We opened the dataset so that we can take advantage of parallel compute environments
using `dask`. We're going to start a cluster now so that future steps can take advantage
of this ability. 

This is an optional step, but speeds up data loading and processing significantly, especially 
when accessing data from the cloud.

There are many ways to [deploy a Dask cluster](https://docs.dask.org/en/stable/deploying.html#deploy-dask-clusters).
Each cell below uses a different approach. Use one of these approaches or choose another method.

In [None]:
# Uncomment exactly one cluster type
#cluster_type = 'Local'
#cluster_type = 'Coiled'
cluster_type = 'Gateway'
#cluster_type = 'Nebari'

#### Use LocalCluster
LocalCluster is available in any computing environment.  It uses the number of CPUs of the computer running the notebook to create a cluster. 

In [None]:
if cluster_type == 'Local':
    from dask.distributed import LocalCluster, Client
    cluster = LocalCluster()
    client = Client(cluster)

#### Use Coiled
[Coiled](https://www.coiled.io/) provides access to remote Dask clusters that can be used from anywhere.  It requires a Coiled account. 

In [None]:
if cluster_type == 'Coiled':
    import coiled
    cluster = coiled.Cluster(
        region="us-west-2",
        arm=True,   # run on ARM to save energy & cost
        worker_vm_types=["t4g.small"],  # cheap, small ARM instances, 2cpus, 2GB RAM
        worker_options={'nthreads':2},
        n_workers=30,
        wait_for_workers=False,
        compute_purchase_option="spot_with_fallback",
        name='coawst',   # Dask cluster name
        software='esip-pangeo-arm',  # Conda environment name
        workspace='esip-lab',
        timeout=180   # leave cluster running for 3 min in case we want to use it again
    )

    client = cluster.get_client()

#### Use a Dask Gateway Cluster
[Dask Gateway](https://gateway.dask.org/) is a common way to spin up a Dask Cluster.  [Nebari](https://nebari.dev) and [DaskHub](https://github.com/dask/helm-chart) are popular ways of deploying a JupyterHub with Dask Gateway.  You can use a JupyterHub with DaskGateway for free by [signing up for access to the Microsoft Planetary Computer hub](https://planetarycomputer.microsoft.com/account/request).

In [None]:
from dask_gateway import Gateway

gateway = Gateway()  # instantiate Dask gateway 

# Cluster options on Nebari 
options = gateway.cluster_options()

In [None]:
options.image=os.environ['JUPYTER_IMAGE']

In [None]:
options

In [None]:
%%time
if cluster_type == 'Gateway':   # Pangeo@EOSC DaskHub
    from dask_gateway import Gateway

    gateway = Gateway()  # instantiate Dask gateway 

    # Cluster options on Nebari 
    options = gateway.cluster_options()
    options.image = os.environ['JUPYTER_IMAGE']

    # Create a Dask Gateway cluster
    cluster = gateway.new_cluster(options)

    # Get the Dask client for the Dask Gateway cluster
    client = cluster.get_client()

    # Scale the cluster
    cluster.adapt(minimum=4, maximum=30)

In [None]:
%%time
if cluster_type == 'Nebari':    #ESIP Nebari Deployment
    import sys, os
    sys.path.append(os.path.join(os.environ['HOME'],'shared','users','lib'))
    import nebari_tools as nbt

    aws_profile = 'esip-qhub'
    aws_region = 'us-west-2'
    endpoint_url = f's3.{aws_region}.amazonaws.com'

    nbt.set_credentials(profile=aws_profile, region=aws_region, endpoint_url=endpoint_url)
    worker_max = 30

    client, cluster = nbt.start_dask_cluster(profile=aws_profile, worker_max=worker_max, 
                                          region=aws_region, use_existing_cluster=True,
                                          adaptive_scaling=True, wait_for_cluster=True, 
                                          worker_profile='Small Worker', 
                                          propagate_env=True)

In [None]:
client

Load the entire time series:

In [None]:
%%time
ds_selection = da.xoak.sel(lat_rho=ds_point.lat, lon_rho=ds_point.lon).load()       

In [None]:
ds_selection.hvplot(x=t.name, grid=True) 

## Example: Compute the time mean for a variable over the entire domain for a specific time period

In [None]:
%%time
da_mean = da.cf.sel(T=slice('2016-01-01 00:00','2017-01-01 00:00')).mean(dim=t.name).compute()
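
Label-based slices in xarray include both end points, so the selection above also contains the `2017-01-01 00:00` step. The NumPy sketch below mirrors the same select-then-average logic on synthetic data. The hourly time axis and the random field are assumptions made for illustration only.

```python
import numpy as np

# Hypothetical hourly time axis covering the selection plus one extra day
times = np.arange("2016-01-01T00", "2017-01-02T00", dtype="datetime64[h]")

# Synthetic field with dimensions (time, y, x); values are random placeholders
rng = np.random.default_rng(0)
field = rng.random((times.size, 3, 4))

# Inclusive selection on both ends, matching label-based slicing
start = np.datetime64("2016-01-01T00")
stop = np.datetime64("2017-01-01T00")
mask = (times >= start) & (times <= stop)

n_steps = int(mask.sum())
time_mean = field[mask].mean(axis=0)   # collapse time, keep the 2D map

print(n_steps, time_mean.shape)
```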

In [None]:
da_mean.hvplot.quadmesh(x=x.name, y=y.name, rasterize=True, geo=True, tiles='OSM', cmap='viridis')

## Example: Subset a time and space region and export to NetCDF

In [None]:
def bbox2ij(lon, lat, bbox=(-160., -155., 18., 23.)):
    """Return the i,j index ranges of the grid points that lie inside the bounding box.
    i0,i1,j0,j1 = bbox2ij(lon,lat,bbox)
    lon,lat = 2D arrays that are the target of the subset
    bbox = sequence containing the bounding box: [lon_min, lon_max, lat_min, lat_max]

    The returned indices are inclusive, so add 1 to the upper bounds when slicing.

    Example
    -------
    >>> i0,i1,j0,j1 = bbox2ij(lon_rho,lat_rho,[-71, -63., 39., 46])
    >>> h_subset = nc.variables['h'][j0:j1+1,i0:i1+1]
    """
    bbox = np.array(bbox)
    # Corners of the bounding box as a polygon: (lon, lat) pairs
    mypath = np.array([bbox[[0,1,1,0]], bbox[[2,2,3,3]]]).T
    p = path.Path(mypath)
    # Test every grid point for membership in the polygon
    points = np.vstack((lon.ravel(), lat.ravel())).T
    n, m = np.shape(lon)
    inside = p.contains_points(points).reshape((n, m))
    ii, jj = np.meshgrid(range(m), range(n))
    return ii[inside].min(), ii[inside].max(), jj[inside].min(), jj[inside].max()

In [None]:
bbox = [-76.63290610753754, -73.55671530588432, 37.57888442021855, 41.225532965406224]   # DRB

In [None]:
i0,i1,j0,j1 = bbox2ij(x.values, y.values, bbox=bbox)
print(i0,i1,j0,j1)
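
To check index ranges like these on a small case, the sketch below uses a NumPy-only variant for axis-aligned boxes on a synthetic grid. The function name `bbox2ij_rect` and the grid are hypothetical. It returns inclusive indices, so the slice adds 1 to each upper bound to keep the last row and column inside the box.

```python
import numpy as np

def bbox2ij_rect(lon, lat, bbox):
    """Inclusive i0,i1,j0,j1 of grid points inside [lon_min, lon_max, lat_min, lat_max]."""
    lon_min, lon_max, lat_min, lat_max = bbox
    inside = (lon >= lon_min) & (lon <= lon_max) & (lat >= lat_min) & (lat <= lat_max)
    jj, ii = np.nonzero(inside)   # row (j) and column (i) indices of inside points
    return int(ii.min()), int(ii.max()), int(jj.min()), int(jj.max())

# Synthetic 1-degree grid standing in for the model's 2D lon/lat arrays
lon, lat = np.meshgrid(np.linspace(-80, -70, 11), np.linspace(35, 45, 11))

i0, i1, j0, j1 = bbox2ij_rect(lon, lat, [-76.5, -73.5, 37.5, 41.5])
subset = lon[j0:j1 + 1, i0:i1 + 1]   # +1 because Python slices exclude the stop index

print(i0, i1, j0, j1, subset.shape)
```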

In [None]:
# bbox2ij returns inclusive indices, so add 1 to keep the last row and column
ds_drb = ds[['temp', 'salt', 'Hwave']].isel(eta_rho=slice(j0,j1+1), xi_rho=slice(i0,i1+1))

In [None]:
ds_drb

In [None]:
ds_drb_timeslice = ds_drb.cf.sel(T=slice('2022-04-01 00:00','2022-04-08 00:00'))

In [None]:
ds_drb_timeslice = ds_drb_timeslice.chunk({'eta_rho':-1, 'xi_rho':-1})  # chunk to full spatial subset domain
print(f'Uncompressed dataset size: {ds_drb_timeslice.nbytes/1e6} MB')

In [None]:
%%time
var = 'salt'
da_drb = ds_drb_timeslice[var].load()

In [None]:
viz = da_drb.hvplot.quadmesh(x=x.name, y=y.name, geo=True,
                    cmap='turbo', rasterize=True, tiles='OSM', title=var)
viz = pn.panel(viz, widgets={'ocean_time': pn.widgets.Select} )
pn.Column(viz).servable('DRB Explorer')

Close the Dask client, since we can't write NetCDF in parallel:

In [None]:
client.close()

Specify the encoding to enable compression in the NetCDF file:

In [None]:
%%time
encoding={}
for name in ds_drb_timeslice.variables:   # use `name` so the earlier `var` is not overwritten
    encoding[name] = dict(zlib=True, complevel=4, 
                          fletcher32=False, shuffle=True,
                          _FillValue=None)

ds_drb_timeslice.to_netcdf('drb.nc', encoding=encoding, mode='w')

## Stop cluster

In [None]:
cluster.shutdown()