## Make Diagnostic Plots for DART Reanalysis Zarr Stores

In [1]:
import xarray as xr
import numpy as np

import dask.distributed
from dask.distributed import Client
from ncar_jobqueue import NCARCluster

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from pathlib import Path
import os

import pprint
import json

## Run These Cells for Dask CASPER

In [None]:
# Start a Dask cluster on Casper and connect a client to it.
#
# NOTE(review): the original note here read "Processes is processes PER CORE".
# dask-jobqueue normally describes `processes` as a per-job setting -- confirm.
# Earlier configuration that worked fine:
#cluster = NCARCluster(cores=15, processes=1, memory='100GB', project='STDD0003')
# Earlier configuration that also worked, but occasionally hung near the end:
#cluster = NCARCluster(cores=10, processes=1, memory='50GB', project='STDD0003')

# Resource settings for Casper (trailing comments are presumably values tried earlier).
num_cores = 2 #1
num_jobs = 50 #90 #4
walltime = "1:00:00" #"1:00:00"
memory = '50GB' #'50GB'

# Build the cluster from the settings above, then request `num_jobs` jobs.
cluster = NCARCluster(cores=num_cores, processes=1, memory=memory, project='STDD0003', walltime=walltime)
cluster.scale(jobs=num_jobs)

# Attach a client to the cluster; the bare `cluster` expression is the cell's displayed value.
client = Client(cluster)
cluster

## Run These Cells for Dask CHEYENNE

In [2]:

# Start a Dask cluster on Cheyenne and connect a client to it.

num_cores = 16
processes_per_node= 16
num_nodes = 2 
memory = '109GB'

# Consider that land variable maps are > 40 minutes each.
walltime = "4:30:00" #"2:00:00" 

# Run <= 4 workers on each node to avoid crashes.
# NOTE(review): `processes_per_node` is 16 above, which does not match the
# "<= 4 workers" advice -- confirm which of the two is intended.
cluster = NCARCluster(cores=num_cores, processes=processes_per_node, 
                      memory=memory, walltime=walltime)

# Request `num_nodes` jobs from the scheduler.
cluster.scale(jobs=num_nodes)

# `Client` is re-imported here (it is also imported at the top of the notebook);
# `format_bytes` is not used anywhere in the visible cells.
from distributed import Client
from distributed.utils import format_bytes
# Attach a client to the cluster; the bare `cluster` expression is the cell's displayed value.
client = Client(cluster)
cluster

