# Overview
This notebook is used to simulate an X-ray CT scan and reconstruction.

The scan object is a twisted hexagonal geometry with helical flow channels (object.stl), and the scan parameters are taken from an existing scan (scan_param.xml).

The aim is to create a dataset of projections similar to the experiment to:
- practise the reconstruction methods
- prepare the scripting for data import and processing when the raw projection data arrives

gVXR is used to simulate the projections; the data is then written to unsigned 16-bit raw format, read back in and reconstructed.
The experimental data is very large, so the simulation reduces the projection resolution by a reduction factor and uses fewer projection angles.

# Setup

## Install
- [environment: CIL CPU](https://github.com/TomographicImaging/CIL)
> conda create --name cilCPU -c conda-forge -c https://software.repos.intel.com/python/conda -c ccpi cil=24.2.0 ipp=2021.12 astra-toolbox=*=py* tigre ccpi-regulariser tomophantom ipykernel ipywidgets scikit-image

- [package: gVXR](https://gvirtualxray.fpvidal.net/)
> pip install --upgrade gvxr

- [package: k3d](https://github.com/K3D-tools/K3D-jupyter)
> conda install -c conda-forge k3d

## Import modules

In [None]:
# CIL
from cil.framework import AcquisitionData, AcquisitionGeometry, DataContainer
from cil.utilities.display import show_geometry, show2D
from cil.plugins.astra import ProjectionOperator
from cil.io import RAWFileWriter
from cil.optimisation.algorithms import CGLS, SIRT
from cil.optimisation.functions import IndicatorBox

# gvxr
import k3d
from gvxrPython3 import gvxr
from gvxrPython3.utils import visualise

# other
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import numpy as np

# Read data

## User input

In [None]:
# input files: scan-parameter XML (xml name) and object mesh (stl name)
xmlName = 'scan_param.xml'
stlName = 'object.stl'

# data reduction factor (scale down): the detector pixel counts are divided
# by this value and the pixel size is multiplied by it
dataRed = 15

## read xml data

In [None]:
# convert xml to a nested dictionary:
#   d[group][item] = text                  for leaf items
#   d[group][item][subitem] = text         for items that have children
tree = ET.parse(xmlName)
d = {}
for group in tree.getroot():
    d[group.tag] = {}
    for item in group:
        # len() rather than truthiness: bool(Element) is deprecated and only
        # ever meant "has children"
        if len(item):
            # build the sub-dictionary once; re-creating it inside the subitem
            # loop kept only the last child of every nested item
            d[group.tag][item.tag] = {subitem.tag: subitem.text for subitem in item}
        else:
            d[group.tag][item.tag] = item.text

# geometry data (object at the origin, source at y = -souPos)
numPixX = int(d['ScanParameter']['DetectorPixelX'])
numPixY = int(d['ScanParameter']['DetectorPixelY'])
souPos = float(d['Geometrie']['SourceObjectDist'])
# detPos is used as the detector's y coordinate while the source sits at
# y = -souPos, so it must be the object-detector distance (SDD - SOD).
# Using SourceDetectorDist directly placed the detector souPos too far away
# and produced the wrong magnification.
detPos = float(d['Geometrie']['SourceDetectorDist']) - souPos
pixSiz = float(d['Recon']['ProjectionPixelSizeX'])
detOffX = float(d['Recon']['ProjectionCenterOffsetX'])
detOffY = float(d['Recon']['ProjectionCenterOffsetY'])

# simulation data - resolution decreased by factor of dataRed
# (fewer, proportionally larger pixels)
numPixX = int(numPixX / dataRed)
numPixY = int(numPixY / dataRed)
pixSiz = pixSiz * dataRed

# scan data
scaVol = float(d['ScanParameter']['Voltage'])
numPro = int(d['Recon']['ProjectionCount'])


## Load stl and display

In [None]:
# load the raw STL bytes and render the mesh interactively with k3d
with open(stlName, "rb") as stlFile:
    stlBytes = stlFile.read()

plot = k3d.plot()
plot += k3d.stl(stlBytes, color=0xfdea4f)
plot.display()

# Simulate Projection

## User Settings

In [None]:
# projection: number of projections (overrides the XML ProjectionCount read
# earlier) and the end angle of the sweep (exclusive, see set_angles)
numPro, endAng = 300, 360

# material assigned to the mesh in gVXR
matNam = "ZrO2"    # compound name
matDen = 5.68      # density, g/cm3

## Setup Geometry

In [None]:
# projection angles: numPro equally spaced values in [0, endAng)
angles = np.linspace(0, endAng, numPro, endpoint=False)

# cone-beam acquisition geometry: source at y = -souPos, detector centre at
# (detOffX, detPos, detOffY)
ag = (
    AcquisitionGeometry.create_Cone3D(
        source_position=[0, -souPos, 0],
        detector_position=[detOffX, detPos, detOffY],
    )
    .set_panel(num_pixels=[numPixX, numPixY], pixel_size=pixSiz)
    .set_angles(angles=angles)
)

# display geometry
show_geometry(ag)

# reconstruction volume derived from the acquisition geometry
ig = ag.get_ImageGeometry()

## Setup gVXR

In [None]:
# Create an OpenGL context (hidden window, OpenGL 4.5)
window_id = 0
opengl_major_version = 4
opengl_minor_version = 5
backend = "OPENGL"
visible = False
gvxr.createWindow(window_id, visible, backend, opengl_major_version, opengl_minor_version)

# Setup Detector
# The detector centre is placed exactly where the CIL AcquisitionGeometry `ag`
# puts it, including the centre offsets (x horizontal, z up). Previously the
# offsets were dropped here, so the simulated projections did not match the
# geometry used for reconstruction whenever detOffX/detOffY were non-zero.
gvxr.setDetectorNumberOfPixels(numPixX, numPixY)
gvxr.setDetectorPixelSize(pixSiz, pixSiz, "mm")
gvxr.setDetectorUpVector(0, 0, 1)
gvxr.setDetectorPosition(detOffX, detPos, detOffY, "mm")

# Setup Source
# point source at y = -souPos; the spectrum is a single energy bin whose energy
# equals the tube voltage value (a simplification - a real tube spectrum is
# polychromatic with a mean energy below the kVp)
gvxr.setSourcePosition(0,-souPos,0,"mm")
gvxr.usePointSource()
gvxr.addEnergyBinToSpectrum(scaVol, "keV", 1)

# Setup Model
# clear any meshes from a previous run of this cell before loading the STL
gvxr.removePolygonMeshesFromSceneGraph()
gvxr.loadMeshFile("cell", stlName, "mm")
gvxr.setCompound("cell",matNam)
gvxr.setDensity("cell", matDen, "g/cm3")

## Virtual Scan

In [None]:
# angular increment between consecutive projections (same units as endAng)
angSte = endAng / numPro

# result buffer, one (numPixY x numPixX) image per projection; every slot is
# overwritten below, so no initialisation is needed
projSim = np.empty((numPro, numPixY, numPixX))

# acquire first, then rotate about z, so projection k is taken at k * angSte
# (matching the CIL angles linspace(0, endAng, numPro, endpoint=False))
for idx in range(numPro):
    projSim[idx] = gvxr.computeXRayImage()
    gvxr.rotateScene(angSte, 0, 0, 1)

# Flat-field normalisation: the flat and dark fields are ideal in the
# simulation, so dividing by the total (detector-weighted) beam energy yields
# the transmission I/I0.
projSim /= gvxr.getTotalEnergyWithDetectorResponse()

## Display setup with projection

In [None]:
# show the gVXR scene together with a projection, using a log-scaled display
plot = visualise(use_log=True)
plot.display()

## Export projections

In [None]:
# bit depth of the exported projections
bitDep = 16

# largest integer representable at this bit depth (65535 for 16 bit)
bitMul = (2**bitDep)-1

# base name of the output files
filNam = "proj"

# write data to raw - CIL method
# RAWFileWriter compresses to uint16 using a data-dependent scale/offset
# (see the read cell for how they are undone)
imData = DataContainer(projSim, False)
rawWrite = RAWFileWriter(imData,filNam+"_CIL",'uint16')
rawWrite.write()

# write to raw - numpy
# projSim holds normalised transmission (expected within 0..1), mapped onto
# 0..bitMul. Round to the nearest integer instead of truncating (astype alone
# floors, biasing every value down by up to one LSB), and clip so that any
# out-of-range value saturates instead of wrapping around in uint16.
projSimNP = np.clip(np.rint(projSim * bitMul), 0, bitMul)
projSimNP.astype('uint16').tofile(filNam+"_np.raw")

# Read projections

## Read from .raw

In [None]:
# use simulated or simulated read/write projection data
dataSource = "simRW"        # sim, simRW, exp
dataWriteMethod = "np"      # CIL, np

match(dataSource):
    case "sim":
        print('using sim data')
        proj = projSim
    
    case "exp":
        print('using experimental data')
        
    case "simRW":
        print('using simRW data')

        # get number of pixels
        num_pixels = numPixX * numPixY * numPro

        match(dataWriteMethod):
            case "CIL":
                # Read the .raw file as a numpy array - CIL data
                with open(filNam+"_CIL.raw", "rb") as f:
                    data = np.fromfile(f, dtype=np.uint16, count=num_pixels)

                # normalise data - need to use compression and scale data
                data = (data.astype(float) + 16918.20416302011) / 82453.20416302011
                
            case "np":
                # Read the .raw file as a numpy array - NP data
                with open(filNam+"_np.raw", "rb") as f:
                    data = np.fromfile(f, dtype=np.uint16, count=num_pixels)

                # normalise data
                data = data.astype(float) / bitMul

        # Reshape the data into a 3D array (num_images, height, width)
        projRead = data.reshape((numPro, numPixX, numPixY))

        # reallocate projections
        proj = projRead


## Compare read/write data to simulation

In [None]:
 # plot both and difference
projDiff = projSim[0]-projRead[0]
show2D([projSim[0], projRead[0], projDiff],title=["sim","sim_read_write_"+dataWriteMethod,"difference"],\
        cmap='inferno',num_cols=3,fix_range=[(0,1),(0,1),(projDiff.min(),projDiff.max())], size=(9,3))

plt.figure(figsize=(3,3))
plt.hist(projDiff.ravel(), bins=20)
plt.title("Histogram of difference")
plt.grid('both')
plt.show()

# Reconstruction

## Get sinogram slice

In [None]:
# colormap
cmapSino = "inferno"    # gray, inferno

# Floor non-positive transmission (fully absorbed pixels, or values quantised
# to 0 in the uint16 round trip) at the smallest positive value present, so
# that the log below stays finite; +/-inf in the sinogram would propagate
# into the iterative reconstruction.
transFloor = np.min(proj, where=proj > 0, initial=np.inf)
if not np.isfinite(transFloor):
    raise ValueError("projection data contains no positive transmission values")

# reduce data (64 -> 32 bit)
proj_red = np.maximum(proj, transFloor).astype(np.float32)
aqData = AcquisitionData(proj_red,False,ag)

# get the central slice (2D fan-beam problem)
dataSlice = aqData.get_slice(vertical='centre')

# log correction: transmission -> line integrals, -ln(I/I0)
dataSlice.log(out=dataSlice)
dataSlice *= -1
show2D(dataSlice, cmap=cmapSino, size=(5,5))

# get details
datageom = dataSlice.geometry
imgeom = datageom.get_ImageGeometry()

## User Input

In [None]:
# iterative reconstruction: method (CGLS or SIRT) and number of iterations
iterMeth, numIter = 'SIRT', 100

# colour map for the reconstruction plot (gray, inferno)
cmapRecon = "inferno"

# radius passed to apply_circular_mask on the reconstruction; the mask is
# applied whichever iterative method is selected
maskRad = 1

## Setup and Run

In [None]:

# FBD - DOES NOT WORK ON CPU FOR ANYTHING BUT PARALLEL 2D
# fbp = FBP(None, datageom)
# fbp.set_input(dataSlice)
# reconObj = fbp.get_output()

# ITERATIVE
A = ProjectionOperator(imgeom, dataSlice.geometry, device='cpu')
match iterMeth:
    case 'CGLS':
        cgls = CGLS(initial=imgeom.allocate(), operator=A, data=dataSlice)
        cgls.run(numIter)
        reconObj = cgls
    case 'SIRT':    
        constraint = IndicatorBox(lower=0)
        sirt = SIRT(initial=imgeom.allocate(), operator=A, data=dataSlice, constraint=constraint)
        sirt.run(numIter)
        reconObj = sirt

# get recon data
recon = reconObj.solution
recon.apply_circular_mask(radius=maskRad, in_place=True)

## Plot

In [None]:
# convergence: objective value recorded by the algorithm, on a log scale
fig, ax = plt.subplots(figsize=(3, 3))
ax.plot(reconObj.objective)
ax.set_yscale('log')
ax.set_xlabel('Number of iterations')
ax.set_ylabel('Objective value')
ax.grid()

# reconstructed central slice
show2D(recon, title='recon', cmap=cmapRecon, size=(5, 5), origin='upper-left')