# MAPEM de Pierro algorithm

Authors: Sam Ellis, Kris Thielemans, Richard Brown
First version: 22nd of October 2019
Second version: 27th of October 2019

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019  University College London.
Copyright 2019  King's College London

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# MAPEM
If you have completed the PET component, you will have implemented a version of MAPEM with a quadratic prior. If you haven't, you should probably have a look at that notebook before this one: [../PET/MAPEM.ipynb](../PET/MAPEM.ipynb).

The following chunk of code is copied and pasted directly from that notebook as a starting point. The difference here is that we are going to use an anatomical prior instead of the quadratic prior used previously.

First, run the code chunk to get the results using the quadratic prior...

In [None]:
#%% make sure figures appears inline and animations works
%matplotlib notebook

In [None]:
import sirf.STIR as pet
import matplotlib.pyplot as plt
import os
import numpy as np
from numpy.linalg import norm
import sirf.contrib.kcl.Prior as pr
import shutil

# Get to correct directory (all input paths below are relative to it)
os.chdir("/data")

#%% copy files to working folder and change directory to where the output files are
# Remove any previous working copy; ignore_errors=True, so a missing folder is not an error
shutil.rmtree('working_folder/dePierro',True)
# Start from a fresh copy of the input data
shutil.copytree('SRS_data_exhale','working_folder/dePierro')
os.chdir('working_folder/dePierro')

