In [None]:
import os
import glob
import sirf.STIR
import numpy as np

from sirf.STIR import show_2D_array

In [None]:
# Create a message redirector targeting three text files; going by the file
# names these receive info, warning and error messages respectively.
# NOTE(review): the object is kept in `msg_red`; presumably the redirection
# only lasts while this object is alive — confirm against the SIRF docs.
msg_red = sirf.STIR.MessageRedirector('info.txt', 'warn.txt', 'errr.txt')

In [None]:
def create_sample_image(image, weighting = 1):
    '''Overwrite `image` with a phantom made of four elliptic cylinders.

    The image is first filled with zeros and then modified in place; nothing
    is returned. `weighting` multiplies the scale of every cylinder, so a
    value below 1 produces a scaled-down copy of the same phantom (e.g. for
    an attenuation image, whose density needs to be lower).
    '''
    # (radii, origin, scale relative to `weighting`), in the order of addition
    cylinders = (
        ((80, 40), (0, 40, 30), 1),
        ((130, 110), (0, 0, 0), 1.5),
        ((30, 30), (0, -50, -30), 1.5),
        ((30, 30), (0, 30, -50), 0.75),
    )

    image.fill(0)

    # a single shape object is reused; only radii and origin change between
    # additions, the length stays at 400 throughout
    cylinder = sirf.STIR.EllipticCylinder()
    cylinder.set_length(400)
    for radii, origin, relative_scale in cylinders:
        cylinder.set_radii(radii)
        cylinder.set_origin(origin)
        image.add_shape(cylinder, scale = weighting * relative_scale)

def make_cylindrical_FOV(image):
    """Apply a TruncateToCylinderProcessor to `image` and return that same image object."""
    sirf.STIR.TruncateToCylinderProcessor().apply(image)
    return image

def add_noise(proj_data,noise_factor = 1):
    """Return a Poisson-noise realisation of acquisition data.

    The values of `proj_data` are divided by `noise_factor`, their absolute
    value is used as the Poisson mean, and the sampled counts are stored (as
    float32) in a clone of `proj_data`. `proj_data` itself is not filled or
    otherwise written to. The result is NOT multiplied back by
    `noise_factor`, so a factor above 1 yields fewer counts (and hence
    relatively noisier data).

    Parameters
    ----------
    proj_data : object providing as_array(), clone() and fill()
        The noise-free data.
    noise_factor : number, optional
        Divisor applied to the data before sampling. Must be non-zero.

    Returns
    -------
    A clone of `proj_data` filled with the noisy counts.

    Raises
    ------
    ValueError
        If `noise_factor` is zero; otherwise the division would hand NaN/inf
        means to the Poisson sampler, which fails with an unrelated message.
    """
    if not np.all(noise_factor):
        raise ValueError('noise_factor must be non-zero')
    scaled = proj_data.as_array() / noise_factor
    # Data should be >=0 anyway, but take abs so the Poisson mean is never negative
    poisson_means = np.abs(scaled)
    noisy_counts = np.random.poisson(poisson_means).astype('float32')
    noisy_proj_data = proj_data.clone()
    noisy_proj_data.fill(noisy_counts)
    return noisy_proj_data

In [None]:
# Load the template sinogram. `sino` is used further down to create the
# images (create_uniform_image) and the container for the simulated data.
# NOTE(review): the path is relative, so this presumably expects the notebook
# to run from the directory that contains data/ — confirm.
sino = sirf.STIR.AcquisitionData('data/SPECT/template_sinogram.hs')

In [None]:
# start both the ground truth and the attenuation image from uniform images
# derived from the template sinogram
image = sino.create_uniform_image()
uMap = sino.create_uniform_image()

# draw the same phantom into both; the attenuation image uses a weighting of 0.1
create_sample_image(image)
create_sample_image(uMap, weighting = 0.1)

# zoom both images identically
# (required for now because SPECT is a 360 degree acquisition)
image, uMap = (vol.zoom_image(zooms=(0.5, 1.0, 1.0)) for vol in (image, uMap))

In [None]:
# pull both images out as arrays
image_array = image.as_array()
uMap_array = uMap.as_array()

# show index 0 along the first axis of the ground truth, then of the attenuation image
for title, volume in (('Phantom image', image_array),
                      ('Attenuation image', uMap_array)):
    show_2D_array(title, volume[0,:,:])

In [None]:
# The acquisition model implements the geometric forward projection
# as a multiplication with a ray tracing matrix.
acq_model_matrix = sirf.STIR.SPECTUBMatrix()

# add attenuation
acq_model_matrix.set_attenuation_image(uMap)

# resolution modelling
acq_model_matrix.set_resolution_model(0.5,0.5,full_3D=False)

acq_model = sirf.STIR.AcquisitionModelUsingMatrix(acq_model_matrix)

In [None]:
print('projecting image...')
# Project the image to obtain simulated acquisition data.
# `sino` is used as the template: simulated_data starts out as a uniform copy
# of it and is handed to forward() as the last argument.
# NOTE(review): the 0, 1 arguments presumably mean "subset 0 of 1", i.e. all
# of the data — confirm against the SIRF AcquisitionModel.forward signature.
acq_model.set_up(sino, image)
simulated_data = sino.get_uniform_copy()
acq_model.forward(image, 0, 1, simulated_data)

# Create noisy data with the add_noise helper defined earlier in this
# notebook, so the Poisson sampling lives in one place. The array shown below
# is read back from the noisy container (float32 counts).
noisy_data = add_noise(simulated_data)
noisy_data_as_array = noisy_data.as_array()

# show simulated acquisition data
simulated_data_as_array = simulated_data.as_array()
show_2D_array('Forward projection', simulated_data_as_array[0, 0,:,:])
show_2D_array('Forward projection with added noise', noisy_data_as_array[0, 0,:,:])

In [None]:
# OSEM settings
num_subsets = 21                # number of subsets for OSEM reconstruction
num_subiters = 2 * num_subsets  # two full iterations, i.e. 42 subiterations

# objective function: built from the noisy data, with the acquisition model attached
obj_fun = sirf.STIR.make_Poisson_loglikelihood(noisy_data)
obj_fun.set_acquisition_model(acq_model)

# OSEM reconstructor object using that objective function and the settings above
OSEM_reconstructor = sirf.STIR.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(num_subsets)
OSEM_reconstructor.set_num_subiterations(num_subiters)

In [None]:
# initialisation image: a uniform copy of the phantom (value 1) ...
init_image = image.get_uniform_copy(1)
# ... truncated to the cylindrical FOV
init_image = make_cylindrical_FOV(init_image)

# set up the reconstructor with it
OSEM_reconstructor.set_up(init_image)

In [None]:
# Reconstruct and show reconstructed image
# Run the reconstruction, passing init_image as the starting image.
# NOTE(review): reconstruct() may update init_image in place, in which case
# re-running this cell would continue from the previous result — confirm.
OSEM_reconstructor.reconstruct(init_image)
# Fetch the reconstructor's current estimate and convert it to an array.
out_image = OSEM_reconstructor.get_current_estimate()
out_image_array = out_image.as_array()
# Display index 0 along the first axis, as for the phantom images above.
show_2D_array('Reconstructed image', out_image_array[0,:,:])

In [None]:
#%% delete temporary files
# Removes every regular file in the current working directory whose name
# starts with "tmp" (presumably scratch files written during the run above).
wdpath = os.getcwd()
for filename in glob.glob(os.path.join(wdpath, "tmp*")):
    # the pattern can also match directories, which os.remove cannot delete;
    # skip anything that is not a regular file instead of raising
    if os.path.isfile(filename):
        os.remove(filename)