# MAPEM de Pierro algorithm

Authors: Sam Ellis, Richard Brown
First version: 22nd of October 2019

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019 University College London.
Copyright 2019 King's College London.

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

N.B.: You need to have run the [generate_data](./generate_data.ipynb) notebook first.

In [None]:
#%% make sure figures appear in the notebook and animations work
%matplotlib notebook

In [None]:
import sirf.STIR as pet
import matplotlib.pyplot as plt
import os
import numpy as np
from numpy.linalg import norm
import sirf.contrib.kcl.Prior as pr
import shutil

# Start from the directory that holds the generated data
os.chdir("/data")

#%% make a fresh copy of the exhale data in a working folder and work from there
work_dir = 'working_folder/dePierro_thorax'
shutil.rmtree(work_dir, ignore_errors=True)  # drop any stale copy; absent folder is fine
shutil.copytree('SRS_data_exhale', work_dir)
os.chdir(work_dir)

In [None]:
#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Display a 2D array with a colour bar and hidden axes.

    Usage: imshow(image, [min, max], title)

    Args:
        image: 2D array to display.
        limits: optional [min, max] colour limits; when omitted the image's
            own minimum and maximum are used.
        title: plot title.

    Returns:
        The image handle returned by ``plt.imshow``.
    """
    plt.title(title)
    handle = plt.imshow(image)
    if limits is not None:
        low, high = limits[0], limits[1]
    else:
        low, high = image.min(), image.max()
    plt.clim(low, high)
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return handle

def make_cylindrical_FOV(image):
    """Truncate ``image`` to the cylindrical field of view.

    Applies SIRF's ``TruncateToCylinderProcessor`` directly to ``image`` and
    returns None, so callers rely on ``image`` being modified in place.
    """
    # Named ``processor`` (not ``filter``) so the builtin is not shadowed.
    processor = pet.TruncateToCylinderProcessor()
    processor.apply(image)

### Load some data and set some values

In [None]:
# Measured (noisy) sinogram and the attenuation image
sino = pet.AcquisitionData('noisy_sino.hs')
atten = pet.ImageData('PET_attenuation.nii')

# Subset settings shared by the OSEM and MAPEM reconstructions below
num_subsets = 21
num_subiters = 42

# Anatomical guide image (alternatives: MR_T2.nii, MR_PD.nii)
anatomical = pet.ImageData('MR_T1.nii')
anatomical_arr = anatomical.as_array()

#%% initial estimate: uniform at 10% of the attenuation maximum, cut to the FOV
init_image = atten.get_uniform_copy(.1 * atten.as_array().max())
make_cylindrical_FOV(init_image)

## Set up the acquisition model and objective function

In [None]:
#%% create acquisition model
# Ray-tracing projector, using 5 tangential LORs (set_num_tangential_LORs).
am = pet.AcquisitionModelUsingRayTracingMatrix()
am.set_num_tangential_LORs(5)

# Set up sensitivity due to attenuation
# A sensitivity model is built from the attenuation image and the projector
# `am`, then applied to a sinogram of ones via unnormalise(). The resulting
# sinogram (bin_eff) is wrapped in a new sensitivity model that replaces the
# first. NOTE(review): presumably bin_eff then holds per-bin attenuation
# factors, so they are computed only once; confirm against the SIRF docs.
asm_attn = pet.AcquisitionSensitivityModel(atten, am)
asm_attn.set_up(sino)
bin_eff = pet.AcquisitionData(sino)
bin_eff.fill(1.0)
asm_attn.unnormalise(bin_eff)
asm_attn = pet.AcquisitionSensitivityModel(bin_eff)

# Set sensitivity of the model and set up
# (the trailing ';' only suppresses the notebook's output display)
am.set_acquisition_sensitivity(asm_attn)
am.set_up(sino,atten);

#%% create objective function
# Poisson log-likelihood of the measured sinogram under the model above
obj_fun = pet.make_Poisson_loglikelihood(sino)
obj_fun.set_acquisition_model(am)

## Standard (unregularised) OSEM

In [None]:
#%% create OSEM reconstructor
# No prior is set on obj_fun in this notebook, so this is plain OSEM.
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(num_subsets)
OSEM_reconstructor.set_num_subiterations(num_subiters)

# Reconstruct into a clone so init_image stays unchanged for the MAPEM run.
reconstructed_image = init_image.clone()
OSEM_reconstructor.set_up(reconstructed_image)
OSEM_reconstructor.reconstruct(reconstructed_image)

### Display OSEM

In [None]:
recon_arr = reconstructed_image.as_array()
# Three orthogonal slices through index 60 of the OSEM result
osem_views = (recon_arr[:, :, 60], recon_arr[:, 60, :], recon_arr[60, :, :])
for panel, view in enumerate(osem_views, start=1):
    plt.subplot(1, 3, panel)
    imshow(view, [0, 0.1])

## Now use de Pierro's MAPEM

In [None]:
def dePierroUpdate(imageEM, imageReg, beta, sensImg):
    """Combine an EM update and a regularisation image into one MAPEM step.

    Voxel-wise the result is

        2 * x_EM / (sqrt(b**2 + 4 * beta_j * x_EM) + b),

    where beta_j = beta / sens and b = 1 - beta_j * x_reg.

    Args:
        imageEM: EM (OSEM sub-iteration) update, numpy array.
        imageReg: regularisation image (e.g. from dePierroReg), same shape.
        beta: regularisation strength (scalar).
        sensImg: sensitivity image, same shape. It is NOT modified; the
            clamping below is applied to a private copy.

    Returns:
        A new numpy array holding the updated image.
    """
    # Copy first: clamping used to be done in place on the caller's array.
    sens = np.array(sensImg, copy=True)
    delta = 1e-6*abs(sens).max()
    sens[sens < delta] = delta  # avoid division by zero
    beta_j = beta/sens

    b_j = 1 - beta_j*imageReg

    numer = 2*imageEM
    denom = (b_j**2 + 4*beta_j*imageEM)**0.5 + b_j

    # denom is a freshly computed array, so clamping it in place is safe
    delta = 1e-6*abs(denom).max()
    denom[denom < delta] = delta  # avoid division by zero

    return numer/denom


def dePierroReg(image,weights):
    """Compute the weighted neighbourhood image used by the de Pierro update.

    For every voxel j this returns

        0.5 * sum_k weights[j, k] * (image[nbr_k(j)] + image[j]),

    where nbr_k(j) is the k-th voxel of j's periodically wrapped cubic
    neighbourhood, in the column order produced by ``neighbourExtract``.

    Args:
        image: 3D numpy array of shape (n, m, h); h == 1 is treated as a
            single slice with a square (w x w) neighbourhood.
        weights: array with one row per voxel, rows in Fortran
            (column-major) voxel order, and w**3 columns (w**2 when h == 1).

    Returns:
        Array with the same shape as ``image``.

    Raises:
        ValueError: if the number of weight columns is not w**3 (w**2 for a
            single-slice image) for an integer w, or w is not odd.
    """
    im_size = image.shape
    # Voxels are addressed by Fortran-order linear index throughout, which is
    # the convention neighbourExtract uses for its indices.
    image_vec = image.flatten('F')

    # Neighbourhood side length: cube root for 3D, square root for one slice.
    num_neighbours = weights.shape[1]
    dim = 2 if im_size[2] == 1 else 3
    w = int(round(num_neighbours**(1.0/dim)))
    if w**dim != num_neighbours:
        raise ValueError(
            'weights has %d columns, which is not a perfect %s'
            % (num_neighbours, 'square' if dim == 2 else 'cube'))
    nhood_ind = neighbourExtract(im_size, w)

    # neighbours[j, k] = intensity of the k-th neighbour of voxel j
    neighbours = np.float32(image_vec[nhood_ind])
    centre = np.float32(image).reshape(-1, 1, order='F')

    image_reg = 0.5*np.sum(weights*(neighbours + centre), axis=1)
    return image_reg.reshape(im_size, order='F')


def neighbourExtract(imageSize,w):
    """Return the linear indices of every voxel's w-wide neighbourhood.

    Adapted from Prior class.

    Args:
        imageSize: image shape (n, m, h). When h == 1 the neighbourhood is
            w x w (no offsets along the third axis); otherwise w x w x w.
        w: neighbourhood side length; must be a positive odd integer.

    Returns:
        int32 array of shape (n*m*h, w**3), or (n*m, w**2) when h == 1.
        Row r belongs to the voxel with Fortran-order linear index r.
        Columns enumerate offsets with the first-axis offset varying slowest
        and the third-axis offset fastest. Indices wrap around the image
        edges (periodic boundary, see setBoundary).

    Raises:
        ValueError: if w is not a positive odd integer (an even w used to
            fail later with an IndexError).
    """
    if w < 1 or w % 2 != 1:
        raise ValueError(
            'neighbourhood side length must be a positive odd integer, got %r' % (w,))
    n = imageSize[0]
    m = imageSize[1]
    h = imageSize[2]
    half = w//2
    offsets = np.arange(-half, half + 1)

    if h == 1:
        z_offsets = [0]
        num_neighbours = w*w
    else:
        z_offsets = offsets
        num_neighbours = w*w*w

    # X/Y/Z hold each voxel's index along axes 0/1/2; arrays are (n, m, h).
    Y, X, Z = np.meshgrid(np.arange(0, m), np.arange(0, n), np.arange(0, h))
    N = np.zeros([n*m*h, num_neighbours], dtype='int32')
    col = 0
    for x in offsets:
        Xnew = setBoundary(X + x, n)
        for y in offsets:
            Ynew = setBoundary(Y + y, m)
            for z in z_offsets:
                Znew = setBoundary(Z + z, h)
                # Fortran-order linear index of the shifted voxel
                N[:, col] = Xnew + Ynew*n + Znew*n*m
                col += 1
    return N


def setBoundary(X,n):
    """Wrap indices once into [0, n-1] (periodic boundary) and flatten them.

    Adapted from Prior class. Values below 0 get n added and values above
    n-1 get n subtracted, which is sufficient for offsets of magnitude <= n.
    X is modified in place; callers pass a freshly computed temporary.

    Returns:
        X flattened in Fortran (column-major) order.
    """
    idx = X < 0
    X[idx] = X[idx] + n
    idx = X > n - 1
    X[idx] = X[idx] - n
    return X.flatten('F')

### MAPEM input data and parameters

In [None]:
beta = 5000  # de Pierro regularisation strength

# Sensitivity used to scale beta per voxel in dePierroUpdate.
# NOTE(review): only subset 0's sensitivity is used for every sub-iteration,
# whichever subset the OSEM update processes; confirm this is intended.
sensitivity_image = obj_fun.get_subset_sensitivity(0)

# MAPEM starts from the same initial estimate as OSEM
de_pierro_reconstructed_image = init_image.clone()

In [None]:
# Bowsher weights from the anatomical image, 7 neighbours per voxel.
# The row-sum check below implies every row sums to 7 before normalisation.
num_bowsher_neighbours = 7
myPrior = pr.Prior(anatomical_arr.shape)
weights = myPrior.BowshserWeights(anatomical_arr, num_bowsher_neighbours)
# Normalise so that each voxel's weights sum to one
weights = np.float32(weights / float(num_bowsher_neighbours))
row_sums = np.sum(weights, axis=1)
if np.any(np.abs(row_sums - 1) > 1.0e-6):
    raise ValueError("Weights should sum to 1 for each voxel")

# Fresh OSEM reconstructor that performs the EM part of each MAPEM sub-iteration
print('Setting up reconstruction object')
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(num_subsets)
OSEM_reconstructor.set_num_subiterations(num_subiters)
OSEM_reconstructor.set_up(de_pierro_reconstructed_image)

In [None]:
# MAPEM (de Pierro) sub-iterations: an OSEM update followed by the de Pierro
# combination with the regularisation image of the previous estimate.
# NOTE(review): every sub-iteration uses the subset-0 sensitivity held in
# sensitivity_image; confirm this approximation is intended.
sens_arr = sensitivity_image.as_array()  # loop-invariant, fetched once
# Loop variable is `subiter`, not `iter`, so the builtin is not shadowed.
for subiter in range(1, num_subiters + 1):
    print('\n------------- Subiteration %d' % subiter)

    # Regularisation image from the current estimate (before the EM step)
    imageReg_array = dePierroReg(de_pierro_reconstructed_image.as_array(), weights)

    # OSEM image update, applied to de_pierro_reconstructed_image in place
    OSEM_reconstructor.update(de_pierro_reconstructed_image)
    imageEM_array = de_pierro_reconstructed_image.as_array()

    # Final image update: combine the EM and regularisation images
    imageUpdated_array = dePierroUpdate(imageEM_array, imageReg_array, beta, sens_arr)

    # Fill image and truncate to cylindrical field of view
    de_pierro_reconstructed_image.fill(imageUpdated_array)
    make_cylindrical_FOV(de_pierro_reconstructed_image)

In [None]:
plt.figure()

depier_arr = de_pierro_reconstructed_image.as_array()

# Top row: OSEM; bottom row: de Pierro MAPEM.
# Columns: three orthogonal slices through index 60, shared colour range.
for row, volume in enumerate((recon_arr, depier_arr)):
    slices = (volume[:, :, 60], volume[:, 60, :], volume[60, :, :])
    for col, view in enumerate(slices, start=1):
        plt.subplot(2, 3, 3 * row + col)
        imshow(view, [0, 0.05])