In [None]:
%matplotlib widget

import os
import multiprocessing as mp
os.environ['NUMEXPR_MAX_THREADS'] = str(mp.cpu_count()) # to avoid numexpr warning
import time
import numexpr
import numpy as np
from skimage import transform, filters, io
import tomopy
import svmbir
import dxchange
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, fixed, FloatSlider, IntSlider
import ALS_recon_helper
import utils

### Choose Data
###### Here we choose petiole data

In [None]:
# Input and output paths
dataDir = "/global/cfs/cdirs/als/users/dperl/petiole_data"
# Python does not expand shell variables inside string literals; without expandvars a
# directory literally named "$PSCRATCH" would be created in the working directory.
outputDir = os.path.expandvars("$PSCRATCH")
filename = '20211222_113313_petiole3.h5'

In [None]:
# Full path to the raw scan file, and its metadata as returned by the project helper.
# Later cells read metadata['numslices'] and metadata['pxsize'] from it.
path = os.path.join(dataDir, filename)
metadata = ALS_recon_helper.read_metadata(path)

#### Get projection data. Angles and/or slices can be subsampled using numpy slicing notation: `slice(start, stop, step)`
###### Here we heavily downsample a single slice for a quick reconstruction

In [None]:
# Choose which projections (angles) and sinogram rows (slices) to load. Each index is
# either None (take everything) or a slice(start, stop, step).
# NOTE: assuming read_data applies numpy slicing semantics, a stop of -1 EXCLUDES the
# last element; use None as the stop to include it.
# angles_ind = None # get all angles (implicitly)
# angles_ind = slice(0,None,1) # get all angles (explicitly)
angles_ind = slice(0,-1,4) # get every 4th angle (stop=-1 can drop the final angle; None would not)
# slices_ind =  None # get all slices
# slices_ind = slice(0,10,1) # get first 10 slices
# slices_ind = slice(-10,None,1) # get last 10 slices
mid_slice = metadata['numslices']//2
slices_ind = slice(mid_slice,mid_slice+1,1) # get only middle slice
#downsample_factor = None # do not downsample projections
downsample_factor = 4 # downsample projections (in both slice and ray dimensions)
tomo, angles = ALS_recon_helper.read_data(path, proj=angles_ind, sino=slices_ind, downsample_factor=downsample_factor)
print(f"data shape = {tomo.shape}")

### Find Reconstruction Settings
#### Reconstruct 2D slice with manual center of rotation (COR)
###### Here we try many CORs to find the best one. Iterate over progressively narrower COR ranges with finer steps

In [None]:
%%time
# cors = np.arange(-80,81,5) # wide COR rangeum)
cors = np.arange(30,45,0.5) # narrow COR range
recons = [tomopy.recon(tomo, angles, center=cor/downsample_factor + tomo.shape[2]/2,
                       algorithm=tomopy.astra,
                       options={'method':"FBP_CUDA", 'proj_type':'cuda'})
          for cor in cors]

###### Gaussian smooth recons (FBP is noisy)

In [None]:
smooth_um = 2 # in um

# Convert the smoothing width to pixels of the loaded data. metadata['pxsize']*1e4 gives
# the pixel size in um (presumably pxsize is stored in cm — confirm), scaled by the
# downsampling factor; `or 1` covers downsample_factor = None.
smooth_px = smooth_um/(1e4*metadata['pxsize']*(downsample_factor or 1))
# NOTE: this rebinds `recons` to its smoothed version, so re-running the cell smooths again.
recons = [filters.gaussian(recon,sigma=smooth_px) for recon in recons]

In [None]:
def plot_recons(recon,fignum=1):
    """Open a fresh figure that displays recon[0] in grayscale.

    Contrast limits are the 1st and 99th percentiles of `recon` itself.

    Parameters
    ----------
    recon : array
        Reconstruction indexed by slice; recon[0] is the image shown.
    fignum : int
        Matplotlib figure number. An existing figure with that number is closed first.

    Returns
    -------
    h : the AxesImage, so callers can swap in new data with h.set_data(...).
    axs : the Axes that holds the image.
    """
    # Use the argument (not the global `recons`) so the limits match what is displayed.
    img_lim = np.percentile(recon, [1, 99])
    if plt.fignum_exists(fignum):
        plt.close(fignum)
    plt.figure(num=fignum, figsize=(4, 4))
    axs = plt.gca()
    h = axs.imshow(recon[0], cmap='gray', vmin=img_lim[0], vmax=img_lim[1])
    return h, axs
