## Make Detailed, Individual Diagnostic Plots for NA-CORDEX Zarr Stores

In [None]:
import xarray as xr
import numpy as np

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
from pathlib import Path
import os

import pprint
import json

### Use Dask to Speed up Computations

In [None]:
import dask
from ncar_jobqueue import NCARCluster

# Start a Dask cluster through NCAR's job queue wrapper and attach a client to it.
#
# Original author's note: "Processes is processes PER CORE."
# NOTE(review): confirm this reading of `processes` against the ncar_jobqueue /
# dask-jobqueue documentation.
#
# Earlier configurations, kept for reference:
# This one works fine.
#cluster = NCARCluster(cores=15, processes=1, memory='100GB', project='STDD0003')
# This one also works, but occasionally hangs near the end.
#cluster = NCARCluster(cores=10, processes=1, memory='50GB', project='STDD0003')

# num_jobs is used twice below: as the `cores` argument of the cluster and as the
# number of jobs requested from `scale`.  The trailing values are earlier settings.
num_jobs = 20  #10  #25
walltime = "2:00:00"
cluster = NCARCluster(cores=num_jobs, processes=1, memory='30GB', project='STDD0003', walltime=walltime)
cluster.scale(jobs=num_jobs)

from distributed import Client
# NOTE(review): format_bytes is not used in the visible part of this file.
from distributed.utils import format_bytes
# Connect a client to the cluster created above.
client = Client(cluster)
# Bare expression: as the last statement of a notebook cell this displays the cluster object.
cluster

## Plot Helper Functions

#### Extract XArray metadata

In [None]:
def get_metadata_value(ds, field_name, member_id):
    '''Extract one metadata value from the global attributes of a dataset.

    The attribute ``ds.attrs[field_name]`` may be stored as JSON text, as a
    dictionary, or as a plain value:

    * JSON text is decoded first.
    * A dictionary with a 'hist' key is indexed as ``metadata['hist'][member_id]``.
    * Any other dictionary is indexed as ``metadata[member_id]``.
    * Anything else is returned unchanged.

    Parameters:
        ds: object with an ``attrs`` mapping (e.g. an XArray dataset).
        field_name: name of the attribute to read.
        member_id: key used to select the value when the attribute is a dictionary.

    Raises:
        KeyError: if the attribute is missing, or if the dictionary form has no
            entry for ``member_id``.
    '''
    # Look the attribute up once; a missing attribute surfaces as a plain KeyError.
    raw_value = ds.attrs[field_name]

    try:
        metadata = json.loads(raw_value)
    except (TypeError, ValueError):
        # TypeError: the attribute is not text (e.g. it is already a dict).
        # ValueError: the text is not valid JSON (json.JSONDecodeError subclasses it).
        metadata = raw_value

    if not isinstance(metadata, dict):
        return metadata

    # Check for deeper dictionary structure.
    if 'hist' in metadata:
        return metadata['hist'][member_id]
    return metadata[member_id]

#### Save subplot to PNG

In [None]:
def save_as_png(fig, plotdir, store_name, member_id, plot_type, dpi=400):
    '''Write the figure to a PNG file inside plotdir, then clear the figure.

    The file name is built from the member id (with '.' replaced by '+') and
    the store name (with 'zarr' replaced by the plot type), plus '.png'.
    '''
    member_part = member_id.replace('.', '+')
    store_part = store_name.replace('zarr', plot_type)
    png_path = f"{plotdir}/{member_part}.{store_part}.png"

    fig.savefig(png_path, dpi=dpi)
    # Leave the figure empty so the caller can reuse it for the next plot.
    fig.clf()

#### Create Single Map Plot

