# MAPEM de Pierro algorithm

Authors: Kris Thielemans, Sam Ellis, Richard Brown  
First version: 22nd of October 2019  
Second version: 27th of October 2019

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019  University College London  
Copyright 2019  King's College London  

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Start with [PET/MAPEM](../PET/MAPEM.ipynb)...

If you've already completed the PET component, you will have implemented a version of MAPEM. If you haven't, you'll probably want to give that a go first!

This example extends the quadratic prior used in that notebook to an anatomical prior, with Bowsher weights derived from an MR image.

# All the normal imports and handy functions

In [None]:
#%% Initial imports etc
import numpy
from numpy.linalg import norm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
import sys
import shutil
import time
import numba
#import scipy
#from scipy import optimize
import sirf.STIR as pet
from sirf.Utilities import examples_data_path
# plotting settings
# Switch matplotlib to interactive mode so figures are redrawn while loops run.
plt.ion() # interactive 'on' such that plots appear during loops

%matplotlib notebook

#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Usage: imshow(image, [min,max], title)

    Show a 2D array with the given title, a colour bar and no axes.
    When ``limits`` is None the colour range is ``[image.min(), image.max()]``.
    Returns the handle produced by ``plt.imshow``.
    """
    plt.title(title)
    handle = plt.imshow(image)
    if limits is None:
        low, high = image.min(), image.max()
    else:
        low, high = limits[0], limits[1]
    plt.clim(low, high)
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return handle

def make_cylindrical_FOV(image):
    """truncate to cylindrical FOV

    Applies STIR's TruncateToCylinderProcessor to ``image``; returns None, so
    callers rely on the processor modifying ``image`` itself.
    """
    truncation = pet.TruncateToCylinderProcessor()
    truncation.apply(image)

#%% define a function for plotting images and the updates
# This is the same function as in `ML_reconstruction`
def plot_progress(all_images, title, subiterations, cmax, slice_index=60):
    """Plot one slice of stored reconstructions and their per-subiteration updates.

    This is the same function as in `ML_reconstruction`, with the displayed
    slice made configurable.

    Args:
        all_images: list of arrays indexed as [subiteration, slice, row, col];
            entry 0 along the first axis holds the starting image.
        title: list of titles, one per entry of ``all_images``.
        subiterations: iterable of subiteration indices to show. If it is
            empty or None, subiterations 1..N are shown, where N + 1 is the
            length of the first axis of ``all_images[0]``.
        cmax: upper colour limit for the images; the updates (difference with
            the previous subiteration) are shown in [-0.1*cmax, 0.1*cmax].
        slice_index: index along the second axis to display (default 60).
    """
    if subiterations is None or len(subiterations) == 0:
        num_subiters = all_images[0].shape[0] - 1
        subiterations = range(1, num_subiters + 1)
    num_rows = len(all_images)
    for subiter in subiterations:
        plt.figure(subiter)
        for row in range(num_rows):
            current = all_images[row][subiter, slice_index, :, :]
            previous = all_images[row][subiter - 1, slice_index, :, :]
            plt.subplot(num_rows, 2, 2 * row + 1)
            imshow(current, [0, cmax], '%s at %d' % (title[row], subiter))
            plt.subplot(num_rows, 2, 2 * row + 2)
            imshow(current - previous, [-cmax * .1, cmax * .1], 'update')
        plt.show()

# Load the data
To generate the data needed for this notebook, run the [generate_data](./generate_data.ipynb) notebook first.

In [None]:
# Get to correct directory
os.chdir("/data")

#%% copy files to working folder and change directory to where the output files are
# The second positional argument of rmtree is ignore_errors=True, so a missing
# folder (first run) is not an error; the folder is then re-copied fresh.
shutil.rmtree('working_folder/dePierro_brainweb',True)
shutil.copytree('brainweb','working_folder/dePierro_brainweb')
os.chdir('working_folder/dePierro_brainweb')

# Measured (noisy) sinogram and attenuation image, read from the working folder.
acquired_data = pet.AcquisitionData('subj_04_sino_noisy.hs')
atten = pet.ImageData('subj_04_uMap.hv')

# Anatomical image
anatomical = pet.ImageData('subj_04_T1_tumour.hv') # could be MR_T2.nii or MR_PD.nii
anatomical_arr = anatomical.as_array()

#%%  create initial image
# Uniform image with value 10% of the attenuation image's maximum, on the
# attenuation image's grid, truncated to the cylindrical FOV.
init_image=atten.get_uniform_copy(atten.as_array().max()*.1)
make_cylindrical_FOV(init_image)

# Code from first MAPEM notebook

The following chunk of code is copied and pasted directly from the other notebook as a starting point. 

First, run the code chunk to get the results using the quadratic prior...

### MAPEM functions

In [None]:
# Neighbour weights for the offsets dx = -1, 0, +1 along x (the centre voxel
# itself gets weight 0), normalised so that they sum to 1.
w = numpy.array([1., 0., 1.])
w = w / w.sum()

# Define function for xreg
# numba is imported as a module (`import numba`), so the decorator must be
# referenced as `numba.jit`; a bare `@jit` is an undefined name.
@numba.jit
def compute_xreg(image_array):
    """Compute the quadratic-prior regularisation image ``xreg``.

    For each voxel (z, y, x) this accumulates
    ``w[dx+1]/2 * (image_array[z,y,x] + image_array[z,y,x+dx])`` over
    dx in (-1, 0, 1), i.e. along the last axis only, using the module-level
    weight array ``w``. The first and last voxel along the last axis are
    skipped for simplicity and stay 0 in the result.

    Note: numba treats global arrays such as ``w`` as compile-time constants,
    so changing ``w`` after the first call does not affect the compiled code.
    """
    sizes = image_array.shape
    image_reg = image_array*0  # zero array with the same shape; filled below
    for z in range(0, sizes[0]):
        for y in range(0, sizes[1]):
            for x in range(1, sizes[2]-1):  # ignore first and last pixel for simplicity
                for dx in (-1, 0, 1):
                    image_reg[z,y,x] += w[dx+1]/2*(image_array[z,y,x]+image_array[z,y,x+dx])
    return image_reg

# define a function that computes the MAP-EM update
def compute_MAPEM_update(xEM,xreg, beta):
    """Return the closed-form MAP-EM (de Pierro) update for the quadratic prior.

    Evaluates, elementwise,
    ``2*xEM / (sqrt((1 - beta*xreg)**2 + 4*beta*xEM) + (1 - beta*xreg) + 1e-5)``.
    The small constant in the denominator guards against division by zero.

    This is plain vectorised NumPy, so no JIT compilation is needed.

    Args:
        xEM: EM-updated image (array).
        xreg: regularisation image, e.g. from ``compute_xreg``.
        beta: regularisation strength.
    """
    one_minus = 1 - beta*xreg
    return (2*xEM)/(numpy.sqrt(one_minus**2 + 4*beta*xEM) + one_minus + 0.00001)

In [None]:
def get_obj_fun(acquired_data, atten):
    """Build a Poisson log-likelihood objective function for ``acquired_data``.

    The acquisition model is a ray-tracing matrix model (5 tangential LORs).
    Its acquisition sensitivity is a bin-efficiency model obtained by applying
    an attenuation-based sensitivity model (built from ``atten``) to a
    sinogram of ones.

    Returns the objective function with the acquisition model attached.
    """
    print('\n------------- Setting up objective function')
    #%% create acquisition model
    acq_model = pet.AcquisitionModelUsingRayTracingMatrix()
    acq_model.set_num_tangential_LORs(5)

    # attenuation sensitivity model -> apply to a sinogram of ones -> use the
    # resulting factors as a bin-efficiency sensitivity model
    attn_model = pet.AcquisitionSensitivityModel(atten, acq_model)
    attn_model.set_up(acquired_data)
    attn_factors = pet.AcquisitionData(acquired_data)
    attn_factors.fill(1.0)
    attn_model.unnormalise(attn_factors)
    sensitivity_model = pet.AcquisitionSensitivityModel(attn_factors)

    # attach sensitivity and set up the model
    acq_model.set_acquisition_sensitivity(sensitivity_model)
    acq_model.set_up(acquired_data, atten)

    #%% create objective function
    objective = pet.make_Poisson_loglikelihood(acquired_data)
    objective.set_acquisition_model(acq_model)

    print('\n------------- Finished setting up objective function')
    return objective

def get_reconstructor(num_subsets, num_subiters, obj_fun, init_image):
    """Create an OSMAPOSL (OSEM) reconstructor and set it up.

    Args:
        num_subsets: number of subsets.
        num_subiters: number of subiterations used by ``reconstruct``.
        obj_fun: objective function to optimise.
        init_image: image passed to ``set_up``.

    Returns:
        The set-up ``pet.OSMAPOSLReconstructor``.
    """
    print('\n------------- Setting up reconstructor')

    #%% create OSEM reconstructor
    recon = pet.OSMAPOSLReconstructor()
    recon.set_objective_function(obj_fun)
    recon.set_num_subsets(num_subsets)
    recon.set_num_subiterations(num_subiters)

    #%% initialise
    recon.set_up(init_image)

    print('\n------------- Finished setting up reconstructor')
    return recon


# Reconstruction parameters for the quadratic-prior MAP-OSEM run below.
num_subsets = 21
num_subiters = 1
beta = 1

# Use SSRB to create smaller sinogram to speed up calculations
# NOTE(review): rebin(11,2,0) arguments presumably are segments to combine,
# views to combine and tangential positions to trim — confirm in SIRF docs.
# acquired_data is overwritten with the rebinned data, so re-running this cell
# rebins the already rebinned sinogram again.
acquired_data=acquired_data.rebin(11,2,0)

obj_fun = get_obj_fun(acquired_data, atten)
OSEM_reconstructor = get_reconstructor(num_subsets,num_subiters, obj_fun, init_image)

#%% do a loop, saving images as we go along
current_image = init_image.clone()
# One stored image per subiteration plus the initial one at index 0; the
# remaining entries are uninitialised until the loop fills them.
all_images = numpy.ndarray(shape=(num_subiters+1,) + current_image.as_array().shape );
all_images[0,:,:,:] =  current_image.as_array();

In [None]:
# MAP-OSEM with the quadratic prior, timing each step of every subiteration.
# Durations are typically well below a second, so they are printed as floats
# (an integer %d format would print 0).
for subiter in range(1, num_subiters+1):
    print('\n------------- Subiteration %d' % subiter)
    time_start_iter = time.perf_counter()
    image_reg = compute_xreg(current_image.as_array())  # compute xreg from the current image
    time_compute_xreg = time.perf_counter()
    print("time for compute xreg: %.3f s" % (time_compute_xreg - time_start_iter))
    OSEM_reconstructor.update(current_image)  # compute EM update (current_image is updated)
    time_osem_update = time.perf_counter()
    print("time osem update: %.3f s" % (time_osem_update - time_compute_xreg))
    image_EM = current_image.as_array()  # get xEM as a numpy array
    time_as_array = time.perf_counter()
    print("time as_array: %.3f s" % (time_as_array - time_osem_update))
    updated = compute_MAPEM_update(image_EM, image_reg, beta)  # compute new update
    time_mapem_update = time.perf_counter()
    print("time mapem update: %.3f s" % (time_mapem_update - time_as_array))
    current_image.fill(updated)  # store for next iteration
    time_fill = time.perf_counter()
    print("time fill: %.3f s" % (time_fill - time_mapem_update))
    all_images[subiter, :, :, :] = updated  # save for plotting later on

# # #%% now call this function to see how we went along
# # subiterations = (1,2,4,8,16,32,42);
# # plot_progress([all_images], ['MAP-OSEM'],subiterations, all_images.max()*0.9)

# Implement de Pierro regularisation

In [None]:
def dePierroUpdate(xEM, imageReg, beta, sensImg):
    """Return the de Pierro MAP-EM update with a voxel-dependent penalty strength.

    Uses ``beta_j = beta / sensImg`` and evaluates, elementwise,
    ``2*xEM / (sqrt((1 - beta_j*imageReg)**2 + 4*beta_j*xEM) + (1 - beta_j*imageReg) + 1e-5)``.

    Sensitivity values below ``1e-6 * max(|sensImg|)`` are clamped to that
    value to avoid division by zero. The clamping is done on a copy, so the
    caller's ``sensImg`` array is left unchanged.

    Args:
        xEM: EM-updated image (array).
        imageReg: regularisation image, e.g. from ``dePierroReg``.
        beta: regularisation strength.
        sensImg: sensitivity image (array broadcastable with ``xEM``).

    Raises:
        ValueError: if ``sensImg`` contains no nonzero value, in which case
            no positive clamping threshold exists.
    """
    delta = 1e-6*numpy.abs(sensImg).max()
    if not delta > 0:
        raise ValueError("sensitivity image must contain nonzero values")
    sens_clamped = numpy.maximum(sensImg, delta)  # avoid division by zero, without mutating sensImg
    beta_j = beta/sens_clamped
    one_minus = 1 - beta_j*imageReg
    return (2*xEM)/(numpy.sqrt(one_minus**2 + 4*beta_j*xEM) + one_minus + 0.00001)

def dePierroReg(image,weights,nhoodIndVec):
    """Compute the de Pierro regularisation image ``xreg`` for a weighted prior.

    For voxel j with neighbours k this returns
    ``0.5 * sum_k weights[j, k] * (x_k + x_j)``.

    Args:
        image: image array.
        weights: array of shape (number of voxels, neighbourhood size); row j
            holds the weights of voxel j, voxels in column-major (Fortran)
            order.
        nhoodIndVec: flat neighbour indices, i.e. the column-major flattening
            of a (number of voxels, neighbourhood size) table of column-major
            voxel indices, as produced by ``compute_nhoodIndVec``.

    Returns:
        Array with the same shape as ``image``.
    """
    imSize = image.shape

    # vectorise image for indexing (column-major, matching the rows of weights)
    imageVec = image.reshape(-1,order='F')

    # retrieve voxel intensities for neighbourhoods: result[j, k] is the value
    # of the k-th neighbour of voxel j
    result = imageVec[nhoodIndVec].reshape(weights.shape,order='F')

    # compute xreg
    imageReg = 0.5*numpy.sum(weights*(result + image.reshape(-1,1,order='F')),axis=1)
    return imageReg.reshape(imSize,order='F')

# old version
def dePierroReg2Arg(image,weights):
    """Older variant of ``dePierroReg`` that computes the neighbourhood indices itself.

    The neighbourhood side length is the rounded cube root of
    ``weights.shape[1]``, and the indices come from ``neighbourExtract``.
    Neighbour intensities are cast to float32 before the weighted sum.
    Errors in the weighted sum (e.g. a shape mismatch between ``weights`` and
    the neighbourhood table) are raised to the caller.

    Returns:
        Array with the same shape as ``image``.
    """
    # get size and vectorise image for indexing (column-major voxel order)
    imSize = image.shape
    imageVec = image.reshape(-1,order='F')

    # get the neighbourhoods of each voxel
    w = int(round(weights.shape[1]**(1.0/3))) # side length of neighbourhood
    nhoodInd = neighbourExtract(imSize,w)
    nhoodIndVec = nhoodInd.reshape(-1,order='F')

    # retrieve voxel intensities for neighbourhoods
    resultVec = numpy.float32(imageVec[nhoodIndVec])
    result = resultVec.reshape(nhoodInd.shape,order='F')

    # compute xreg
    imageReg = 0.5*numpy.sum(weights*(result + image.reshape(-1,1,order='F')),axis=1)
    return imageReg.reshape(imSize,order='F')

def compute_nhoodIndVec(image,weights):
    """Return the flat neighbour-index vector used by ``dePierroReg``.

    The neighbourhood side length is the rounded cube root of the number of
    weights per voxel (``weights.shape[1]``); only ``image.shape`` is used.
    The (voxels, neighbours) table from ``neighbourExtract`` is flattened in
    column-major (Fortran) order.
    """
    num_neighbours = weights.shape[1]
    side = int(round(num_neighbours**(1.0/3)))  # side length of neighbourhood
    table = neighbourExtract(image.shape, side)
    return table.reshape(-1, order='F')
    
def neighbourExtract(imageSize,w):
    """Build the table of neighbour indices for every voxel of a 3D image.

    Args:
        imageSize: image shape (n, m, h).
        w: side length of the cubic neighbourhood. Offsets run from
            -floor(w/2) to floor(w/2), i.e. 2*floor(w/2)+1 per axis, while the
            number of columns is w**3 (w**2 when h == 1). These agree only for
            odd w; NOTE(review): for even w the loop writes past the last
            column and raises IndexError.

    Returns:
        int32 array of shape (n*m*h, nN). Row j is voxel j in column-major
        (Fortran) order; each entry is the column-major linear index
        ``x + y*n + z*n*m`` of one neighbour. Out-of-range neighbours wrap
        around via setBoundary.
    """
    # Adapted from Prior class
    n = imageSize[0]
    m = imageSize[1]
    h = imageSize[2]
    wlen = 2*numpy.floor(w/2)
    widx = xidx = yidx = numpy.arange(-wlen/2,wlen/2+1)

    # A single plane (h == 1) only gets in-plane neighbours.
    if h==1:
        zidx = [0]
        nN = w*w
    else:
        zidx = widx
        nN = w*w*w

    # Default 'xy' meshgrid: outputs have shape (n, m, h); X holds the index
    # along axis 0, Y along axis 1 and Z along axis 2.
    Y,X,Z = numpy.meshgrid(numpy.arange(0,m), numpy.arange(0,n), numpy.arange(0,h))
    N = numpy.zeros([n*m*h, nN],dtype='int32')
    l = 0
    # Column l of N holds, for every voxel, its neighbour at offset (x, y, z).
    for x in xidx:
        Xnew = setBoundary(X + x,n)
        for y in yidx:
            Ynew = setBoundary(Y + y,m)
            for z in zidx:
                Znew = setBoundary(Z + z,h)
                N[:,l] = ((Xnew + (Ynew)*n + (Znew)*n*m)).reshape(-1,1).flatten('F')
                l += 1
    return N

def setBoundary(X,n):
    """Wrap out-of-range indices into [0, n-1] and return them flattened.

    Boundary conditions for neighbourExtract (adapted from Prior class).
    Entries below 0 get ``n`` added, and then entries above ``n-1`` get ``n``
    subtracted; this is a single wrap, so offsets are assumed smaller than
    ``n``. ``X`` itself is modified in place. The result is ``X`` flattened in
    column-major (Fortran) order.
    """
    below = X < 0
    X[below] = X[below] + n
    above = X > n-1
    X[above] = X[above] - n
    return X.flatten('F')

In [None]:
def tic():
    """Start a timer (homemade version of MATLAB's tic; read it with toc())."""
    import time
    global startTime_for_tictoc
    # perf_counter is monotonic, so intervals are unaffected by clock changes.
    startTime_for_tictoc = time.perf_counter()

def toc():
    """Return the number of seconds elapsed since the last call to tic().

    Raises NameError if tic() has never been called.
    """
    import time
    return time.perf_counter() - startTime_for_tictoc

### MAPEM input data and parameters

In [None]:
# Fresh OSEM reconstructor with the current num_subsets and num_subiters,
# set up on the initial image.
OSEM_reconstructor = get_reconstructor(num_subsets,num_subiters, obj_fun, init_image)

In [None]:
# Subset-0 sensitivity image scaled by the number of subsets.
# NOTE(review): both sensitivity_image and beta are reassigned later in this
# notebook before the de Pierro loop uses them — confirm these values are needed.
sensitivity_image = obj_fun.get_subset_sensitivity(0)*num_subsets
beta = 500

### Create a Prior for computing Bowsher weights

In [None]:
# The Bowsher weights are only created in the next cell; evaluating them here
# raised a NameError when running the notebook top to bottom. They are
# inspected (row 500) right after they have been computed, further below.

In [None]:
import sirf.contrib.kcl.Prior as pr
myPrior = pr.Prior(anatomical_arr.shape)
# Bowsher weights from the anatomical image.
# NOTE(review): `BowshserWeights` (sic) is presumably the spelling used by
# sirf.contrib.kcl.Prior — confirm before renaming. The argument 7 is presumably
# the number of neighbours kept per voxel; dividing by 7.0 is meant to make
# each row sum to 1, which the check below enforces.
weights = myPrior.BowshserWeights(anatomical_arr,7)
weights = numpy.float32(weights/7.0)
if (numpy.abs(numpy.sum(weights,axis=1)-1)>1.0e-6).any():
    raise ValueError("Weights should sum to 1 for each voxel")

In [None]:
# compute indices of the neighbourhood
# Only the shape of the image argument is used. The anatomical image defines
# the voxel grid the Bowsher weights were computed on (`image` is not defined
# in this notebook).
nhoodIndVec=compute_nhoodIndVec(anatomical_arr,weights)

In [None]:
# illustrate that only a few of the weights in the neighbourhood are kept
# (taking an arbitrary voxel)
# Row 500 = voxel 500 in column-major order; one weight per neighbourhood position.
weights[500,:]

In [None]:
# first run a few OSEM subiterations to get a reasonable image
OSEM_reconstructor.set_num_subiterations(12)
osem_init_image = init_image.clone()
# reconstruct() writes its result into osem_init_image (used below as the
# starting image of the MAP loop and of the OSEM comparison).
OSEM_reconstructor.reconstruct(osem_init_image)

In [None]:
# 42 subiterations (two passes over 21 subsets) and the penalty strength for
# the de Pierro MAP loop.
num_subiters=42
beta=.01

In [None]:
# NOTE(review): `bowsher` is not used elsewhere in this notebook, and when the
# cells run in order current_image still holds the quadratic-prior MAP result
# here (not a Bowsher reconstruction) — confirm the intent.
bowsher=current_image.clone()

In [None]:
# enable this to use uniform weights
# weights[:]=1/27

In [None]:
# De Pierro MAP-OSEM with the Bowsher-weighted prior, starting from the OSEM image.
# Work on a copy: the reconstructor update and fill() below modify current_image
# in place, and osem_init_image must keep the OSEM result for the comparison.
current_image = osem_init_image.clone()
all_images_deP = numpy.ndarray(shape=(num_subiters+1,) + current_image.as_array().shape)
all_images_deP[0,:,:,:] = current_image.as_array()
# subset-0 sensitivity, used to form the voxel-wise beta_j = beta/sensitivity
sensitivity_image = obj_fun.get_subset_sensitivity(0).as_array()
for subiter in range(1, num_subiters+1):
    print('\n------------- Subiteration %d' % subiter)

    # image_reg= compute_xreg(current_image.as_array()) # compute xreg
    image_reg = dePierroReg(current_image.as_array(), weights, nhoodIndVec)  # compute xreg
    OSEM_reconstructor.update(current_image)  # compute EM update
    image_EM = current_image.as_array()  # get xEM as a numpy array
    # updated = compute_MAPEM_update(image_EM, image_reg, beta) # compute new update
    updated = dePierroUpdate(image_EM, image_reg, beta, sensitivity_image)  # compute new update
    current_image.fill(updated)  # store for next iteration
    all_images_deP[subiter,:,:,:] = updated  # save for plotting later on

In [None]:
array=current_image.as_array()
# Middle index along the first axis. NOTE: `slice` shadows the Python builtin;
# it is reused by the OSEM display cell further below.
slice = array.shape[0]//2
# Colour upper limit: 90% of the MAP image maximum, shared with later displays.
cmax=array.max()*.9
imshow(array[slice,:,:],[0,cmax],'MAP-OSEM');

In [None]:
# run OSEM for the same number of subiterations for comparison
# Starts from a copy of osem_init_image; the comparison is only meaningful if
# that image still holds the OSEM initialisation at this point.
OSEM_reconstructor.set_num_subiterations(num_subiters)
osem_image = osem_init_image.clone()
OSEM_reconstructor.reconstruct(osem_image)

In [None]:
# Same slice and colour scale as the MAP-OSEM display, for a direct comparison.
imshow(osem_image.as_array()[slice,:,:],[0,cmax],'OSEM');

In [None]:
#%% now call this function to see how we went along
# Show subiterations 10..num_subiters; arange's stop is exclusive, hence the +1
# so that the final subiteration is included.
plot_progress([all_images_deP], ['MAP-OSEM'], numpy.arange(10, num_subiters+1), cmax)