## Make Diagnostic Plots of Data in DART-CAM6 Zarr Stores

In [None]:
%load_ext watermark

import xarray as xr
import numpy as np
import dask
import intake

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from pathlib import Path
import os

from dask_jobqueue import PBSCluster

%watermark -iv

## Create and Connect to a Dask Distributed Cluster

Run the cell below if the notebook is running on a supercomputer with a PBS scheduler.
If the notebook is running in a different parallel computing environment, you will need
to replace `PBSCluster` with an equivalent cluster object from `dask_jobqueue` or `dask_gateway`.

In [None]:
# Dask worker / PBS job settings. Each PBS job runs one single-process worker.
num_jobs = 20
walltime = '0:20:00'
memory = '10GB'
ncpus = 1

# Build the PBS resource request from the same values passed to Dask, so
# changing `memory` or `ncpus` above keeps the two in agreement.
cluster = PBSCluster(cores=ncpus, processes=1, walltime=walltime, memory=memory,
                     queue='casper',
                     resource_spec=f'select=1:ncpus={ncpus}:mem={memory}')
cluster.scale(jobs=num_jobs)


from distributed import Client
client = Client(cluster)
cluster

## Find and Obtain Data Using an Intake Catalog

#### Choose Cloud Storage (AWS or NCAR Cloud)

In [None]:
# Storage backend switch:
#   True  -> read the data from NCAR Cloud Storage.
#   False -> read the data from AWS Cloud Storage.
USE_NCAR_CLOUD_STORAGE = True

#### Define the Intake Catalog URL and Storage Access Options

In [None]:
# Anonymous (unauthenticated) access is used for both backends; the NCAR
# store additionally needs its S3-compatible endpoint URL.
storage_options = {"anon": True}
if not USE_NCAR_CLOUD_STORAGE:
    catalog_url = "https://ncar-dart-cam6.s3-us-west-2.amazonaws.com/catalogs/aws-dart-cam6.json"
else:
    catalog_url = "https://stratus.ucar.edu/ncar-dart-cam6/catalogs/aws-dart-cam6.json"
    storage_options["client_kwargs"] = {"endpoint_url": "https://stratus.ucar.edu/"}

#### Open catalog and produce a content summary

In [None]:
# Open the intake-esm catalog at the URL selected above.

# Displaying `col` as the last expression renders its summary in the notebook.
col = intake.open_esm_datastore(catalog_url)
col

In [None]:
# Summarize the catalog: collect the unique values of each catalog column,
# then print the variable names that can be searched for below.
uniques = col.unique()
variable_list = uniques["variable"]
print(f'variables: {variable_list}\n')

#### Load data into xarray using the catalog

In [None]:
# Variable to plot; it can be any name printed in the catalog summary above.
data_var = 'PS'

# Narrow the catalog to the entries for that variable and display the result.
col_subset = col.search(variable=data_var)
col_subset

#### Show the chosen Zarr store attributes

In [None]:
# Show the catalog entries matched by the search as a table.
col_subset.df

#### Convert catalog subset to a dictionary of xarray datasets, and use the first one.

In [None]:
# Open every matched Zarr store as an xarray Dataset (consolidated metadata
# makes opening faster), keyed by the catalog's dataset keys.
dsets = col_subset.to_dataset_dict(
    xarray_open_kwargs={"consolidated": True}, storage_options=storage_options
)
print(f"\nDataset dictionary keys:\n {dsets.keys()}")

# Load the first dataset and display a summary. Fail with a clear message,
# rather than a bare IndexError, if the search matched nothing.
if not dsets:
    raise ValueError(f"No datasets found in the catalog for variable {data_var!r}.")
dataset_key = next(iter(dsets))
ds = dsets[dataset_key]

ds

## Define Plot Functions

#### Get consistently shaped data slices for both 2D and 3D variables.

In [None]:
def getSlice(ds, data_var):
    '''Return `ds[data_var]`, reduced to a single vertical level if needed.

    If the variable has a 'lev' dimension, keep only the last level, which
    this notebook treats as the level closest to the Earth's surface
    (assumes 'lev' is ordered top-down, as in CAM -- TODO confirm for other
    model output). Variables without 'lev' are returned unchanged.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing `data_var`.
    data_var : str
        Name of the variable to slice.

    Returns
    -------
    xarray.DataArray
    '''
    data_slice = ds[data_var]

    if 'lev' in data_slice.dims:
        # Positional selection removes only the 'lev' dimension. squeeze() is
        # deliberately avoided: it would also drop any other length-1
        # dimension (e.g. a single-member 'member_id') that later
        # reductions rely on.
        data_slice = data_slice.isel(lev=-1)

    return data_slice

