##### Load dependencies

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
from pathlib import Path

import ipywidgets as widgets
from ipywidgets import VBox
from ipywidgets import HBox, Label
from IPython.display import display, clear_output

from cil.framework import DataContainer #,AcquisitionData
from cil.utilities.display import show_geometry, show2D
from cil.utilities.jupyter import islicer
from cil.io import NikonDataReader, TIFFWriter
from cil.processors import TransmissionAbsorptionConverter, CentreOfRotationCorrector, RingRemover
from cil.recon import FDK
from cil.optimisation.algorithms import CGLS, FISTA
from cil.optimisation.functions import LeastSquares, L2NormSquared#, ZeroFunction, 
from cil.optimisation.functions import TotalVariation
from cil.optimisation.utilities import callbacks

##### Load the data

Specify the parent directory and list the scans

In [None]:
# Specify the parent directory
parent_directory = "data"

# Make sure the directory exists
if not os.path.isdir(parent_directory):
    raise IOError(parent_directory + " is not a directory.")

# List the Nikon scan descriptors (*.xtekct) in the directory.
# glob returns matches in arbitrary, filesystem-dependent order, so sort them
# to make the default choice (files[0]) reproducible across machines and runs.
files = sorted(glob.glob(os.path.join(parent_directory, '*.xtekct')))

# Make sure the directory contains at least one xtekct file
if len(files) == 0:
    raise IOError(parent_directory + " is a directory but it does not contain any xtekct file.")

if len(files) > 1:
    print("WARNING:", parent_directory + " is a directory but it contains several xtekct files. The first one", files[0], "will be used by default.")


# Show each candidate scan with its index
for i, file in enumerate(files):
    print(str(i) + ': ' + str(file))

Choose the scan to load from the drop-down menu below (the list above shows the same files with their index).

In [None]:
# Drop-down listing every scan found above; the first entry is preselected.
file_menu = widgets.Dropdown(
    description='File:',
    options=files,
    value=files[0],
    disabled=False,
)
file_menu

In [None]:
# The drop-down is read when this cell executes: re-run it after changing the selection.
filename = file_menu.value

In [None]:
# Widgets controlling optional binning at read time. The two binning sliders
# are only shown while the "Use binning" checkbox is ticked.
use_binning_checkbox = widgets.Checkbox(
    value=False,
    description='Use binning',
    disabled=False,
    indent=False
)