In [None]:
#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Show a 2D array with a colour bar and hidden axes.

    Usage: imshow(image, [min,max], title)

    When ``limits`` is None the colour scale runs from the image's own
    minimum to its maximum. Returns the object produced by ``plt.imshow``.
    """
    plt.title(title)
    handle = plt.imshow(image)
    if limits is None:
        low, high = image.min(), image.max()
    else:
        low, high = limits[0], limits[1]
    plt.clim(low, high)
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return handle

def make_cylindrical_FOV(image):
    """Truncate ``image`` to the cylindrical field of view (applied to the image itself)."""
    truncator = pet.TruncateToCylinderProcessor()
    truncator.apply(image)

### Load some data and set some values

In [None]:
# Measured (noisy) sinogram and attenuation image from the working folder
sino = pet.AcquisitionData('noisy_sino.hs')
atten = pet.ImageData('PET_attenuation.nii')
# 42 sub-iterations with 21 subsets = 2 full passes through the data
num_subsets = 21
num_subiters = 42

# Anatomical image (used below to compute the Bowsher weights of the prior)
anatomical = pet.ImageData('MR_T1.nii') # could be MR_T2.nii or MR_PD.nii
anatomical_arr = anatomical.as_array()

#%%  create initial image
# Uniform copy of the attenuation image with value 10% of its maximum,
# truncated to the cylindrical field of view
init_image=atten.get_uniform_copy(atten.as_array().max()*.1)
make_cylindrical_FOV(init_image)

## Set up the acquisition model and objective function

In [None]:
#%% create acquisition model
am = pet.AcquisitionModelUsingRayTracingMatrix()
# number of tangential LORs traced per bin by the ray tracer
am.set_num_tangential_LORs(5)

# Set up sensitivity due to attenuation
asm_attn = pet.AcquisitionSensitivityModel(atten, am)
asm_attn.set_up(sino)
# Apply the attenuation model to an all-ones sinogram.
# NOTE(review): this presumably leaves the per-bin attenuation factors in
# bin_eff, which then replace the image-based model by a fixed sinogram-based
# one (so attenuation is not re-projected every time) — confirm in SIRF docs.
bin_eff = pet.AcquisitionData(sino)
bin_eff.fill(1.0)
asm_attn.unnormalise(bin_eff)
asm_attn = pet.AcquisitionSensitivityModel(bin_eff)

# Set sensitivity of the model and set up
am.set_acquisition_sensitivity(asm_attn)
am.set_up(sino,atten);

#%% create objective function
# Poisson log-likelihood of the measured sinogram under the acquisition model
obj_fun = pet.make_Poisson_loglikelihood(sino)
obj_fun.set_acquisition_model(am)

## Normal OSEM

In [None]:
#%% create OSEM reconstructor
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(num_subsets)
OSEM_reconstructor.set_num_subiterations(num_subiters)

# Reconstruct into a clone, so init_image is still available for the MAP run
# below; reconstruct() writes its result into reconstructed_image
reconstructed_image = init_image.clone()
OSEM_reconstructor.set_up(reconstructed_image)
OSEM_reconstructor.reconstruct(reconstructed_image)

### Display OSEM

In [None]:
# Three orthogonal slices through index 60 of each axis, colour range [0, 0.1]
recon_arr = reconstructed_image.as_array()
plt.subplot(1,3,1);
imshow(recon_arr[:,:,60],[0, 0.1]);
plt.subplot(1,3,2);
imshow(recon_arr[:,60,:],[0, 0.1]);
plt.subplot(1,3,3);
imshow(recon_arr[60,:,:],[0, 0.1]);

## Now use dePierro MAPEM

In [None]:
def dePierroUpdate(imageEM, imageReg, beta, sensImg):
    """Combine an EM update with the regularised image using de Pierro's update.

    For every voxel j this returns the positive root x of

        beta_j*x**2 + b_j*x - imageEM_j = 0,   b_j = 1 - beta_j*imageReg_j,

    with beta_j = beta / sensImg_j, i.e.

        x = 2*imageEM / (sqrt(b_j**2 + 4*beta_j*imageEM) + b_j).

    Args:
        imageEM: numpy array, the (OS)EM update of the current image.
        imageReg: numpy array of the same shape, the regularised image x_reg.
        beta: prior strength (non-negative scalar).
        sensImg: numpy array of the same shape, the sensitivity image.
            It is not modified.

    Returns:
        numpy array with the updated image.

    Raises:
        ValueError: if sensImg is zero everywhere.
    """
    # Clamp small/negative sensitivities to a fraction of the maximum to avoid
    # division by zero. np.maximum returns a new array: the caller's is untouched.
    delta = 1e-6*np.abs(sensImg).max()
    if delta == 0:
        raise ValueError("sensitivity image is zero everywhere")
    beta_j = beta/np.maximum(sensImg, delta)

    b_j = 1 - beta_j*imageReg
    root = np.sqrt(b_j**2 + 4*beta_j*imageEM)

    # Both expressions below equal the positive root. 2*x_EM/(root + b) suffers
    # cancellation (or 0/0 when imageEM == 0) if b <= 0, while
    # (root - b)/(2*beta_j) suffers cancellation if b > 0, so pick per voxel.
    # The unused denominator is replaced by 1 so neither branch divides by zero.
    positive_b = b_j > 0
    from_em = 2*imageEM / np.where(positive_b, root + b_j, 1)
    from_reg = (root - b_j) / np.where(positive_b, 1, 2*beta_j)
    return np.where(positive_b, from_em, from_reg)


def dePierroReg(image,weights):
    """Compute the regularised image x_reg used by the de Pierro update.

    For each voxel j:  x_reg_j = 0.5 * sum_k w_jk * (x_j + x_k),
    where k runs over the voxel's neighbourhood given by neighbourExtract.

    Args:
        image: 3D numpy array, the current image estimate.
        weights: 2D array with one row per voxel (voxels in Fortran /
            column-major order) and one column per neighbour, in the column
            order produced by neighbourExtract. It has w**3 columns, or w**2
            columns when image.shape[2] == 1.

    Returns:
        numpy array with the same shape as ``image``.

    Raises:
        ValueError: if the number of weight columns is not a perfect square
            (2D case) or cube (3D case).
    """
    imSize = image.shape
    # Vectorise in Fortran order to match the indices from neighbourExtract.
    imageVec = image.ravel(order='F')

    # Side length of the neighbourhood. neighbourExtract uses w*w neighbours
    # when the third dimension is 1 and w*w*w otherwise, so invert accordingly.
    nN = weights.shape[1]
    dims = 2 if imSize[2] == 1 else 3
    w = int(round(nN**(1.0/dims)))
    if w**dims != nN:
        raise ValueError("number of weight columns (%d) is not a %s"
                         % (nN, "square" if dims == 2 else "cube"))
    nhoodInd = neighbourExtract(imSize, w)

    # Intensities of each voxel's neighbours: shape (number of voxels, nN).
    neighbours = np.float32(imageVec[nhoodInd])
    # Each voxel's own value as a column, broadcast against its neighbours.
    centre = np.float32(image).reshape(-1, 1, order='F')

    imageReg = 0.5*np.sum(weights*(neighbours + centre), axis=1)
    return imageReg.reshape(imSize, order='F')


def neighbourExtract(imageSize,w):
    """Return the linear indices of every voxel's w-wide neighbourhood.

    Adapted from the Prior class.

    Args:
        imageSize: image shape (n, m, h).
        w: positive odd integer, side length of the neighbourhood.

    Returns:
        int32 array of shape (n*m*h, nN), with nN = w*w if h == 1 and w*w*w
        otherwise. Row r belongs to the voxel whose Fortran (column-major)
        linear index is r. Each column holds the Fortran linear index of one
        neighbour, with offsets ordered x (outermost loop), y, z (innermost).
        Neighbours beyond the image edge wrap around (periodic boundary).

    Raises:
        ValueError: if w is not a positive odd integer.
    """
    if w < 1 or w % 2 == 0:
        raise ValueError("neighbourhood size w must be a positive odd integer")
    n, m, h = imageSize[0], imageSize[1], imageSize[2]
    half = w // 2
    offsets = range(-half, half + 1)
    # 2D images (h == 1) only use in-plane neighbours.
    zoffsets = [0] if h == 1 else offsets

    # Voxel coordinates along each axis, each of shape (n, m, h).
    X, Y, Z = np.meshgrid(np.arange(n), np.arange(m), np.arange(h), indexing='ij')
    N = np.zeros([n*m*h, len(offsets)*len(offsets)*len(zoffsets)], dtype='int32')
    l = 0
    for dx in offsets:
        Xnew = np.mod(X + dx, n)
        for dy in offsets:
            Ynew = np.mod(Y + dy, m)
            for dz in zoffsets:
                Znew = np.mod(Z + dz, h)
                N[:, l] = (Xnew + Ynew*n + Znew*n*m).ravel(order='F')
                l += 1
    return N


def setBoundary(X,n):
    """Wrap coordinates into [0, n) (periodic boundary) and flatten them.

    Boundary conditions for neighbourExtract, adapted from the Prior class.
    The input array is left unchanged, and offsets of any size are wrapped.

    Args:
        X: numpy array of (possibly out-of-range) coordinates.
        n: size of the axis.

    Returns:
        1D array: ``X`` modulo ``n``, flattened in Fortran (column-major) order.
    """
    return np.mod(X, n).flatten('F')

### MAPEM input data and parameters

In [None]:
# NOTE(review): the sensitivity of subset 0 is used for every sub-iteration
# below, presumably as an approximation for the other subsets — confirm.
sensitivity_image = obj_fun.get_subset_sensitivity(0)
# Start the MAP reconstruction from the same initial image as OSEM
de_pierro_reconstructed_image = init_image.clone()
# Prior strength
beta = 500

In [None]:
# create a Prior for computing Bowsher weights
myPrior = pr.Prior(anatomical_arr.shape)
# NOTE(review): 7 is presumably the number of most similar neighbours kept per
# voxel by the Bowsher prior. Dividing by 7 normalises each voxel's weights,
# and the check below enforces that every row of weights sums to 1.
weights = myPrior.BowshserWeights(anatomical_arr,7)
weights = np.float32(weights/7.0)
if (np.abs(np.sum(weights,axis=1)-1)>1.0e-6).any():
    raise ValueError("Weights should sum to 1 for each voxel")

# Create OSEM reconstructor
# (only its update() is used below, to obtain the EM step of each sub-iteration)
print('Setting up reconstruction object')
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)                             
OSEM_reconstructor.set_num_subsets(num_subsets)
OSEM_reconstructor.set_num_subiterations(num_subiters)
OSEM_reconstructor.set_up(de_pierro_reconstructed_image)

In [None]:
# MAP-OSEM with de Pierro's update: every sub-iteration combines the OSEM step
# with x_reg computed from the estimate of the previous sub-iteration.
for iter in range(1,num_subiters + 1):
    print('\n------------- Subiteration %d' % iter) 
    
    # Calculate imageReg (from the current image, before the EM update) and return as an array
    imageReg_array = dePierroReg(de_pierro_reconstructed_image.as_array(),weights)
    
    # OSEM image update, applied to de_pierro_reconstructed_image itself
    OSEM_reconstructor.update(de_pierro_reconstructed_image)
    imageEM_array = de_pierro_reconstructed_image.as_array()

    # Final image update (NOTE(review): subset-0 sensitivity for all sub-iterations)
    imageUpdated_array = dePierroUpdate \
        (imageEM_array, imageReg_array, beta, sensitivity_image.as_array())

    # Fill image and truncate to cylindrical field of view
    de_pierro_reconstructed_image.fill(imageUpdated_array)
    make_cylindrical_FOV(de_pierro_reconstructed_image)

In [None]:
# Compare OSEM (top row) with de Pierro MAP-EM (bottom row): three orthogonal
# slices through index 60, colour range [0, 0.05]
plt.figure()

depier_arr = de_pierro_reconstructed_image.as_array()

plt.subplot(2,3,1);
imshow(recon_arr[:,:,60],[0, 0.05]);
plt.subplot(2,3,2);
imshow(recon_arr[:,60,:],[0, 0.05]);
plt.subplot(2,3,3);
imshow(recon_arr[60,:,:],[0, 0.05]);
plt.subplot(2,3,4);
imshow(depier_arr[:,:,60],[0, 0.05]);
plt.subplot(2,3,5);
imshow(depier_arr[:,60,:],[0, 0.05]);
plt.subplot(2,3,6);
imshow(depier_arr[60,:,:],[0, 0.05]);

# Writing MAP-EM with SIRF
This notebook provides a start for writing MAP-EM yourself. It gives you basic lines for setting up a simulation, a simplified implementation of MAP-EM (with one known problem), and a few lines to help plot results.

You are strongly advised to complete (at least) the first half of the ML_reconstruction exercise before starting with this one. For instance, the simulation and plotting code here is taken directly from the `ML_reconstruction` code, but with fewer comments.

Author: Kris Thielemans  
First version: 19th of May 2018

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2017 - 2018 University College London.

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Initial set-up

In [None]:
#%% Initial imports etc
import numpy
from numpy.linalg import norm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
import sys
import shutil
#import scipy
#from scipy import optimize
import sirf.STIR as pet
from sirf.Utilities import examples_data_path
# plotting settings
plt.ion() # interactive 'on' such that plots appear during loops

%matplotlib notebook

In [None]:
#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Display a 2D array with a colour bar and no axes.

    Usage: imshow(image, [min,max], title)

    If ``limits`` is None or empty (e.g. ``[]``), the colour range is the
    image's own [min, max]. Returns the object created by ``plt.imshow``.
    """
    plt.title(title)
    bitmap=plt.imshow(image)
    # Accept both None and [] for "automatic" limits.
    if limits is None or len(limits)==0:
        limits=[image.min(),image.max()]

    plt.clim(limits[0], limits[1])
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return bitmap

