# Dual PET tracer de Pierro no motion

Authors: Kris Thielemans, Sam Ellis, Richard Brown, Casper da Costa-Luis  
First version: 2nd of November 2019

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019  University College London  
Copyright 2019  King's College London  

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# The challenge!

This notebook is an open-ended look into using de Pierro MAPEM to reconstruct dual-PET acquisitions.

- Imagine two different scans (FDG and amyloid) were performed in a short space of time on a single patient. 
- Your task is to implement an alternating reconstruction of the two scans using de Pierro's MAPEM algorithm!

## Suggested workflow - no motion

- Take inspiration from [de_Pierro_MAPEM.ipynb](de_Pierro_MAPEM.ipynb), in which Bowsher weights are calculated from some known side information.
- Now imagine that the side information evolves along with our image estimates.
- First, perform an update on one of the images (image A).
- Then recalculate the Bowsher weights of the second image (image B) using the newly updated image A.
- Then perform a normal de Pierro update on image B.
- Then recalculate the Bowsher weights of image A using the newly updated image B, and repeat.
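The alternating scheme above hinges on recomputing Bowsher weights from the *other* image. The notebook delegates this to `sirf.contrib.kcl.Prior`; purely as an illustration of the idea (not that class's code), the sketch below keeps, for each voxel, the `num_neighbours` candidates whose side-image intensity is closest to the centre voxel's and gives each of them weight `1/num_neighbours`, mirroring the normalisation in `update_bowsher_weights`. The function name `bowsher_weights_sketch` and the neighbour-index layout (one row of flat indices per voxel) are assumptions made for this example.

```python
import numpy as np

def bowsher_weights_sketch(side, nhood, num_neighbours):
    """Illustrative Bowsher weights (hypothetical helper, not sirf.contrib.kcl.Prior).

    side:  1D array of side-image intensities, one entry per voxel
    nhood: int array (num_voxels, num_candidates); row j holds the flat indices
           of voxel j's candidate neighbours (it may include j itself)
    Returns an array shaped like nhood with 1/num_neighbours for the
    num_neighbours candidates most similar to voxel j in the side image, else 0.
    """
    side = np.asarray(side, dtype=float)
    diffs = np.abs(side[nhood] - side[:, None])
    # a voxel is never selected as its own neighbour
    diffs[nhood == np.arange(side.size)[:, None]] = np.inf
    # stable sort: ties are broken by candidate position
    chosen = np.argsort(diffs, axis=1, kind="stable")[:, :num_neighbours]
    weights = np.zeros(nhood.shape, dtype=np.float32)
    np.put_along_axis(weights, chosen, 1.0 / num_neighbours, axis=1)
    return weights

# 1D toy side image with an edge, periodic neighbourhood {j-1, j, j+1}
side = np.array([0, 0, 0, 10, 10, 10])
n = side.size
nhood = np.stack([(np.arange(n) + d) % n for d in (-1, 0, 1)], axis=1)
w = bowsher_weights_sketch(side, nhood, num_neighbours=1)
print(w[2], w[3])  # → [1. 0. 0.] [0. 0. 1.]
```

Voxels on either side of the edge only pick neighbours on their own side. Because the weights depend only on the side image, recomputing them after each update of the other image is what couples the two reconstructions.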

### But what about motion?

It's possible that there's motion between the two images, since they were acquired at different times. Once you've got everything working for the dual-PET reconstruction, it's time to add motion, just to complicate things!

- Imagine two different scans (FDG and amyloid) were performed in a short space of time on a single patient. 
- Your task is to implement an alternating reconstruction of the two scans using de Pierro's MAPEM algorithm!
- Bear in mind that the two scans weren't performed at the same time, so the patient's head isn't necessarily in the same place...

## Suggested workflow - motion

1. Since we can't be sure of the patient's position, you should probably reconstruct each image individually first.
2. Then register them.
3. Then modify your no-motion case so that each image is resampled into the other's space before the Bowsher weights are calculated.
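Step 3 boils down to resampling one image onto the other's voxel grid using the transformation found in step 2. In the notebook you would use `sirf.Reg` for this (see the registration notebook in the hints). Purely to show what resampling means, the sketch below does nearest-neighbour resampling in plain numpy. It assumes, hypothetically, that both images have the same grid shape and that the transformation is a 4x4 matrix acting on voxel indices; real registration transforms usually act on world coordinates in mm, and practical resamplers interpolate rather than pick the nearest voxel. `resample_nearest_sketch` is an illustrative name, not a SIRF function.

```python
import numpy as np

def resample_nearest_sketch(floating, ref_to_float):
    """Nearest-neighbour resampling onto a reference grid of the same shape (illustrative only).

    ref_to_float: 4x4 homogeneous matrix mapping reference voxel indices (i, j, k, 1)
    to voxel indices in `floating`. Voxels mapping outside `floating` become 0.
    """
    shape = floating.shape
    idx = np.indices(shape).reshape(3, -1)                # 3 x N reference voxel indices
    homog = np.vstack([idx, np.ones((1, idx.shape[1]))])  # 4 x N homogeneous coordinates
    src = np.rint(ref_to_float @ homog)[:3].astype(int)   # nearest voxel in `floating`
    inside = np.all((src >= 0) & (src < np.array(shape)[:, None]), axis=0)
    out = np.zeros(idx.shape[1], dtype=floating.dtype)
    out[inside] = floating[tuple(src[:, inside])]
    return out.reshape(shape)

vol = np.arange(27, dtype=float).reshape(3, 3, 3)
shift = np.eye(4)
shift[0, 3] = 1.0  # reference voxel i samples floating voxel i+1 along the first axis
moved = resample_nearest_sketch(vol, shift)
```

In the motion case, the Bowsher weights for image A would then be computed from image B resampled into A's space (and vice versa), rather than from B directly.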

Hints:
- For an implementation of de Pierro MAPEM, check out the [de_Pierro_MAPEM.ipynb](de_Pierro_MAPEM.ipynb) notebook.
- To go faster, rebin your sinograms (as done in [de_Pierro_MAPEM.ipynb](de_Pierro_MAPEM.ipynb))!
- For registration and resampling, check out the [../Reg/sirf_registration.ipynb](../Reg/sirf_registration.ipynb) notebook. 

### One final word

We've given you some pointers below that you can fill in bit by bit. The sections marked with asterisks won't be needed until you implement the motion case.

# 0a. Some imports and imshow-esque functions

In [None]:
# All the normal stuff you've already seen
import notebook_setup

#%% Initial imports etc
import numpy
import matplotlib.pyplot as plt
import os
import sys
import shutil
from tqdm.auto import tqdm, trange
import time
import sirf.STIR as pet
from sirf_exercises import exercises_data_path
import sirf.Reg as Reg
import sirf.contrib.kcl.Prior as pr

# plotting settings
plt.ion() # interactive 'on' such that plots appear during loops

%matplotlib notebook

#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Usage: imshow(image, [min,max], title)"""
    plt.title(title)
    bitmap = plt.imshow(image)
    if limits is None:
        limits = [image.min(), image.max()]

    plt.clim(limits[0], limits[1])
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return bitmap

def make_cylindrical_FOV(image):
    """truncate to cylindrical FOV"""
    filter = pet.TruncateToCylinderProcessor()
    filter.apply(image)   

#%% define a function for plotting images and the updates
# This is the same function as in `ML_reconstruction`
def plot_progress(all_images1, all_images2, title1, title2, subiterations, cmax):
    """Show one transaxial slice of each image set at the requested sub-iterations."""
    if len(subiterations) == 0:
        num_subiters = all_images1[0].shape[0] - 1
        subiterations = range(1, num_subiters + 1)
    num_rows = len(all_images1)
    slice_idx = 60  # transaxial slice to display
    for it in subiterations:
        plt.figure()
        for r in range(num_rows):
            plt.subplot(num_rows, 2, 2*r + 1)
            imshow(all_images1[r][it, slice_idx, :, :], [0, cmax], '%s at %d' % (title1[r], it))
            plt.subplot(num_rows, 2, 2*r + 2)
            imshow(all_images2[r][it, slice_idx, :, :], [0, cmax], '%s at %d' % (title2[r], it))
        plt.show()

def subplot_(idx,vol,title,clims=None,cmap="viridis"):
    plt.subplot(*idx)
    plt.imshow(vol,cmap=cmap)
    if clims is not None:
        plt.clim(clims)
    plt.colorbar()
    plt.title(title)
    plt.axis("off")

# 0b. Input data

In [None]:
# Get to correct directory
os.chdir(exercises_data_path('Synergistic'))

# copy files to working folder and change directory to where the output files are
shutil.rmtree('working_folder/dual_PET_noMotion',True)
shutil.copytree('brainweb','working_folder/dual_PET_noMotion')
os.chdir('working_folder/dual_PET_noMotion')

fname_FDG_sino = 'FDG_sino_noisy.hs'
fname_FDG_uMap = 'uMap_small.hv'
# No motion filenames
fname_amyl_sino = 'amyl_sino_noisy.hs'
fname_amyl_uMap = 'uMap_small.hv'
# Motion filenames
# fname_amyl_sino = 'amyl_sino_noisy_misaligned.hs'
# fname_amyl_uMap = 'uMap_misaligned.hv'

full_fdg_sino = pet.AcquisitionData(fname_FDG_sino)
fdg_sino = full_fdg_sino.rebin(3)
fdg_uMap = pet.ImageData(fname_FDG_uMap)

full_amyl_sino = pet.AcquisitionData(fname_amyl_sino)
amyl_sino = full_amyl_sino.rebin(3)
amyl_uMap = pet.ImageData(fname_amyl_uMap)

fdg_init_image=fdg_uMap.get_uniform_copy(fdg_uMap.as_array().max()*.1)
make_cylindrical_FOV(fdg_init_image)

amyl_init_image=amyl_uMap.get_uniform_copy(amyl_uMap.as_array().max()*.1)
make_cylindrical_FOV(amyl_init_image)
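Both initial images are uniform, at 10% of the maximum of the corresponding μ-map, and are then truncated to a cylindrical field of view with STIR's `TruncateToCylinderProcessor`. The numpy sketch below only illustrates what such a truncation does; the function name is hypothetical, and the inscribed-circle radius is an assumption (STIR's exact radius convention may differ). It also assumes the image array is ordered (z, y, x), consistent with `plot_progress` indexing slices along the first image axis.

```python
import numpy as np

def truncate_to_cylinder_sketch(volume):
    """Zero voxels outside the largest circle inscribed in each transaxial slice.

    Illustrative stand-in for a cylindrical FOV truncation; assumes (z, y, x) ordering.
    """
    _, ny, nx = volume.shape
    y = np.arange(ny) - (ny - 1) / 2.0
    x = np.arange(nx) - (nx - 1) / 2.0
    radius = min(ny, nx) / 2.0
    inside = (y[:, None] ** 2 + x[None, :] ** 2) <= radius ** 2  # 2D transaxial mask
    return volume * inside[None, :, :]

init = np.full((2, 5, 5), 0.1)  # toy uniform "initial image"
truncated = truncate_to_cylinder_sketch(init)
```

Truncating the initial image means voxels outside the FOV start (and, under multiplicative EM updates, stay) at zero.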

# 0c. Set up normal reconstruction stuff
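`get_obj_fun` in the next cell builds an attenuation `AcquisitionSensitivityModel` from the μ-map, applies it once to a sinogram filled with ones via `unnormalise`, and wraps the result in a second `AcquisitionSensitivityModel`. As far as we understand, this precomputes one attenuation factor per sinogram bin, so the μ-map does not have to be re-projected every time the acquisition model is used. Physically, each factor is exp(-∫μ dl) along the bin's line of response. The toy function below (a hypothetical helper, not SIRF's ray tracer) computes that quantity for parallel rays running along one array axis, assuming μ in 1/cm and a known sampling step.

```python
import numpy as np

def attenuation_factors_sketch(mu_map, step_cm):
    """exp(-line integral of mu) for parallel rays along the last axis (toy model).

    mu_map:  2D array of linear attenuation coefficients in 1/cm,
             shape (num_rays, num_samples_per_ray)
    step_cm: distance between consecutive samples along each ray, in cm
    """
    line_integrals = mu_map.sum(axis=-1) * step_cm  # Riemann sum of mu along each ray
    return np.exp(-line_integrals)

# 20 cm of material with mu = 0.1 / cm along each of 3 rays
factors = attenuation_factors_sketch(np.full((3, 20), 0.1), step_cm=1.0)
```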

In [None]:
# Code to set up objective functions and OSEM reconstructors
def get_obj_fun(acquired_data, atten):
    print('\n------------- Setting up objective function')
    #%% create acquisition model
    am = pet.AcquisitionModelUsingRayTracingMatrix()
    am.set_num_tangential_LORs(5)

    # Set up sensitivity due to attenuation
    asm_attn = pet.AcquisitionSensitivityModel(atten, am)
    asm_attn.set_up(acquired_data)
    bin_eff = pet.AcquisitionData(acquired_data)
    bin_eff.fill(1.0)
    asm_attn.unnormalise(bin_eff)
    asm_attn = pet.AcquisitionSensitivityModel(bin_eff)

    # Set sensitivity of the model and set up
    am.set_acquisition_sensitivity(asm_attn)
    am.set_up(acquired_data,atten);

    #%% create objective function
    obj_fun = pet.make_Poisson_loglikelihood(acquired_data)
    obj_fun.set_acquisition_model(am)

    print('\n------------- Finished setting up objective function')
    return obj_fun

def get_reconstructor(num_subsets, num_subiters, obj_fun, init_image):
    print('\n------------- Setting up reconstructor') 

    #%% create OSEM reconstructor
    OSEM_reconstructor = pet.OSMAPOSLReconstructor()
    OSEM_reconstructor.set_objective_function(obj_fun)
    OSEM_reconstructor.set_num_subsets(num_subsets)
    OSEM_reconstructor.set_num_subiterations(num_subiters)

    #%% initialise
    OSEM_reconstructor.set_up(init_image)

    print('\n------------- Finished setting up reconstructor')
    return OSEM_reconstructor
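Each call to `OSEM_reconstructor.update(...)` (used later in `MAPEM_iteration`) performs one OSEM sub-iteration, i.e. an EM step that uses only one subset of the data. The dense-matrix toy below illustrates that multiplicative update, $x \leftarrow \frac{x}{A_s^T \mathbf{1}}\, A_s^T \frac{y_s}{A_s x}$. The helper names are hypothetical and, for simplicity, subsets are formed from individual bins rather than from sinogram views as STIR does.

```python
import numpy as np

def interleaved_subsets(num_bins, num_subsets):
    """Hypothetical helper: subset s holds bins s, s + num_subsets, s + 2*num_subsets, ..."""
    return [np.arange(s, num_bins, num_subsets) for s in range(num_subsets)]

def osem_subiteration(x, A, y, rows, eps=1e-12):
    """One OSEM sub-iteration using only the bins listed in `rows`.

    x: current non-negative image (num_voxels,)
    A: dense system matrix (num_bins, num_voxels)
    y: measured counts (num_bins,)
    """
    A_s, y_s = A[rows], y[rows]
    sensitivity = A_s.sum(axis=0)                 # A_s^T 1 (subset sensitivity)
    ratio = y_s / np.maximum(A_s @ x, eps)        # measured / expected counts
    backprojected = A_s.T @ ratio
    # voxels this subset does not see are left unchanged
    update = np.where(sensitivity > 0, backprojected / np.maximum(sensitivity, eps), 1.0)
    return x * update

# Toy problem: identity system matrix, 4 bins, 2 subsets, 1 full pass
A = np.eye(4)
y = np.array([1.0, 2.0, 3.0, 4.0])
x = np.ones(4)
for rows in interleaved_subsets(len(y), num_subsets=2):
    x = osem_subiteration(x, A, y, rows)
print(x)  # → [1. 2. 3. 4.]
```

With `num_subsets = 21` and `num_subiters = 42`, as set in the next cell, every subset is used twice, i.e. two full passes through the data.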

In [None]:
num_subsets = 21
num_subiters = 42

# 1. Two individual reconstructions *

In [None]:
# Some code goes here

# 2. Register images *

In [None]:
# Some more code goes here

# 3. A resample function? *

In [None]:
# How about a bit of code here?

# 4. Maybe some de Pierro functions
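For reference, here is what the de Pierro helpers in this section compute, read directly off the code. With $w_{jk}$ the normalised Bowsher weights of voxel $j$ over its neighbourhood $\mathcal{N}_j$, $x^{(n)}$ the current image and $x^{\mathrm{EM}}$ the result of one OSEM sub-iteration applied to it, `dePierroReg` and `dePierroUpdate` evaluate (the latter with an extra $10^{-5}$ in the denominator as a safeguard):

$$
\begin{aligned}
x^{\mathrm{reg}}_j &= \tfrac{1}{2}\sum_{k \in \mathcal{N}_j} w_{jk}\,\bigl(x^{(n)}_k + x^{(n)}_j\bigr),\\
x^{(n+1)}_j &= \frac{2\,x^{\mathrm{EM}}_j}{\sqrt{\bigl(1-\beta x^{\mathrm{reg}}_j\bigr)^2 + 4\beta x^{\mathrm{EM}}_j} + \bigl(1-\beta x^{\mathrm{reg}}_j\bigr)},
\end{aligned}
$$

which is the positive root of $\beta x^2 + \bigl(1-\beta x^{\mathrm{reg}}_j\bigr)\,x - x^{\mathrm{EM}}_j = 0$.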

In [None]:
# A pinch more code here
def dePierroReg(image,weights,nhoodIndVec):
    """Get the de Pierro regularisation image"""
    imSize = image.shape

    # vectorise image for indexing 
    imageVec = image.reshape(-1,order='F')

    # retrieve voxel intensities for neighbourhoods 
    resultVec = imageVec[nhoodIndVec]
    result = resultVec.reshape(weights.shape,order='F')

    # compute xreg
    imageReg = 0.5*numpy.sum(weights*(result + image.reshape(-1,1,order='F')),axis=1)
    imageReg = imageReg.reshape(imSize,order='F')

    return imageReg

def compute_nhoodIndVec(image,weights):
    """Get the neighbourhoods of each voxel"""
    weightsSize = weights.shape
    w = int(round(weightsSize[1]**(1.0/3))) # side length of neighbourhood
    nhoodInd    = neighbourExtract(image.shape,w)
    return nhoodInd.reshape(-1,order='F')

def neighbourExtract(imageSize,w):
    """Adapted from Prior class"""
    n = imageSize[0]
    m = imageSize[1]
    h = imageSize[2]
    wlen = 2*numpy.floor(w/2)
    widx = xidx = yidx = numpy.arange(-wlen/2,wlen/2+1)

    if h==1:
        zidx = [0]
        nN = w*w
    else:
        zidx = widx
        nN = w*w*w

    Y,X,Z = numpy.meshgrid(numpy.arange(0,m), numpy.arange(0,n), numpy.arange(0,h))                
    N = numpy.zeros([n*m*h, nN],dtype='int32')
    l = 0
    for x in xidx:
        Xnew = setBoundary(X + x,n)
        for y in yidx:
            Ynew = setBoundary(Y + y,m)
            for z in zidx:
                Znew = setBoundary(Z + z,h)
                N[:,l] = ((Xnew + (Ynew)*n + (Znew)*n*m)).reshape(-1,1).flatten('F')
                l += 1
    return N

def setBoundary(X,n):
    """Boundary conditions for neighbourExtract.
    Adapted from Prior class"""
    idx = X<0
    X[idx] = X[idx] + n
    idx = X>n-1
    X[idx] = X[idx] - n
    return X.flatten('F')

def dePierroUpdate(xEM, imageReg, beta):
    """Update the image based on the de Pierro regularisation image"""
    return (2*xEM)/(((1 - beta*imageReg)**2 + 4*beta*xEM)**0.5 + (1 - beta*imageReg) + 0.00001)


fdg_prior = pr.Prior(fdg_init_image.shape)
amyl_prior = pr.Prior(amyl_init_image.shape)

num_bowsher_neighbours = 7

def update_bowsher_weights(prior,side_image,num_bowsher_neighbours):
    weights = prior.BowshserWeights(side_image.as_array(),num_bowsher_neighbours)
    weights = numpy.float32(weights/float(num_bowsher_neighbours))
    return weights

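# Initial weights: needed here mainly for their shape (used to build the
# neighbourhood indices); they are recomputed in every sub-iteration below
weights_fdg = update_bowsher_weights(fdg_prior,amyl_init_image,num_bowsher_neighbours)
weights_amyl = update_bowsher_weights(amyl_prior,fdg_init_image,num_bowsher_neighbours)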

# compute indices of the neighbourhood
nhoodIndVec_fdg=compute_nhoodIndVec(fdg_init_image,weights_fdg)
nhoodIndVec_amyl=compute_nhoodIndVec(amyl_init_image,weights_amyl)

def MAPEM_iteration(OSEM_reconstructor,current_image,weights,nhoodIndVec,beta):
    image_reg = dePierroReg(current_image.as_array(),weights,nhoodIndVec) # compute xreg
    OSEM_reconstructor.update(current_image); # compute EM update
    image_EM=current_image.as_array() # get xEM as a numpy array
    updated = dePierroUpdate(image_EM, image_reg, beta) # compute new update
    current_image.fill(updated) # store for next iteration
    return current_image
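`dePierroUpdate` is a numerically stable way of writing the positive root of the per-voxel quadratic β x² + (1 − β x_reg) x − x_EM = 0: the textbook form (−b + √(b² + 4β x_EM)) / (2β), with b = 1 − β x_reg, suffers from cancellation when β is small, while the rationalised form used in the notebook does not. The standalone check below re-implements the same closed form without the 1e-5 safeguard and confirms, on arbitrary random positive test values, that it satisfies the quadratic and reduces to the plain EM image as β → 0.

```python
import numpy as np

def de_pierro_update(xEM, imageReg, beta):
    # same closed form as dePierroUpdate, without the 1e-5 safeguard
    b = 1 - beta * imageReg
    return 2 * xEM / (np.sqrt(b ** 2 + 4 * beta * xEM) + b)

rng = np.random.default_rng(0)
xEM = rng.uniform(0.1, 5.0, size=10)    # arbitrary positive "EM update" values
x_reg = rng.uniform(0.1, 5.0, size=10)  # arbitrary positive "regularisation image" values
beta = 0.1
x_new = de_pierro_update(xEM, x_reg, beta)
residual = beta * x_new ** 2 + (1 - beta * x_reg) * x_new - xEM  # should vanish
```

Since √(b² + 4β x_EM) > |b| whenever β x_EM > 0, the denominator stays positive even when β x_reg > 1, so the update keeps the image non-negative.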

# 5. Are we ready?

In [None]:
beta = 0.1

# Final code!

# set up objective functions and OSEM reconstructors for both tracers
fdg_obj_fn = get_obj_fun(fdg_sino,fdg_uMap)
fdg_reconstructor = get_reconstructor(num_subsets,num_subiters,fdg_obj_fn,fdg_init_image)
amyl_obj_fn = get_obj_fun(amyl_sino,amyl_uMap)
amyl_reconstructor = get_reconstructor(num_subsets,num_subiters,amyl_obj_fn,amyl_init_image)

# start both estimates from the uniform initial images
current_fdg_image = fdg_init_image.clone()
current_amyl_image = amyl_init_image.clone()

# storage for the image after every sub-iteration (index 0 = initial image)
all_images_fdg = numpy.zeros(shape=(num_subiters+1,) + current_fdg_image.as_array().shape)
all_images_amyl = numpy.zeros(shape=(num_subiters+1,) + current_amyl_image.as_array().shape)

all_images_fdg[0,:,:,:] = current_fdg_image.as_array()
all_images_amyl[0,:,:,:] = current_amyl_image.as_array()

for it in trange(1, num_subiters+1):
    # Update FDG weights as fn. of amyloid image
    weights_fdg = update_bowsher_weights(fdg_prior,current_amyl_image,num_bowsher_neighbours)

    # Do FDG de Pierro update
    current_fdg_image = MAPEM_iteration(fdg_reconstructor,current_fdg_image,weights_fdg,nhoodIndVec_fdg,beta)
    all_images_fdg[it,:,:,:] = current_fdg_image.as_array()

    # Now update the amyloid weights as fn. of FDG image
    weights_amyl = update_bowsher_weights(amyl_prior,current_fdg_image,num_bowsher_neighbours)

    # And do amyloid de Pierro update
    current_amyl_image = MAPEM_iteration(amyl_reconstructor,current_amyl_image,weights_amyl,nhoodIndVec_amyl,beta)
    all_images_amyl[it,:,:,:] = current_amyl_image.as_array();

In [None]:
#%% now call this function to see how we went along
# (plot_progress creates its own figure for each sub-iteration)
subiterations = (1,2,4,8,16,32,42)
plot_progress([all_images_fdg],[all_images_amyl], ['FDG MAPEM'], ['Amyloid MAPEM'],subiterations, all_images_fdg.max())