#### Get lat/lon dimension names 

In [None]:
def getSpatialDimensionNames(data_slice):
    '''Return the [latitude, longitude] dimension names for this data slice.

    Each axis is resolved independently: 'lat' is preferred, with 'slat' as
    the fallback, and 'lon' is preferred, with 'slon' as the fallback
    (presumably the staggered-grid names).

    Parameters
    ----------
    data_slice : object with a `dims` attribute (e.g. xarray.DataArray)

    Returns
    -------
    list of str
        [lat_dim, lon_dim]

    Raises
    ------
    ValueError
        If neither candidate name is present for an axis.
    '''
    dims = data_slice.dims

    def _pick(preferred, fallback):
        # Pick whichever candidate exists, failing early with a clear message
        # instead of returning a name that later reductions cannot find.
        if preferred in dims:
            return preferred
        if fallback in dims:
            return fallback
        raise ValueError(
            f"Expected a {preferred!r} or {fallback!r} dimension, found {tuple(dims)}."
        )

    return [_pick('lat', 'slat'), _pick('lon', 'slon')]

#### Produce Time Series Spaghetti Plot of Ensemble Members

In [None]:
def _plot_ensemble_panel(ax, time, member_vals, title, unit_string, linewidth):
    '''Draw one panel: a white band spanning the ensemble range at each time,
       with every member drawn on top as a faint red line.

    `member_vals` must have dims ('time', 'member_id'), so that matplotlib
    draws one line per member.
    '''
    range_max = member_vals.max(dim='member_id')
    range_min = member_vals.min(dim='member_id')
    ax.set_facecolor('lightgrey')
    ax.fill_between(time, range_min, range_max, linewidth=linewidth, color='white')
    ax.plot(time, member_vals, linewidth=linewidth, color='red', alpha=0.1)
    ax.set_title(title, fontsize=20)
    ax.set_ylabel(unit_string)


def plot_timeseries(ds, data_var, store_name):
    '''Create a spaghetti plot for a given variable.

    Three stacked panels show, for every ensemble member, the spatial
    maximum, mean, and minimum of the variable at each time step.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing `data_var`, with 'time' and 'member_id' dimensions.
    data_var : str
        Name of the variable to plot.
    store_name : str
        Used as the figure's overall title.

    Returns
    -------
    matplotlib.figure.Figure
    '''
    figWidth = 25
    figHeight = 20
    linewidth = 0.5

    data_slice = getSlice(ds, data_var)
    spatial_dims = getSpatialDimensionNames(data_slice)

    # A variable without a 'units' attribute should still be plottable.
    unit_string = ds[data_var].attrs.get('units', '')

    # Persist the slice so it's read from disk only once.
    # This is faster when data values are reused many times.
    data_slice = data_slice.persist()

    # Aggregate over space for each (member, time) pair.
    panels = [
        ('Ensemble Member Maxima Over Time', data_slice.max(dim=spatial_dims)),
        ('Ensemble Member Means Over Time', data_slice.mean(dim=spatial_dims)),
        ('Ensemble Member Minima Over Time', data_slice.min(dim=spatial_dims)),
    ]

    fig, axs = plt.subplots(len(panels), 1, figsize=(figWidth, figHeight))
    for ax, (title, reduced) in zip(axs, panels):
        # Order dims explicitly rather than relying on the on-disk order, and
        # load the small (time x member) result once, since it is used three
        # times while drawing.
        member_vals = reduced.transpose('time', 'member_id').compute()
        _plot_ensemble_panel(ax, ds.time, member_vals, title, unit_string, linewidth)

    fig.suptitle(store_name, fontsize=25)

    return fig

### Create the Spaghetti Plot Showing All Ensemble Members

In [None]:
%%time

# Use the store name as the figure title and draw the three-panel plot;
# %%time reports how long this cell took to run.
store_name = f'{data_var}.zarr'
fig = plot_timeseries(ds, data_var, store_name)

### Save/Download the figure

To download the figure file:
* Run the following cell.
* Find the file using the Jupyter file browser in the left sidebar.
* Right-click the file name, and select "Download".

In [None]:
# Write the figure to a PDF named after the variable (e.g. 'PS.zarr.pdf') in
# the notebook's working directory; the white facecolor keeps the page background opaque.
fig.savefig(f'{data_var}.zarr.pdf', facecolor='white', dpi=200)

### Release the Dask workers.

In [None]:
# Close the Client created earlier before shutting down the cluster it is
# connected to, so no client is left attached to a stopped scheduler.
client.close()
cluster.close()

#### 