In [None]:
def plotMap(map_slice, lat_lon_values, date_object=None, member_id=None, ds=None, titleText=None):
    '''Draw a single map, with title, axis ticks and colorbar, on the current pyplot figure.

    Parameters:
        map_slice: 2-D data passed to ``plt.imshow`` (drawn with origin='lower').
        lat_lon_values: dictionary with 'latIndexes', 'latLabels', 'lonIndexes'
            and 'lonLabels' entries, as returned by getLatLonValues.
        date_object: optional date of the slice; when given, the title gets a
            "Date: ..." line built from ``date_object.values``.
        member_id: optional ensemble member name; used to look up the title and,
            when no date is given, to add an "Over Period" line.
        ds: optional dataset supplying the 'title' attribute, the units of its
            first data variable, and the time axis for the period line.
        titleText: text placed before "Over Period" (e.g. 'Minimum').

    Without a usable dataset the title omits the units (and the period line)
    instead of failing on an undefined variable.
    '''
    colorbar_shrink = 0.6
    colorbar_pad = 0.1

    plt.imshow(map_slice, origin='lower')

    titleString = ''
    units = None

    # Same condition as the dataset's truthiness check, but safe when ds is None.
    if ds is not None and ds.data_vars:
        titleString = get_metadata_value(ds, 'title', member_id)
        data_var = list(ds.data_vars)[0]
        units = ds.data_vars[data_var].attrs['units']

    # Only mention units when they are actually known.
    unitsText = f',  Units: {units}' if units is not None else ''

    # Compare with None explicitly: the truth value of an array-valued date
    # depends on its contents, not on whether a date was supplied.
    if date_object is not None:
        titleString = titleString + f'\nDate: {date_object.values.astype(str)[:10]}' + unitsText

    elif member_id and ds is not None:
        # The period line needs the dataset's time axis.
        startYear = ds.time.values[0].astype('datetime64[Y]')
        endYear = ds.time.values[-1].astype('datetime64[Y]')
        titleString = titleString + f'\n{titleText} Over Period {startYear}-{endYear}' + unitsText

    plt.title(titleString, fontsize=14)

    plt.ylabel('Latitude')
    plt.xlabel('Longitude')
    plt.xticks(lat_lon_values['lonIndexes'], lat_lon_values['lonLabels'])
    plt.yticks(lat_lon_values['latIndexes'], lat_lon_values['latLabels'])

    plt.colorbar(orientation='horizontal', shrink=colorbar_shrink, pad=colorbar_pad)

#### Return Lat/Lon Tick Mark Locations and Labels

In [None]:
def _tick_positions(coord, start_fraction, spacing_fraction):
    '''Return (indexes, labels) of evenly spaced ticks along one coordinate.'''
    length = coord.size
    start = int(np.round(start_fraction * length))
    # A spacing that rounds to zero would make np.arange and the slice below
    # fail, so fall back to a tick on every grid point.
    step = max(1, int(np.round(spacing_fraction * length)))

    indexes = np.arange(start, length, step)
    labels = ["%.0f" % value for value in coord.values[start:length:step]]
    return indexes, labels


def getLatLonValues(ds, latStart, latSpacing, lonStart, lonSpacing):
    ''' Return tick index positions and labels for the lat/lon axes of a dataset.

        Start and spacing values are given as *fractions* of the axis length
        (e.g. 0.05 for 5%), to allow for variable grid resolutions.

        Parameters:
            ds: object with ``lat`` and ``lon`` coordinates exposing ``size`` and ``values``.
            latStart, lonStart: fraction of the axis length at which the first tick is placed.
            latSpacing, lonSpacing: fraction of the axis length between ticks;
                at least one grid point is always used.

        Returns:
            Dictionary with 'latIndexes', 'latLabels', 'lonIndexes' and 'lonLabels';
            labels are the coordinate values formatted without decimals.'''
    lat_lon_values = {}

    lat_lon_values['latIndexes'], lat_lon_values['latLabels'] = _tick_positions(
        ds.lat, latStart, latSpacing)
    lat_lon_values['lonIndexes'], lat_lon_values['lonLabels'] = _tick_positions(
        ds.lon, lonStart, lonSpacing)

    return lat_lon_values

## Main Plot Functions

### Create Time Series Plots over Multiple Pages
These plots also mark the time steps where data are missing.

