# MAPEM de Pierro algorithm

Authors: Kris Thielemans, Sam Ellis, Richard Brown  
First version: 22nd of October 2019  
Second version: 27th of October 2019

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019  University College London  
Copyright 2019  King's College London  

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Start with [PET/MAPEM](../PET/MAPEM.ipynb)...

If you've already completed the PET component, you will have implemented a version of MAPEM. If you haven't, you'll probably want to give that a go first!

This example extends the quadratic prior used in that notebook to an anatomically guided (Bowsher) prior.

# All the normal imports and handy functions

In [None]:
#%% Initial imports etc
import numpy
import matplotlib.pyplot as plt
import os
import sys
import shutil
from numba import jit
import time
from scipy.ndimage.filters import gaussian_filter
import sirf.STIR as pet
from sirf.Utilities import examples_data_path
# plotting settings
plt.ion() # interactive 'on' such that plots appear during loops

%matplotlib notebook

#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Usage: imshow(image, [min,max], title)

    Show a 2D array in the current axes with a colour bar and no axis ticks.
    When limits is None the colour range is [image.min(), image.max()].
    Returns the object created by plt.imshow.
    """
    plt.title(title)
    handle = plt.imshow(image)
    if limits is None:
        low, high = image.min(), image.max()
    else:
        low, high = limits[0], limits[1]
    plt.clim(low, high)
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return handle

def make_cylindrical_FOV(image):
    """truncate to cylindrical FOV

    Applies STIR's TruncateToCylinderProcessor to image. Nothing is returned;
    callers rely on image itself being modified.
    """
    truncator = pet.TruncateToCylinderProcessor()
    truncator.apply(image)

#%% define a function for plotting images and the updates
# This is the same function as in `ML_reconstruction`
def plot_progress(all_images, title, subiterations, cmax, slice_index=60):
    """Plot selected subiterations, and the update made by each, for several reconstructions.

    One figure is opened per subiteration. It has one row per entry of
    all_images: the image on the left, the difference from the previous
    subiteration on the right.

    Parameters
    ----------
    all_images : sequence of arrays indexed [subiteration, z, y, x]; entry 0 of
        the first axis is the initial image.
    title : sequence of str, one label per entry of all_images.
    subiterations : iterable of subiteration indices to show. If it is empty
        (or None), all subiterations 1..N are shown, with N taken from
        all_images[0].
    cmax : upper colour limit for the images. Updates are shown in
        [-0.1*cmax, 0.1*cmax].
    slice_index : index along the first spatial axis to display (default 60).
    """
    if subiterations is None or len(subiterations) == 0:
        num_subiters = all_images[0].shape[0] - 1
        subiterations = range(1, num_subiters + 1)
    num_rows = len(all_images)
    for subiter in subiterations:
        plt.figure()
        for row, images in enumerate(all_images):
            current = images[subiter, slice_index, :, :]
            previous = images[subiter - 1, slice_index, :, :]
            plt.subplot(num_rows, 2, 2 * row + 1)
            imshow(current, [0, cmax], '%s at %d' % (title[row], subiter))
            plt.subplot(num_rows, 2, 2 * row + 2)
            imshow(current - previous, [-cmax * .1, cmax * .1], 'update')
        plt.show()
        
def subplot_(idx, vol, title, clims=None, cmap="viridis"):
    """Draw vol into subplot position idx (e.g. [rows, cols, index]) with a colour bar.

    clims, if given, is passed to plt.clim. Axis ticks are hidden.
    """
    plt.subplot(*idx)
    plt.imshow(vol, cmap=cmap)
    if clims is not None:
        plt.clim(clims)
    plt.colorbar()
    plt.title(title)
    plt.axis("off")

# Load the data
To generate the data needed for this notebook, run the [generate_data](./generate_data.ipynb) notebook first.

In [None]:
# Work inside the Synergistic examples data directory
os.chdir(examples_data_path('Synergistic'))

# Start from a fresh copy of the brainweb data in a dedicated working folder
# and move into it, so that output files are written there
work_dir = 'working_folder/dePierro_brainweb'
shutil.rmtree(work_dir, ignore_errors=True)
shutil.copytree('brainweb', work_dir)
os.chdir(work_dir)

# emission sinogram and attenuation (mu-map) image
full_acquired_data = pet.AcquisitionData('fdg_sino_noisy.hs')
atten = pet.ImageData('uMap_small.hv')

# Anatomical image (could be T2_small.hv instead)
anatomical = pet.ImageData('T1_small.hv')
anatomical_arr = anatomical.as_array()

# initial image: uniform at 10% of the attenuation image maximum, truncated to the FOV
init_image = atten.get_uniform_copy(atten.as_array().max() * .1)
make_cylindrical_FOV(init_image)

# Code from first MAPEM notebook

The following chunk of code is copied and pasted more-or-less directly from the other notebook as a starting point. 

First, run the code chunk to get the results using the quadratic prior...

### MAPEM functions

In [None]:
# Weights for the neighbourhood offsets dx = -1, 0, +1 used by compute_xreg;
# the centre voxel gets zero weight.
w = numpy.array([1., 0., 1.])
w = w / w.sum()  # normalise so the weights sum to 1

# Define function for xreg.
# Using jit gets computation time from 90 secs to 2!
@jit
def compute_xreg(image_array):
    """Compute the de Pierro regularisation image for the quadratic prior.

    For each voxel [z, y, x], this accumulates
    w[dx+1]/2 * (x_centre + x_neighbour) over the neighbours dx = -1, 0, +1
    along the last axis, using the module-level weights w. The first and last
    voxel along the last axis are skipped (left at image_array*0).
    Expects a 3D array indexed [z, y, x].
    """
    shape = image_array.shape
    image_reg = image_array * 0  # same shape/dtype as the input; values filled below
    for iz in range(shape[0]):
        for iy in range(shape[1]):
            for ix in range(1, shape[2] - 1):  # skip edges so ix +/- 1 stays in range
                centre = image_array[iz, iy, ix]
                for offset in (-1, 0, 1):
                    # accumulate in the output array itself (keeps its dtype's rounding)
                    image_reg[iz, iy, ix] += w[offset + 1] / 2 * (centre + image_array[iz, iy, ix + offset])
    return image_reg

# define a function that computes the MAP-EM update
@jit
def compute_MAPEM_update(xEM, xreg, beta):
    """Combine the EM image and the regularisation image into the MAP-EM image.

    Elementwise: 2*xEM / (sqrt((1 - beta*xreg)**2 + 4*beta*xEM) + (1 - beta*xreg) + 1e-5).
    The 1e-5 term presumably guards against division by zero.
    """
    one_minus = 1 - beta * xreg
    return (2 * xEM) / (numpy.sqrt(one_minus ** 2 + 4 * beta * xEM) + one_minus + 0.00001)

In [None]:
def get_obj_fun(acquired_data, atten):
    """Build a Poisson log-likelihood objective function for acquired_data.

    The acquisition model is a ray-tracing matrix (5 tangential LORs). Its
    sensitivity model holds attenuation factors computed from atten by
    "unnormalising" a sinogram filled with ones.
    """
    print('\n------------- Setting up objective function')

    # ray-tracing projector
    acq_model = pet.AcquisitionModelUsingRayTracingMatrix()
    acq_model.set_num_tangential_LORs(5)

    # Compute the attenuation factors once into a sinogram, then wrap that
    # sinogram in a new sensitivity model (presumably so the attenuation
    # projection is not repeated on every forward/back projection).
    attn_model = pet.AcquisitionSensitivityModel(atten, acq_model)
    attn_model.set_up(acquired_data)
    attn_factors = pet.AcquisitionData(acquired_data)
    attn_factors.fill(1.0)
    attn_model.unnormalise(attn_factors)
    acq_model.set_acquisition_sensitivity(pet.AcquisitionSensitivityModel(attn_factors))
    acq_model.set_up(acquired_data, atten)

    objective = pet.make_Poisson_loglikelihood(acquired_data)
    objective.set_acquisition_model(acq_model)

    print('\n------------- Finished setting up objective function')
    return objective

def get_reconstructor(num_subsets, num_subiters, obj_fun, init_image):
    """Create an OSMAPOSL (OSEM) reconstructor for obj_fun and set it up on init_image."""
    print('\n------------- Setting up reconstructor')

    recon = pet.OSMAPOSLReconstructor()
    recon.set_objective_function(obj_fun)
    recon.set_num_subsets(num_subsets)
    recon.set_num_subiterations(num_subiters)
    recon.set_up(init_image)

    print('\n------------- Finished setting up reconstructor')
    return recon

In [None]:
# Use SSRB to create smaller sinogram to speed up calculations.
# rebin is applied to a clone, so full_acquired_data itself is left untouched.
acquired_data = full_acquired_data.clone().rebin(3)

In [None]:
# Build the Poisson log-likelihood for the rebinned data
obj_fun = get_obj_fun(acquired_data=acquired_data, atten=atten)

In [None]:
# 42 subiterations over 21 subsets, i.e. 42/21 = 2 passes through the subsets
num_subsets, num_subiters = 21, 42

# Do a normal OSEM (for comparison)

In [None]:
# Do initial OSEM recon, starting from a copy of the initial image
OSEM_reconstructor = get_reconstructor(num_subsets, num_subiters, obj_fun, init_image)
osem_image = init_image.clone()
OSEM_reconstructor.reconstruct(osem_image)

# show slice 60 of the result
osem_slice = osem_image.as_array()[60, :, :]
plt.figure()
plt.imshow(osem_slice)
plt.show()

# Now do a normal MAPEM

In [None]:
beta = 10
# We don't have to get a new reconstructor each time,
# but if we don't, then we'll start from the same subiteration
# that we finished last time. But we save a few seconds.
# OSEM_reconstructor = get_reconstructor(num_subsets,num_subiters, obj_fun, init_image)

#%% do a loop, saving images as we go along
current_image = init_image.clone()
# one slot per subiteration plus the initial image; every slot is filled below
all_images = numpy.empty(shape=(num_subiters + 1,) + current_image.as_array().shape)
all_images[0, :, :, :] = current_image.as_array()

for subiter in range(1, num_subiters + 1):
    start_time = time.time()
    image_reg = compute_xreg(current_image.as_array())  # xreg from the image before the EM step
    OSEM_reconstructor.update(current_image)  # EM update; current_image now holds xEM
    image_EM = current_image.as_array()  # get xEM as a numpy array
    updated = compute_MAPEM_update(image_EM, image_reg, beta)  # compute new update
    current_image.fill(updated)  # store for next subiteration
    all_images[subiter, :, :, :] = updated  # save for plotting later on
    print('\n------------- Subiteration %i finished in %i s.' % (subiter, time.time() - start_time))

#%% now call this function to see how we went along
# Powers of two below num_subiters, plus the last subiteration.
# This equals (1,2,4,8,16,32,42) for num_subiters = 42 and never indexes past the end.
subiterations = tuple(s for s in (1, 2, 4, 8, 16, 32) if s < num_subiters) + (num_subiters,)
plot_progress([all_images], ['Quadratic prior MAP-OSEM'], subiterations, all_images.max() * 0.9)

# Implement de Pierro regularisation

In [None]:
def dePierroReg(image,weights,nhoodIndVec):
    """Get the de Pierro regularisation image

    Parameters
    ----------
    image : ndarray, the current image.
    weights : 2D ndarray (num_voxels, num_neighbours). Row r holds the weights
        for the voxel whose Fortran-order linear index is r.
    nhoodIndVec : 1D integer array, the (num_voxels, num_neighbours)
        neighbour-index table flattened in Fortran order.

    Returns
    -------
    ndarray shaped like image, with value 0.5 * sum_k w_rk * (x_r + x_nk(r))
    at each voxel r.
    """
    flat = image.ravel(order='F')

    # neighbour intensities, one row per voxel
    neighbours = flat[nhoodIndVec].reshape(weights.shape, order='F')
    centre = flat[:, numpy.newaxis]  # column vector, broadcast across neighbours

    reg = 0.5 * (weights * (neighbours + centre)).sum(axis=1)
    return reg.reshape(image.shape, order='F')

def compute_nhoodIndVec(image,weights):
    """Get the neighbourhoods of each voxel, flattened in Fortran order.

    The neighbourhood side length is inferred from the number of columns of
    weights:
    - single-plane image (image.shape[2] == 1): square root, since
      neighbourExtract then builds a 2D w*w neighbourhood;
    - otherwise: cube root, for a 3D w*w*w neighbourhood.

    Returns the 1D index vector expected as nhoodIndVec by dePierroReg.

    Raises ValueError if the number of columns of weights is not a perfect
    square/cube matching the image dimensionality.
    """
    num_neighbours = weights.shape[1]
    # neighbourExtract treats the third axis as absent when its size is 1
    dim = 2 if image.shape[2] == 1 else 3
    side = int(round(num_neighbours ** (1.0 / dim)))  # side length of neighbourhood
    if side ** dim != num_neighbours:
        raise ValueError('weights has %d columns, which is not a %dD cubic neighbourhood'
                         % (num_neighbours, dim))
    nhoodInd = neighbourExtract(image.shape, side)
    return nhoodInd.reshape(-1, order='F')

def neighbourExtract(imageSize,w):
    """Adapted from Prior class

    Build the neighbourhood index table for an image of shape imageSize.

    Parameters
    ----------
    imageSize : sequence of 3 ints (n, m, h)
    w : neighbourhood side length. The code assumes w is odd. For even w,
        w+1 offsets per axis are generated but only w**3 (or w**2) columns
        are allocated, so the column assignment below runs out of bounds.

    Returns
    -------
    N : int32 array of shape (n*m*h, nN), with nN = w*w*w (or w*w when h == 1).
        Row r is the voxel whose Fortran-order (column-major) linear index is r.
        Each entry is the Fortran-order linear index x + y*n + z*n*m of one of
        that voxel's neighbours. Neighbours outside the image are wrapped once
        to the opposite side (see setBoundary). Columns enumerate the offsets
        with the first-axis offset varying slowest and the third-axis offset
        fastest.
    """
    n = imageSize[0]
    m = imageSize[1]
    h = imageSize[2]
    # offsets -floor(w/2) .. +floor(w/2) (float-valued), shared by all axes
    wlen = 2*numpy.floor(w/2)
    widx = xidx = yidx = numpy.arange(-wlen/2,wlen/2+1)

    if h==1:
        # single plane: no offsets along the third axis
        zidx = [0]
        nN = w*w
    else:
        zidx = widx
        nN = w*w*w
        
    # With numpy's default 'xy' indexing, meshgrid(a, b, c) returns arrays of
    # shape (len(b), len(a), len(c)) = (n, m, h) here, so X holds the first-axis
    # index, Y the second-axis index and Z the third-axis index of each voxel.
    Y,X,Z = numpy.meshgrid(numpy.arange(0,m), numpy.arange(0,n), numpy.arange(0,h))                
    N = numpy.zeros([n*m*h, nN],dtype='int32')
    l = 0  # current column of N
    for x in xidx:
        # setBoundary wraps out-of-range indices and flattens in Fortran order
        Xnew = setBoundary(X + x,n)
        for y in yidx:
            Ynew = setBoundary(Y + y,m)
            for z in zidx:
                Znew = setBoundary(Z + z,h)
                # Fortran-order linear index of the (wrapped) neighbour. The
                # whole-number float result is cast to int32 on assignment; the
                # reshape/flatten leave the already-1D array's values unchanged.
                N[:,l] = ((Xnew + (Ynew)*n + (Znew)*n*m)).reshape(-1,1).flatten('F')
                l += 1
    return N

def setBoundary(X,n):
    """Boundary conditions for neighbourExtract.
    Adapted from Prior class

    Wrap indices of an axis of length n once around: values below 0 get n
    added, then values above n-1 get n subtracted. X is modified in place, and
    a Fortran-order flattened copy is returned.
    """
    X[X < 0] += n
    X[X > n - 1] -= n
    return X.flatten('F')

def dePierroUpdate(xEM, imageReg, beta):
    """Update the image based on the de Pierro regularisation image

    Elementwise: 2*xEM / (((1 - beta*imageReg)**2 + 4*beta*xEM)**0.5 + (1 - beta*imageReg) + 1e-5).
    The 1e-5 term presumably guards against a zero denominator.
    """
    one_minus = 1 - beta*imageReg
    denominator = (one_minus**2 + 4*beta*xEM)**0.5 + one_minus + 0.00001
    return (2*xEM) / denominator

### Create a Prior for computing Bowsher weights

In [None]:
import sirf.contrib.kcl.Prior as pr
def update_bowsher_weights(prior,side_image,num_bowsher_neighbours):
    """Compute Bowsher weights from side_image, divided by num_bowsher_neighbours.

    Uses prior.BowshserWeights (the method name as spelled in the Prior class)
    on the array of side_image.
    """
    raw_weights = prior.BowshserWeights(side_image.as_array(), num_bowsher_neighbours)
    return raw_weights / float(num_bowsher_neighbours)

# keep the 7 most similar neighbours (in the anatomical image) of each voxel
num_bowsher_neighbours = 7
# Prior helper sized like the anatomical image
myPrior = pr.Prior(anatomical_arr.shape)
weights = update_bowsher_weights(myPrior, anatomical, num_bowsher_neighbours)

In [None]:
# compute indices of the neighbourhood of every voxel (flattened, as dePierroReg expects)
nhoodIndVec = compute_nhoodIndVec(atten, weights)

In [None]:
# Illustrate that only a few of the weights in the neighbourhood are kept,
# by printing the weight row of an arbitrary voxel (index 500)
print(weights[500])

In [None]:
# enable this to use uniform weights
# weights[:]=1/27

In [None]:
def MAPEM_iteration(OSEM_reconstructor,current_image,weights,nhoodIndVec,beta):
    """Perform one de Pierro MAP-EM subiteration and return current_image.

    xreg is computed from the image before the EM step. OSEM_reconstructor.update
    then performs the EM step on current_image, and the two are combined with
    dePierroUpdate. The result is written back into current_image (which is
    also returned).
    """
    x_reg = dePierroReg(current_image.as_array(), weights, nhoodIndVec)
    OSEM_reconstructor.update(current_image)
    current_image.fill(dePierroUpdate(current_image.as_array(), x_reg, beta))
    return current_image

In [None]:
current_image = init_image.clone()
# one slot per subiteration plus the initial image; every slot is filled below
all_images_deP = numpy.empty(shape=(num_subiters + 1,) + current_image.as_array().shape)
all_images_deP[0, :, :, :] = current_image.as_array()
for subiter in range(1, num_subiters + 1):
    start_time = time.time()

    current_image = MAPEM_iteration(OSEM_reconstructor, current_image, weights, nhoodIndVec, beta)

    all_images_deP[subiter, :, :, :] = current_image.as_array()  # save for plotting later on
    print('\n------------- Subiteration %i finished in %i s.' % (subiter, time.time() - start_time))

#%% now call this function to see how we went along
# No plt.figure() here: plot_progress opens its own figure per subiteration,
# so an extra call would only leave an empty figure behind.
# Powers of two below num_subiters, plus the last subiteration (never past the end).
subiterations = tuple(s for s in (1, 2, 4, 8, 16, 32) if s < num_subiters) + (num_subiters,)
plot_progress([all_images_deP], ['Bowsher MAP-OSEM'], subiterations, all_images_deP.max())

In [None]:
# Plot slice 60 of the anatomical, OSEM, and two MAPEM images in a 2x2 grid
plt.figure()
subplot_([2, 2, 1], anatomical_arr[60, :, :], "Anatomical")
subplot_([2, 2, 2], osem_image.as_array()[60, :, :], "OSEM")
# final subiteration of each MAP-EM run
subplot_([2, 2, 3], all_images[num_subiters, 60, :, :], "Quadratic prior")
subplot_([2, 2, 4], all_images_deP[num_subiters, 60, :, :], "Bowsher prior")

# Finally, misalignment between anatomical and emission images?

What happens if you want to use an anatomical prior but the anatomical image isn't aligned with the image you're trying to reconstruct?  

You'll have to register them of course! Have a look at the [registration notebook](../Reg/sirf_registration.ipynb) if you haven't already.  

The idea here would be to run an initial reconstruction (say, OSEM), and then register the anatomical image to the resulting reconstruction...

Once we've got the anatomical image in the correct space, we can calculate the Bowsher weights.

In [None]:
import sirf.Reg as Reg

# Rigidly register the anatomical image (floating) to the OSEM reconstruction (reference).
registration = Reg.NiftyAladinSym()
registration.set_reference_image(osem_image)
registration.set_floating_image(anatomical)
registration.set_parameter('SetPerformRigid', '1')
registration.set_parameter('SetPerformAffine', '0')
registration.process()
anatomical_in_emission_space = registration.get_output()

# Recompute the Bowsher weights from the registered anatomical image
weights = update_bowsher_weights(myPrior, anatomical_in_emission_space, num_bowsher_neighbours)

If we were trying to do some sort of synergistic alternating reconstruction where motion was present, then we would probably want to try something along the lines of:

- Get the best looking images independently
- Register the images
- Extract forward and back transformations
- The regularisation images evolve as each other's side information evolves.
- We would therefore need to resample into the target space before recalculating the weights.