# Demonstration of basic BSREM implementation with SIRF

This notebook simulates some data, largely following the `display_and_projection` notebook from the SIRF-Exercises.
It then reconstructs the simulated data with BSREM.

Author: Kris Thielemans  

CCP SyneRBI Synergistic Image Reconstruction Framework (SIRF).  
Copyright 2015 - 2017 Rutherford Appleton Laboratory STFC.  
Copyright 2015 - 2018, 2021, 2024 University College London.

This is software developed for the Collaborative Computational
Project in Synergistic Reconstruction for Biomedical Imaging
(http://www.ccpsynerbi.ac.uk/).

SPDX-License-Identifier: Apache-2.0

In [None]:
#%% make sure figures appear inline and animations work
%matplotlib widget

In [None]:
#%% Initial imports etc
import numpy
import matplotlib.pyplot as plt
import os
import sys
import sirf.STIR as STIR
from sirf.Utilities import examples_data_path
#from sirf_exercises import exercises_data_path

In [None]:
import sys

# Put the local SIRF-Contribs checkout at the front of the module search path.
# NOTE(review): hard-coded, machine-specific location - adjust to your installation.
sys.path.insert(0, "/home/sirfuser/devel/SIRF-Contribs/src/Python/sirf")

In [None]:
import contrib.partitioner.partitioner as partitioner
from contrib.BSREM.BSREM import BSREM1
from contrib.BSREM.BSREM import BSREM2

In [None]:
# Folder holding the input files used by this notebook:
# the 'thorax_single_slice' sub-folder of the PET examples data.
data_path = os.path.join(
    examples_data_path("PET"),
    "thorax_single_slice",
)

In [None]:
# Select the 'memory' storage scheme for acquisition data.
# Needed for get_subsets()
STIR.AcquisitionData.set_storage_scheme('memory')

In [None]:
# Set up redirection of STIR messages to files
# (info, warnings and errors each go to their own text file).
# The redirector object is bound to `_`; presumably it has to stay alive
# for the redirection to remain active - TODO confirm.
_ = STIR.MessageRedirector('info.txt', 'warnings.txt', 'errors.txt')

In [None]:
#%% some handy function definitions
def plot_2d_image(idx,vol,title,clims=None,cmap="viridis"):
    """Draw a 2D array as an image in one panel of a subplot grid.

    idx   -- sequence unpacked into plt.subplot (e.g. [rows, cols, index])
    vol   -- 2D array handed to plt.imshow
    title -- text shown as the panel title
    clims -- optional colour limits handed to plt.clim; skipped when None
    cmap  -- colormap handed to plt.imshow
    """
    plt.subplot(*idx)
    plt.imshow(vol, cmap=cmap)
    # only override the automatic colour scaling when limits were supplied
    if clims is not None:
        plt.clim(clims)
    plt.colorbar(shrink=0.4)
    # hide the axes; the title is still drawn
    plt.axis("off")
    plt.title(title)


In [None]:
#%% Read in images
# emission image, scaled by a factor 0.05
image = STIR.ImageData(os.path.join(data_path, "emission.hv")) * 0.05
# attenuation image
attn_image = STIR.ImageData(os.path.join(data_path, "attenuation.hv"))
# acquisition data read from the template sinogram file
template = STIR.AcquisitionData(
    os.path.join(data_path, "template_sinogram.hs")
)

In [None]:
#%% save max for future displays
# display ceiling: 60% of the (scaled) emission image maximum
cmax = 0.6 * image.max()

In [None]:
# create attenuation
# ray-tracing acquisition model used only to build the attenuation model
acq_model_for_attn = STIR.AcquisitionModelUsingRayTracingMatrix()
# sensitivity model constructed from the attenuation image
asm_attn = STIR.AcquisitionSensitivityModel(attn_image, acq_model_for_attn)
asm_attn.set_up(template)
# apply the model to a sinogram of ones; presumably this yields the
# per-bin attenuation factors - TODO confirm
attn_factors = asm_attn.forward(template.get_uniform_copy(1))
# rebind asm_attn to a sensitivity model built directly from those factors
# (this replaces the image-based model created above)
asm_attn = STIR.AcquisitionSensitivityModel(attn_factors)

In [None]:
# fake background: a uniform copy of the template sinogram with value 1
# (passed to the data partitioner in a later cell)
background = template.get_uniform_copy(1)

In [None]:
# create acquisition model (used below to simulate the data)
acq_model = STIR.AcquisitionModelUsingRayTracingMatrix()
# we will increase the number of rays used for every Line-of-Response (LOR) as an example
# (it is not required for the exercise of course)
acq_model.set_num_tangential_LORs(5)
# attach the attenuation-based sensitivity model
acq_model.set_acquisition_sensitivity(asm_attn)
# set-up with the sinogram template and the emission image
acq_model.set_up(template,image)

In [None]:
#%% simulate some data using forward projection
# NOTE(review): no background/additive term is set on acq_model, so the
# simulated data does not contain the `background` sinogram that is later
# passed to the partitioner - confirm this mismatch is intended.
acquired_data=acq_model.forward(image)

In [None]:
# quick check: show the maximum of the simulated data as the cell output
acquired_data.max()

In [None]:
# start from a uniform image at a quarter of the display maximum
initial_image=image.get_uniform_copy(cmax / 4)
# Truncate the initial image to the cylindrical field of view.
# This notebook never defines or imports a `make_cylindrical_FOV` helper,
# so calling it raised a NameError; apply the STIR processor directly instead.
STIR.TruncateToCylinderProcessor().apply(initial_image)
# index of the central slice along the first image dimension, used for display
im_slice = initial_image.dimensions()[0] // 2
#plt.figure()
#plot_2d_image([1,1,1],initial_image.as_array()[im_slice,:,:], 'initial image',[0,cmax])

In [None]:
# number of subsets the data is split into
num_subsets = 4
# partition the simulated data; the call returns three objects, bound here to
# the subset data, the acquisition models and the objective functions
# (presumably one of each per subset - TODO confirm)
data, acq_models, obj_funs = partitioner.data_partition(
    acquired_data,
    background,
    attn_factors,
    num_subsets,
)

In [None]:
prior = STIR.RelativeDifferencePrior()
# The same prior is attached to every subset objective function below,
# so its weight is divided evenly over the subsets.
prior.set_penalisation_factor(1 / num_subsets)
prior.set_up(initial_image)
for subset_obj_fun in obj_funs:
    subset_obj_fun.set_prior(prior)

In [None]:
# BSREM1 takes the subset data and the subset objective functions
# (the prior was attached to those objective functions in an earlier cell)
bsrem1 = BSREM1(
    data,
    obj_funs,
    initial=initial_image,
    initial_step_size=1,
    relaxation_eta=0.01,
    update_objective_interval=5,
)
bsrem1.max_iteration = 50
bsrem1.run()

In [None]:
# BSREM2 takes the subset data, the subset acquisition models and the prior
# as separate arguments
bsrem2 = BSREM2(
    data,
    acq_models,
    prior,
    initial=initial_image,
    initial_step_size=1,
    relaxation_eta=0.01,
    update_objective_interval=5,
)
bsrem2.max_iteration = 50
bsrem2.run()

In [None]:
# Show the two reconstructions side by side.
plt.figure()
# `.x` presumably holds the current iterate of each algorithm - TODO confirm
tmp1=bsrem1.x
tmp2=bsrem2.x
# Both panels use the BSREM1 maximum as the upper colour limit so that the
# two images are displayed on the same scale.
shared_clims = [0, tmp1.max()]
# Label each panel with its algorithm; previously both were titled 'image'
# and could not be told apart.
plot_2d_image([1,2,1], tmp1.as_array()[im_slice,:,:], 'BSREM1', shared_clims)
plot_2d_image([1,2,2], tmp2.as_array()[im_slice,:,:], 'BSREM2', shared_clims)

In [None]:
# close all open matplotlib figures
plt.close('all')