Tab(children=(HTML(value='\n            <div class="jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-Ou…

## Define Plot Functions

#### Hide unused subplot panels (Helper Function)

In [3]:
def hide_subplots(axes, start_row):
    '''Switch off every subplot located in row ``start_row`` or any later row.

    ``axes`` is an array of axes objects, either 1-D (one panel per row) or
    2-D (rows x columns).  Each affected panel receives ``.axis('off')``.
    Any other dimensionality trips the assertion below.
    '''
    if axes.ndim == 1:
        # One panel per row: index the panels directly.
        for row_idx in np.arange(start_row, len(axes)):
            axes[row_idx].axis('off')
        return

    assert axes.ndim == 2
    row_count, _ = axes.shape
    for row_idx in np.arange(start_row, row_count):
        # Walk the panels of this row from left to right.
        for panel in axes[row_idx]:
            panel.axis('off')


#### Get consistently shaped data slices for both 2D and 3D variables.

In [4]:
def getSlice(ds, data_var, member_id):
    '''Return the data for one ensemble member, reduced to a single level.

    The variable ``data_var`` is selected from ``ds`` for the given
    ``member_id``.  When the result has a 'lev' dimension, the last value of
    ``ds.lev`` is selected (intended to be the level closest to the Earth's
    surface, for 2-D diagnostic plots) and length-1 dimensions are squeezed
    out.  Variables without 'lev' are returned as selected.
    '''
    member_data = ds[data_var].sel(member_id=member_id)

    # Nothing more to do for variables without vertical levels.
    if 'lev' not in member_data.dims:
        return member_data

    bottom_level = ds.lev.values[-1]
    return member_data.sel(lev=bottom_level).squeeze()

In [5]:
def getSpatialDimensionNames(data_slice):
    '''Return ``[lat_name, lon_name]`` for the given data slice.

    'lat' / 'lon' are used when present in ``data_slice.dims``; otherwise
    the alternate names 'slat' / 'slon' are returned.
    '''
    dims = data_slice.dims
    names = []
    for preferred, fallback in (('lat', 'slat'), ('lon', 'slon')):
        names.append(preferred if preferred in dims else fallback)
    return names

#### Create Single Map Plot (Helper Function)

In [6]:
def plotMap(ax, map_slice, date_object=None, member_id=None):
    '''Draw ``map_slice`` as an image on ``ax`` and annotate its min/max.

    Parameters:
        ax: matplotlib axes to draw on.
        map_slice: 2-D data slice; its spatial dimension names are looked up
            with ``getSpatialDimensionNames``.
        date_object: optional object exposing ``.values``; the first 10
            characters of its string form become the panel title.
        member_id: optional label written as the y-axis label.

    Returns:
        The same ``ax`` that was passed in.
    '''

    ax.imshow(map_slice, origin='lower')

    spatial_dims = getSpatialDimensionNames(map_slice)
    minval = map_slice.min(dim=spatial_dims)
    maxval = map_slice.max(dim=spatial_dims)

    # "%4g" pads to a minimum field width of 4 characters and keeps the
    # default "%g" precision (up to 6 significant digits).
    text_height = 0.17
    ax.text(0.01, text_height, "%4g" % minval, transform=ax.transAxes, fontsize=12)
    ax.text(0.99, text_height, "%4g" % maxval, transform=ax.transAxes, fontsize=12, horizontalalignment='right')
    ax.set_xticks([])
    ax.set_yticks([])

    # Compare against None explicitly: a truthiness test would silently skip
    # valid but falsy arguments (e.g. member id 0).
    if date_object is not None:
        ax.set_title(date_object.values.astype(str)[:10], fontsize=12)

    if member_id is not None:
        ax.set_ylabel(member_id, fontsize=12)

    return ax

#### Create Statistical Map Plots Over Multiple Pages

In [7]:
def plot_stat_maps(ds, data_var, store_name, plotdir):
    '''Write a multi-page PDF of per-member statistical maps.

    For every ensemble member in ``ds.coords['member_id']`` one row of four
    maps is drawn: min, max, mean and std of ``data_var`` over the 'time'
    dimension.  Four members go on each page and the result is written to
    ``{plotdir}/{store_name}_maps.pdf``.

    Original timing note: with 30 workers, expect 1 minute walltime for
    computation and 1-2 minutes for plot rendering.

    Parameters:
        ds: dataset holding ``data_var`` with a 'member_id' coordinate.
        data_var: name of the variable to plot.
        store_name: used for the page title and the output file name.
        plotdir: directory that receives the PDF.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = int(np.ceil(numEnsembleMembers / numPlotsPerPage))
    numPlotCols = 4

    figWidth = 25
    figHeight = 12 #20

    # The context manager finalizes the PDF even if plotting raises.
    with PdfPages(f'{plotdir}/{store_name}_maps.pdf') as pp:

        for pageNum in range(numPages):

            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)

            # Plot the aggregate statistics across time.
            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols,
                                    figsize=(figWidth, figHeight), constrained_layout=True)
            try:
                plot_row_index = 0

                for index in range(memberStart, memberEnd):

                    mem_id = member_names[index]
                    data_slice = getSlice(ds, data_var, mem_id)

                    # Try loading the slice so it's read from disk only once.
                    data_slice = data_slice.load()

                    plotMap(axs[plot_row_index, 0], data_slice.min(dim='time'), member_id=mem_id)
                    plotMap(axs[plot_row_index, 1], data_slice.max(dim='time'))
                    plotMap(axs[plot_row_index, 2], data_slice.mean(dim='time'))
                    plotMap(axs[plot_row_index, 3], data_slice.std(dim='time'))

                    plot_row_index = plot_row_index + 1

                axs[0, 0].set_title(f'min({data_var})', fontsize=15)
                axs[0, 1].set_title(f'max({data_var})', fontsize=15)
                axs[0, 2].set_title(f'mean({data_var})', fontsize=15)
                axs[0, 3].set_title(f'std({data_var})', fontsize=15)

                # Only the final page can have unused rows; hide them.
                if plot_row_index < numPlotsPerPage:
                    hide_subplots(axs, plot_row_index)

                fig.suptitle(store_name, fontsize=20)
                pp.savefig(fig)
            finally:
                # Always release the figure, even when a plot call fails.
                plt.close(fig)

### Create Time Series Plots over Multiple Pages
These also mark the locations of missing values.

In [8]:
def plot_timeseries(ds, data_var, store_name, plotdir):
    '''Write a multi-page PDF of per-member time series.

    For every ensemble member in ``ds.coords['member_id']`` one panel shows
    the spatial max, min and mean of ``data_var`` at each time step, a grey
    band of mean +/- std, and magenta tick marks near the bottom of the panel
    at time steps whose spatial minimum is NaN (missing data).  Four members
    go on each page and the result is written to
    ``{plotdir}/{store_name}_ts.pdf``.

    Original timing note: with 30 workers, expect 1 minute walltime for
    computation and 1-2 minutes for plot rendering.

    Parameters:
        ds: dataset holding ``data_var`` with 'member_id' and 'time' coordinates.
        data_var: name of the variable to plot.
        store_name: used for the page title and the output file name.
        plotdir: directory that receives the PDF.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = int(np.ceil(numEnsembleMembers / numPlotsPerPage))
    numPlotCols = 1

    figWidth = 25
    figHeight = 20

    linewidth = 0.5

    # Loop-invariant lookups.  A variable without a 'units' attribute gets an
    # empty y-axis label instead of raising KeyError.
    time_axis = ds.time
    unit_string = ds[data_var].attrs.get('units', '')

    # The context manager finalizes the PDF even if plotting raises.
    with PdfPages(f'{plotdir}/{store_name}_ts.pdf') as pp:

        for pageNum in range(numPages):

            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)

            # One column of panels, so `axs` is a 1-D array of axes.
            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols, figsize=(figWidth, figHeight))
            try:
                print(f'Shape of subplots object: {axs.shape}')

                plot_row_index = 0

                for index in range(memberStart, memberEnd):

                    mem_id = member_names[index]
                    data_slice = getSlice(ds, data_var, mem_id)
                    spatial_dims = getSpatialDimensionNames(data_slice)

                    # Try loading the slice so it's read from disk only once.
                    data_slice = data_slice.load()

                    min_vals = data_slice.min(dim=spatial_dims)
                    max_vals = data_slice.max(dim=spatial_dims)
                    mean_vals = data_slice.mean(dim=spatial_dims)
                    std_vals = data_slice.std(dim=spatial_dims)

                    # Time steps whose spatial minimum is NaN are reported as missing.
                    nan_indexes = np.isnan(min_vals)
                    nan_times = time_axis[nan_indexes]

                    ax = axs[plot_row_index]
                    ax.plot(time_axis, max_vals, linewidth=linewidth, label='max', color='red')
                    ax.plot(time_axis, min_vals, linewidth=linewidth, label='min', color='blue')
                    ax.plot(time_axis, mean_vals, linewidth=linewidth, label='mean', color='black')
                    ax.fill_between(time_axis, (mean_vals - std_vals), (mean_vals + std_vals), color='grey',
                                    linewidth=0, label='std', alpha=0.5)

                    # Place the "missing" rug marks 1% above the bottom of the y-range.
                    ymin, ymax = ax.get_ylim()
                    rug_y = ymin + 0.01*(ymax-ymin)
                    ax.plot(nan_times, [rug_y]*len(nan_times), '|', color='m', label='missing')
                    ax.set_title(mem_id, fontsize=20)
                    ax.legend(loc='upper right')
                    ax.set_ylabel(unit_string)

                    plot_row_index = plot_row_index + 1

                fig.suptitle(store_name, fontsize=25)
                fig.tight_layout(pad=10.2, w_pad=3.5, h_pad=3.5)

                # Only the final page can have unused rows; hide them.
                if plot_row_index < numPlotsPerPage:
                    hide_subplots(axs, plot_row_index)

                pp.savefig(fig)
            finally:
                # Always release the figure, even when a plot call fails.
                plt.close(fig)


#### Function Producing Maps of First, Middle, Last Timesteps

In [9]:
def getValidDateIndexes(member_slice):
    '''Search for the first and last dates with finite values.

    The spatial minimum is taken at every time step; a time step counts as
    valid when that minimum is finite.  The spatial dimension names come from
    ``getSpatialDimensionNames`` so that slices using 'slat' / 'slon' are
    handled the same way as 'lat' / 'lon' slices.

    Returns:
        ``(start_index, end_index)``: positions of the first and last valid
        time steps.

    Raises:
        IndexError: if no time step has a finite spatial minimum.
    '''
    spatial_dims = getSpatialDimensionNames(member_slice)
    min_values = member_slice.min(dim=spatial_dims)
    is_finite = np.isfinite(min_values)
    finite_indexes = np.where(is_finite)[0]
    start_index = finite_indexes[0]
    end_index = finite_indexes[-1]
    return start_index, end_index


def plot_first_mid_last(ds, data_var, store_name, plotdir):
    '''Write a multi-page PDF of first / middle / last time-step maps.

    For every ensemble member in ``ds.coords['member_id']`` one row of three
    maps is drawn: the first time step with finite data, the time step at the
    middle of ``ds.time``, and the last time step with finite data.  Four
    members go on each page and the result is written to
    ``{plotdir}/{store_name}_fml.pdf``.

    Original timing note: with 30 workers, expect 1 minute walltime for
    computation and 1-2 minutes for plot rendering.

    Parameters:
        ds: dataset holding ``data_var`` with 'member_id' and 'time' coordinates.
        data_var: name of the variable to plot.
        store_name: used for the page title and the output file name.
        plotdir: directory that receives the PDF.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = int(np.ceil(numEnsembleMembers / numPlotsPerPage))
    numPlotCols = 3

    figWidth = 18
    figHeight = 12 #20

    # Loop-invariant: the middle position of the full time axis.
    time_axis = ds.time
    midDateIndex = len(time_axis) // 2

    # The context manager finalizes the PDF even if plotting raises.
    with PdfPages(f'{plotdir}/{store_name}_fml.pdf') as pp:

        for pageNum in range(numPages):

            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)

            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols,
                                    figsize=(figWidth, figHeight), constrained_layout=True)
            try:
                plot_row_index = 0

                for index in range(memberStart, memberEnd):

                    mem_id = member_names[index]
                    data_slice = getSlice(ds, data_var, mem_id)

                    start_index, end_index = getValidDateIndexes(data_slice)

                    startDate = time_axis[start_index]
                    first_step = data_slice.sel(time=startDate)
                    plotMap(axs[plot_row_index, 0], first_step, startDate, mem_id)

                    midDate = time_axis[midDateIndex]
                    mid_step = data_slice.sel(time=midDate)
                    plotMap(axs[plot_row_index, 1], mid_step, midDate)

                    endDate = time_axis[end_index]
                    last_step = data_slice.sel(time=endDate)
                    plotMap(axs[plot_row_index, 2], last_step, endDate)

                    plot_row_index = plot_row_index + 1

                fig.suptitle(store_name, fontsize=20)

                # Only the final page can have unused rows; hide them.
                if plot_row_index < numPlotsPerPage:
                    hide_subplots(axs, plot_row_index)

                pp.savefig(fig)
            finally:
                # Always release the figure, even when a plot call fails.
                plt.close(fig)

### Loop over Zarr Stores in Directory and Make Plots.

In [14]:
# Driver cell: find Zarr stores under `zarr_directory` and write diagnostic
# plots for the first data variable of each store into `plot_directory`.
# For now, the directories are notebook-level globals.
#dirout = '/glade/scratch/bonnland/na-cordex/zarr-demo'
zarr_directory = '/glade/scratch/bonnland/DART/ds345.0/zarr-publish/'
plot_directory = '/glade/scratch/bonnland/DART/ds345.0/zarr-plots/'

p = Path(zarr_directory)
# Swap the two lines below to process every store instead of only HR.zarr.
#stores = list(p.rglob("*.zarr"))
stores = list(p.rglob("HR.zarr"))
for store in stores:
    print(f'Opening {store}...')
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
    except Exception as e:
        # Report the failure and move on to the next store.
        print(e)
        continue
    else:
        print('\n')

    # Only the first data variable of each store is plotted.
    data_vars = list(ds.data_vars)
    data_var = data_vars[0]
    store_name = store.name

    # Each store gets its own plot directory.
    plotdir = f'{plot_directory}{store_name}'
    if os.path.exists(plotdir):
        # Plots exist; uncomment `continue` to skip to the next case.
        #continue
        pass
    else:
        os.makedirs(plotdir)

    # Enable whichever plot types are wanted.
    #plot_stat_maps(ds, data_var, store_name, plotdir)
    #plot_first_mid_last(ds, data_var, store_name, plotdir)
    plot_timeseries(ds, data_var, store_name, plotdir)

plt.show()

Opening /glade/scratch/bonnland/DART/ds345.0/zarr-publish/HR.zarr...


Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)
Shape of subplots object: (4,)


### Release the workers.

In [11]:
!date

Tue Jul 20 18:00:32 MDT 2021


In [12]:
!qstat -u bonnland

Job id            Name             User              Time Use S Queue
----------------  ---------------- ----------------  -------- - -----
8634095.chadmin1  st_archive.FCnu  tilmes                   0 H shareex         
8645149.chadmin1  f.e21.B1850.f09  ranfeng                  0 H shareex         
8658478.chadmin1  f.e21.B1850.f09  ranfeng                  0 H shareex         
8663535.chadmin1  st_archive.FWma  nadavis                  0 H shareex         
8675323.chadmin1  f.e21.B1850.f09  ranfeng                  0 H shareex         
8675561.chadmin1  f.e21.B1850.f09  ranfeng                  0 H shareex         
8675724.chadmin1  f.e21.B1850.f09  ranfeng                  0 H shareex         
8795079.chadmin1  Test2.st_archiv  lzastko                  0 H shareex         
8848196.chadmin1  st_archive.f.e2  nadavis                  0 H shareex         
8848305.chadmin1  st_archive.b.e2  nadavis                  0 H shareex         
9007934.chadmin1  st_archive.one_  mohsinn        

In [None]:
# Close the Dask cluster (the "Release the workers" step).
cluster.close()

In [None]:
# Scratch cell: display the list of Zarr store paths found by the driver cell.
stores

In [None]:
# Scratch cell: re-open the store left in `store` by the driver loop for inspection.
ds = xr.open_zarr(store.as_posix(), consolidated=True)

In [None]:
# Scratch cell: display the opened dataset.
ds

In [None]:
    # Scratch cell: list the data variable names in `ds`, then display them.
    # (The indentation is left over from copying this line out of the driver loop.)
    data_vars = [vname for vname in ds.data_vars]
data_vars

In [None]:
    # Scratch cell: repeat the per-store setup from the driver loop by hand --
    # pick the first data variable and derive the store name from the path.
    data_var = data_vars[0]
    store_name = store.as_posix().split('/')[-1]
    
    # Only produce plots that haven't been created already.  
    plotdir = plot_directory + store_name

In [None]:
# Scratch cell: display the plot directory computed for the current store.
plotdir

In [None]:
# Scratch cell: show any figures that are still open.
plt.show()