In [None]:
def plot_timeseries(ds, data_var, store_name, plotdir):
    '''Save one time series PNG per ensemble member.

    For every member the spatial (lat/lon) minimum, mean and maximum are drawn
    against time, with a grey band of one standard deviation around the mean.
    Time steps whose spatial minimum is NaN are marked with a rug of '|'
    symbols near the bottom of the plot.

    Parameters:
        ds: dataset with 'member_id', 'time', 'lat' and 'lon' coordinates.
        data_var: name of the data variable to plot.
        store_name: name of the Zarr store; used to build the output file names.
        plotdir: directory that receives the PNG files.
    '''
    member_names = ds.coords['member_id'].values
    unit_string = ds[data_var].attrs['units']
    spatial_dims = ['lat', 'lon']

    # Make figure width a function of the time span length.
    startYear = ds.time.values[0].astype('datetime64[Y]')
    endYear = ds.time.values[-1].astype('datetime64[Y]')
    scaleFactor = 0.75
    # Short time spans would otherwise give a zero-width (invalid) or unreadably
    # narrow figure.
    minFigWidth = 6

    figWidth = max(minFigWidth, int(scaleFactor * (endYear - startYear)))
    figHeight = 6

    linewidth = 0.5
    dpi = 100 #700

    fig = plt.figure(figsize=(figWidth, figHeight), constrained_layout=True)
    try:
        for mem_id in member_names:
            data_slice = ds[data_var].sel(member_id=mem_id)

            # Evaluate the four statistics together, once.  Left lazy, each one
            # would be recomputed every time it is used below.
            min_vals, max_vals, mean_vals, std_vals = dask.compute(
                data_slice.min(dim=spatial_dims),
                data_slice.max(dim=spatial_dims),
                data_slice.mean(dim=spatial_dims),
                data_slice.std(dim=spatial_dims),
            )

            nan_indexes = np.isnan(min_vals)
            nan_times = ds.time[nan_indexes]

            plt.clf()
            plt.plot(ds.time, max_vals, linewidth=linewidth, label='max', color='red')
            plt.plot(ds.time, mean_vals, linewidth=linewidth, label='mean', color='black')
            plt.plot(ds.time, min_vals, linewidth=linewidth, label='min', color='blue')
            plt.fill_between(ds.time, (mean_vals - std_vals), (mean_vals + std_vals), color='grey',
                             linewidth=0, label='std', alpha=0.5)

            # Draw the missing-value rug just above the bottom of the y range.
            ymin, ymax = plt.ylim()
            rug_y = ymin + 0.01*(ymax-ymin)
            plt.plot(nan_times, [rug_y]*len(nan_times), '|', color='m', label='missing')

            titleString = get_metadata_value(ds, 'title', mem_id)
            # Remove reference to "Hist" in the title
            titleString = titleString.replace(' Hist ', ' ')
            plt.title(titleString, fontsize=20)

            # Place the legend outside the plot, near upper right corner.
            plt.legend(bbox_to_anchor=(1.002, 1), loc="upper left")
            plt.ylabel(unit_string)
            plt.xlabel('Year')

            # Reduce x axis margins.
            xm, ym = plt.margins()
            plt.margins(0.005, ym)

            save_as_png(fig, plotdir, store_name, mem_id, 'ts', dpi)
    finally:
        # Release the figure; otherwise one figure per store stays open.
        plt.close(fig)


#### Create Statistical Map Plots Over Multiple Pages

In [None]:
def plot_stat_maps(ds, data_var, store_name, plotdir):
    '''Save four statistical map PNGs per ensemble member.

    For every member, the minimum, maximum, mean and standard deviation over
    time are each drawn as a map and saved with the plot types 'map-min',
    'map-max', 'map-mean' and 'map-std'.

    Original author's timing note: with 30 workers, expect 1 minute walltime
    for computation and 1-2 minutes for plot rendering.

    Parameters:
        ds: dataset with 'member_id', 'time', 'lat' and 'lon' coordinates.
        data_var: name of the data variable to plot.
        store_name: name of the Zarr store; used to build the output file names.
        plotdir: directory that receives the PNG files.
    '''
    member_names = ds.coords['member_id'].values

    lat_lon_values = getLatLonValues(ds, 0.05, 0.31, 0.1, 0.2)

    figWidth = 12 #15 #18
    figHeight = 8 #10 #12
    # Used for all four maps.  Previously only 'map-min' received this value
    # and the other three silently fell back to save_as_png's default.
    dpi = 200

    # (reduction method name, plot type used in the file name, text for the title)
    statistics = (
        ('min', 'map-min', 'Minimum'),
        ('max', 'map-max', 'Maximum'),
        ('mean', 'map-mean', 'Mean'),
        ('std', 'map-std', 'Standard Deviation'),
    )

    fig = plt.figure(figsize=(figWidth, figHeight))
    try:
        for mem_id in member_names:
            data_slice = ds[data_var].sel(member_id=mem_id)

            for method_name, plot_type, title_text in statistics:
                data_agg = getattr(data_slice, method_name)(dim='time')
                plotMap(data_agg, lat_lon_values, member_id=mem_id, ds=ds, titleText=title_text)
                save_as_png(fig, plotdir, store_name, mem_id, plot_type, dpi)
    finally:
        # Release the figure; otherwise one figure per store stays open.
        plt.close(fig)


