## Make Detailed, Individual Diagnostic Plots for NA-CORDEX Zarr Stores

In [None]:
import xarray as xr
import numpy as np

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from pathlib import Path
import os

import pprint
import json

### Use Dask to Speed up Computations

In [None]:
import dask
from ncar_jobqueue import NCARCluster

# Earlier cluster configurations, kept for reference.
# Original author's note: "Processes is processes PER CORE."
# NOTE(review): the meaning of `cores` / `processes` is defined by ncar_jobqueue
# (not visible here) -- confirm whether they are per-job or per-core settings.
# This one works fine.
#cluster = NCARCluster(cores=15, processes=1, memory='100GB', project='STDD0003')
# This one also works, but occasionally hangs near the end.
#cluster = NCARCluster(cores=10, processes=1, memory='50GB', project='STDD0003')

# `num_jobs` is passed both as `cores=` to the cluster constructor and as
# `jobs=` to `cluster.scale()` below.
# NOTE(review): reusing one number for both looks coincidental -- confirm intent.
num_jobs = 10  #25
walltime = "4:00:00"
cluster = NCARCluster(cores=num_jobs, processes=1, memory='30GB', project='STDD0003', walltime=walltime)
cluster.scale(jobs=num_jobs)

from distributed import Client
# NOTE(review): `format_bytes` is not used in any cell visible in this notebook.
from distributed.utils import format_bytes
# Connect a client to the cluster created above.
client = Client(cluster)
# Last expression of the cell: displays the cluster object in the notebook.
cluster

## Plot Helper Functions

#### Hide unused subplot panels

In [None]:
def hide_subplots(axes, start_row):
    """Turn off every subplot located in row ``start_row`` or below.

    ``axes`` is the array of axes as returned by ``plt.subplots``; it may be
    1-D (one subplot per row) or 2-D (rows x columns).  Rows before
    ``start_row`` are left untouched.
    """
    if axes.ndim == 1:
        # One axes object per row.
        for row in range(start_row, len(axes)):
            axes[row].axis('off')
        return

    # Anything that is not 1-D must be a regular 2-D grid.
    assert axes.ndim == 2
    row_count, col_count = axes.shape
    for row in range(start_row, row_count):
        for col in range(col_count):
            axes[row, col].axis('off')


#### Find subplot bounding box extent

In [None]:
from matplotlib.transforms import Bbox

def full_extent(ax, pad=0.0):
    """Return the window bounding box of a subplot, tick labels and title included.

    The box is the union of the window extents of the axes itself, its title,
    both axis labels and all x/y tick labels, scaled about its centre by a
    factor of ``1 + pad`` in both directions.
    """
    # Text extents are undefined until the figure has been drawn once.
    ax.figure.canvas.draw()

    artists = [
        *ax.get_xticklabels(),
        *ax.get_yticklabels(),
        ax,
        ax.title,
        ax.xaxis.label,
        ax.yaxis.label,
    ]
    combined = Bbox.union([artist.get_window_extent() for artist in artists])

    scale = 1.0 + pad
    return combined.expanded(scale, scale)

#### Save figure to PNG

In [None]:
def save_as_png(fig, plotdir, store_name, member_id, plot_type, dpi=400):
    """Write the whole figure to a PNG named after the member and plot type.

    The output path is ``<plotdir>/<member>.<store>.png`` where ``<member>``
    is ``member_id`` with every '.' replaced by '+', and ``<store>`` is
    ``store_name`` with every occurrence of 'zarr' replaced by ``plot_type``.
    The figure is saved in full at the requested ``dpi``.
    """
    member_tag = member_id.replace('.', '+')
    plot_tag = store_name.replace('zarr', plot_type)
    target = f"{plotdir}/{member_tag}.{plot_tag}.png"
    fig.savefig(target, dpi=dpi)

#### Create Single Map Plot (Helper Function)