def set_cor(i):
    """Slider callback: show COR-sweep reconstruction number `i`.

    Reads the globals `recons` and `cors`, and updates whichever image handle `h`
    and axes `axs` are currently bound at module level.
    """
    image = recons[i][0]
    h.set_data(image)
    axs.set_title(f"COR = {cors[i]} pixels (at full res)")

###### With a mismatched COR, arc/shadow artifacts appear around features. Choose the COR that minimizes them

In [None]:
# Show the first COR candidate, then attach a slider that swaps in the others.
# set_cor updates the global `h`/`axs` bound here.
h, axs = plot_recons(recons[0],fignum=1)
# Trailing ';' hides the function repr that interact() returns.
interact(set_cor, i=IntSlider(min=0, max=len(cors)-1, step=1, value=0));

#### Reconstruct 2D slice with FBP at full resolution with best COR 

In [None]:
%%time
angles_ind = None # get all angles
slices_ind = slice(metadata['numslices']//2,metadata['numslices']//2+1,1) # get only middle slice
downsample_factor = None
smooth_um = 1 # in um
COR = 36.5
tomo, angles = ALS_recon_helper.read_data(path, proj=angles_ind, sino=slices_ind, downsample_factor=downsample_factor)
fbp_recon = tomopy.recon(tomo, angles, center=COR + tomo.shape[2]/2,
                       algorithm=tomopy.astra,
                       options={'method':"FBP_CUDA", 'proj_type':'cuda'})

smooth_px = smooth_um/(1e4*metadata['pxsize'])
fbp_recon = filters.gaussian(fbp_recon,sigma=smooth_px)
print(fbp_recon.shape)

In [None]:
# Contrast limits come from the full-resolution FBP itself, not from the low-res,
# differently smoothed COR sweep stored in `recons`.
img_lim = np.percentile(fbp_recon[0], [1, 99])
if plt.fignum_exists(2): plt.close(2)
plt.figure(num=2,figsize=(4,4))
plt.imshow(fbp_recon[0],cmap='gray',vmin=img_lim[0],vmax=img_lim[1])
plt.title("FBP (full res)");

#### Reconstruct 2D slice with SVMBIR at full resolution with best COR
###### SVMBIR recon takes ~2-4 min/slice at full resolution with all (128) threads

In [None]:
%%time
T = 0.1 #smoothing parameter
q = 1.2 #smoothing parameter
p = 2 #smoothing parameter
sharpness = 0.5
snr_db = 60.0
angles_ind = None # get all angles
slices_ind = slice(metadata['numslices']//2,metadata['numslices']//2+1,1) # get only middle slice
downsample_factor = None
COR = 36.5
tomo, angles = ALS_recon_helper.read_data(path, proj=angles_ind, sino=slices_ind, downsample_factor=downsample_factor)
tomo = ALS_recon_helper.shift_prjections(tomo,COR) # must manually shift COR, not shift projector
svmbir_recon = svmbir.recon(tomo,angles,
                          center_offset=0.0, # MUST BE ZERO TO AVOID VERY LONG COMPUTATION OF PROJECTION MATRIX  
                          # init_image=fbp_recon.copy(), # init with fbp for faster convergence. copy is important to not overwite fbp_recon                     
                          T=T, q=q, p=p, sharpness=sharpness, snr_db=snr_db,
                          # num_threads=128,
                          # max_iterations=10,
                          svmbir_lib_path='/pscratch/sd/d/dperl/svmbir_cache',
                          verbose=1)
print(svmbir_recon.shape)

In [None]:
# Side-by-side FBP vs SVMBIR for the same slice.
# Dedicated names avoid overwriting the global `axs`, which the COR-slider callback
# set_cor reads; rebinding it to an array of Axes would break that slider.
if plt.fignum_exists(3): plt.close(3)
fig_cmp, axs_cmp = plt.subplots(1,2,num=3,figsize=(10,6),sharex=True,sharey=True)
axs_cmp[0].imshow(fbp_recon[0],cmap='gray',vmin=np.percentile(fbp_recon[0],1),vmax=np.percentile(fbp_recon[0],99))
axs_cmp[0].set_title('FBP')
# SVMBIR is displayed with a tighter 10-90 percentile window than FBP's 1-99
axs_cmp[1].imshow(svmbir_recon[0],cmap='gray',vmin=np.percentile(svmbir_recon[0],10),vmax=np.percentile(svmbir_recon[0],90))
axs_cmp[1].set_title('SVMBIR')
plt.tight_layout()

### Reconstruct 3D stack
###### FBP_CUDA with 100 slices takes ~ 30 sec with 1 GPU 

In [None]:
%%time
angles_ind = None
slices_ind = slice(1000,1100,1)
smooth_um = 1 # in um
COR = 36.5
tomo, angles = ALS_recon_helper.read_data(path, proj=angles_ind, sino=slices_ind)
fbp_recon = tomopy.recon(tomo, angles, center=COR + tomo.shape[2]/2,
                       algorithm=tomopy.astra,
                       options={'method':"FBP_CUDA", 'proj_type':'cuda'})

smooth_px = smooth_um/(1e4*metadata['pxsize'])
recon_3d = filters.gaussian(fbp_recon,sigma=smooth_px)
print(fbp_recon.shape)

In [None]:
def plot_recon_3d(recon_3d,fignum=1):
    """Open a fresh figure that displays slice 0 of a reconstructed stack.

    Contrast limits are the 1st and 99th percentiles of recon_3d[0].

    Parameters
    ----------
    recon_3d : array
        Stack indexed by slice; recon_3d[0] is the image shown.
    fignum : int
        Matplotlib figure number. An existing figure with that number is closed first.

    Returns
    -------
    h : the AxesImage, so callers can swap in other slices with h.set_data(...).
    axs : the Axes that holds the image.
    """
    # Both limits come from the argument; previously the upper limit came from the
    # unrelated global `recons` (the COR sweep).
    img_lim = np.percentile(recon_3d[0], [1, 99])
    if plt.fignum_exists(fignum):
        plt.close(fignum)
    plt.figure(num=fignum, figsize=(6, 6))
    axs = plt.gca()
    h = axs.imshow(recon_3d[0], cmap='gray', vmin=img_lim[0], vmax=img_lim[1])
    return h, axs
def set_z(z):
    """Slider callback: show slice `z` of the global `recon_3d`.

    Updates whichever image handle `h` and axes `axs` are currently bound at
    module level.
    """
    current = recon_3d[z]
    h.set_data(current)
    axs.set_title(f"Slice {z}")

In [None]:
# Browse the reconstructed stack (not the COR sweep `recons`) with a slice slider
# that starts at the middle slice.
h, axs = plot_recon_3d(recon_3d,fignum=4)
# Trailing ';' hides the function repr that interact() returns.
interact(set_z, z=IntSlider(min=0, max=len(recon_3d)-1, step=1, value=len(recon_3d)//2));

#### Reconstruct and save volume in chunks of 100 slices

In [None]:
# Reconstruct slices [start, stop) in chunks of `nchunk` slices and write each chunk
# as a TIFF stack named after the scan file.
# expandvars guards against outputDir holding an unexpanded "$PSCRATCH".
saveDir = os.path.join(os.path.expandvars(outputDir),os.path.splitext(filename)[0])
os.makedirs(saveDir, exist_ok=True)
savename = os.path.join(saveDir,"img")

start = 0
stop = 500
nchunk = 100

smooth_um = 1 # in um; set to 0 or None to skip smoothing
COR = 36.5

# The smoothing width is the same for every chunk, so compute it once.
smooth_px = smooth_um/(1e4*metadata['pxsize']) if smooth_um else None

# Step by nchunk starting at `start` (previously the step was hard-coded to 100 and the
# chunk end ignored `start`).
for start_iter in range(start, stop, nchunk):
    stop_iter = min(start_iter + nchunk, stop)
    tomo, angles = ALS_recon_helper.read_data(path, sino=slice(start_iter,stop_iter,1))
    print(f"Starting recon of slices {start_iter}-{stop_iter}...",end=' ')
    tic = time.time()
    recon = tomopy.recon(tomo, angles, center=COR + tomo.shape[2]/2,
                           algorithm=tomopy.astra,
                           options={'method':"FBP_CUDA", 'proj_type':'cuda'})
    if smooth_px:
        # Smooth THIS chunk and keep the result so it is what gets saved (previously the
        # stale `fbp_recon` from an earlier cell was smoothed and the output discarded).
        recon = filters.gaussian(recon,sigma=smooth_px)
    print(f"Finished: took {time.time()-tic} sec. Saving files...")

    dxchange.write_tiff_stack(recon, fname=savename, start=start_iter)

### Submit batch job
###### For very long jobs

In [None]:
# TBD