In [None]:
import datetime
import json
import os
import shutil
import time

import dxchange
from IPython import display
import matplotlib
import matplotlib.pyplot as plt

from ipypathchooser import PathChooser
from ipysliceviewer import SliceViewer

from distributed_recon import ReconstructionAnalysis

## Connect to Dask

In [None]:
# Connect to the Dask cluster running inside a Cori job when its scheduler file
# is advertised in ./dask_client; otherwise fall back to a local Dask cluster.
from dask.distributed import Client, LocalCluster, as_completed

# Size of the fallback local cluster.
LOCAL_N_WORKERS = 4
LOCAL_THREADS_PER_WORKER = 2

# Assume we need a local cluster until a remote connection actually succeeds.
localcluster = True
if os.path.exists("dask_client"):
    try:
        # The file holds the path of the scheduler file written by the job.
        with open("dask_client", 'r') as f:
            scheduler_file = f.read().strip()

        dask_client = Client(scheduler_file=scheduler_file)
        localcluster = False
    except Exception as e:
        # Best effort: report why the remote connection failed, then fall back.
        print("Unable to use existing dask_client file to connect to a Dask cluster! ({})".format(e))

if localcluster:
    # No usable Dask cluster present, so start a local one and attach the
    # client to *this* cluster. A bare Client() would spin up a second,
    # default-sized cluster and leave this one idle.
    cluster = LocalCluster(n_workers=LOCAL_N_WORKERS, threads_per_worker=LOCAL_THREADS_PER_WORKER)
    dask_client = Client(cluster)

dask_client

## Set inputs and directories for analyses

In [None]:
# Name of the scan to reconstruct. It also names the staging directory and is
# used as the TIFF prefix when previewing results.
dataset_name = "RTV_18A_air_760torr_08_fast"

# Directory containing one raw-data folder per dataset.
RAW_DATA_ROOT = '/global/cfs/cdirs/als/users/parkinson/SLS_Feb2019/disk1'

# Widget for choosing the raw-data directory. It is preselected from
# dataset_name, so switching datasets only requires changing one value.
inputdir = PathChooser()
inputdir.chosen_path = os.path.join(RAW_DATA_ROOT, dataset_name)
inputdir

In [None]:
# Stage inputs and outputs on $SCRATCH. os.path.expandvars would silently keep
# the literal "$SCRATCH" in the path when the variable is unset (and os.mkdir
# would then fail obscurely), so look it up explicitly and fail clearly.
scratchdir = os.environ.get("SCRATCH")
if not scratchdir:
    raise RuntimeError("$SCRATCH is not set; it is needed to stage reconstruction inputs and outputs")
input_stagedir = os.path.join(scratchdir, "als_reconstruction_inputs")

# Create a fresh, timestamped directory to store all of our analyses for this session.
sessiondir = os.path.join(scratchdir, "{}_{}".format(dataset_name, datetime.datetime.now().isoformat()))
os.mkdir(sessiondir)

# Copy the input data directory to $SCRATCH once; later sessions reuse the staged copy.
# NOTE(review): a partial copy left by an interrupted earlier run is reused
# as-is. Delete dataset_stagedir manually if that happens.
dataset_stagedir = os.path.join(input_stagedir, dataset_name)
if not os.path.exists(dataset_stagedir):
    shutil.copytree(inputdir.chosen_path, dataset_stagedir)

# This is the notebook we will be running through papermill, which calls the
# reconstruction code. Fail now rather than inside every Dask task.
currentdir = os.getcwd()
template_nb = os.path.join(currentdir, 'tomopy_recon_template.ipynb')
if not os.path.isfile(template_nb):
    raise FileNotFoundError("Template notebook not found: {}".format(template_nb))

## Define initial parameters and analysis inputs, then process the dataset

In [None]:
# Parameters shared by every reconstruction run; individual analyses extend or
# override them. The raw file name is derived from dataset_name so the two
# cannot drift apart.
generic_params = dict(
    filename="{}.h5".format(dataset_name),
    inputPath=dataset_stagedir,
    chunk_proj=1,
    chunk_sino=1,
    ncore=1,
    filetype='sls'
)

# Timepoints passed to every analysis (presumably scan indices within the
# file; TODO confirm against the template notebook).
initial_timepoints = list(range(10))

# Baseline run with the template's default reconstruction settings. The label
# uses dataset_name, matching the labels of the parameter-sweep analyses.
defaults_analysis = ReconstructionAnalysis(
    template_nb=template_nb,
    label="{}_defaults".format(dataset_name),
    description="Processing {} with default values".format(generic_params["filename"]))

completed, failed = defaults_analysis.run_analysis(
    outputdir=sessiondir,
    params=generic_params,
    timepoints=initial_timepoints,
    dask_client=dask_client
)

# Surface failures instead of silently ignoring them.
if failed:
    print("Some timepoints failed in the defaults analysis:", failed)

## Display a quick preview of a sample image

In [None]:
# Directory the defaults run wrote its reconstructed images into.
imagesdir = os.path.join(sessiondir, defaults_analysis._outputdir, "images")
last_index = initial_timepoints[-1]

