## Make Diagnostic Plots for NA-CORDEX Zarr Stores

In [None]:
import xarray as xr
import numpy as np

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns

import pprint
import json

### Use Dask to Speed up Computations

In [None]:
import dask
from ncar_jobqueue import NCARCluster

# Start a Dask cluster through the NCAR job queue and attach a client to it.
#
# NOTE(review): the original comment here read "Processes is processes PER CORE".
# In dask-jobqueue, `cores` is the core count per job and `processes` is the
# number of worker processes each job is split into -- presumably the same
# holds for NCARCluster; confirm against the ncar_jobqueue documentation.
#
# Configurations tried previously:
# This one works fine.
#cluster = NCARCluster(cores=15, processes=1, memory='100GB', project='STDD0003')
# This one also works, but occasionally hangs near the end.
#cluster = NCARCluster(cores=10, processes=1, memory='50GB', project='STDD0003')

# `num_jobs` is used both as the `cores` argument and as the number of jobs
# requested from the scheduler below.
num_jobs=20
cluster = NCARCluster(cores=num_jobs, processes=1, memory='10GB', project='STDD0003')
cluster.scale(jobs=num_jobs)

from distributed import Client
from distributed.utils import format_bytes
# Connect this session to the cluster created above.
client = Client(cluster)
# Bare expression: displays the cluster object as the notebook cell output.
cluster

## Define Plot Functions

### Function Producing Super-Wide Time Series Plots

In [None]:
def plot_timeseries(ds, data_var, store_name):
    """Save a super-wide grid of spatially aggregated time series as a PDF.

    The figure has one row per ensemble member and four columns holding the
    min, max, mean and std of ``ds[data_var]`` taken over the ``lat`` and
    ``lon`` dimensions, each plotted against ``ds.time`` as a thin line with
    point markers on top. The result is written to ``'{store_name}_ts.pdf'``
    in the current working directory.

    Parameters
    ----------
    ds : dataset with a ``member_id`` coordinate, a ``time`` variable and
        ``lat``/``lon`` dimensions (presumably an ``xarray.Dataset`` opened
        from a Zarr store).
    data_var : str
        Name of the variable in ``ds`` to summarize.
    store_name : str
        Used as the figure title and as the output file name prefix.

    Returns
    -------
    None
    """
    # Original author's timing note: with 30 workers, expect 1 minute walltime
    # for computation and 1-2 minutes for plot rendering.
    member_names = ds.coords['member_id'].values
    stats = ('min', 'max', 'mean', 'std')

    fig_width = 200
    fig_height = 80

    marker_style = 'k.'
    linewidth = 0.5
    markersize = 0.5

    # squeeze=False keeps `axs` two-dimensional even for a single-member
    # ensemble, so the [row, col] indexing below always works.
    fig, axs = plt.subplots(member_names.size, len(stats),
                            figsize=(fig_width, fig_height),
                            sharey='col', squeeze=False)

    for row, mem_id in enumerate(member_names):
        data_slice = ds[data_var].sel(member_id=mem_id)

        for col, stat in enumerate(stats):
            # Evaluate the reduction once; it is drawn twice below and would
            # otherwise be recomputed for each plot call.
            data_agg = getattr(data_slice, stat)(dim=['lat', 'lon']).compute()
            ax = axs[row, col]
            ax.plot(ds.time, data_agg, linewidth=linewidth)
            ax.plot(ds.time, data_agg, marker_style, markersize=markersize)
            ax.set_ylabel(mem_id, fontsize=15)

    # Column headers go on the top row only. These use the `data_var`
    # argument; the previous code referenced an undefined name `var`.
    for col, stat in enumerate(stats):
        axs[0, col].set_title(f'{stat}({data_var})', fontsize=40)

    # Title with the `store_name` argument rather than a notebook global, so
    # the function works regardless of what the calling cell has defined.
    fig.suptitle(store_name, fontsize=50)
    fig.tight_layout(pad=20.2, w_pad=5.5, h_pad=5.5)
    fig.savefig(f'{store_name}_ts.pdf')

    # Release the (very large) figure; pyplot keeps figures alive until closed.
    plt.close(fig)



## Create Spatial Plots from Zarr Store

#### Function Producing Map Plots Over Multiple Pages

In [None]:
def plot_maps(ds, data_var, store_name):
    """Save maps of time-aggregated statistics as a multi-page PDF.

    Each page holds up to four ensemble members (one per row) and four
    columns showing the min, max, mean and std of ``ds[data_var]`` taken over
    the ``time`` dimension. Pages are written to ``'{store_name}_maps.pdf'``
    in the current working directory.

    Parameters
    ----------
    ds : dataset with a ``member_id`` coordinate and a ``time`` dimension
        (presumably an ``xarray.Dataset`` opened from a Zarr store).
    data_var : str
        Name of the variable in ``ds`` to summarize.
    store_name : str
        Used as the page title and as the output file name prefix.

    Returns
    -------
    None
    """
    # Original author's timing note: with 30 workers, expect 1 minute walltime
    # for computation and 1-2 minutes for plot rendering.
    member_names = ds.coords['member_id'].values
    num_members = member_names.size

    plots_per_page = 4
    num_pages = int(np.ceil(num_members / plots_per_page))
    stats = ('min', 'max', 'mean', 'std')

    fig_width = 25
    fig_height = 12

    # The context manager finalizes the PDF even if a page fails to render.
    with PdfPages(f'{store_name}_maps.pdf') as pdf:
        # Iterate over the number of pages actually needed. The previous loop
        # ran range(plots_per_page), which silently dropped members beyond
        # the 16th and emitted blank pages for smaller ensembles.
        for page in range(num_pages):
            first = page * plots_per_page
            last = min(first + plots_per_page, num_members)
            page_members = member_names[first:last]

            fig, axs = plt.subplots(plots_per_page, len(stats),
                                    figsize=(fig_width, fig_height),
                                    constrained_layout=True)

            # Last image drawn in each column; used for that column's colorbar.
            images = [None] * len(stats)

            for row, mem_id in enumerate(page_members):
                data_slice = ds[data_var].sel(member_id=mem_id)
                for col, stat in enumerate(stats):
                    data_agg = getattr(data_slice, stat)(dim='time')
                    images[col] = axs[row, col].imshow(data_agg, origin='lower')
                    axs[row, col].set_ylabel(mem_id, fontsize=8)

            # Blank out rows left unused on a partially filled final page.
            for row in range(len(page_members), plots_per_page):
                for col in range(len(stats)):
                    axs[row, col].axis('off')

            for col, stat in enumerate(stats):
                axs[0, col].set_title(f'{stat}({data_var})', fontsize=15)
                # NOTE: no shared color limits are passed to imshow, so this
                # colorbar describes only the last image drawn in the column.
                fig.colorbar(images[col], ax=axs[:, col],
                             location='bottom', shrink=0.7)

            # Title with the `store_name` argument rather than a notebook global.
            fig.suptitle(store_name, fontsize=20)
            pdf.savefig(fig)

            # Free each page's figure once written; pyplot keeps them otherwise.
            plt.close(fig)

#### Function Producing Map Plots Over a SINGLE Page

In [None]:
def plot_maps_OLD(ds, data_var, store_name):
    """Save maps of time-aggregated statistics on a single PDF page.

    The figure has one row per ensemble member and four columns showing the
    min, max, mean and std of ``ds[data_var]`` taken over the ``time``
    dimension. It is written to ``'{store_name}_maps.pdf'`` in the current
    working directory (the same file name ``plot_maps`` writes to).

    Parameters
    ----------
    ds : dataset with a ``member_id`` coordinate and a ``time`` dimension
        (presumably an ``xarray.Dataset`` opened from a Zarr store).
    data_var : str
        Name of the variable in ``ds`` to summarize.
    store_name : str
        Used as the figure title and as the output file name prefix.

    Returns
    -------
    None
    """
    # Original author's timing note: with 30 workers, expect 1 minute walltime
    # for computation and 1-2 minutes for plot rendering.
    member_names = ds.coords['member_id'].values
    stats = ('min', 'max', 'mean', 'std')

    fig_width = 17
    fig_height = 35

    # squeeze=False keeps `axs` two-dimensional even for a single-member
    # ensemble, so the [row, col] indexing and axs[:, col] slices always work.
    fig, axs = plt.subplots(member_names.size, len(stats),
                            figsize=(fig_width, fig_height),
                            constrained_layout=True, squeeze=False)

    # Last image drawn in each column; used for that column's colorbar.
    images = [None] * len(stats)

    for row, mem_id in enumerate(member_names):
        data_slice = ds[data_var].sel(member_id=mem_id)
        for col, stat in enumerate(stats):
            data_agg = getattr(data_slice, stat)(dim='time')
            images[col] = axs[row, col].imshow(data_agg, origin='lower')
            axs[row, col].set_ylabel(mem_id, fontsize=8)

    for col, stat in enumerate(stats):
        # Column headers use the `data_var` argument; the previous code
        # referenced an undefined name `var`.
        axs[0, col].set_title(f'{stat}({data_var})', fontsize=15)
        # NOTE: no shared color limits are passed to imshow, so this colorbar
        # describes only the last image drawn in the column.
        fig.colorbar(images[col], ax=axs[:, col], location='bottom', shrink=0.7)

    # Title with the `store_name` argument rather than a notebook global.
    fig.suptitle(store_name, fontsize=20)
    fig.savefig(f'{store_name}_maps.pdf')

    # Release the figure; pyplot keeps figures alive until closed.
    plt.close(fig)


### Alternate Function for Time Series Plots
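This version draws the min, max, mean, and std series on a single axis per ensemble member and marks the time steps that contain missing values (NaN) with tick marks along the bottom of each plot.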

In [None]:
def plot_timeseries_alt(ds, data_var, store_name):
    """Save per-member time series of spatial statistics as a multi-page PDF.

    Each page holds up to four ensemble members, one axis per member. On each
    axis the min, max, mean and std of ``ds[data_var]`` over the ``lat`` and
    ``lon`` dimensions are drawn against ``ds.time``, and the times at which
    the spatial minimum is NaN are marked with magenta ticks near the bottom
    of the axis. Pages are written to ``'{store_name}_ts2.pdf'`` in the
    current working directory.

    Parameters
    ----------
    ds : dataset with a ``member_id`` coordinate, a ``time`` variable and
        ``lat``/``lon`` dimensions (presumably an ``xarray.Dataset`` opened
        from a Zarr store).
    data_var : str
        Name of the variable in ``ds`` to summarize.
    store_name : str
        Used as the page title and as the output file name prefix.

    Returns
    -------
    None
    """
    # Original author's timing note: with 30 workers, expect 1 minute walltime
    # for computation and 1-2 minutes for plot rendering.
    member_names = ds.coords['member_id'].values
    num_members = member_names.size

    plots_per_page = 4
    num_pages = int(np.ceil(num_members / plots_per_page))
    stats = ('min', 'max', 'mean', 'std')

    fig_width = 25
    fig_height = 20

    linewidth = 0.5

    # The context manager finalizes the PDF even if a page fails to render.
    with PdfPages(f'{store_name}_ts2.pdf') as pdf:
        # Iterate over the number of pages actually needed. The previous loop
        # ran range(plots_per_page), which silently dropped members beyond
        # the 16th and emitted blank pages for smaller ensembles.
        for page in range(num_pages):
            first = page * plots_per_page
            last = min(first + plots_per_page, num_members)
            page_members = member_names[first:last]

            # squeeze=False plus the column slice yields a 1-D array of axes
            # whatever value plots_per_page has.
            fig, axs = plt.subplots(plots_per_page, 1,
                                    figsize=(fig_width, fig_height),
                                    squeeze=False)
            axs = axs[:, 0]

            for ax, mem_id in zip(axs, page_members):
                data_slice = ds[data_var].sel(member_id=mem_id)

                # Evaluate each reduction once so later uses (plotting and
                # NaN detection) do not trigger a recomputation.
                series = {
                    stat: getattr(data_slice, stat)(dim=['lat', 'lon']).compute()
                    for stat in stats
                }

                for stat, values in series.items():
                    ax.plot(ds.time, values, linewidth=linewidth, label=stat)

                # Times where the spatial minimum is NaN.
                # Assumes the reduced series has `time` as its only dimension
                # -- TODO confirm for stores with extra dimensions.
                nan_mask = np.isnan(series['min']).values
                nan_times = ds.time[nan_mask]

                # Place the missing-value ticks 1% above the bottom of the
                # y-range established by the four series.
                ymin, ymax = ax.get_ylim()
                rug_y = ymin + 0.01 * (ymax - ymin)
                ax.plot(nan_times, [rug_y] * len(nan_times), '|', color='m')

                ax.set_title(mem_id, fontsize=20)
                ax.legend(loc='lower right')

            # Blank out axes left unused on a partially filled final page.
            for ax in axs[len(page_members):]:
                ax.axis('off')

            # Title with the `store_name` argument rather than a notebook global.
            fig.suptitle(store_name, fontsize=25)
            fig.tight_layout(pad=10.2, w_pad=3.5, h_pad=3.5)

            pdf.savefig(fig)

            # Free each page's figure once written; pyplot keeps them otherwise.
            plt.close(fig)

### Loop over Zarr Stores in Directory and Make Plots.

In [None]:
# For now, make the Zarr output directory a global variable.
dirout = '/glade/scratch/bonnland/na-cordex/zarr-demo'

# Find every Zarr store below `dirout`, at any depth.
p = Path(dirout)
stores = list(p.rglob("*.zarr"))
# Alternative: restrict the run to a subset of stores.
#stores = list(p.rglob("huss.hist*.zarr"))
for store in stores:
    print(f'Opening {store}...')
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
    except Exception as e:
        # Best effort: report a store that cannot be opened and move on, so
        # one bad store does not stop the whole batch.
        print(e)
        continue
    print('\n')

    # Plot the first data variable in the store. A store with no data
    # variables is skipped instead of raising IndexError mid-batch.
    data_vars = list(ds.data_vars)
    if not data_vars:
        print(f'No data variables found in {store}; skipping.')
        continue
    data_var = data_vars[0]

    # The final path component names the output files, which are therefore
    # written to the current working directory.
    store_name = store.name

    plot_maps(ds, data_var, store_name)
    plot_timeseries_alt(ds, data_var, store_name)
    

In [None]:
!date

### Release the workers.

In [None]:
# Disconnect the client before tearing down the cluster it is attached to, so
# the client is not left trying to reach a scheduler that no longer exists.
client.close()
cluster.close()