def make_positive(image_array):
    """Zero every negative entry of ``image_array`` in place and return that same array."""
    is_negative = image_array < 0
    image_array[is_negative] = 0
    return image_array

def make_cylindrical_FOV(image):
    """Truncate ``image`` to the cylindrical field of view, modifying it directly."""
    pet.TruncateToCylinderProcessor().apply(image)



# Create some sample data

In [None]:
#%% go to directory with input files
# adapt this path to your situation (or start everything in the relevant directory)
# examples_data_path (from sirf.Utilities) locates SIRF's example data; use its 'PET' part
os.chdir(examples_data_path('PET'))


In [None]:
#%% copy files to working folder and change directory to where the output files are
# ignore_errors=True: no error if the working folder does not exist yet
shutil.rmtree('working_folder/thorax_single_slice',True)
shutil.copytree('thorax_single_slice','working_folder/thorax_single_slice')
os.chdir('working_folder/thorax_single_slice')

In [None]:
#%% Read in image
# Ground-truth emission image, used below to simulate the data
image = pet.ImageData('emission.hv');
# save max for future displays (60% of the maximum is used as upper colour limit)
image_array=image.as_array()
cmax = image_array.max()*.6

In [None]:
#%% create acquisition model
am = pet.AcquisitionModelUsingRayTracingMatrix()
# NOTE(review): the template sinogram presumably only defines the acquisition geometry
templ = pet.AcquisitionData('template_sinogram.hs')
am.set_up(templ,image); 