# Preselect an image in the middle of the stack. File names are expected to
# carry a zero-padded, 5-digit index.
mid_timepoint = "{:05d}".format(last_index // 2 + 1)

# Image files are prefixed with the raw file's stem (the same convention
# mpl_preview_image uses below).
image_prefix = os.path.splitext(generic_params["filename"])[0]
fname = os.path.join(
    imagesdir,
    "{}_{}.tiff".format(image_prefix, mid_timepoint))

# Alternatively, select an image to preview
#image_path = PathChooser(default_directory=imagesdir)
#display.display(image_path)
#fname = image_path.chosen_path

sample_image = dxchange.reader.read_tiff(fname)
matplotlib.rcParams['figure.dpi'] = 300
fig, ax = plt.subplots()
ax.set_title(os.path.basename(fname))
ax.imshow(sample_image, cmap="gray");

## Preview the full image stack

In [None]:
# Interactive viewer for scrolling through every reconstructed image of the
# defaults run. SliceViewer comes from the imports cell at the top.
slice_viewer = SliceViewer(default_directory=imagesdir)
display.display(slice_viewer)

## Process the same dataset using multiple parameter sets

In [None]:
# Capture all of our processing runs so their outputs can be previewed below.
analyses = []

# Centre of rotation used for the butterworth sweep (presumably determined
# beforehand for this dataset; TODO confirm), and the cutoffs to try.
BCUTOFF_COR = 385
BUTTERWORTH_CUTOFFS = [0.05, 0.1, 0.2]

# Set to True to also run the pipeline-option experiments.
RUN_PIPELINE_OPTIONS = False

# Experiment with butterworth cutoff values. Dict-literal merging lets the
# overrides win even if generic_params gains 'cor' or 'butterworth_cutoff';
# dict(**generic_params, cor=...) would raise TypeError on a duplicate key.
param_set_bcutoff = [
    {**generic_params, 'cor': BCUTOFF_COR, 'butterworth_cutoff': cutoff}
    for cutoff in BUTTERWORTH_CUTOFFS]

# Experiment with pipeline options: each set enables exactly one extra flag.
param_set_pipeline_options = [
    {**generic_params, option: True}
    for option in ('doPhaseRetrieval', 'doOutliers1D', 'doPolarRing', 'doPolarRing2')]


def run_param_set(label, description, params):
    """Run the template notebook over initial_timepoints with one parameter set.

    Parameters
    ----------
    label : str
        Short identifier for the analysis.
    description : str
        Human-readable description of what this run changes.
    params : dict
        Parameters passed to the template notebook.

    Returns
    -------
    ReconstructionAnalysis
        The analysis object after run_analysis has returned.
    """
    analysis = ReconstructionAnalysis(
        template_nb=template_nb,
        label=label,
        description=description)
    completed, failed = analysis.run_analysis(
        outputdir=sessiondir,
        params=params,
        timepoints=initial_timepoints,
        dask_client=dask_client)
    # Surface failures instead of silently ignoring them.
    if failed:
        print("Some timepoints failed in analysis {}:".format(label), failed)
    return analysis


for params in param_set_bcutoff:
    cutoff = params['butterworth_cutoff']
    analyses.append(run_param_set(
        label="{}_bcutoff_{}".format(dataset_name, cutoff),
        description="Processing {} with cor={} and a butterworth cutoff of {}".format(
            generic_params["filename"],
            params['cor'],
            cutoff),
        params=params))

if RUN_PIPELINE_OPTIONS:
    for params in param_set_pipeline_options:
        # The single key this parameter set adds on top of generic_params.
        option = next(k for k in params if k not in generic_params)
        analyses.append(run_param_set(
            label="{}_pipeline_option_{}".format(dataset_name, option),
            description="Processing {} with default values and a pipeline option of {}".format(
                generic_params["filename"],
                option),
            params=params))

## Preview the output for each analysis

In [None]:
def mpl_preview_image(analysis, timepoint=None, timepoints=None):
    """Plot one reconstructed image from an analysis in its own figure.

    Parameters
    ----------
    analysis : ReconstructionAnalysis
        A completed run. Its ``_outputdir`` is expected to contain an
        ``images`` folder, and ``_params['filename']`` supplies the image prefix.
    timepoint : str or int, optional
        Index of the image to show. Ints are zero-padded to 5 digits to match
        the expected file names; strings are used verbatim. Defaults to the
        middle of ``timepoints``.
    timepoints : sequence of int, optional
        Timepoints the analysis was run over. Defaults to the notebook-level
        ``initial_timepoints``.
    """
    if timepoints is None:
        timepoints = initial_timepoints
    analysis_imagesdir = os.path.join(analysis._outputdir, "images")

    if timepoint is None:
        # Preselect an image in the middle of the stack.
        timepoint = timepoints[-1] // 2 + 1
    if isinstance(timepoint, int):
        timepoint = "{:05d}".format(timepoint)

    image_prefix = os.path.splitext(analysis._params['filename'])[0]
    fname = os.path.join(
        analysis_imagesdir,
        "{}_{}.tiff".format(image_prefix, timepoint))

    sample_image = dxchange.reader.read_tiff(fname)
    matplotlib.rcParams['figure.dpi'] = 300
    # Use a fresh figure per call. With the implicit pyplot axes, every call
    # drew onto the same axes and only the last analysis stayed visible.
    fig, ax = plt.subplots()
    ax.set_title(analysis.label)
    ax.imshow(sample_image, cmap="gray")


for recon_analysis in analyses:
    mpl_preview_image(recon_analysis)

## Close the Dask connection

In [None]:
# Disconnect from the scheduler. If the connection cell started its own
# LocalCluster, close that cluster too so its workers do not leak; a cluster
# we merely connected to is left running.
dask_client.close()
if localcluster:
    cluster.close()