## Make Diagnostic Plots for NA-CORDEX Zarr Stores

In [None]:
import xarray as xr
import numpy as np

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from pathlib import Path
import os

import pprint
import json

## Run These Cells for Dask Processing (NEW)

In [None]:
import dask
from dask_jobqueue import PBSCluster
from distributed import Client

# Route the dashboard link through the JupyterHub proxy so it opens from the lab UI.
dask.config.set({'distributed.dashboard.link': '/proxy/{port}/status'})

# Each PBS job runs a single one-core worker with 10 GB on the 'casper' queue.
num_jobs = 35
walltime = '6:00:00'

cluster = PBSCluster(
    cores=1,
    processes=1,
    walltime=walltime,
    queue='casper',
    resource_spec='select=1:ncpus=1:mem=10GB',
)
cluster.scale(jobs=num_jobs)

client = Client(cluster)

# Last expression in the cell: the notebook renders the cluster widget.
cluster

## Define Plot Functions

#### Hide unused subplot panels (Helper Function)

In [None]:
def hide_subplots(axes, start_row):
    '''Turn off every subplot whose row index is ``start_row`` or later.

    Used to blank out unused panels on a partially-filled final page.

    Parameters
    ----------
    axes : numpy.ndarray
        Array of Axes as returned by ``plt.subplots``; either 1-D (a single
        column of rows) or 2-D (rows x columns).
    start_row : int
        First (non-negative) row index to hide; earlier rows are untouched.

    Raises
    ------
    ValueError
        If ``axes`` is neither 1-D nor 2-D.
    '''
    axes = np.asarray(axes)
    if axes.ndim not in (1, 2):
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        raise ValueError(f'Expected a 1-D or 2-D array of axes, got {axes.ndim}-D')
    # Slicing the leading (row) axis works for both layouts; ``.flat`` then
    # visits every Axes in the selected rows.
    for ax in axes[start_row:].flat:
        ax.axis('off')


#### Create Single Map Plot (Helper Function)

In [None]:
def plotMap(ax, map_slice, date_object=None, member_id=None):
    '''Draw a 2-D field on ``ax`` and annotate it with its min/max values.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    map_slice : xarray.DataArray
        2-D field with ``lat`` and ``lon`` dimensions (possibly dask-backed).
    date_object : xarray.DataArray, optional
        Scalar time coordinate; the first 10 characters of its string form
        (the YYYY-MM-DD part) become the axes title.
    member_id : str, optional
        Ensemble member name used as the y-axis label.

    Returns
    -------
    The same ``ax``.
    '''
    # Materialize the slice once; otherwise imshow, min and max would each
    # trigger a separate (possibly dask) computation of the same data.
    map_slice = map_slice.compute()

    ax.imshow(map_slice, origin='lower')

    minval = float(map_slice.min(dim=['lat', 'lon']))
    maxval = float(map_slice.max(dim=['lat', 'lon']))

    # '%4g': general format (6 significant digits), minimum field width 4.
    ax.text(0.01, 0.03, "%4g" % minval, transform=ax.transAxes, fontsize=12)
    ax.text(0.99, 0.03, "%4g" % maxval, transform=ax.transAxes, fontsize=12,
            horizontalalignment='right')
    ax.set_xticks([])
    ax.set_yticks([])

    # Compare with None explicitly: the truthiness of a scalar DataArray or
    # numpy value depends on its contents, not on whether it was supplied.
    if date_object is not None:
        # str() handles both numpy datetime64 scalars and 0-d object arrays
        # (e.g. cftime dates), where slicing the array itself would fail.
        ax.set_title(str(date_object.values)[:10], fontsize=12)

    if member_id is not None:
        ax.set_ylabel(member_id, fontsize=12)

    return ax

#### Create Statistical Map Plots Over Multiple Pages

In [None]:
def plot_stat_maps(ds, data_var, store_name, plotdir):
    '''Write a multi-page PDF of per-member min/max/mean/std maps over time.

    Each page holds up to ``numPlotsPerPage`` ensemble members, one per row;
    the columns are the min, max, mean and standard deviation of
    ``ds[data_var]`` reduced over the ``time`` dimension. The output file is
    ``{plotdir}/{store_name}_maps.pdf``.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes
    of plot rendering.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``member_id`` coordinate.
    data_var : str
        Name of the variable to plot.
    store_name : str
        Used for the page title and the output file name.
    plotdir : str or path-like
        Existing directory that receives the PDF.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = -(-numEnsembleMembers // numPlotsPerPage)  # ceiling division

    # Column order of the statistics; each name is also a DataArray reduction method.
    stat_names = ('min', 'max', 'mean', 'std')
    numPlotCols = len(stat_names)

    figWidth = 25
    figHeight = 12

    # The context manager finalizes the PDF even if a computation fails part-way.
    with PdfPages(f'{plotdir}/{store_name}_maps.pdf') as pp:
        for pageNum in range(numPages):
            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)
            numRowsUsed = memberEnd - memberStart

            # squeeze=False guarantees a 2-D axes array for the [row, col] indexing below.
            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols, figsize=(figWidth, figHeight),
                                    constrained_layout=True, squeeze=False)
            try:
                for plot_row_index, mem_id in enumerate(member_names[memberStart:memberEnd]):
                    # Persist the slice so it's read from disk only once; it is
                    # reduced once per statistic below.
                    data_slice = ds[data_var].sel(member_id=mem_id).persist()

                    for col, stat in enumerate(stat_names):
                        data_agg = getattr(data_slice, stat)(dim='time')
                        # Only the first column carries the member name as its y-label.
                        plotMap(axs[plot_row_index, col], data_agg,
                                member_id=mem_id if col == 0 else None)

                for col, stat in enumerate(stat_names):
                    axs[0, col].set_title(f'{stat}({data_var})', fontsize=15)

                # Only the final page can be partially filled; hide its empty rows.
                if numRowsUsed < numPlotsPerPage:
                    hide_subplots(axs, numRowsUsed)

                fig.suptitle(store_name, fontsize=20)
                pp.savefig(fig)
            finally:
                # Close this specific figure so pages don't accumulate in memory.
                plt.close(fig)

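### Create Time Series Plots over Multiple Pages
These also mark, with a magenta rug along the bottom of each panel, the timesteps whose spatial minimum is NaN (i.e. where every grid cell is missing).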

In [None]:
def plot_timeseries(ds, data_var, store_name, plotdir):
    '''Write a multi-page PDF of per-member time series of spatial statistics.

    Each page holds up to ``numPlotsPerPage`` ensemble members, one panel
    per member. A panel shows the min (blue), max (red) and mean (black) of
    ``ds[data_var]`` over ``lat``/``lon`` at every timestep, a grey
    mean +/- std band, and a magenta rug at timesteps whose spatial minimum
    is NaN. The output file is ``{plotdir}/{store_name}_ts.pdf``.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes
    of plot rendering.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``member_id`` and ``time`` coordinates.
    data_var : str
        Name of the variable to plot.
    store_name : str
        Used for the page title and the output file name.
    plotdir : str or path-like
        Existing directory that receives the PDF.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = -(-numEnsembleMembers // numPlotsPerPage)  # ceiling division
    numPlotCols = 1

    figWidth = 25
    figHeight = 20

    linewidth = 0.5

    # A variable without a 'units' attribute should not abort the whole PDF.
    unit_string = ds[data_var].attrs.get('units', '')
    times = ds.time

    # The context manager finalizes the PDF even if a computation fails part-way.
    with PdfPages(f'{plotdir}/{store_name}_ts.pdf') as pp:
        for pageNum in range(numPages):
            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)
            numRowsUsed = memberEnd - memberStart

            # squeeze=False gives a 2-D (rows x 1) axes array regardless of layout.
            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols, figsize=(figWidth, figHeight),
                                    squeeze=False)
            try:
                for plot_row_index, mem_id in enumerate(member_names[memberStart:memberEnd]):
                    ax = axs[plot_row_index, 0]

                    # Persist the slice so it's read from disk only once; it is
                    # reduced four times below.
                    data_slice = ds[data_var].sel(member_id=mem_id).persist()

                    # Compute each statistic once: they are reused by several plot calls,
                    # and lazy values would otherwise be recomputed on every use.
                    min_vals = data_slice.min(dim=['lat', 'lon']).compute()
                    max_vals = data_slice.max(dim=['lat', 'lon']).compute()
                    mean_vals = data_slice.mean(dim=['lat', 'lon']).compute()
                    std_vals = data_slice.std(dim=['lat', 'lon']).compute()

                    # With NaN-skipping reductions the spatial min is NaN only when
                    # the whole map at that timestep is missing.
                    nan_times = times[np.isnan(min_vals.values)]

                    ax.plot(times, max_vals, linewidth=linewidth, label='max', color='red')
                    ax.plot(times, min_vals, linewidth=linewidth, label='min', color='blue')
                    ax.plot(times, mean_vals, linewidth=linewidth, label='mean', color='black')
                    ax.fill_between(times, (mean_vals - std_vals), (mean_vals + std_vals),
                                    color='grey', linewidth=0, label='std', alpha=0.5)

                    # Draw the missing-data rug just above the bottom of the current y-range.
                    ymin, ymax = ax.get_ylim()
                    rug_y = ymin + 0.01 * (ymax - ymin)
                    ax.plot(nan_times, [rug_y] * len(nan_times), '|', color='m', label='missing')
                    ax.set_title(mem_id, fontsize=20)
                    ax.legend(loc='upper right')
                    ax.set_ylabel(unit_string)

                fig.suptitle(store_name, fontsize=25)
                fig.tight_layout(pad=10.2, w_pad=3.5, h_pad=3.5)

                # Only the final page can be partially filled; hide its empty rows.
                if numRowsUsed < numPlotsPerPage:
                    hide_subplots(axs, numRowsUsed)

                pp.savefig(fig)
            finally:
                # Close this specific figure so pages don't accumulate in memory.
                plt.close(fig)


#### Function Producing Maps of First, Middle, Last Timesteps

In [None]:
def getValidDateIndexes(member_slice):
    '''Return the positions of the first and last timesteps with finite data.

    A timestep is treated as valid when the minimum of ``member_slice`` over
    ``lat``/``lon`` is finite. (With xarray's default NaN-skipping for float
    data, this means at least one grid cell is non-NaN.)

    Parameters
    ----------
    member_slice : xarray.DataArray
        Data for one ensemble member with ``lat`` and ``lon`` dimensions and
        one remaining (time) dimension.

    Returns
    -------
    (start_index, end_index) : tuple of numpy integers
        Positional indexes along the time dimension.

    Raises
    ------
    ValueError
        If no timestep has a finite spatial minimum.
    '''
    # np.asarray forces computation of a lazy (dask) reduction.
    min_values = np.asarray(member_slice.min(dim=['lat', 'lon']))
    finite_indexes = np.flatnonzero(np.isfinite(min_values))
    if finite_indexes.size == 0:
        raise ValueError('No timestep contains finite values.')
    return finite_indexes[0], finite_indexes[-1]


def plot_first_mid_last(ds, data_var, store_name, plotdir):
    '''Write a multi-page PDF of maps at the first, middle and last valid timesteps.

    Each page holds up to ``numPlotsPerPage`` ensemble members, one per row.
    For every member, the first and last timesteps with finite data are
    found with ``getValidDateIndexes``; the middle map is taken at the
    midpoint of that valid range. The output file is
    ``{plotdir}/{store_name}_fml.pdf``.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes
    of plot rendering.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``member_id`` and ``time`` coordinates.
    data_var : str
        Name of the variable to plot.
    store_name : str
        Used for the page title and the output file name.
    plotdir : str or path-like
        Existing directory that receives the PDF.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    numPlotsPerPage = 4
    numPages = -(-numEnsembleMembers // numPlotsPerPage)  # ceiling division
    numPlotCols = 3

    figWidth = 18
    figHeight = 12

    # The context manager finalizes the PDF even if a computation fails part-way.
    with PdfPages(f'{plotdir}/{store_name}_fml.pdf') as pp:
        for pageNum in range(numPages):
            memberStart = pageNum * numPlotsPerPage
            memberEnd = min(memberStart + numPlotsPerPage, numEnsembleMembers)
            numRowsUsed = memberEnd - memberStart

            # squeeze=False guarantees a 2-D axes array for the [row, col] indexing below.
            fig, axs = plt.subplots(numPlotsPerPage, numPlotCols, figsize=(figWidth, figHeight),
                                    constrained_layout=True, squeeze=False)
            try:
                for plot_row_index, mem_id in enumerate(member_names[memberStart:memberEnd]):
                    # Persist the slice so it's read from disk only once; it is
                    # reduced and indexed several times below.
                    data_slice = ds[data_var].sel(member_id=mem_id).persist()

                    start_index, end_index = getValidDateIndexes(data_slice)
                    # Midpoint of this member's valid range (not of the whole time
                    # axis), so the middle map cannot land in a missing stretch
                    # outside the first/last valid dates.
                    mid_index = (start_index + end_index) // 2

                    for col, time_index in enumerate((start_index, mid_index, end_index)):
                        date = ds.time[time_index]
                        time_step = data_slice.isel(time=time_index)
                        # Only the first column carries the member name as its y-label.
                        plotMap(axs[plot_row_index, col], time_step, date,
                                mem_id if col == 0 else None)

                fig.suptitle(store_name, fontsize=20)

                # Only the final page can be partially filled; hide its empty rows.
                if numRowsUsed < numPlotsPerPage:
                    hide_subplots(axs, numRowsUsed)

                pp.savefig(fig)
            finally:
                # Close this specific figure so pages don't accumulate in memory.
                plt.close(fig)

### Loop over Zarr Stores in Directory and Make Plots.

In [None]:
import shutil

# For now, make the Zarr output directory a global variable.
#dirout = '/glade/scratch/bonnland/na-cordex/zarr-demo'
#zarr_directory = '/glade/scratch/bonnland/na-cordex/zarr/'
zarr_directory = '/glade/scratch/bonnland/na-cordex/zarr-publish/'
plot_directory = '/glade/scratch/bonnland/na-cordex/zarr-plots/'

p = Path(zarr_directory)
stores = list(p.rglob("*.zarr"))
#stores = list(p.rglob("uas.rcp85.*.zarr"))
for store in stores:
    print(f'Opening {store}...')
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
        print('\n')
    except Exception as e:
        # Unreadable store: report it and move on to the next one.
        print(e)
        continue

    store_name = store.name

    # Plot the first data variable in the store.
    data_var = next(iter(ds.data_vars), None)
    if data_var is None:
        print(f'No data variables in {store}; skipping.')
        del ds
        continue

    # Only produce plots that haven't been created already: an existing
    # plot directory marks the store as done.
    plotdir = os.path.join(plot_directory, store_name)
    try:
        os.makedirs(plotdir)
    except FileExistsError:
        del ds
        continue

    finished = False
    try:
        plot_stat_maps(ds, data_var, store_name, plotdir)
        plot_first_mid_last(ds, data_var, store_name, plotdir)
        plot_timeseries(ds, data_var, store_name, plotdir)
        finished = True
    finally:
        if not finished:
            # The directory was created in this iteration; remove the partial
            # output so the store is retried next run instead of being
            # skipped as already done.
            shutil.rmtree(plotdir, ignore_errors=True)
        # See if we can avoid memory leaks in the lab session
        del ds
    

In [None]:
# The plot functions called above close each figure right after saving it to
# PDF, so this is expected to display nothing; it is a harmless leftover.
plt.show()

### Release the workers.

In [None]:
# IPython shell escape: print the current date/time (presumably to record
# when the plotting run finished).
!date

In [None]:
# Disconnect the Dask client first so it does not keep trying to reach the
# scheduler being shut down, then release the PBS worker jobs.
client.close()
cluster.close()

## Unused Plot Functions

### Function producing Super-Wide Time Series Plots. 

In [None]:
def plot_timeseries_wide(ds, data_var, store_name):
    '''Save one very wide PDF of per-member time series of spatial statistics.

    The grid has one row per ensemble member and one column per statistic
    (min, max, mean, std of ``ds[data_var]`` over ``lat``/``lon``). Each
    series is drawn as a thin line plus small black point markers. The
    output is ``{store_name}_ts.pdf`` in the current working directory.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes
    of plot rendering.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    # Column order of the statistics; each name is also a DataArray reduction method.
    stat_names = ('min', 'max', 'mean', 'std')
    numPlotRows = numEnsembleMembers
    numPlotCols = len(stat_names)

    figWidth = 200
    figHeight = 80

    linestyle = 'k.'
    linewidth = 0.5
    markersize = 0.5

    # squeeze=False keeps axs 2-D even when there is a single member.
    fig, axs = plt.subplots(numPlotRows, numPlotCols, figsize=(figWidth, figHeight),
                            sharey='col', squeeze=False)

    for index, mem_id in enumerate(member_names):
        data_slice = ds[data_var].sel(member_id=mem_id)
        for col, stat in enumerate(stat_names):
            # Compute once; the series is drawn twice (line and markers).
            data_agg = getattr(data_slice, stat)(dim=['lat', 'lon']).compute()
            ax = axs[index, col]
            ax.plot(ds.time, data_agg, linewidth=linewidth)
            ax.plot(ds.time, data_agg, linestyle, markersize=markersize)
            ax.set_ylabel(mem_id, fontsize=15)

    # Titles use this function's own arguments rather than notebook globals.
    for col, stat in enumerate(stat_names):
        axs[0, col].set_title(f'{stat}({data_var})', fontsize=40)

    fig.suptitle(store_name, fontsize=50)
    fig.tight_layout(pad=20.2, w_pad=5.5, h_pad=5.5)
    fig.savefig(f'{store_name}_ts.pdf')
    # The figure is huge; release it once it has been written to disk.
    plt.close(fig)



#### Function Producing Map Plots Over a SINGLE Page

In [None]:
def plot_maps_OLD(ds, data_var, store_name):
    '''Save a single-page PDF of per-member min/max/mean/std maps over time.

    The grid has one row per ensemble member and one column per statistic
    of ``ds[data_var]`` reduced over ``time``, with a colorbar under each
    column. The output is ``{store_name}_maps.pdf`` in the current working
    directory.

    With 30 workers, expect about 1 minute of computation and 1-2 minutes
    of plot rendering.
    '''
    member_names = ds.coords['member_id'].values
    numEnsembleMembers = member_names.size

    # Column order of the statistics; each name is also a DataArray reduction method.
    stat_names = ('min', 'max', 'mean', 'std')
    numPlotRows = numEnsembleMembers
    numPlotCols = len(stat_names)

    figWidth = 17
    figHeight = 35

    # squeeze=False keeps axs 2-D even when there is a single member.
    fig, axs = plt.subplots(numPlotRows, numPlotCols, figsize=(figWidth, figHeight),
                            constrained_layout=True, squeeze=False)

    # Image handle from the last row of each column, used for that column's colorbar.
    images = [None] * numPlotCols
    for index, mem_id in enumerate(member_names):
        data_slice = ds[data_var].sel(member_id=mem_id)
        for col, stat in enumerate(stat_names):
            data_agg = getattr(data_slice, stat)(dim='time')
            images[col] = axs[index, col].imshow(data_agg, origin='lower')
            axs[index, col].set_ylabel(mem_id, fontsize=8)

    for col, stat in enumerate(stat_names):
        # Titles use this function's own arguments rather than notebook globals.
        axs[0, col].set_title(f'{stat}({data_var})', fontsize=15)
        # NOTE(review): every panel auto-scales its own color range, so this
        # column colorbar only matches the last member's panel exactly.
        plt.colorbar(images[col], ax=axs[:, col], location='bottom', shrink=0.7)

    fig.suptitle(store_name, fontsize=20)
    fig.savefig(f'{store_name}_maps.pdf')
    # Release the figure once it has been written to disk.
    plt.close(fig)
