In [None]:
import datetime
import json
import os
import shutil
import time

import dxchange
from IPython import display
from tqdm.auto import tqdm
import nbclient.exceptions
import matplotlib
import matplotlib.pyplot as plt
import papermill as pm

from ipypathchooser import PathChooser
from ipysliceviewer import SliceViewer

## Set input parameters, create directories and stage input data

In [None]:
# Name of the dataset; the input directory, the .h5 file and the images
# directory are all derived from it so they cannot drift apart.
dataset_name = "RTV_18A_air_760torr_08_fast"
task_prefix = "papermill_tomopy_{}"

# write results to $SCRATCH
curdir = os.getcwd()
scratchdir = os.path.expandvars("$SCRATCH")
# expandvars() leaves an unknown variable untouched; fail early with a clear
# message instead of creating a literal "$SCRATCH" directory tree.
if scratchdir == "$SCRATCH":
    raise EnvironmentError("The SCRATCH environment variable is not set")

# one session directory per run, named after the start time
sessiondir = os.path.join(scratchdir, task_prefix.format(datetime.datetime.now().isoformat()))
imagesdir = os.path.join(sessiondir, "{}_images".format(dataset_name))

# create the output directory and its images subdir as needed
# (imagesdir lives inside sessiondir, so this creates both; exist_ok avoids
# the check-then-create race of os.path.exists() + os.mkdir())
os.makedirs(imagesdir, exist_ok=True)

# copy our input data directory to $SCRATCH for execution
src = os.path.join('/global/cfs/cdirs/als/users/parkinson/SLS_Feb2019/disk1', dataset_name)
inputdir = os.path.join(scratchdir, dataset_name)

# create the input directory on $SCRATCH as needed and copy the data
# NOTE(review): an interrupted earlier copy leaves inputdir present but
# incomplete, and this guard will then skip the copy.
if not os.path.exists(inputdir):
    shutil.copytree(src, inputdir)

# set the parameters we want to pass to the notebook
params = dict(
    filename="{}.h5".format(dataset_name),
    inputPath=inputdir,
    fulloutputPath=imagesdir,
    chunk_proj=1,
    chunk_sino=1,
    ncore=1,
    filetype='sls',
    cor=385
)

# this is the notebook we will be running through papermill, which calls the reconstruction code
in_nb = os.path.join(curdir, 'tomopy_recon_template.ipynb')

## Connect to Dask

In [None]:
# connect to our Dask cluster running inside of a Cori job, or a local Dask cluster if no job
from dask.distributed import Client, as_completed

# The file "dask_client" (in the current working directory), when present,
# holds the path of the scheduler file to connect to.
localcluster = False
if os.path.exists("dask_client"):
    try:
        with open("dask_client", 'r') as f:
            scheduler_file = f.read().strip()

        dask_client = Client(scheduler_file=scheduler_file)
    except Exception as e:
        # report why the connection failed, then fall back to a local cluster
        print("Unable to use existing dask_client file to connect to a Dask cluster!")
        print("  {}: {}".format(type(e).__name__, e))
        localcluster = True
else:
    localcluster = True

if localcluster:
    # No Dask cluster present, we will start a local cluster
    from dask.distributed import LocalCluster

    cluster = LocalCluster(n_workers=4, threads_per_worker=2)
    # Connect to the cluster configured above. A bare Client() would start a
    # second, default-sized cluster and leave this one running unused.
    dask_client = Client(cluster)

# bare last expression so the notebook displays the client summary
dask_client

## Process all timepoints and collect Papermill results

In [None]:
# process all chunks
# timepoints range from 0 to 105; last_timepoint is exclusive (range() semantics)
first_timepoint = 0
last_timepoint = 10
submits = []
fails = []  # not used in this notebook; kept so the name stays defined

# save the input parameters out to a file with the output notebooks and data
input_params_file = os.path.join(sessiondir, task_prefix.format("input_parameters.json"))
with open(input_params_file, 'w') as f:
    json.dump(params, f, indent=4)

# Parameters of every submitted task, keyed by the key of its Dask future.
# Each task gets its OWN dict: reusing and mutating the shared `params` dict
# would make every entry of `failed` alias one object holding the last
# timepoint, and the list position of a future is not its timepoint when
# first_timepoint != 0.
submitted_params = {}

# submit all papermill tasks to Dask
num_points = last_timepoint - first_timepoint
with tqdm(total=num_points, desc="Tasks submitted", unit="task") as submits_pbar:
    for timepoint in range(first_timepoint, last_timepoint):
        task_params = dict(params, timepoint=timepoint)
        out_nb = os.path.join(sessiondir, '{}.ipynb'.format(task_prefix.format(timepoint)))

        future = dask_client.submit(
            pm.execute_notebook,
            in_nb,
            out_nb,
            task_params,
            start_timeout=60,
            progress_bar=False)
        submits.append(future)
        submitted_params[future.key] = task_params
        submits_pbar.update(1)
        # pause between submissions, as the original did
        time.sleep(1)

print("{}: {} tasks submitted, {} timepoints".format(params['filename'], len(submits), num_points))

last_exc = None
completed = []
# parameter dicts (one per task, each with its own 'timepoint') of the tasks
# that raised; consumed by the resubmission cell
failed = []
# wait for all the tasks to complete
with tqdm(total=len(submits), desc="Tasks completed", unit="task") as completed_pbar:
    for future in as_completed(submits):
        try:
            completed.append(future.result())
        except nbclient.exceptions.DeadKernelError:
            # the notebook kernel died; queue the task for resubmission
            failed.append(submitted_params[future.key])
        except Exception as e:
            # any other failure is reported and also queued for resubmission
            task_params = submitted_params[future.key]
            print("timepoint {}: {}".format(task_params['timepoint'], e))
            failed.append(task_params)
        finally:
            completed_pbar.update(1)

## Resubmit failed tasks, if any

In [None]:
# cleanup any remaining tasks, if present

# Upper bound on retry rounds so a task that fails every time cannot keep
# this cell looping forever.
max_resubmit_rounds = 5
resubmit_round = 0

resubmits = []
while len(failed) > 0 and resubmit_round < max_resubmit_rounds:
    resubmit_round += 1
    display.clear_output(wait=True)

    # Start every round with a fresh list of futures and a fresh
    # future-key -> parameters map. Carrying futures over from an earlier
    # round would misalign list positions between `resubmits` and `failed`.
    resubmits = []
    pending = {}
    with tqdm(total=len(failed), desc="Tasks resubmitted", unit="task") as resubmits_pbar:
        for task_params in failed:
            out_nb = os.path.join(sessiondir, '{}.ipynb'.format(task_prefix.format(task_params['timepoint'])))

            # pure=False gives every resubmission a unique key, so the
            # scheduler re-executes the notebook instead of handing back the
            # cached failure of an identical earlier submission
            future = dask_client.submit(
                pm.execute_notebook, in_nb, out_nb, task_params, progress_bar=False, pure=False)
            resubmits.append(future)
            pending[future.key] = task_params
            resubmits_pbar.update(1)
            time.sleep(1)

    try:
        with tqdm(total=len(resubmits), desc="Cleanup Tasks completed", unit="task") as cleanup_pbar:
            for future in as_completed(resubmits):
                try:
                    completed.append(future.result())

                    # clear the failure
                    del pending[future.key]
                except RuntimeError as e:
                    # remember the most recent failure; the task stays pending
                    last_exc = e
                finally:
                    cleanup_pbar.update(1)
    finally:
        # Keep `failed` consistent (only tasks that have not succeeded) even
        # when an exception other than RuntimeError propagates out of the loop.
        failed = list(pending.values())

if len(failed) > 0:
    print("{} task(s) still failing after {} resubmission round(s); timepoints: {}".format(
        len(failed), resubmit_round, [task_params['timepoint'] for task_params in failed]))
    if last_exc is not None:
        print("Last error: {}".format(last_exc))

## List the output data directory and contents

In [None]:
print("Output notebooks and data in:\n {}\n\n".format(sessiondir))

# count the executed notebooks written to the session directory
num_notebooks = sum(
    1 for entry in os.scandir(sessiondir) if entry.name.endswith(".ipynb"))

# count the TIFF images and track the largest numeric index found in their names
num_images = 0
last_index = 0
for entry in os.scandir(imagesdir):
    if not entry.name.endswith(".tiff"):
        continue
    num_images += 1

    # text between the last "_" and the first following "."
    index_text = entry.name.split("_")[-1].split('.')[0]
    # when that text contains "-", only the part before the first "-" is the
    # index (split() returns the text unchanged when there is no "-")
    index_value = int(index_text.split('-')[0])

    last_index = max(last_index, index_value)

print("{} Jupyter Notebooks were created".format(num_notebooks))
print("{} Images were created, with last_index: {}".format(num_images, last_index))

## Display a quick preview of a sample image

In [None]:
# default preview: the image one past the halfway index, zero-padded to 5 digits
mid_index = int(last_index / 2) + 1
mid_timepoint = "{:05d}".format(mid_index)
image_name = "{}_{}.tiff".format(dataset_name, mid_timepoint)
fname = os.path.join(imagesdir, image_name)

# To pick the preview image interactively instead, uncomment:
#image_path = PathChooser(default_directory=imagesdir)
#display.display(image_path)
#fname = image_path.chosen_path

# render at high resolution (this changes the global matplotlib setting)
matplotlib.rcParams['figure.dpi'] = 300
sample_image = dxchange.reader.read_tiff(fname)
plt.imshow(sample_image, cmap="gray")

## Preselect the data output directory for viewing images

In [None]:
# select the directory to preview (currently only one directory for data)
# The chooser starts at the session directory and is then pointed at the
# images directory, so the images directory is what `chosen_path` holds
# unless the selection is changed afterwards.
slices_path = PathChooser(default_directory=sessiondir)
slices_path.chosen_path = imagesdir
# bare last expression: the notebook renders the chooser as the cell output
slices_path

## Preview the full image stack

In [None]:
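# Optional: uncomment the lines below to browse the full image stack
# in the directory selected by `slices_path`.
#from ipysliceviewer import SliceViewer
#s = SliceViewer(default_directory=slices_path.chosen_path)
#display.display(s)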

In [None]:
# When this notebook started its own local cluster, shut the cluster down
# before disconnecting; a cluster that was joined via a scheduler file is
# left running.
if localcluster:
    dask_client.shutdown()

# in both cases, disconnect this client
dask_client.close()