#### Function Producing Maps of First, Middle, Last Timesteps

In [None]:
def getValidDateIndexes(member_slice):
    '''Return the positions of the first and last time steps with a finite spatial minimum.

    The minimum is taken over the 'lat' and 'lon' dimensions.  An IndexError
    is raised when no finite value exists.
    '''
    spatial_min = member_slice.min(dim=['lat', 'lon'])
    valid_positions = np.where(np.isfinite(spatial_min))[0]
    return valid_positions[0], valid_positions[-1]


def plot_first_mid_last(ds, data_var, store_name, plotdir):
    ''' Save maps of the first, middle and last time steps for every ensemble member.

    "First" and "last" are the first and last time steps with finite values
    (see getValidDateIndexes); "middle" is the midpoint of the whole time axis.
    The maps are saved with the plot types 'first', 'middle' and 'last'.

    Parameters:
        ds: dataset with 'member_id', 'time', 'lat' and 'lon' coordinates.
        data_var: name of the data variable to plot.
        store_name: name of the Zarr store; used to build the output file names.
        plotdir: directory that receives the PNG files.
    '''
    member_names = ds.coords['member_id'].values

    lat_lon_values = getLatLonValues(ds, 0.05, 0.31, 0.1, 0.2)

    figWidth = 12 #8 #18
    figHeight = 8 #5 #12
    dpi = 400

    # The midpoint of the time axis is the same for every member.
    midDateIndex = len(ds.time) // 2

    fig = plt.figure(figsize=(figWidth, figHeight))
    try:
        for mem_id in member_names:
            data_slice = ds[data_var].sel(member_id=mem_id)

            start_index, end_index = getValidDateIndexes(data_slice)

            # (plot type used in the file name, position along the time axis)
            time_steps = (
                ('first', start_index),
                ('middle', midDateIndex),
                ('last', end_index),
            )

            for plot_type, time_index in time_steps:
                step_date = ds.time[time_index]
                # Select by position: the index is already known, and a label
                # lookup could match several steps if time labels repeat.
                step_map = data_slice.isel(time=time_index)
                plotMap(step_map, lat_lon_values, step_date, mem_id, ds=ds)
                save_as_png(fig, plotdir, store_name, mem_id, plot_type, dpi)
    finally:
        # Release the figure; otherwise one figure per store stays open.
        plt.close(fig)


### Loop over Zarr Stores in Directory and Make Plots.

In [None]:
# For now, make the Zarr output directory a global variable.
#dirout = '/glade/scratch/bonnland/na-cordex/zarr-demo'
zarr_directory = '/glade/scratch/bonnland/na-cordex/zarr-publish/'
plot_directory = '/glade/scratch/bonnland/na-cordex/zarr-plots-test/'

# Find every Zarr store below the input directory and plot each one into its
# own sub-directory of plot_directory.
p = Path(zarr_directory)
stores = list(p.rglob("*.zarr"))
#stores = list(p.rglob("uas.rcp85.*.zarr"))
for store in stores:
    store_name = store.name
    plotdir = os.path.join(plot_directory, store_name)

    # Only produce plots that haven't been created already.  Check this before
    # opening the store so that finished stores cost nothing.
    if os.path.exists(plotdir):
        # Plots exist; skip to the next case.
        continue

    print(f'Opening {store}...')
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
        print('\n')
    except Exception as e:
        # Best effort: report the unreadable store and carry on with the rest.
        print(e)
        continue

    data_vars = list(ds.data_vars)
    if not data_vars:
        # Nothing to plot; leave no output directory behind so a later run retries.
        print(f'No data variables found in {store}; skipping.')
        continue
    # The plots are made for the first data variable of the store.
    data_var = data_vars[0]

    os.makedirs(plotdir)

    plot_stat_maps(ds, data_var, store_name, plotdir)
    plot_first_mid_last(ds, data_var, store_name, plotdir)
    plot_timeseries(ds, data_var, store_name, plotdir)
    

### Release the workers.

In [None]:
# IPython shell escape: run the system `date` command to show when the run finished.
!date

In [None]:
# Shut down the Dask cluster created at the top of the notebook.
# NOTE(review): the `client` attached to it is not closed explicitly -- confirm
# whether it should be closed first.
cluster.close()