# MAPEM de Pierro algorithm

In [None]:
#%% make sure figures appear inline and animations work
%matplotlib notebook

In [None]:
import sirf.STIR as pet
import matplotlib.pyplot as plt
import os
import numpy as np
from numpy.linalg import norm
import sirf.contrib.kcl.Prior as pr

# Get to correct directory: every data file below is opened by a relative name
os.chdir("/data/SRS_data_exhale/")

# We'll need a template sinogram; it is the geometry passed to the set_up calls
# of the acquisition/sensitivity models (span-11 mMR template, per its filename)
templ_sino = pet.AcquisitionData('mMR_template_span11.hs')

In [None]:
num_subiters = 50  # subiterations for both the OSEM and the de Pierro reconstructions

#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Show a 2-D array on the current axes with a colour bar and no axis ticks.

    Usage: imshow(image, [min,max], title)

    image  -- array handed to plt.imshow
    limits -- optional [low, high] colour limits; when omitted the image's own
              minimum and maximum are used
    title  -- title set on the current axes before drawing
    Returns whatever plt.imshow returned.
    """
    plt.title(title)
    handle = plt.imshow(image)
    low, high = (image.min(), image.max()) if limits is None else (limits[0], limits[1])
    plt.clim(low, high)
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return handle

def make_cylindrical_FOV(image):
    """Truncate *image* in place to a cylindrical field of view.

    Applies a SIRF TruncateToCylinderProcessor to *image*; the image passed in
    is modified directly and nothing is returned.
    """
    # Named 'truncator' rather than 'filter' so the builtin is not shadowed.
    truncator = pet.TruncateToCylinderProcessor()
    truncator.apply(image)

## PET ground truth

In [None]:
gt_act = pet.ImageData('PET_activity.nii')
gt_atten = pet.ImageData('PET_attenuation.nii')

fig, axs = plt.subplots(2,3)
fig.suptitle('Ground truth PET')

# Row 0 shows the activity, row 1 the attenuation. The columns are slices at
# index 60 along the third, second and first array axis respectively.
for row, (label, volume) in enumerate((('Activity', gt_act), ('Attenuation', gt_atten))):
    arr = volume.as_array()
    axs[row, 0].set_ylabel(label, rotation=90, size='large')
    for col, view in enumerate((np.s_[:, :, 60], np.s_[:, 60, :], np.s_[60, :, :])):
        axs[row, col].imshow(arr[view])

## MR ground truth

In [None]:
gt_T1 = pet.ImageData('MR_T1.nii')
gt_T2 = pet.ImageData('MR_T2.nii')
gt_PD = pet.ImageData('MR_PD.nii')

fig, axs = plt.subplots(3,3)
fig.suptitle('Ground truth MR')

# One row per MR contrast: row 0 = T1, row 1 = T2, row 2 = PD.
# The columns are slices at index 60 along the third, second and first array axis.
for row, (label, volume) in enumerate((('T1', gt_T1), ('T2', gt_T2), ('PD', gt_PD))):
    arr = volume.as_array()
    axs[row, 0].set_ylabel(label, rotation=90, size='large')
    axs[row, 0].imshow(arr[:, :, 60])
    axs[row, 1].imshow(arr[:, 60, :])
    axs[row, 2].imshow(arr[60, :, :])

## Simulate noisy data

In [None]:
#%% create acquisition model
# Ray-tracing-matrix projector using 5 tangential LORs per bin
am = pet.AcquisitionModelUsingRayTracingMatrix()
am.set_num_tangential_LORs(5)

# Set up sensitivity due to attenuation.
# bin_eff starts as an all-ones sinogram; unnormalise() then applies the
# attenuation model (built from gt_atten and am) to it.
asm_attn = pet.AcquisitionSensitivityModel(gt_atten, am)
asm_attn.set_up(templ_sino)
bin_eff = pet.AcquisitionData(templ_sino)
bin_eff.fill(1.0)
print('applying attenuation (please wait, may take a while)...')
asm_attn.unnormalise(bin_eff)
# asm_attn is rebound to a sensitivity model built from the precomputed
# sinogram. Presumably this avoids re-projecting the attenuation image on
# every forward/backward projection — TODO confirm.
asm_attn = pet.AcquisitionSensitivityModel(bin_eff)

am.set_acquisition_sensitivity(asm_attn)

# Set the model up for the template sinogram geometry and gt_act's image geometry
am.set_up(templ_sino,gt_act);

In [None]:
#%% simulate some data using forward projection
# Noise-free sinogram of the ground-truth activity through the model `am`
gt_sino = am.forward(gt_act)

In [None]:
# np.random.poisson needs non-negative means, hence the absolute value.
gt_sino_array = np.absolute(gt_sino.as_array())
# Cast the Poisson draws to float64 before filling a copy of the sinogram.
noisy_array = np.random.poisson(gt_sino_array).astype(np.float64)
noisy_sino = gt_sino.clone()
noisy_sino.fill(noisy_array);

In [None]:
sino_max = gt_sino_array[0,400,:,:].max()

plt.figure()
# Noise-free and noisy sinogram planes side by side. Both use the colour scale
# of the noise-free plane.
for panel, (caption, sino) in enumerate((('Original', gt_sino_array), ('Noisy', noisy_array)), start=1):
    plt.subplot(1, 2, panel)
    imshow(sino[0, 400, :, :], [0, sino_max], caption)

## OSEM reconstruction of noisy data

In [None]:
#%% create objective function
# Poisson log-likelihood of the noisy sinogram under the acquisition model `am`
obj_fun = pet.make_Poisson_loglikelihood(noisy_sino)
obj_fun.set_acquisition_model(am)
#%% create OSEM reconstructor
# 21 subsets, num_subiters subiterations in total
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(21)
OSEM_reconstructor.set_num_subiterations(num_subiters)
#%%  create initial image
# Uniform image at 10% of the ground-truth activity maximum, truncated to the cylindrical FOV
init_image=gt_act.get_uniform_copy(gt_act.as_array().max()*.1)
make_cylindrical_FOV(init_image)

# Reconstruct into a clone so that init_image itself is left untouched (it is
# cloned again later as the de Pierro starting image). The result is written
# into reconstructed_image.
reconstructed_image = init_image.clone()
OSEM_reconstructor.set_up(reconstructed_image)
OSEM_reconstructor.reconstruct(reconstructed_image)

In [None]:
plt.figure()

max_in_recon = reconstructed_image.as_array()[:,:,60].max()

# Three slices at index 60 along the third, second and first array axis.
# All panels share the colour scale of the [:, :, 60] slice.
for panel, view in enumerate((np.s_[:, :, 60], np.s_[:, 60, :], np.s_[60, :, :]), start=1):
    plt.subplot(1, 3, panel)
    imshow(reconstructed_image.as_array()[view], [0, max_in_recon])

## Now use dePierro MAPEM

In [None]:
def dePierroUpdate(imageEM, imageReg, beta, sensImg):
    """Combine an EM update with a regularisation image (De Pierro MAP-EM step).

    Per voxel, returns the positive root x of
        beta_j * x**2 + b_j * x - imageEM = 0,
    where beta_j = beta / sensImg and b_j = 1 - beta_j * imageReg.
    The root is evaluated in the rationalised form
        x = 2 * imageEM / (b_j + sqrt(b_j**2 + 4 * beta_j * imageEM)).
    With beta == 0 this reduces to imageEM.

    imageEM  -- array produced by the (OS)EM update
    imageReg -- regularisation image, same shape as imageEM
    beta     -- regularisation strength (scalar, or broadcastable array)
    sensImg  -- sensitivity array. It is NOT modified: values below 1e-6 of its
                maximum absolute value are clamped in a local copy.

    Raises ValueError if sensImg is zero everywhere (the clamp floor would be 0).
    """
    sens_abs_max = np.abs(sensImg).max()
    if sens_abs_max == 0:
        raise ValueError("sensitivity image is zero everywhere")

    # Avoid division by zero. np.where builds a new array, so the caller's
    # sensitivity image is left untouched.
    sens_floor = 1e-6 * sens_abs_max
    sens = np.where(sensImg < sens_floor, sens_floor, sensImg)
    beta_j = beta / sens

    b_j = 1 - beta_j * imageReg

    numer = 2 * imageEM
    denom = np.sqrt(b_j**2 + 4 * beta_j * imageEM) + b_j

    # Keep the denominator away from zero for the same reason.
    denom_floor = 1e-6 * np.abs(denom).max()
    denom = np.where(denom < denom_floor, denom_floor, denom)

    return numer / denom


def dePierroReg(image, weights):
    """Compute the De Pierro regularisation image for a weighted neighbourhood prior.

    For voxel j with neighbours N_j (periodic boundaries, see neighbourExtract):
        imageReg_j = 0.5 * sum_k weights[j, k] * (image[N_j[k]] + image_j)

    image   -- 3-D array of shape (n, m, h)
    weights -- 2-D array whose row j holds the weights of voxel j's neighbours.
               Voxels are numbered in Fortran (column-major) order. The number
               of columns must be w**3 (w**2 when h == 1) for an odd w.

    Returns an array with the same shape as image.
    Raises ValueError if the weight count does not match a valid neighbourhood.
    """
    imSize = image.shape
    # Fortran-order numbering matches the indices produced by neighbourExtract.
    imageVec = np.asarray(image, dtype=np.float32).ravel(order='F')

    # Side length of the neighbourhood, consistent with neighbourExtract's
    # square (h == 1) and cubic cases.
    nN = weights.shape[1]
    if imSize[2] == 1:
        w = int(round(nN ** 0.5))
    else:
        w = int(round(nN ** (1.0 / 3)))
    nhoodInd = neighbourExtract(imSize, w)
    if nhoodInd.shape[1] != nN:
        raise ValueError(
            "weights has %d columns but a %d-wide neighbourhood of an image of "
            "shape %s has %d neighbours" % (nN, w, imSize, nhoodInd.shape[1]))

    # (nVoxels, nNeighbours) intensities; column k holds each voxel's k-th neighbour.
    neighbours = imageVec[nhoodInd]
    imageReg = 0.5 * np.sum(weights * (neighbours + imageVec[:, np.newaxis]), axis=1)
    return imageReg.reshape(imSize, order='F')


def neighbourExtract(imageSize, w):
    """Return Fortran-order linear indices of every voxel's w-wide neighbourhood.

    Adapted from Prior class.

    imageSize -- (n, m, h) image shape
    w         -- positive odd neighbourhood side length

    Returns an int32 array of shape (n*m*h, nN), with nN = w**3 (w**2 when
    h == 1). Row i + j*n + k*n*m lists the indices of voxel (i, j, k)'s
    neighbours in the same numbering. Neighbours outside the image wrap
    around to the opposite side.
    Raises ValueError if w is not a positive odd integer (an even w cannot be
    centred on a voxel).
    """
    n = imageSize[0]
    m = imageSize[1]
    h = imageSize[2]
    if w < 1 or w % 2 == 0:
        raise ValueError("neighbourhood width must be a positive odd integer, got %r" % (w,))

    half = w // 2
    offsets = np.arange(-half, half + 1)
    if h == 1:
        zoffsets = [0]
        nN = w * w
    else:
        zoffsets = offsets
        nN = w * w * w

    X, Y, Z = np.meshgrid(np.arange(n), np.arange(m), np.arange(h), indexing='ij')
    N = np.zeros([n * m * h, nN], dtype='int32')
    col = 0
    for x in offsets:
        Xnew = setBoundary(X + x, n)
        for y in offsets:
            Ynew = setBoundary(Y + y, m)
            for z in zoffsets:
                Znew = setBoundary(Z + z, h)
                N[:, col] = Xnew + Ynew * n + Znew * n * m
                col += 1
    return N


def setBoundary(X, n):
    """Wrap indices periodically into [0, n) and flatten them in Fortran order.

    Boundary conditions for neighbourExtract, adapted from Prior class.
    X itself is not modified. np.mod also wraps offsets larger than n, which a
    single +/- n correction would leave out of range.
    """
    return np.mod(X, n).flatten('F')

In [None]:
# NOTE(review): the de Pierro loop uses subset 0's sensitivity for every
# subiteration, while the reconstructor below cycles over 21 subsets. Confirm
# that this approximation is intended.
sensitivity_image = obj_fun.get_subset_sensitivity(0)

MR_image_for_guidance = gt_T1.as_array()
# Can't have <= 0 for weights
MR_image_for_guidance[MR_image_for_guidance<=0] = 0.001

# create a Prior for computing Bowsher weights.
# The same count is passed to BowshserWeights and used as the normalisation
# factor that should make each voxel's weights sum to 1 (validated below).
# Keeping it in one name stops the two uses drifting apart.
num_bowsher_neighbours = 7
myPrior = pr.Prior(MR_image_for_guidance.shape)
weights = myPrior.BowshserWeights(MR_image_for_guidance, num_bowsher_neighbours)
weights = np.float32(weights / float(num_bowsher_neighbours))
if (np.abs(np.sum(weights,axis=1)-1)>1.0e-6).any():
    raise ValueError("Weights should sum to 1 for each voxel")

# The de Pierro reconstruction starts from the same initial image as OSEM
de_pierro_reconstructed_image = init_image.clone()

# Create OSEM reconstructor
print('Setting up reconstruction object')
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(21)
OSEM_reconstructor.set_num_subiterations(num_subiters)
OSEM_reconstructor.set_up(de_pierro_reconstructed_image)

In [None]:
# De Pierro MAP-EM. Each subiteration does three things:
#   1. computes the regularisation image from the current estimate;
#   2. applies one OSEM subiteration in place;
#   3. combines the two results with dePierroUpdate.
beta = 50000  # regularisation strength passed to dePierroUpdate
# The sensitivity image does not change inside the loop, so fetch its array once.
sensitivity_array = sensitivity_image.as_array()

for subiter in range(1, num_subiters + 1):
    print('\n------------- Subiteration %d' % subiter)

    # Regularisation image from the estimate *before* this OSEM update
    imageReg_array = dePierroReg(de_pierro_reconstructed_image.as_array(), weights)

    # OSEM image update (written into de_pierro_reconstructed_image)
    OSEM_reconstructor.update(de_pierro_reconstructed_image)
    imageEM_array = de_pierro_reconstructed_image.as_array()

    # Final image update
    imageUpdated_array = dePierroUpdate(imageEM_array, imageReg_array, beta, sensitivity_array)

    # Fill image and truncate to cylindrical field of view
    de_pierro_reconstructed_image.fill(imageUpdated_array)
    make_cylindrical_FOV(de_pierro_reconstructed_image)

In [None]:
plt.figure()

# Shared colour scale: the larger of the OSEM and de Pierro maxima on slice
# [:, :, 60]. The previous value of max_in_recon already holds the OSEM maximum,
# so the de Pierro image is the one that must be included here.
max_in_recon = max(max_in_recon, de_pierro_reconstructed_image.as_array()[:,:,60].max())

# Top row shows OSEM and the bottom row de Pierro MAP-EM. The columns are slices
# at index 60 along the third, second and first array axis.
views = (np.s_[:, :, 60], np.s_[:, 60, :], np.s_[60, :, :])
for row, recon in enumerate((reconstructed_image, de_pierro_reconstructed_image)):
    for col, view in enumerate(views):
        plt.subplot(2, 3, 3 * row + col + 1)
        imshow(recon.as_array()[view], [0, max_in_recon])