In [None]:
#%% simulate some data using forward projection
# (no noise is added, so these are noise-free simulated data)
acquired_data=am.forward(image)

# create OSEM reconstructor

In [None]:
#%% create objective function
# Poisson log-likelihood of the simulated data under the same acquisition model
obj_fun = pet.make_Poisson_loglikelihood(acquired_data)
obj_fun.set_acquisition_model(am)

In [None]:
#%% create OSEM reconstructor
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
# a single subset, so each update is a full EM iteration
OSEM_reconstructor.set_num_subsets(1)
num_subiters=10;
OSEM_reconstructor.set_num_subiterations(num_subiters)

In [None]:
#%%  create initial image
# uniform image at a quarter of cmax, truncated to the cylindrical FOV
init_image=image.get_uniform_copy(cmax/4)
make_cylindrical_FOV(init_image)

In [None]:
#%% initialise
# set up the reconstructor once, before update() is called in the MAP-OSEM loop
OSEM_reconstructor.set_up(init_image)

# Implement MAP-EM!
Actually, we will implement MAP-OSEM, as that's easy to do.

The lines below (almost) implement MAP-OSEM with a prior which just smooths along the horizontal direction. If you run it, you should see some warnings, and get the wrong image. Try to fix that. Once you have done that, you can evaluate the result, extend the prior to 2D, and/or increase the speed of your Python code.

