## Make Diagnostic Plots for DART Reanalysis Zarr Stores

In [2]:
import xarray as xr
import numpy as np

import dask.distributed
from dask.distributed import Client
from ncar_jobqueue import NCARCluster

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from pathlib import Path
import os

import pprint
import json

## Run This Cell to Start a Dask Cluster on Casper (use either this section or the Cheyenne section, not both)

In [None]:
# Casper: start a Dask cluster by submitting `num_jobs` batch jobs, each
# with `num_cores` cores and `memory` of RAM, then connect a client to it.
#
# NOTE(review): in dask-jobqueue, `processes` is the number of Dask worker
# processes started per job (the job's cores and memory are divided among
# them), not per core -- confirm against the installed ncar_jobqueue version.
#
# Alternative configurations tried previously:
#   cores=15, processes=1, memory='100GB' -- works fine.
#   cores=10, processes=1, memory='50GB'  -- works, but occasionally hangs
#                                            near the end.
num_cores = 2
processes_per_job = 1
num_jobs = 50
walltime = "1:00:00"
memory = '50GB'

cluster = NCARCluster(cores=num_cores, processes=processes_per_job,
                      memory=memory, project='STDD0003', walltime=walltime)
cluster.scale(jobs=num_jobs)

client = Client(cluster)
cluster

## Run This Cell to Start a Dask Cluster on Cheyenne (use either this section or the Casper section, not both)

In [2]:

# Cheyenne: start a Dask cluster by submitting `num_nodes` batch jobs
# (presumably one node each), each with `num_cores` cores, `memory` of RAM
# and `processes_per_node` Dask worker processes, then connect a client.

num_cores = 16
processes_per_node = 16
num_nodes = 2
memory = '109GB'

# Land-variable maps take more than 40 minutes each, so allow a long walltime.
walltime = "4:30:00"

# NOTE(review): an earlier note advised running <= 4 workers per node to
# avoid crashes, but this configuration starts `processes_per_node` (16)
# workers per job. In dask-jobqueue the job's `memory` is divided among its
# worker processes (~109GB / 16 each here). Lower `processes_per_node` if
# workers crash or run out of memory.
cluster = NCARCluster(cores=num_cores, processes=processes_per_node,
                      memory=memory, walltime=walltime)

cluster.scale(jobs=num_nodes)

from distributed import Client
from distributed.utils import format_bytes
client = Client(cluster)
cluster