In [None]:
def plotMap(ax, map_slice, date_object=None, member_id=None):
    '''Draw a 2-D field as an image on ``ax`` and annotate it with its min/max.

    Parameters:
        ax: axes to draw on.
        map_slice: field to show; must support ``.min(dim=['lat', 'lon'])``
            and ``.max(dim=['lat', 'lon'])``.
        date_object: optional date whose ``.values`` is rendered as a string;
            the first 10 characters become the subplot title.
        member_id: optional label shown as the y-axis label.

    Returns:
        The axes that was passed in.
    '''

    ax.imshow(map_slice, origin='lower')

    minval = map_slice.min(dim=['lat', 'lon'])
    maxval = map_slice.max(dim=['lat', 'lon'])

    # "%4g" is general number formatting with a minimum field width of 4
    # characters (the 4 is a width, not a precision).
    # Minimum goes in the lower-left corner, maximum in the lower-right.
    ax.text(0.01, 0.03, "%4g" % minval, transform=ax.transAxes, fontsize=12)
    ax.text(0.99, 0.03, "%4g" % maxval, transform=ax.transAxes, fontsize=12, horizontalalignment='right')
    ax.set_xticks([])
    ax.set_yticks([])

    # Compare against None explicitly: the truth value of an array-like date
    # depends on its contents, so a valid date could silently lose its title.
    if date_object is not None:
        ax.set_title(date_object.values.astype(str)[:10], fontsize=12)

    # An empty label is equivalent to no label, so plain truthiness is fine here.
    if member_id:
        ax.set_ylabel(member_id, fontsize=12)

    return ax

## Main Plot Functions

#### Create Statistical Map Plots Over Multiple Pages

In [None]:
def plot_stat_maps(ds, data_var, store_name, plotdir):
    """Plot min/max/mean/std-over-time maps for every ensemble member.

    Members are laid out four per page (one row per member, one column per
    statistic) in ``<plotdir>/<store_name>_maps.pdf``.  Each panel is also
    written as its own PNG, cropped to that subplot.

    Parameters:
        ds: dataset with a ``member_id`` coordinate; ``ds[data_var]`` must have
            a ``time`` dimension plus the ``lat``/``lon`` dimensions that
            ``plotMap`` reduces over.
        data_var: name of the variable to plot.
        store_name: Zarr store name; used for the page title and file names.
        plotdir: existing directory that receives the PDF and PNG files.
    """
    # Local import keeps this function self-contained.
    import dask

    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    stat_names = ('min', 'max', 'mean', 'std')
    numPlotsPerPage = 4
    numPages = int(np.ceil(numEnsembleMembers / numPlotsPerPage))
    numPlotCols = len(stat_names)

    figWidth = 25
    figHeight = 12
    png_dpi = 400

    def save_panel(fig, ax, mem_id, plot_type):
        # save_as_png() takes no axes argument and always writes the whole
        # figure, so the per-panel crop is done here.  full_extent() returns
        # display (pixel) coordinates while savefig() expects inches.
        extent = full_extent(ax).transformed(fig.dpi_scale_trans.inverted())
        filename = f"{plotdir}/{mem_id.replace('.', '+')}.{store_name.replace('zarr', plot_type)}.png"
        fig.savefig(filename, bbox_inches=extent, dpi=png_dpi)

    # The context manager finalizes the PDF even if a page fails to render.
    with PdfPages(f'{plotdir}/{store_name}_maps.pdf') as pp:
        for pageNum in range(numPages):
            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)
            rows_used = memberEnd - memberStart

            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols,
                                    figsize=(figWidth, figHeight), constrained_layout=True)
            try:
                for plot_row_index, mem_id in enumerate(member_names[memberStart:memberEnd]):
                    data_slice = ds[data_var].sel(member_id=mem_id)

                    # Evaluate all four reductions together so the member's
                    # data is read once; plotMap() would otherwise re-evaluate
                    # each lazy result for the image, the min and the max.
                    computed = dask.compute(
                        data_slice.min(dim='time'),
                        data_slice.max(dim='time'),
                        data_slice.mean(dim='time'),
                        data_slice.std(dim='time'),
                    )

                    for col, (stat, data_agg) in enumerate(zip(stat_names, computed)):
                        ax = axs[plot_row_index, col]
                        # Only the first column carries the member label.
                        plotMap(ax, data_agg, member_id=mem_id if col == 0 else None)
                        save_panel(fig, ax, mem_id, stat)

                # Column headers go on the top row only.
                for col, stat in enumerate(stat_names):
                    axs[0, col].set_title(f'{stat}({data_var})', fontsize=15)

                # Only the final page can have fewer members than rows.
                if rows_used < numPlotsPerPage:
                    hide_subplots(axs, rows_used)

                fig.suptitle(store_name, fontsize=20)
                pp.savefig(fig)
            finally:
                # Release the figure even when plotting raised.
                plt.close(fig)

### Create Time Series Plots, One PNG per Ensemble Member
These plots also mark the time steps where values are missing.