We will use the algorithm described in

Guobao Wang and Jinyi Qi,  
Penalized Likelihood PET Image Reconstruction using Patch-based Edge-preserving Regularization  
IEEE Trans Med Imaging. 2012 Dec; 31(12): 2194–2204.   
[doi:  10.1109/TMI.2012.2211378](https://dx.doi.org/10.1109%2FTMI.2012.2211378)

However, we will not use patches here, but just a simple quadratic prior.

We will occasionally use the notation of the paper (but not consistently, mainly to avoid conflicts between a coordinate *x* and the image *x*).

Please note: this code was written to be simple. It is not terribly safe, quite slow, and doesn't use best programming practices (it uses global variables, has no docstrings, no validation of input, etc).

## define weights as an array
Note that code further below assumes that this has 3 elements


In [None]:
# weights for the neighbours at horizontal offsets -1, 0, +1 (the centre weight is 0)
w=numpy.array([1.,0.,1.])
# normalise to have sum 1
w/=w.sum()

## define a function that computes $x_{reg}$
For a simple quadratic prior (with normalised weights)

$x^{reg}_j={1\over 2}\sum_{k\in N_j} w_{jk}(x_j+x_k)$

In [None]:
def compute_xreg(image_array, weights=None):
    """Compute x_reg for a quadratic prior along the last (horizontal) axis.

        x_reg_j = 1/2 * sum_k w_k * (x_j + x_{j+k}),  k = -h..h,  h = len(weights)//2

    Args:
        image_array: numpy array; neighbours are taken along its last axis.
        weights: odd-length 1D sequence of neighbour weights centred on the
            voxel itself (normally summing to 1). Defaults to the global ``w``.

    Returns:
        numpy array with the same shape as ``image_array``. Voxels at the
        first/last position of the last axis reuse the edge value for their
        missing neighbours (replicate boundary), instead of getting x_reg = 0.

    Raises:
        ValueError: if ``weights`` has an even number of elements.
    """
    if weights is None:
        weights = w
    weights = numpy.asarray(weights)
    if len(weights) % 2 == 0:
        raise ValueError("weights must have an odd number of elements")
    half = len(weights) // 2
    width = image_array.shape[-1]
    # Floating-point result even for integer input; float32 stays float32.
    image_reg = numpy.zeros(image_array.shape,
                            dtype=numpy.result_type(image_array.dtype, numpy.float32))
    if width == 0:
        return image_reg
    # Pad only the last axis, repeating the edge values.
    pad_widths = [(0, 0)] * (image_array.ndim - 1) + [(half, half)]
    padded = numpy.pad(image_array, pad_widths, mode='edge')
    for k, weight in enumerate(weights):
        # padded[..., k:k+width] is the neighbour at offset k - half
        image_reg += weight/2*(image_array + padded[..., k:k + width])
    return image_reg

## define a function that computes the MAP-EM update, given $x_{EM}$ and $x_{reg}$

$x^{new}={2 x^{EM}  \over \sqrt{(1-\beta x^{reg})^2 + 4 \beta x^{EM}} + (1-\beta x^{reg})}$

In [None]:
def compute_MAPEM_update(xEM,xreg, beta):
    """Return the MAP-EM image from x_EM and x_reg (Wang & Qi 2012).

        x_new = 2 x_EM / (sqrt((1 - beta x_reg)^2 + 4 beta x_EM) + (1 - beta x_reg))

    This is the positive root of  beta x^2 + (1 - beta x_reg) x - x_EM = 0.
    Written as above it gives 0/0 (NaN) where x_EM == 0 and beta*x_reg >= 1,
    and loses precision whenever 1 - beta*x_reg < 0. For those voxels the
    algebraically equivalent (sqrt(...) - (1 - beta x_reg)) / (2 beta) is used.

    Args:
        xEM: numpy array, the EM update.
        xreg: numpy array of the same shape, the regularised image.
        beta: non-negative scalar, or an array broadcastable with xEM
            (e.g. a per-voxel beta_j).

    Returns:
        numpy array with the updated image.
    """
    b = 1 - beta*xreg
    root = numpy.sqrt(b**2 + 4*beta*xEM)
    positive_b = b > 0
    # The denominator of the branch that is not selected is replaced by 1,
    # so neither expression can divide by zero.
    from_em = 2*xEM / numpy.where(positive_b, root + b, 1)
    from_reg = (root - b) / numpy.where(positive_b, 1, 2*beta)
    return numpy.where(positive_b, from_em, from_reg)

## write MAP-OSEM and test it
We have an existing EMML SIRF reconstruction that does the hard work, i.e. compute $x_{EM}$, so let's use that!

In [None]:
#%% useful for timing (using `time.time()`)
import time

In [None]:
#%% define a function for plotting images and the updates
# This is the same function as in `ML_reconstruction`
def plot_progress(all_images, title, subiterations = None):
    """Plot image stacks at selected sub-iterations, each next to its update.

    Args:
        all_images: list of arrays indexed [subiteration, z, y, x]; one row of
            subplots is drawn per array.
        title: list of titles, one per entry of ``all_images``.
        subiterations: sub-iterations to show. If None or empty, all of
            1 .. (number of stored images - 1) are shown.

    Reads the globals ``slice`` (z index to display) and ``cmax`` (upper
    colour limit); these must be set before calling.
    """
    if subiterations is None or len(subiterations)==0:
        num_subiters = all_images[0].shape[0]-1
        subiterations = range(1, num_subiters+1)
    num_rows = len(all_images)
    for it in subiterations:
        plt.figure(it)
        for r in range(num_rows):
            plt.subplot(num_rows,2,2*r+1)
            imshow(all_images[r][it,slice,:,:], [0,cmax], '%s at %d' % (title[r],  it))
            # difference with the previous sub-iteration
            plt.subplot(num_rows,2,2*r+2)
            imshow(all_images[r][it,slice,:,:]-all_images[r][it-1,slice,:,:],[-cmax*.1,cmax*.1], 'update')
        plt.show()

Finally, let's implement MAP-OSEM!

In [None]:
#%% do a loop, saving images as we go along
# num_subiters is redefined here; the reconstructor was configured with 10, but
# only its update() is called, once per loop iteration
num_subiters = 16;
beta =1;
current_image = init_image.clone()
# all_images[i] holds the image after sub-iteration i (index 0 = initial image)
all_images = numpy.ndarray(shape=(num_subiters+1,) + current_image.as_array().shape );
all_images[0,:,:,:] =  current_image.as_array();
for iter in range(1, num_subiters+1):
    image_reg= compute_xreg(current_image.as_array()) # compute xreg (from the estimate before the EM step)
    OSEM_reconstructor.update(current_image); # compute EM update
    image_EM=current_image.as_array() # get xEM as a numpy array
    updated = compute_MAPEM_update(image_EM, image_reg, beta) # compute new update
    current_image.fill(updated) # store for next iteration
    print(iter)
    all_images[iter,:,:,:] =  updated; # save for plotting later on

In [None]:
#%% display
# NOTE: `slice` shadows the Python builtin; plot_progress also reads it as a global
slice = 0
imshow(current_image.as_array()[slice,:,:],[])

In [None]:
#%% now call this function to see how we went along
# one figure per listed sub-iteration: image (left) and its change from the previous one (right)
subiterations = (1,2,4,8,16);
plot_progress([all_images], ['MAP-OSEM'],subiterations)
    