In [None]:
from cil.io import NEXUSDataReader, NEXUSDataWriter
from cil.optimisation.algorithms import SIRT
from cil.plugins.astra.operators import ProjectionOperator
from cil.optimisation.functions import IndicatorBox
from cil.utilities.display import show2D
from cil.processors import Slicer

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid

import numpy as np
import h5py

## Load data after the `RingRemover` processor. 

## For the reconstruction, we use only 5 channels and 5 vertical slices.

## To match the results in the publication, comment out the `Slicer` line below so that all channels and vertical slices are used.

In [None]:
# Read the ring-removed hyperspectral dataset from disk
name = "data_after_ring_remover_318_398.nxs"
reader = NEXUSDataReader(file_name="HyperspectralData/" + name)
data = reader.read()

# Restrict to 5 channels (37-41) and 5 vertical slices (17-21); the ROI end index is exclusive.
# Comment out the next line to reconstruct all channels and vertical slices.
data = Slicer(roi={'channel': (37, 42), 'vertical': (17, 22)})(data)

# Get the image and acquisition geometries

In [None]:
# Full multi-channel image geometry, used to allocate the per-channel result containers below
ig = data.geometry.get_ImageGeometry()

# Single-channel acquisition/image geometries: the projector is built for one channel
# and reused for every channel in the reconstruction loops
ag3D = data.geometry.subset(channel=0)
ig3D = ag3D.get_ImageGeometry()
A = ProjectionOperator(ig3D, ag3D, 'gpu')

# Non-negativity constraint

In [None]:
constraint = IndicatorBox(lower=0.)  # enforce non-negative voxel values in SIRT

In [None]:
max_iterations = 100  # number of SIRT iterations run for each channel

# SIRT reconstruction without warm start

In [None]:
# Cold start: every channel is reconstructed independently from the same
# freshly allocated starting image `x0`, which this cell never modifies.
# Reuses `constraint` (non-negativity) and `max_iterations` from the cells above
# instead of redefining them here.
sirt4D_nowarm = ig.allocate()

x0 = ig3D.allocate()

sirt3D_nowarm = SIRT(max_iteration=max_iterations)

for i in range(ig.channels):
    # The same SIRT object is reused for every channel, so its iteration counter
    # is reset before set_up/run — presumably so each channel runs the full
    # `max_iterations` rather than stopping at once (verify against the CIL version in use).
    sirt3D_nowarm.iteration = 0
    sirt3D_nowarm.set_up(initial=x0, operator=A, constraint=constraint,
                         data=data.subset(channel=i))
    sirt3D_nowarm.run(verbose=0)
    sirt4D_nowarm.fill(sirt3D_nowarm.solution, channel=i)
    print("Finish SIRT reconstruction for channel {}".format(i))

# SIRT reconstruction with warm start. 

## Initialise each channel's reconstruction with the SIRT reconstruction of the previous channel

In [None]:
# Warm start: channel 0 starts from a freshly allocated image; every later
# channel starts from the SIRT solution of the previous channel.
# Reuses `constraint` (non-negativity) and `max_iterations` from the cells above
# instead of redefining them here.
sirt4D_warm = ig.allocate()
x0 = ig3D.allocate()
sirt3D_warm = SIRT(max_iteration=max_iterations)

for i in range(ig.channels):
    # Reset the iteration counter of the reused SIRT object — presumably so each
    # channel runs the full `max_iterations` (verify against the CIL version in use).
    sirt3D_warm.iteration = 0
    sirt3D_warm.set_up(initial=x0, operator=A, constraint=constraint,
                       data=data.subset(channel=i))
    sirt3D_warm.run(verbose=0)
    sirt4D_warm.fill(sirt3D_warm.solution, channel=i)
    # Carry this channel's solution forward as the starting image for the next channel.
    x0.fill(sirt3D_warm.solution)
    print("Finish SIRT reconstruction for channel {}".format(i))

In [None]:
# Side-by-side comparison of the warm- and cold-started reconstructions,
# drawn on one shared colour range so the two panels are directly comparable.
# NOTE(review): confirm slice_list=[1, 1] selects the intended slice of the 4D data.
show2D(
    [sirt4D_warm, sirt4D_nowarm],
    slice_list=[1, 1],
    title=["SIRT warm", "SIRT no warm"],
    origin="upper",
    fix_range=(0, 0.45),
    cmap="inferno",
)

In [None]:
# Save both SIRT reconstructions (cold start, then warm start) as NeXus files
for name, sirt_recon in [("sirt_recon_nowarm.nxs", sirt4D_nowarm),
                         ("sirt_recon_warm.nxs", sirt4D_warm)]:
    writer = NEXUSDataWriter(file_name="HyperspectralData/" + name,
                             data=sirt_recon)
    writer.write()