In [None]:
def plot_timeseries(ds, data_var, store_name, plotdir):
    """Write one time-series PNG per ensemble member.

    For each member the spatial (lat/lon) min, max and mean are drawn as
    lines, mean +/- std as a shaded band, and time steps whose spatial minimum
    is NaN are marked with a rug of '|' symbols near the bottom of the plot.

    Parameters:
        ds: dataset with ``member_id`` and ``time`` coordinates; ``ds[data_var]``
            must have ``lat`` and ``lon`` dimensions.
        data_var: name of the variable to plot.
        store_name: Zarr store name, used to build the PNG file names.
        plotdir: existing directory that receives the PNG files.
    """
    # Local import keeps this function self-contained.
    import dask

    member_names = ds.coords['member_id'].values
    data = ds[data_var]
    times = ds.time
    spatial_dims = ['lat', 'lon']

    # Loop-invariant; a store without a 'units' attribute just gets no y label.
    unit_string = data.attrs.get('units', '')

    figWidth = 25
    figHeight = 6

    linewidth = 0.5
    dpi = 700

    # A single figure is reused for all members and cleared between them.
    fig = plt.figure(figsize=(figWidth, figHeight))
    try:
        for mem_id in member_names:
            data_slice = data.sel(member_id=mem_id)

            # Evaluate the four reductions together so the member's data is
            # read once instead of once per plotted series.
            min_vals, max_vals, mean_vals, std_vals = dask.compute(
                data_slice.min(dim=spatial_dims),
                data_slice.max(dim=spatial_dims),
                data_slice.mean(dim=spatial_dims),
                data_slice.std(dim=spatial_dims),
            )

            # Time steps whose spatial minimum is NaN are reported as missing.
            nan_times = times[np.isnan(min_vals.values)]

            fig.clf()
            ax = fig.add_subplot(1, 1, 1)
            ax.plot(times, min_vals, linewidth=linewidth, label='min', color='blue')
            ax.plot(times, max_vals, linewidth=linewidth, label='max', color='red')
            ax.plot(times, mean_vals, linewidth=linewidth, label='mean', color='black')
            ax.fill_between(times, (mean_vals - std_vals), (mean_vals + std_vals), color='grey',
                            linewidth=0, label='std', alpha=0.5)

            # Place the missing-value rug 1% above the bottom of the y range.
            ymin, ymax = ax.get_ylim()
            rug_y = ymin + 0.01 * (ymax - ymin)
            ax.plot(nan_times, [rug_y] * len(nan_times), '|', color='m', label='missing')
            ax.set_title(mem_id, fontsize=20)

            # Place the legend outside the plot, near upper right corner.
            ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left")
            ax.set_ylabel(unit_string)

            save_as_png(fig, plotdir, store_name, mem_id, 'ts', dpi)
    finally:
        # Without this, every processed store leaves an open figure behind.
        plt.close(fig)





#### Function Producing Maps of First, Middle, Last Timesteps

In [None]:
def getValidDateIndexes(member_slice):
    '''Return the positions of the first and last time steps with finite data.

    A time step counts as valid when its minimum over the ``lat`` and ``lon``
    dimensions is finite.

    Parameters:
        member_slice: data for one ensemble member; must support
            ``.min(dim=['lat', 'lon'])`` and yield one value per time step.

    Returns:
        ``(start_index, end_index)`` as positions along the time axis.

    Raises:
        IndexError: if no time step has a finite minimum.
    '''
    min_values = member_slice.min(dim=['lat', 'lon'])
    # Materialize the mask as a plain array before searching it.
    is_finite = np.asarray(np.isfinite(min_values))
    finite_positions = np.flatnonzero(is_finite)
    if finite_positions.size == 0:
        # Same exception type the bare [0] lookup used to raise, but with a
        # message that names the actual problem.
        raise IndexError('no time step with finite values was found for this member')
    return finite_positions[0], finite_positions[-1]


