# Demonstration of the Hybrid Kernelised Expectation Maximisation (HKEM) reconstruction with SIRF
This demonstration shows how to use HKEM and investigate the role of each kernel parameter in edge preservation and noise suppression.

While this notebook doesn't do a complete parameter search, it does run quite a few reconstructions. You could therefore first run it all the way through and then come back to check the results. Note that if you do that, some of the plots might not show; you can then simply re-run those cells.

N.B.: You need to have run the [BrainWeb](./BrainWeb.ipynb) notebook first to generate the data!

Authors: Daniel Deidda, Kris Thielemans, Evgueni Ovtchinnikov, Richard Brown  
First version: 30th of September 2019  
Second version: 6th of November 2019  
Third version: June 2021

CCP SyneRBI Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019, 2021  National Physical Laboratory  
Copyright 2019  Rutherford Appleton Laboratory STFC  
Copyright 2019, 2021  University College London

This is software developed for the Collaborative Computational
Project in Synergistic Reconstruction for Biomedical Imaging.
(http://www.synerbi.ac.uk/).

SPDX-License-Identifier: Apache-2.0

## HKEM brief description
The Kernel Expectation Maximisation (KEM) method was suggested in  
Wang, Guobao, and Jinyi Qi. ‘PET Image Reconstruction Using Kernel Method’. IEEE Transactions on Medical Imaging 34, no. 1 (January 2015): 61–71. https://doi.org/10.1109/TMI.2014.2343916.

The main idea was to use "kernels" (constructed based on another image such as an MR image) to construct "basis functions" for the PET reconstruction. The reconstruction estimates the PET image as a linear combination of these kernels.
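As a toy illustration of the kernel idea, the sketch below builds a 1D kernel matrix from an "anatomical" profile, using a Gaussian on intensity differences within a small neighbourhood, and then forms the PET image as x = Kα. The function name, the row normalisation and the 1D setting are assumptions made for illustration only; this is not how STIR builds its kernels.

```python
import numpy as np


def gaussian_kernel_matrix(anatomical, sigma_m, num_neighbours):
    """Hypothetical 1D kernel matrix built from an anatomical profile (toy sketch).

    K[j, l] = exp(-(v_j - v_l)**2 / (2 * sigma_m**2)) for voxels l within a window of
    `num_neighbours` voxels centred on j, and 0 elsewhere. Rows are normalised to sum
    to 1, an assumption made for this sketch.
    """
    n_vox = anatomical.size
    half = num_neighbours // 2
    K = np.zeros((n_vox, n_vox))
    for j in range(n_vox):
        lo, hi = max(0, j - half), min(n_vox, j + half + 1)
        diff = anatomical[lo:hi] - anatomical[j]
        K[j, lo:hi] = np.exp(-diff**2 / (2.0 * sigma_m**2))
    K /= K.sum(axis=1, keepdims=True)
    return K


# toy "MR" profile with an edge between voxels 4 and 5
anatomical = np.array([1.0] * 5 + [3.0] * 5)
K = gaussian_kernel_matrix(anatomical, sigma_m=0.2, num_neighbours=3)

# the PET image is modelled as a linear combination of kernels: x = K alpha
alpha = np.linspace(1.0, 2.0, anatomical.size)
x = K @ alpha
print(np.round(K[4, 3:7], 3))  # weights of voxel 4: the neighbour across the edge gets ~0
```

Because the neighbour across the MR edge receives (almost) zero weight, the kernel representation averages within tissue regions but not across their boundaries.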

One of the potential problems with KEM is what happens if there are unique features in the PET image, which are not present in the "guidance" (i.e. the MR image). If the MR-derived kernels are too "wide", there is a danger that the PET-unique features are suppressed.

To overcome this problem, Deidda *et al.* developed the Hybrid KEM method, see  
Deidda, Daniel, Nicolas A. Karakatsanis, Philip M. Robson, Yu-Jung Tsai, Nikos Efthimiou, Kris Thielemans, Zahi A. Fayad, Robert G. Aykroyd, and Charalampos Tsoumpas. ‘Hybrid PET-MR List-Mode Kernelized Expectation Maximization Reconstruction’. Inverse Problems 35, no. 4 (March 2019): 044001. https://doi.org/10.1088/1361-6420/ab013f.

The main idea here is to compute new kernels at every image update. These kernels are "hybrids" between the MR and the (current) PET image. This way, if there is a PET feature, it gradually influences the kernels, which will then no longer suppress it. (Or at least, that's what the authors hope!)
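A toy 1D sketch of why a hybrid kernel can help: the "MR" image below has no feature at all, while the current PET estimate contains a "lesion". A kernel built from the MR image alone averages the lesion with its neighbours. Multiplying it element-wise by a Gaussian kernel on the current PET estimate suppresses the weights of neighbours that differ in PET, so the lesion survives. The element-wise product, the row normalisation and all names are simplifying assumptions; the actual HKEM kernels also contain spatial terms and are recomputed from the current estimate at every update.

```python
import numpy as np


def neighbourhood_kernel(features, sigma, num_neighbours):
    """Unnormalised 1D Gaussian kernel on feature differences within a window (toy sketch)."""
    n_vox = features.size
    half = num_neighbours // 2
    K = np.zeros((n_vox, n_vox))
    for j in range(n_vox):
        lo, hi = max(0, j - half), min(n_vox, j + half + 1)
        diff = features[lo:hi] - features[j]
        K[j, lo:hi] = np.exp(-diff**2 / (2.0 * sigma**2))
    return K


def row_normalise(K):
    """Scale each row to sum to 1 (an assumption of this sketch)."""
    return K / K.sum(axis=1, keepdims=True)


anatomical = np.ones(11)        # MR image with no feature at all
pet_estimate = np.ones(11)
pet_estimate[5] = 4.0           # PET-unique "lesion" present in the current estimate

# KEM-like kernel: built from the MR image only
K_mr = row_normalise(neighbourhood_kernel(anatomical, sigma=0.2, num_neighbours=3))
# HKEM-like kernel (simplified): element-wise product of the MR kernel and a kernel on the PET estimate
K_hybrid = row_normalise(neighbourhood_kernel(anatomical, sigma=0.2, num_neighbours=3)
                         * neighbourhood_kernel(pet_estimate, sigma=0.2, num_neighbours=3))

print(round((K_mr @ pet_estimate)[5], 3))      # lesion averaged with its neighbours
print(round((K_hybrid @ pet_estimate)[5], 3))  # lesion kept
```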

Implementing HKEM in SIRF would be rather involved. Luckily, Daniel Deidda did the hard work for you and implemented it in STIR (with help from a few others). `sirf.STIR.KOSMAPOSLReconstructor` wraps this STIR implementation.

## Initial set-up

In [None]:
#%% make sure figures appear inline and animations work
%matplotlib widget

# Setup the working directory for the notebook
import notebook_setup
from sirf_exercises import cd_to_working_dir
cd_to_working_dir('Synergistic', 'HKEM')

In [None]:
#%% Initial imports etc
import numpy
from numpy.linalg import norm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
import sys
import shutil
import string
#import scipy
#from scipy import optimize
import sirf.STIR as pet
from sirf_exercises import exercises_data_path

brainweb_sim_data_path = exercises_data_path('working_folder', 'Synergistic', 'BrainWeb')

In [None]:
# set-up redirection of STIR messages to files
msg_red = pet.MessageRedirector('info.txt', 'warnings.txt', 'errors.txt')

In [None]:
#%% some handy function definitions
def imshow(image, limits, title='', cmap=None):
    """Usage: imshow(image, [min,max], title, cmap); pass [] as limits to use the image range"""
    plt.title(title)
    bitmap = plt.imshow(image, cmap=cmap)
    if len(limits)==0:
        limits = [image.min(), image.max()]

    plt.clim(limits[0], limits[1])
    plt.colorbar(shrink=.3)
    plt.axis('off')
    return bitmap

def imshow_hot(image, limits, title=''):
    """Usage: imshow_hot(image, [min,max], title)"""
    return imshow(image, limits, title, cmap="hot")

def imshow_gray(image, limits, title=''):
    """Usage: imshow_gray(image, [min,max], title)"""
    return imshow(image, limits, title, cmap="gray")

def make_positive(image_array):
    """truncate any negatives to zero"""
    image_array[image_array<0] = 0
    return image_array

def make_cylindrical_FOV(image):
    """truncate to cylindrical FOV (modifies image in place)"""
    processor = pet.TruncateToCylinderProcessor()
    processor.apply(image)

### Load some data and set some values

In [None]:
# Load the data we generated previously in BrainWeb.ipynb
full_sino = pet.AcquisitionData(os.path.join(brainweb_sim_data_path, 'FDG_tumour_sino_noisy.hs'))

atten = pet.ImageData(os.path.join(brainweb_sim_data_path, 'uMap_small.hv'))

# Anatomical image
anatomical = pet.ImageData(os.path.join(brainweb_sim_data_path, 'T2_small.hv')).abs()

#%%  create initial image
init_image=anatomical.get_uniform_copy(1)
make_cylindrical_FOV(init_image)

In [None]:
image = pet.ImageData(os.path.join(brainweb_sim_data_path, 'FDG_tumour.hv'))
image_array = image.as_array()
#%% save a reference value (60% of the maximum) for future display ranges
cmax = image_array.max()*.6

In [None]:
## Show anatomical image and true image
anatomical_array=anatomical.as_array()
atten_array=atten.as_array()
im_slice = 62 #atten_array.shape[0]//2

plt.figure()
plt.subplot(1,2,1)
imshow_gray(anatomical_array[im_slice,:,:,], [0,220],'Anatomical image')
plt.subplot(1,2,2)
imshow_hot(image_array[im_slice,:,:,], [0,cmax*2],'True image')

## Set up the acquisition model and objective function

We first use the `rebin` functionality to create smaller acquisition data, which speeds up the calculations.
The call below keeps only segment 0 (`max_in_segment_num_to_process=0`).

To make things even faster, you can rebin your data further by compressing axial and view bins.
Of course, this will somewhat affect the quality of the reconstructed images.

If you have enough computational power, you can try setting `max_in_segment_num_to_process` higher (or even use `sino=full_sino`).
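Combining view bins is easiest to picture on a single 2D sinogram. The toy function below (`combine_views` is a hypothetical name, not part of SIRF) merges groups of adjacent view angles, so there are fewer, coarser views. It only illustrates the idea behind `num_views_to_combine`; STIR's rebinning also handles segments, axial positions and the scanner geometry, and the meaning given here to `do_normalisation` (summing versus averaging) is our assumption, so check `help(full_sino.rebin)`.

```python
import numpy as np


def combine_views(sinogram, num_views_to_combine, do_normalisation=False):
    """Toy sketch: merge groups of adjacent view angles of a 2D sinogram.

    sinogram has shape (num_views, num_tangential_positions). Counts in each group
    are summed; with do_normalisation=True they are averaged instead (an assumption
    of this sketch, not a description of STIR's code).
    """
    num_views, num_tang = sinogram.shape
    if num_views % num_views_to_combine != 0:
        raise ValueError("num_views must be a multiple of num_views_to_combine")
    # row-major reshape groups consecutive views together
    grouped = sinogram.reshape(num_views // num_views_to_combine, num_views_to_combine, num_tang)
    combined = grouped.sum(axis=1)
    if do_normalisation:
        combined = combined / num_views_to_combine
    return combined


rng = np.random.default_rng(0)
sino = rng.poisson(5.0, size=(8, 6)).astype(float)
small = combine_views(sino, num_views_to_combine=2)
print(sino.shape, "->", small.shape)
```

Summing keeps the total number of counts, so the rebinned data still look like Poisson counts, just with coarser angular sampling.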

In [None]:
help(full_sino.rebin)

In [None]:
sino = full_sino.rebin(1, num_views_to_combine=1,max_in_segment_num_to_process=0,do_normalisation=False)

This is a copy of the function in the BrainWeb notebook, except that we set the number of LORs to trace per bin a bit lower, so that we commit less of an "inverse crime" (simulating and reconstructing the data with the same model).

In [None]:
def get_acquisition_model(uMap, templ_sino, global_factor=.01):
    '''create an acq_model given a mu-map and a global sensitivity factor
    
    The default global_factor is chosen such that the mean values of the
    forward projected BrainWeb data have a reasonable magnitude
    '''
    #%% create acquisition model
    am = pet.AcquisitionModelUsingRayTracingMatrix()
    am.set_num_tangential_LORs(5)

    # Set up sensitivity due to attenuation
    asm_attn = pet.AcquisitionSensitivityModel(uMap, am)
    asm_attn.set_up(templ_sino)
    bin_eff = templ_sino.get_uniform_copy(global_factor)
    print('applying attenuation (please wait, may take a while)...')
    asm_attn.unnormalise(bin_eff)
    asm = pet.AcquisitionSensitivityModel(bin_eff)

    am.set_acquisition_sensitivity(asm)
    return am

In [None]:
am=get_acquisition_model(atten,sino)

In [None]:
#%% create objective function
obj_fun = pet.make_Poisson_loglikelihood(sino)
obj_fun.set_acquisition_model(am)

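## Create the KOSMAPOSL reconstructor
`sirf.STIR.KOSMAPOSLReconstructor` implements Ordered Subsets HKEM (as long as you do not add an additional prior).

In this section we set the parameters shared by all reconstructions below. With 21 subsets, 42 sub-iterations correspond to 2 full iterations.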
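For orientation, the kernelised EM update of Wang and Qi estimates kernel coefficients α and forms the image as x = Kα, with α ← α / (Kᵀ Aᵀ 1) · Kᵀ Aᵀ (y / (A K α + b)), where A is the system matrix, y the measured data and b the additive background. The toy sketch below applies this update with ordered subsets on a tiny dense problem. The function name, the dense matrices and the subset choice (every `num_subsets`-th bin) are assumptions for illustration; STIR's KOSMAPOSL works with projectors and its own subset scheme and, for HKEM, recomputes the kernel at each update.

```python
import numpy as np


def kem_os_pass(alpha, K, A, y, background, subsets):
    """One pass over all subsets of a kernelised OS-EM update (toy dense-matrix sketch).

    alpha: kernel coefficients, K: kernel matrix (voxels x voxels),
    A: system matrix (bins x voxels), y: measured counts,
    background: additive expected counts per bin, subsets: list of bin-index arrays.
    """
    for rows in subsets:
        A_s = A[rows]
        expected = A_s @ (K @ alpha) + background[rows]      # forward projection of x = K alpha
        sensitivity = K.T @ (A_s.T @ np.ones(len(rows)))     # K^T A_s^T 1
        alpha = alpha / sensitivity * (K.T @ (A_s.T @ (y[rows] / expected)))
    return alpha


rng = np.random.default_rng(1)
n_bins, n_vox, num_subsets = 12, 5, 3
A = rng.uniform(0.1, 1.0, size=(n_bins, n_vox))
K = 0.6 * np.eye(n_vox) + 0.1              # arbitrary positive, invertible kernel matrix for the demo
x_true = np.array([1.0, 2.0, 3.0, 2.0, 1.0])
background = np.full(n_bins, 0.1)
y = A @ x_true + background                # noiseless data for simplicity
subsets = [np.arange(s, n_bins, num_subsets) for s in range(num_subsets)]

alpha = np.ones(n_vox)
x_start = K @ alpha
for _ in range(200):                       # 200 full passes = 600 sub-iterations
    alpha = kem_os_pass(alpha, K, A, y, background, subsets)
x_est = K @ alpha
print("error at start:", np.linalg.norm(x_start - x_true))
print("error after 200 passes:", np.linalg.norm(x_est - x_true))
```

The update is multiplicative, so α (and hence x for a non-negative K) stays positive, just as in MLEM.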

In [None]:
recon = pet.KOSMAPOSLReconstructor()
recon.set_objective_function(obj_fun)

recon.set_anatomical_prior(anatomical)
recon.set_num_non_zero_features(1)
recon.set_num_subsets(21)
recon.set_num_subiterations(42)

## Study parameter sigma_m (MR edge preservation) 
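Assuming the anatomical part of the kernel has a Gaussian form exp(-d²/(2σ_m²)), where d is the difference between the (normalised) anatomical features of two voxels, `sigma_m` sets how quickly a neighbour's weight decays as its features differ from those of the centre voxel. The snippet below just evaluates this Gaussian for a few illustrative values of d; the values and the helper name are hypothetical, and how STIR scales the features is not shown here.

```python
import numpy as np


def gaussian_weight(difference, sigma):
    """Gaussian similarity exp(-d**2 / (2 sigma**2)) for a feature difference d (assumed form)."""
    return np.exp(-np.asarray(difference)**2 / (2.0 * sigma**2))


# hypothetical normalised feature differences: same tissue, slightly different, across a boundary
differences = np.array([0.0, 0.1, 0.5])
for sigma_m in [0.05, 0.2, 1.0]:
    print(f"sigma_m={sigma_m}:", np.round(gaussian_weight(differences, sigma_m), 3))
```

With a small σ_m only voxels with nearly identical anatomical features contribute, which preserves MR edges strongly but averages fewer voxels (less noise suppression). With a large σ_m even voxels across a boundary keep a large weight, so the kernel smooths across edges.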


In [None]:
#%% reconstruct the image 
H1m_reconstructed_image = [] 

#fix other parameters
recon.set_num_neighbours(3)
recon.set_sigma_p(0.2)
recon.set_sigma_dm(5.0)
recon.set_sigma_dp(5.0)

# use a list rather than a set: set iteration order is not guaranteed,
# and later cells assume that index 1 holds the sigma_m=0.2 result
sigma_m = [0.05, 0.2, 1]
for i in sigma_m:

    H1m_reconstructed_image.append(init_image.clone())

    recon.set_sigma_m(i)

    # set up the reconstructor and reconstruct
    recon.set_hybrid(True)
    recon.set_up(H1m_reconstructed_image[-1])
    recon.reconstruct(H1m_reconstructed_image[-1])

In [None]:
# %% bitmap display of images
# define lists
H1m_reconstructed_array = []
H1m_error_array = []

for ii, i in enumerate(sigma_m):

    H1m_reconstructed_array.append(H1m_reconstructed_image[ii].as_array())
    H1m_error_array.append(image_array - H1m_reconstructed_array[ii])

    j = "{}".format(i)
    plt.figure()
    plt.subplot(1,3,1)
    imshow_hot(image_array[im_slice,:,:,], [0,cmax*2],'True image')
    plt.subplot(1,3,2)
    imshow_hot(H1m_reconstructed_array[ii][im_slice,:,:,], [0,cmax*2], 'sigma_m='+j)
    plt.subplot(1,3,3)
    imshow(H1m_error_array[ii][im_slice,:,:,], [-cmax*0.5,cmax*0.5], 'HKEM error')

## Study parameter sigma_p (PET edge preservation)


In [None]:
#%% reconstruct the image 
H1p_reconstructed_image = [] 

#fix other parameters
recon.set_num_neighbours(3)
recon.set_sigma_m(0.2)
recon.set_sigma_dm(5.0)
recon.set_sigma_dp(5.0)

# a list (not a set) keeps the order of the results well defined
sigma_p = [0.05, 2]
for i in sigma_p:

    H1p_reconstructed_image.append(init_image.clone())

    recon.set_sigma_p(i)
    # set up the reconstructor and reconstruct
    recon.set_hybrid(True)
    recon.set_up(H1p_reconstructed_image[-1])
    recon.reconstruct(H1p_reconstructed_image[-1])

In [None]:
# add the sigma_p=0.2 case, which was reconstructed in the sigma_m study above
# (sigma_m=0.2, sigma_p=0.2, all other parameters identical)
H1p_reconstructed_image.append(H1m_reconstructed_image[1])
#%% bitmap display of images
# define lists
H1p_reconstructed_array = []
H1p_error_array = []
sigma_p = [0.05, 2, 0.2]
for ii, i in enumerate(sigma_p):

    j = "{}".format(i)

    H1p_reconstructed_array.append(H1p_reconstructed_image[ii].as_array())

    H1p_error_array.append(image_array - H1p_reconstructed_array[ii])

    plt.figure()
    plt.subplot(1,3,1)
    imshow_hot(image_array[im_slice,:,:,], [0,cmax*2],'True image')
    plt.subplot(1,3,2)
    imshow_hot(H1p_reconstructed_array[ii][im_slice,:,:,], [0,cmax*2], 'sigma_p='+j)
    plt.subplot(1,3,3)
    imshow(H1p_error_array[ii][im_slice,:,:,], [-cmax*0.5,cmax*0.5], 'HKEM error')

## Study parameter sigma_d (smoothing, depends on the voxel size)
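The spatial part of the kernel down-weights neighbours that are further away. Assuming again a Gaussian form exp(-r²/(2σ_d²)) with the distance r measured in mm (which is why a sensible value depends on the voxel size), the snippet compares the weights of the face, edge and corner neighbours of a voxel. The voxel size is a placeholder and the helper name is hypothetical; in the study below, `sigma_dm` and `sigma_dp` are set to the same value.

```python
import numpy as np


def spatial_weights(sigma_d, voxel_size_mm):
    """Gaussian spatial weights for the face, edge and corner neighbours of a voxel.

    Assumes isotropic voxels and distances measured in mm (toy sketch).
    """
    distances = voxel_size_mm * np.sqrt(np.array([1.0, 2.0, 3.0]))
    return np.exp(-distances**2 / (2.0 * sigma_d**2))


voxel_size_mm = 2.0   # hypothetical value: check the voxel sizes of your own images
for sigma_d in [0.5, 1.0, 5.0]:
    print(f"sigma_d={sigma_d}:", np.round(spatial_weights(sigma_d, voxel_size_mm), 3))
```

Under these assumptions, a σ_d well below the voxel size leaves almost only the centre voxel, while a σ_d of several voxels gives nearly equal weights across the whole neighbourhood.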


In [None]:
#%% reconstruct the image 
H1d_reconstructed_image = [] 

#fix other parameters
recon.set_num_neighbours(3)
recon.set_sigma_m(0.2)
recon.set_sigma_p(0.2)

# the same value is used for sigma_dm and sigma_dp; a list keeps the order well defined
sigma_dm = [0.5, 1]
for i in sigma_dm:

    H1d_reconstructed_image.append(init_image.clone())

    recon.set_sigma_dp(i)
    recon.set_sigma_dm(i)

    # set up the reconstructor and reconstruct
    recon.set_hybrid(True)
    recon.set_up(H1d_reconstructed_image[-1])
    recon.reconstruct(H1d_reconstructed_image[-1])

In [None]:
# add the sigma_d=5 case, which was reconstructed in the sigma_m study above
H1d_reconstructed_image.append(H1m_reconstructed_image[1])
#%% bitmap display of images
# define lists
H1d_reconstructed_array = []
H1d_error_array = []
sigma_dm = [0.5, 1, 5]
for ii, i in enumerate(sigma_dm):

    j = "{}".format(i)

    H1d_reconstructed_array.append(H1d_reconstructed_image[ii].as_array())

    H1d_error_array.append(image_array - H1d_reconstructed_array[ii])

    plt.figure()
    plt.subplot(1,3,1)
    imshow_hot(image_array[im_slice,:,:,], [0,cmax*2],'True image')
    plt.subplot(1,3,2)
    imshow_hot(H1d_reconstructed_array[ii][im_slice,:,:,], [0,cmax*2], 'sigma_dm='+j)
    plt.subplot(1,3,3)
    imshow(H1d_error_array[ii][im_slice,:,:,], [-cmax*0.5,cmax*0.5], 'HKEM error')

plt.show()

## Study parameter "neighbourhood size", n
Try to rebin the data even more if this is too slow.
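Assuming `set_num_neighbours(n)` selects a cubic n×n×n neighbourhood in 3D (we have not checked how STIR treats image borders), the number of kernel entries per voxel grows as n³, which is why larger neighbourhoods are much slower. Under that assumption, n=1 leaves only the voxel itself. The helper below is hypothetical and simply counts the offsets.

```python
import itertools


def neighbourhood_offsets(n):
    """Voxel offsets of a cubic n x n x n neighbourhood centred on a voxel (n odd, toy sketch)."""
    if n < 1 or n % 2 == 0:
        raise ValueError("n must be a positive odd number so the neighbourhood is centred")
    half = n // 2
    steps = range(-half, half + 1)
    return list(itertools.product(steps, steps, steps))


for n in [1, 3, 5]:
    print(n, len(neighbourhood_offsets(n)))   # 1 1, then 3 27, then 5 125
```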

In [None]:
#%% reconstruct the image 
H1n_reconstructed_image = [] 

#fix other parameters
recon.set_sigma_m(0.2)
recon.set_sigma_p(0.2)
recon.set_sigma_dm(5.0)
recon.set_sigma_dp(5.0)

# a list (not a set) keeps the order of the results well defined
n = [1, 5]
for i in n:

    H1n_reconstructed_image.append(init_image.clone())

    recon.set_num_neighbours(i)

    # set up the reconstructor and reconstruct
    recon.set_hybrid(True)
    recon.set_up(H1n_reconstructed_image[-1])
    recon.reconstruct(H1n_reconstructed_image[-1])

In [None]:
# add the n=3 case, which we reconstructed in the sigma_m study above
# careful: the list now holds the results for n in the order 1, 5, 3 (see below)
H1n_reconstructed_image.append(H1m_reconstructed_image[1])

In [None]:
#%% bitmap display of images
# define lists

H1n_reconstructed_array = []
H1n_error_array = []
n=[1, 5, 3]
for ii in range(len(n)):

    i=n[ii]
    j="{}".format(i)
    
    H1n_reconstructed_array.append(H1n_reconstructed_image[ii].as_array())

    H1n_error_array.append(image_array - H1n_reconstructed_array[ii])

    plt.figure()
    plt.subplot(1,3,1)
    imshow_hot(image_array[im_slice,:,:,], [0,cmax*2],'True image')
    plt.subplot(1,3,2)
    imshow_hot(H1n_reconstructed_array[ii][im_slice,:,:,], [0,cmax*2], 'HKEM, N='+j)
    plt.subplot(1,3,3)
    imshow(H1n_error_array[ii][im_slice,:,:,], [-cmax*0.5,cmax*0.5], 'HKEM error')

plt.show()

## Reconstruct with KEM
HKEM reduces to KEM when hybrid is set to `False` (`recon.set_hybrid(False)`), i.e. when the kernels are computed from the anatomical image only.

In [None]:
#KEM image is: H0_reconstructed_array

H0_reconstructed_image = [] 

#fix other parameters
recon.set_sigma_m(0.2)
recon.set_sigma_p(0.2)
recon.set_sigma_dm(5.0)
recon.set_sigma_dp(5.0)

H0_reconstructed_image.append(init_image.clone())
recon.set_num_neighbours(5)

#   set up the reconstructor
recon.set_hybrid(False)
recon.set_up(H0_reconstructed_image[0])
recon.reconstruct(H0_reconstructed_image[0])


## Compare HKEM and KEM 


In [None]:
H0_reconstructed_array= H0_reconstructed_image[0].as_array()
plt.figure()
plt.subplot(1,2,1)
imshow_hot(H0_reconstructed_array[im_slice,:,:,], [0,cmax*2.], 'KEM, N=5')
plt.subplot(1,2,2)
# index 1 holds the HKEM result with N=5, matching the neighbourhood used for KEM
imshow_hot(H1n_reconstructed_array[1][im_slice,:,:,], [0,cmax*2.], 'HKEM, N=5')

plt.show()

## Suggested exercises

### 1) What differences do you see when you change each parameter? And between HKEM and KEM?
The plots above might give you some feeling for how the different parameters change the images. It would be better to use quantitative measures, such as the RMSE with respect to the ground truth, or ROI values (in particular in the tumour).
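A minimal sketch of two such measures, using plain NumPy arrays. The function names are hypothetical, and the toy arrays only stand in for a reconstruction and the ground truth.

```python
import numpy as np


def rmse(estimate, reference, mask=None):
    """Root-mean-square error, optionally restricted to a boolean mask."""
    diff = np.asarray(estimate, dtype=float) - np.asarray(reference, dtype=float)
    if mask is not None:
        diff = diff[mask]
    return float(np.sqrt(np.mean(diff**2)))


def roi_mean_bias_percent(estimate, reference, roi_mask):
    """Relative bias (in %) of the mean value inside an ROI, e.g. the tumour."""
    reference_mean = reference[roi_mask].mean()
    return float(100.0 * (estimate[roi_mask].mean() - reference_mean) / reference_mean)


# toy example: a hot 2x2 "tumour" in a uniform background, under-estimated by 10 %
reference = np.ones((4, 4))
reference[1:3, 1:3] = 4.0
estimate = 0.9 * reference
tumour_mask = reference > 2.0
print(rmse(estimate, reference), roi_mean_bias_percent(estimate, reference, tumour_mask))
```

In the notebook you could, for instance, call `rmse(H1m_reconstructed_array[ii], image_array)` for each parameter value. How to define the tumour mask (e.g. by thresholding the true image) is up to you; check visually that it covers only the tumour.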

### 2) What happens if there is misalignment between the anatomical image and the PET image?

There can be motion between the PET and anatomical acquisitions. If the resulting misalignment is too large, it will clearly be disadvantageous to use the anatomical image for "guidance". The effect was studied for HKEM in  
Deidda, Daniel, N. A. Karakatsanis, Philip M. Robson, Nikos Efthimiou, Zahi A. Fayad, Robert G. Aykroyd, and Charalampos Tsoumpas. ‘Effect of PET-MR Inconsistency in the Kernel Image Reconstruction Method’. IEEE Transactions on Radiation and Plasma Medical Sciences 3, no. 4 (July 2019): 400–409. https://doi.org/10.1109/TRPMS.2018.2884176.

You can try to reproduce some of that investigation using the following steps:
- you can create misalignment by shifting or rotating the anatomical image, as in the [BrainWeb notebook](BrainWeb.ipynb)
- set the `KOSMAPOSL` reconstructor to use the new anatomical image and reconstruct as above.
- plot images, or investigate ROI values, e.g. in the tumour.
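A minimal way to create a whole-voxel translation with NumPy is sketched below; `shift_volume` is a hypothetical helper, and the vacated region is filled with zeros rather than wrapped around.

```python
import numpy as np


def shift_volume(volume, shift):
    """Translate a 3D array by whole voxels, filling the uncovered region with zeros.

    shift is a tuple of signed integer shifts, one per axis (z, y, x).
    """
    shifted = np.zeros_like(volume)
    dst, src = [], []
    for s, size in zip(shift, volume.shape):
        if abs(s) >= size:
            raise ValueError("shift must be smaller than the array size along each axis")
        dst.append(slice(max(s, 0), size + min(s, 0)))
        src.append(slice(max(-s, 0), size + min(-s, 0)))
    shifted[tuple(dst)] = volume[tuple(src)]
    return shifted


volume = np.arange(27, dtype=float).reshape(3, 3, 3)
moved = shift_volume(volume, (0, 1, 0))   # shift by one voxel along y
```

With SIRF, the shifted array could then be copied into a clone of the anatomical image, e.g. `misaligned = anatomical.clone()` followed by `misaligned.fill(shift_volume(anatomical.as_array(), (0, 3, 0)))` (the 3-voxel shift is arbitrary), and passed to `recon.set_anatomical_prior(misaligned)`. For rotations, `scipy.ndimage.rotate` with `reshape=False` is one option.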

### 3) Try to resolve the misalignment before running HKEM
Run an OSEM reconstruction and align the anatomical image from the previous exercise with the OSEM image. You can use `sirf.Reg` for this.
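Purely to illustrate what estimating a misalignment means, the sketch below uses a different and much simpler technique than `sirf.Reg`: phase correlation, which recovers an integer translation between two arrays. It assumes similar image content in both volumes, so it is no substitute for proper multi-modal (e.g. MR to PET) registration, and it handles neither rotations nor sub-voxel shifts. The function name is hypothetical.

```python
import numpy as np


def estimate_integer_shift(fixed, moving):
    """Estimate the integer shift d such that fixed ~ np.roll(moving, d) (phase correlation)."""
    cross_power = np.fft.fftn(fixed) * np.conj(np.fft.fftn(moving))
    cross_power /= np.abs(cross_power) + 1e-12          # keep only the phase
    correlation = np.abs(np.fft.ifftn(cross_power))     # peaks at the shift
    peak = np.unravel_index(np.argmax(correlation), correlation.shape)
    # convert peak indices to signed shifts (the FFT wraps around)
    return tuple(int(p) if p <= s // 2 else int(p - s) for p, s in zip(peak, correlation.shape))


rng = np.random.default_rng(2)
moving = rng.random((16, 16, 16))
fixed = np.roll(moving, (2, -3, 1), axis=(0, 1, 2))
print(estimate_integer_shift(fixed, moving))
```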

### Similar exercises can be done using other algorithms that use anatomical information.
You could have a look at the [MAPEM_Bowsher notebook](MAPEM_Bowsher.ipynb) and repeat these exercises and compare results.