binning_xy_slider = widgets.IntSlider(
    value=1,
    min=1,
    max=10,
    step=1,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

binning_z_slider = widgets.IntSlider(
    value=1,
    min=1,
    max=10,
    step=1,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

binning_xy_label = Label('Binning along the X- & Y-axes:')

binning_z_label = Label('Binning along the Z-axis:')

out_vbox = VBox(children=[
    HBox([binning_xy_label, binning_xy_slider]),
    HBox([binning_z_label, binning_z_slider]),
])

# Initial visibility follows the checkbox's starting value.
out_vbox.layout.visibility = 'visible' if use_binning_checkbox.value else 'hidden'


def on_value_change(change):
    """Show the binning sliders only while the "Use binning" checkbox is ticked.

    Parameters
    ----------
    change : dict
        ipywidgets trait-change notification for the checkbox's ``value``;
        ``change['new']`` holds the new boolean state.
    """
    # Only the visibility changes; the existing label/slider widgets are kept,
    # so slider positions and label objects survive repeated toggling.
    out_vbox.layout.visibility = 'visible' if change['new'] else 'hidden'


use_binning_checkbox.observe(on_value_change, names='value')

HBox(children=[use_binning_checkbox, out_vbox])

In [None]:
# Instantiate the reader. Binning is applied only when the "Use binning"
# checkbox is ticked (test the widget's .value: the widget object itself is
# always truthy). The X/Y slider bins the detector's horizontal and vertical
# axes; the Z slider bins the angle (projection) axis. Each roi tuple is
# (start, stop, step), with None meaning the full extent.
if use_binning_checkbox.value:
    roi = {'horizontal': (None, None, binning_xy_slider.value),
           'vertical': (None, None, binning_xy_slider.value),
           'angle': (None, None, binning_z_slider.value)}
    reader = NikonDataReader(file_name=filename, roi=roi, mode="bin")
else:
    reader = NikonDataReader(file_name=filename)

# Read the data
data = reader.read()

In [None]:
# Inspect the geometry: print the acquisition geometry read from the .xtekct file
print(data.geometry);

In [None]:
# Plot the acquisition geometry and save the figure as geometry.png in the data directory
fname = os.path.join(parent_directory, "geometry.png");
show_geometry(data.geometry).save(fname);

In [None]:
# Inspect the projections: interactive viewer stepping along the 'angle' axis
islicer(data, direction='angle', origin="upper-left");

In [None]:
# Static view of the projection data (show2D chooses which slice to display)
show2D(data, origin="upper-left");

##### Normalise using $-\ln\left(\frac{data}{white\_level}\right)$

##### Transmission to absorption 

Use the CIL `TransmissionAbsorptionConverter`
- If the data contains zero or negative values, set `min_intensity` to a small positive value: intensities below it are clipped before calculating $-\log$.

In [None]:
# Convert transmission to absorption, -ln(data / white_level), using the brightest
# pixel as the white level and clipping intensities below 1e-5 before the log.
data_corr = TransmissionAbsorptionConverter(
    white_level=data.max(),
    min_intensity=0.00001,
)(data)

Plot the sinogram of the centre slice using show2D

In [None]:
# Sinogram through the central detector row, before and after the -log conversion.
# Look up the 'vertical' axis by its label rather than by position:
# data.shape[2] is the horizontal (detector-column) size, not the number of rows.
centre_row = data.shape[data.dimension_labels.index('vertical')] // 2
show2D([data, data_corr], title=['Transmission', 'Absorption'],
       slice_list=('vertical', centre_row))

##### Get a vertical slice of the data

##### Filtered back projection

We use the CIL FDK filtered back projection (for cone-beam data). By default this uses a Ram-Lak filter.

In [None]:
data_slice = data_corr.get_slice(vertical="centre")

In [None]:
ig = data_slice.geometry.get_ImageGeometry();
recons_FDK_before = FDK(data_slice, ig).run(verbose=False)
show2D(recons_FDK_before, origin="upper-left")

In [None]:
def _count_slider(n_max):
    """Return an integer slider over 1..n_max, initialised to n_max."""
    return widgets.IntSlider(
        value=n_max, min=1, max=n_max, step=1, disabled=False,
        orientation='horizontal', readout=True, readout_format='d',
    )


# Number of voxels to reconstruct along each axis (defaults: full extent).
voxel_num_x_slider = _count_slider(ig.voxel_num_x)
voxel_num_y_slider = _count_slider(ig.voxel_num_y)
voxel_num_z_slider = _count_slider(data.shape[1])

VBox(children=[
    HBox([Label('Number of voxels along the X-axis:'), voxel_num_x_slider]),
    HBox([Label('Number of voxels along the Y-axis:'), voxel_num_y_slider]),
    HBox([Label('Number of voxels along the Z-axis:'), voxel_num_z_slider])
])

In [None]:
# Reconstruct the 3D volume with the voxel counts chosen on the sliders above
# (by default the full volume); smaller counts reconstruct a sub-volume.
ig = data_corr.geometry.get_ImageGeometry()
ig.voxel_num_x, ig.voxel_num_y, ig.voxel_num_z = (
    voxel_num_x_slider.value,
    voxel_num_y_slider.value,
    voxel_num_z_slider.value,
)

reco_before = FDK(data_corr, ig).run(verbose=False)
show2D(reco_before, origin="upper-left")

In [None]:
# Scroll through the slices of the reconstruction made before the centre-of-rotation correction
islicer(reco_before, origin="upper-left")

##### Centre of rotation correction

Use the CIL `CentreOfRotationCorrector.image_sharpness` processor to find the centre of rotation offset automatically: it searches for the offset that gives the sharpest reconstruction of the chosen slice (`slice_index`).
- Alternatively, if the data has projections which are 180 degrees apart, the `CentreOfRotationCorrector.xcorrelation` processor can be used instead: specify a first projection to use for the correlation, and the algorithm will identify the second angle which is 180 degrees from the first — within a specified angular tolerance.

In [None]:
# Estimate the centre-of-rotation offset with the image-sharpness method on the
# central slice and write the result straight back into data_corr (out=data_corr),
# so every later cell uses the corrected geometry. Re-running this cell operates
# on the already-corrected data.
# NOTE(review): confirm the units of tolerance=1/125 against the CIL documentation.
processor = CentreOfRotationCorrector.image_sharpness(slice_index='centre', tolerance=1/125)
processor.set_input(data_corr)
processor.get_output(out=data_corr)

In [None]:
# Reconstruct again with the corrected geometry and compare with the uncorrected result.
recons_FDK = FDK(data_corr, ig).run(verbose=False)
show2D(
    [reco_before, recons_FDK, reco_before - recons_FDK],
    title=['Before centre of rotation correction',
           'After centre of rotation correction',
           'Signed difference'],
)

Print the geometry to see the rotation axis has been changed

In [None]:
# The rotation-axis position printed here should differ from the geometry printed earlier
print(data_corr.geometry)

Alternatively, uncomment the cell below and manually enter a pixel offset.

In [None]:
# pixel_offset = ???
# data_corr.geometry.set_centre_of_rotation(pixel_offset, distance_units='pixels')

In [None]:
from cil.plugins.astra.operators import ProjectionOperator
from cil.optimisation.functions import IndicatorBox



Next, we define the `ProjectionOperator`, which maps an image in the image geometry `ig` to projection data in the acquisition geometry of `data_corr`. The iterative algorithms below use it to compare the current reconstruction with the measured data.


In [None]:
# Device used by the ASTRA ProjectionOperator created below
device = 'gpu'

In [None]:
# Create projection operator using Astra-Toolbox.
# reorder() is called without assigning its result, so data_corr itself is
# switched to the ASTRA axis order; every cell after this one sees reordered data.
# NOTE(review): cil.recon.FDK uses the TIGRE backend and presumably expects TIGRE
# order, so later FDK calls on data_corr (or data derived from it) may need reorder('tigre').
data_corr.reorder('astra')
A = ProjectionOperator(ig, data_corr.geometry, device)

In [None]:
# Start CGLS from an all-zero image, evaluating the objective at every
# iteration so that its convergence can be plotted afterwards.
initial = ig.allocate(0)

cgls = CGLS(
    initial=initial,
    operator=A,
    data=data_corr,
    update_objective_interval=1,
)

In [None]:
# Number of CGLS iterations to run in the next cell (1-50, default 5)
cgls_number_of_iteration_slider = widgets.IntSlider(
    value=5, min=1, max=50, step=1,
    disabled=False, orientation='horizontal',
    readout=True, readout_format='d',
)

HBox([Label('Number of iterations (CGLS):'), cgls_number_of_iteration_slider])

In [None]:
# Run the number of CGLS iterations selected on the slider above, printing progress
cgls.run(cgls_number_of_iteration_slider.value, callbacks=[callbacks.TextProgressCallback()])

In [None]:
# get and visualise the results
recon_CGLS = cgls.solution

# The third panel is the signed difference between the two reconstructions
# being compared (FBP minus CGLS), not the earlier before/after-CoR pair.
show2D([recons_FDK, recon_CGLS, recons_FDK - recon_CGLS],
    ['FBP', 'CGLS', 'Signed difference'])

In [None]:
# CGLS objective per iteration on a logarithmic scale
fig, ax = plt.subplots()
ax.plot(cgls.objective)
ax.set_yscale('log')
ax.set_xlabel('Number of iterations')
ax.set_ylabel('Objective value')
ax.grid()

In [None]:
# Prepare (but do not yet run) a TIFF writer for the sinogram slice `data_slice`;
# uncomment writer.write() once `file_name` points at a real location.
# TODO: add data type, cast to float16
writer = TIFFWriter()
writer.set_up(file_name='path_to_data/data.tiff', data=data_slice)
# writer.write()

In [None]:
# Deliberate stop: "Run All" halts here, so the cells below only run when executed manually.
raise UserWarning('Exit Early')


Constrained reconstruction

Perhaps the most intuitive constraint one can enforce on reconstructed data is the non-negativity constraint. The image data we are reconstructing is the linear attenuation coefficient of the material, so intuitively this cannot have a negative value. Here we use the FISTA algorithm to minimise a least-squares data-fidelity term plus a total variation (TV) regulariser weighted by `alpha`; the `lower=0` option of `TotalVariation` additionally enforces non-negativity. An alternative is the SIRT algorithm, an algebraic iterative method for a particular weighted least-squares problem which in addition accepts certain convex constraints such as a non-negativity constraint. Like CGLS, it exhibits semi-convergence, but it tends to require more iterations. Box constraints (lower and upper bounds) can be enforced with the `IndicatorBox` function.


In [None]:
# Data fidelity: least-squares misfit between the projected image A x and data_corr
f = LeastSquares(A, data_corr)

# Regulariser: TV weighted by alpha. max_iteration/tolerance configure the inner
# TV solver; lower=0 keeps its solution non-negative.
alpha = 5.
g = alpha * TotalVariation(max_iteration=50, tolerance=0, lower=0)

# FISTA starts from the same zero image as CGLS and records the objective every iteration
fista = FISTA(initial=initial, f=f, g=g,
              max_iteration=500, update_objective_interval=1)



In [None]:
# Number of FISTA iterations to run in the next cell (1-500, default 250)
fista_number_of_iteration_slider = widgets.IntSlider(
    value=250, min=1, max=500, step=1,
    disabled=False, orientation='horizontal',
    readout=True, readout_format='d',
)

HBox([Label('Number of iterations (FISTA):'), fista_number_of_iteration_slider])

In [None]:
# Run the number of FISTA iterations selected on the slider above, printing progress
fista.run(fista_number_of_iteration_slider.value, callbacks=[callbacks.TextProgressCallback()])

In [None]:
# FISTA objective per iteration on a logarithmic scale
fig, ax = plt.subplots()
ax.plot(fista.objective)
ax.set_yscale('log')
ax.set_xlabel('Number of iterations')
ax.set_ylabel('Objective value')
ax.grid()

In [None]:
# Side-by-side comparison of the three reconstructions
recon_FISTA = fista.solution

show2D(
    [recons_FDK, recon_CGLS, recon_FISTA],
    ['FBP', 'CGLS', 'FISTA'],
    num_cols=3,
    size=(15, 10),
    origin='upper-left',
)

In [None]:
# Deliberate stop: "Run All" halts here, so the cells below only run when executed manually.
raise UserWarning('Exit Early')

##### Crop the data

In [None]:
# data_before = data_slice.copy()

# processor = Slicer(roi = {'horizontal':(500,2100,1)})
# processor.set_input(data_slice)
# data_slice = processor.get_output()

# show2D([data_before, data_slice], title=['Before cropping', 'After cropping'])


Compare the reconstruction

In [None]:
# reco = FBP(data_slice).run(verbose=False)
# reco.apply_circular_mask(0.9)

# show2D([reco_before.array[1000:1100,1000:1100], reco.array[1000:1100,1000:1100]])

Plot a cross-section through the reconstruction

In [None]:
# plt.plot(reco_before.array[1100,1100:1200])
# plt.plot(reco.array[1100,1100:1200])
# plt.xlabel('Horizontal x (pixels)')
# plt.ylabel('Intensity')
# plt.legend(['Before phase retrieval','After phase retrieval'])

##### Ring remover

Use the CIL ring remover processor to remove rings using a wavelet decomposition method

- Increasing sigma increases the frequency of ring artefacts that can be removed
- Increasing the number of decompositions (`decNum`) will increase the strength of the ring remover, but too high a value of `sigma` will distort the profile of the image

In [None]:
# data_before = data_slice.copy()
# reco_before = reco.copy()

As above, we can loop through different parameters and view the reconstructions with islicer

In [None]:
array_list = []
array_list.append(reco.array)
decNum_list = [1, 10, 50, 100, 500]
for d in decNum_list:
    processor = RingRemover(decNum = d, wname = "db35", sigma = 1.5,  info = True)
    processor.set_input(data_corr)
    temp_data = processor.get_output()
    reco = FDK(temp_data, ig).run(verbose=False)
    array_list.append(reco.array[1])

In [None]:
temp = []

for i in range(len(array_list)):
    show2D(array_list[i][1])

In [None]:
DC = DataContainer(np.stack(temp, axis=0), dimension_labels=tuple(['Ring remover decNum']) + reco.geometry.dimension_labels)


There's a ring visible at (700-1100, 700-1100). Cycle through the slices to see how well it is removed

In [None]:
# Step through the stacked ring-remover results along the first axis, starting at index 0
islicer(DC, slice_number=0)


Choose the preferred ring removal method and apply it to the data 

In [None]:
# Apply the preferred ring-removal setting to the centre-of-rotation-corrected
# central sinogram and compare single-slice reconstructions before and after.
best_decNum = 10  # TODO: set to the decNum that best removed the ring in the comparison above

data_before = data_corr.get_slice(vertical='centre')

processor = RingRemover(decNum=best_decNum, wname="db35", sigma=1.5, info=True)
processor.set_input(data_before)
# data_slice now holds the processed sinogram that the save cell below writes out
data_slice = processor.get_output()

ig_slice = data_before.geometry.get_ImageGeometry()
reco_before_rings = FDK(data_before, ig_slice).run(verbose=False)
reco_after_rings = FDK(data_slice, ig_slice).run(verbose=False)
show2D([reco_before_rings, reco_after_rings],
       ['Before ring removal', 'After ring removal'], origin="upper-left")

##### Save the processed data

Once we're happy with the reconstruction, save the processed data as TIFF.

In [None]:
# Prepare (but do not yet run) a TIFF writer for `data_slice`;
# uncomment writer.write() once `file_name` points at a real location.
# TODO: add data type, cast to float16
writer = TIFFWriter()
writer.set_up(file_name='path_to_data/data.tiff', data=data_slice)
# writer.write()