def plot_first_mid_last(ds, data_var, store_name, plotdir):
    """Plot maps of the first valid, middle and last valid time step per member.

    Members are laid out four per page (one row per member; columns are
    first / middle / last) in ``<plotdir>/<store_name>_fml.pdf``.  Each panel is
    also written as its own PNG, cropped to that subplot.

    "First" and "last" are the first and last time steps with finite data as
    reported by ``getValidDateIndexes``; "middle" is the midpoint of the whole
    time axis, regardless of where the valid data lies.

    Parameters:
        ds: dataset with ``member_id`` and ``time`` coordinates; ``ds[data_var]``
            must have ``lat`` and ``lon`` dimensions.
        data_var: name of the variable to plot.
        store_name: Zarr store name; used for the page title and file names.
        plotdir: existing directory that receives the PDF and PNG files.
    """
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = int(np.ceil(numEnsembleMembers / numPlotsPerPage))
    numPlotCols = 3

    figWidth = 18
    figHeight = 12
    png_dpi = 400

    # Midpoint of the full time axis; identical for every member.
    midDateIndex = len(ds.time) // 2

    def save_panel(fig, ax, mem_id, plot_type):
        # save_as_png() takes no axes argument and always writes the whole
        # figure, so the per-panel crop is done here.  full_extent() returns
        # display (pixel) coordinates while savefig() expects inches.
        extent = full_extent(ax).transformed(fig.dpi_scale_trans.inverted())
        filename = f"{plotdir}/{mem_id.replace('.', '+')}.{store_name.replace('zarr', plot_type)}.png"
        fig.savefig(filename, bbox_inches=extent, dpi=png_dpi)

    # The context manager finalizes the PDF even if a page fails to render.
    with PdfPages(f'{plotdir}/{store_name}_fml.pdf') as pp:
        for pageNum in range(numPages):
            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)
            rows_used = memberEnd - memberStart

            # plt.subplots (not plt.figure) is required: the code below needs
            # both the figure and the grid of axes.
            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols,
                                    figsize=(figWidth, figHeight), constrained_layout=True)
            try:
                for plot_row_index, mem_id in enumerate(member_names[memberStart:memberEnd]):
                    data_slice = ds[data_var].sel(member_id=mem_id)

                    start_index, end_index = getValidDateIndexes(data_slice)
                    panels = (
                        ('first', start_index),
                        ('middle', midDateIndex),
                        ('last', end_index),
                    )

                    for col, (plot_type, time_index) in enumerate(panels):
                        ax = axs[plot_row_index, col]
                        date = ds.time[time_index]
                        # Load the single time step once; plotMap() uses it for
                        # the image, the min and the max.
                        step = data_slice.isel(time=time_index).compute()
                        # Only the first column carries the member label.
                        plotMap(ax, step, date, mem_id if col == 0 else None)
                        save_panel(fig, ax, mem_id, plot_type)

                fig.suptitle(store_name, fontsize=20)

                # Only the final page can have fewer members than rows.
                if rows_used < numPlotsPerPage:
                    hide_subplots(axs, rows_used)

                pp.savefig(fig)
            finally:
                # Release the figure even when plotting raised.
                plt.close(fig)

### Loop over Zarr Stores in Directory and Make Plots.

In [None]:
# Input (Zarr stores) and output (plots) locations, kept as notebook globals.
#dirout = '/glade/scratch/bonnland/na-cordex/zarr-demo'
zarr_directory = '/glade/scratch/bonnland/na-cordex/zarr-publish/'
plot_directory = '/glade/scratch/bonnland/na-cordex/zarr-plots-test/'

# Collect every Zarr store below the input directory, at any depth.
p = Path(zarr_directory)
stores = list(p.rglob("*.zarr"))
#stores = list(p.rglob("uas.rcp85.*.zarr"))
for store in stores:
    print(f'Opening {store}...')
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
    except Exception as e:
        # A store that cannot be opened is reported and skipped.
        print(e)
        continue
    else:
        print('\n')

    # The first data variable of the store is the one that gets plotted.
    data_vars = list(ds.data_vars)
    data_var = data_vars[0]
    store_name = store.name

    # An existing plot directory means this store was handled before; skip it.
    plotdir = f'{plot_directory}{store_name}'
    if os.path.exists(plotdir):
        continue
    os.makedirs(plotdir)

    # Enable or disable individual plot types here.
    #plot_stat_maps(ds, data_var, store_name, plotdir)
    #plot_first_mid_last(ds, data_var, store_name, plotdir)
    plot_timeseries(ds, data_var, store_name, plotdir)
    

### Release the workers.

In [None]:
# IPython shell escape: run the shell's `date` command to print the current date and time.
!date

In [None]:
# Disconnect the client created during cluster setup before shutting the
# cluster down, so it is not left connected to a scheduler that no longer exists.
client.close()
cluster.close()