Tab(children=(HTML(value='\n            <div class="jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-Ou…

## Define Plot Functions

#### Get consistently shaped data slices for both 2D and 3D variables.

In [5]:
def getSlice(ds, data_var, member_ids):
    '''Select the requested ensemble members of one variable for 2-D plots.

    If the variable has a vertical 'lev' dimension, keep only its last
    level. This is intended to be the level closest to the Earth's surface
    (assumes 'lev' is ordered top to bottom -- TODO confirm for every store).
    2-D and 3-D variables therefore come back with the same layout.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing ``data_var`` and a 'member_id' coordinate.
    data_var : str
        Name of the variable to select.
    member_ids : list or scalar
        Value(s) of the 'member_id' coordinate to select.

    Returns
    -------
    xarray.DataArray
        The selected members, without a 'lev' dimension.
    '''
    data_slice = ds[data_var].sel(member_id=member_ids)

    if 'lev' in data_slice.dims:
        # Positional selection of the last level drops the 'lev' dimension.
        # Deliberately no squeeze(): it would also drop any other length-1
        # dimension (e.g. a single selected member), making 3-D variables
        # come back shaped differently from 2-D ones.
        data_slice = data_slice.isel(lev=-1)

    return data_slice

In [12]:
def getSpatialDimensionNames(data_slice):
    '''Get the [latitude, longitude] dimension names for this data slice.

    Latitude is 'lat' or 'slat' and longitude is 'lon' or 'slon'; the plain
    name wins if both are present. Each axis is resolved independently, so
    mixed conventions (e.g. 'slat' with 'lon') are supported.

    Parameters
    ----------
    data_slice : object with a ``dims`` attribute (e.g. xarray.DataArray)

    Returns
    -------
    list of str
        ``[lat_dim, lon_dim]``.

    Raises
    ------
    ValueError
        If neither candidate name for an axis is a dimension of the slice.
        Raising here avoids a confusing failure later in the reductions.
    '''
    dims = data_slice.dims

    def _pick(preferred, alternate):
        # Return whichever candidate name is an actual dimension.
        if preferred in dims:
            return preferred
        if alternate in dims:
            return alternate
        raise ValueError(
            f"data slice has neither '{preferred}' nor '{alternate}' "
            f"dimension; found {tuple(dims)}")

    return [_pick('lat', 'slat'), _pick('lon', 'slon')]

### Create Time Series Plots of Selected Ensemble Members

In [33]:
def plot_timeseries(ds, data_var, member_ids, store_name, *,
                    figsize=(25, 20), linewidth=0.75):
    '''Plot each member's spatial maximum, mean and minimum over time.

    With 30 workers, expect about 1 minute of walltime for computation and
    1-2 minutes for plot rendering.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened from a Zarr store; must contain ``data_var``.
    data_var : str
        Name of the variable to plot.
    member_ids : list
        Values of the 'member_id' coordinate to plot, one line per member.
    store_name : str
        Text used as the overall figure title.
    figsize : tuple of float, optional
        Figure (width, height) in inches.
    linewidth : float, optional
        Width of each member's line.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with three stacked panels: maximums, means and minimums.
    '''
    data_slice = getSlice(ds, data_var, member_ids)
    spatial_dims = getSpatialDimensionNames(data_slice)

    # Not every variable is guaranteed to carry a 'units' attribute.
    unit_string = ds[data_var].attrs.get('units', '')

    # Build the three reductions lazily and compute them in one pass. The
    # shared task graph reads each chunk from disk only once. Only the small
    # reduced arrays reach the client, instead of loading the full-resolution
    # slice into client memory.
    stats = xr.Dataset({
        'max': data_slice.max(dim=spatial_dims),
        'mean': data_slice.mean(dim=spatial_dims),
        'min': data_slice.min(dim=spatial_dims),
    }).compute()

    panels = [
        ('max', 'Ensemble Member Maximums Over Time'),
        ('mean', 'Ensemble Member Means Over Time'),
        ('min', 'Ensemble Member Minimums Over Time'),
    ]

    fig, axs = plt.subplots(len(panels), 1, figsize=figsize)

    for ax, (stat_name, title) in zip(axs, panels):
        # matplotlib draws one line per column, so make 'time' the first
        # axis regardless of how the remaining dimensions are ordered.
        values = stats[stat_name].transpose('time', ...)
        ax.plot(values['time'], values, linewidth=linewidth, label=member_ids)
        ax.set_title(title, fontsize=20)
        ax.set_ylabel(unit_string)
        ax.legend(loc='upper right')

    fig.suptitle(store_name, fontsize=25)

    return fig

### Select a Zarr store

In [4]:
# Root directory of the published Zarr stores; this notebook reads from
# its 'weekly/' subdirectory.
zarr_directory = '/glade/scratch/bonnland/DART/ds345.0/zarr-publish/'

# Store to examine. To list every available store instead, use e.g.
# sorted(Path(zarr_directory).rglob('*.zarr')).
store = str(Path(zarr_directory) / 'weekly' / 'PS.zarr')

# Open lazily, reading the store's consolidated metadata.
ds = xr.open_zarr(store, consolidated=True)
ds

Unnamed: 0,Array,Chunk
Bytes,15.52 GiB,50.00 MiB
Shape,"(80, 471, 192, 288)","(80, 80, 32, 32)"
Count,325 Tasks,324 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 15.52 GiB 50.00 MiB Shape (80, 471, 192, 288) (80, 80, 32, 32) Count 325 Tasks 324 Chunks Type float64 numpy.ndarray",80  1  288  192  471,

Unnamed: 0,Array,Chunk
Bytes,15.52 GiB,50.00 MiB
Shape,"(80, 471, 192, 288)","(80, 80, 32, 32)"
Count,325 Tasks,324 Chunks
Type,float64,numpy.ndarray


### Select Ensemble Members and Make a Plot

In [None]:
# Variable to plot and the ensemble members to show (one line per member).
data_var = 'PS'
member_ids = [3, 14, 25, 35, 46, 57, 68]

# Derive the figure title from the selected store so the two cannot drift
# apart when a different store is chosen above.
store_name = Path(store).name
fig = plot_timeseries(ds, data_var, member_ids, store_name)

### Save/Download the figure

To download the figure plot file:
* Run the following cell to save the figure as a PNG file.
* Find the file using the Jupyter file browser in the left sidebar.
* Right-click the file name, and select "Download".

In [32]:
# Name the PNG after the plotted store so the file matches the figure.
fig.savefig(f'{store_name}.png')

### Release the Dask workers.

In [None]:
# Disconnect the client before shutting down the cluster, so the client is
# not left connected to a scheduler that no longer exists.
client.close()
cluster.close()