# Dual PET tracer de Pierro with Bowsher, handling motion
This notebook fleshes out the skeleton for the challenge set in the [../Dual_PET notebook](../Dual_PET.ipynb), this time including motion: the amyloid data (and its mu-map) are misaligned with respect to the FDG data.

Authors: Richard Brown, Sam Ellis, Kris Thielemans  
First version: 2nd of November 2019

Second version: June 2021

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2019, 2021  University College London  
Copyright 2019  King's College London  

This is software developed for the Collaborative Computational
Project in Synergistic Reconstruction for Biomedical Imaging.
(http://www.synerbi.ac.uk/).

SPDX-License-Identifier: Apache-2.0

# 0a. Some includes and imshow-esque functions

In [None]:
# All the normal stuff you've already seen
import notebook_setup

#%% Initial imports etc
import numpy
import matplotlib.pyplot as plt
import os
import sys
import shutil
import time
import sirf.STIR as pet
from sirf_exercises import exercises_data_path
import sirf.Reg as Reg
import sirf.contrib.kcl.Prior as pr

# plotting settings
plt.ion() # interactive 'on' such that plots appear during loops

# Jupyter/IPython magic (not plain Python): select the interactive 'widget'
# backend (provided by the ipympl package) for all subsequent figures
%matplotlib widget

#%% some handy function definitions
def imshow(image, limits=None, title=''):
    """Show a 2D array with a colour bar, a title and no axes.

    Usage: imshow(image, [min,max], title)

    When `limits` is None the colour range is the image's own [min, max].
    Returns the object produced by plt.imshow.
    """
    plt.title(title)
    handle = plt.imshow(image)
    lo, hi = (image.min(), image.max()) if limits is None else (limits[0], limits[1])
    plt.clim(lo, hi)
    plt.colorbar(shrink=.6)
    plt.axis('off')
    return handle

def make_cylindrical_FOV(image):
    """Truncate `image` to the cylindrical field of view, in place.

    Applies sirf.STIR's TruncateToCylinderProcessor to `image`. Nothing is
    returned; callers keep using the image object they passed in.
    """
    # Named `processor` (not `filter`) so the builtin is not shadowed.
    processor = pet.TruncateToCylinderProcessor()
    processor.apply(image)
    
#%% define a function for plotting images and the updates
# This is the same function as in `ML_reconstruction`
def plot_progress_compare(all_images1, all_images2, title1, title2, subiterations, cmax,
                          slice_index=60):
    """Plot one slice of two image stacks side by side, one figure per sub-iteration.

    Parameters
    ----------
    all_images1, all_images2 : sequences of 4D arrays
        Indexed as [subiteration, i, j, k]; ``slice_index`` selects along
        the first image axis. Entry r of each sequence is drawn in row r
        (``all_images1`` on the left, ``all_images2`` on the right).
    title1, title2 : sequences of str
        Per-row titles for the left and right columns.
    subiterations : iterable of int, or empty / None
        Sub-iterations to plot. If empty or None, every sub-iteration from 1
        to ``all_images1[0].shape[0] - 1`` is plotted (index 0 is skipped;
        in this notebook it holds the initial image).
    cmax : float
        Upper colour limit for every panel (the lower limit is 0).
    slice_index : int, optional
        Slice to display; defaults to 60, the previously hard-coded value.
    """
    if subiterations is None or len(subiterations) == 0:
        subiterations = range(1, all_images1[0].shape[0])
    num_rows = len(all_images1)
    for subiter in subiterations:
        plt.figure()
        for row in range(num_rows):
            plt.subplot(num_rows, 2, 2 * row + 1)
            imshow(all_images1[row][subiter, slice_index, :, :], [0, cmax],
                   '%s at %d' % (title1[row], subiter))
            plt.subplot(num_rows, 2, 2 * row + 2)
            imshow(all_images2[row][subiter, slice_index, :, :], [0, cmax],
                   '%s at %d' % (title2[row], subiter))
        plt.show()

def subplot_(idx, vol, title, clims=None, cmap="viridis"):
    """Draw `vol` into subplot `idx` with a colour bar and title, axes hidden.

    idx   : (nrows, ncols, index), unpacked into plt.subplot.
    vol   : array passed to plt.imshow (2D slices in this notebook).
    title : subplot title.
    clims : optional colour limits passed to plt.clim; when None,
            matplotlib's automatic scaling is kept.
    cmap  : matplotlib colour map name.
    """
    plt.subplot(*idx)
    plt.imshow(vol, cmap=cmap)
    if clims is not None:
        plt.clim(clims)
    plt.colorbar()
    plt.title(title)
    plt.axis("off")

# 0b. Input data

In [None]:
# Setup the working directory for the notebook
# (notebook_setup was already imported in the first cell; re-importing is harmless)
import notebook_setup

from sirf_exercises import cd_to_working_dir
cd_to_working_dir('Synergistic', 'BrainWeb')

fname_FDG_sino = 'FDG_sino_noisy.hs'
fname_FDG_uMap = 'uMap_small.hv'
# No motion filenames
# fname_amyl_sino = 'amyl_sino_noisy.hs'
# fname_amyl_uMap = 'uMap_small.hv'
# Motion filenames: the amyloid sinogram and its mu-map are misaligned
# relative to the FDG data
fname_amyl_sino = 'amyl_sino_noisy_misaligned.hs'
fname_amyl_uMap = 'uMap_misaligned.hv'

# Load the FDG data; rebin(3) makes a smaller sinogram (see sirf.STIR
# AcquisitionData.rebin for exactly which dimension is combined)
full_fdg_sino = pet.AcquisitionData(fname_FDG_sino)
fdg_sino = full_fdg_sino.rebin(3)
fdg_uMap = pet.ImageData(fname_FDG_uMap)

# Same for the (misaligned) amyloid data
full_amyl_sino = pet.AcquisitionData(fname_amyl_sino)
amyl_sino = full_amyl_sino.rebin(3)
amyl_uMap = pet.ImageData(fname_amyl_uMap)

# Initial images: uniform value of 10% of the mu-map maximum, on each tracer's
# own mu-map grid, then truncated to the cylindrical FOV (in place)
fdg_init_image=fdg_uMap.get_uniform_copy(fdg_uMap.as_array().max()*.1)
make_cylindrical_FOV(fdg_init_image)

amyl_init_image=amyl_uMap.get_uniform_copy(amyl_uMap.as_array().max()*.1)
make_cylindrical_FOV(amyl_init_image)

# 0c. Set up normal reconstruction stuff

In [None]:
# Code to set up objective functions and OSEM reconstructors
def get_obj_fun(acquired_data, atten):
    """Return a Poisson log-likelihood objective function for `acquired_data`.

    The forward model is a ray-tracing matrix projector (5 tangential LORs).
    Attenuation from the mu-map `atten` is first evaluated into an
    AcquisitionData (by "unnormalising" a sinogram of ones). That sinogram
    is then used as the acquisition sensitivity of the projector, presumably
    so the attenuation factors are only computed once.
    """
    print('\n------------- Setting up objective function')

    # Projector used both for the attenuation calculation and for the final model
    acq_model = pet.AcquisitionModelUsingRayTracingMatrix()
    acq_model.set_num_tangential_LORs(5)

    # Attenuation factors: push a sinogram of ones through the attenuation model
    attn_model = pet.AcquisitionSensitivityModel(atten, acq_model)
    attn_model.set_up(acquired_data)
    attn_factors = pet.AcquisitionData(acquired_data)
    attn_factors.fill(1.0)
    attn_model.unnormalise(attn_factors)
    sensitivity = pet.AcquisitionSensitivityModel(attn_factors)

    # Attach the precomputed sensitivity and finish setting up the projector
    acq_model.set_acquisition_sensitivity(sensitivity)
    acq_model.set_up(acquired_data, atten)

    obj_fun = pet.make_Poisson_loglikelihood(acquired_data)
    obj_fun.set_acquisition_model(acq_model)

    print('\n------------- Finished setting up objective function')
    return obj_fun

def get_reconstructor(num_subsets, num_subiters, obj_fun, init_image):
    """Return an OSMAPOSL reconstructor for `obj_fun`, already set up.

    num_subsets  : number of subsets for the ordered-subsets update.
    num_subiters : number of sub-iterations `reconstruct` will run.
    init_image   : image passed to set_up (defines the image it works with).
    """
    print('\n------------- Setting up reconstructor')

    recon = pet.OSMAPOSLReconstructor()
    recon.set_objective_function(obj_fun)
    recon.set_num_subsets(num_subsets)
    recon.set_num_subiterations(num_subiters)
    recon.set_up(init_image)

    print('\n------------- Finished setting up reconstructor')
    return recon

In [None]:
# OSEM settings used by every reconstruction below:
# 21 subsets and 42 sub-iterations, i.e. two full passes over the data.
num_subsets, num_subiters = 21, 42

In [None]:
# Independent OSEM reconstruction of each tracer (no prior); these images are
# only used to estimate the motion between the two scans.
# create initial image (reconstruct() updates osem_fdg in place)
osem_fdg=fdg_init_image.clone()
fdg_obj_fn = get_obj_fun(fdg_sino,fdg_uMap)
fdg_reconstructor = get_reconstructor(num_subsets,num_subiters,fdg_obj_fn,fdg_init_image)
fdg_reconstructor.reconstruct(osem_fdg)

# create initial image (amyloid uses its own, misaligned, mu-map)
osem_amyl=amyl_init_image.clone()
amyl_obj_fn = get_obj_fun(amyl_sino,amyl_uMap)
amyl_reconstructor = get_reconstructor(num_subsets,num_subiters,amyl_obj_fn,amyl_init_image)
amyl_reconstructor.reconstruct(osem_amyl)

# Show slice 60 of both reconstructions; the misalignment should be visible
plt.figure();
subplot_([1,2,1],osem_fdg.as_array()[60,:,:],"FDG")
subplot_([1,2,2],osem_amyl.as_array()[60,:,:],"Amyloid")

# 2. Register images

In [None]:
# Rigid registration: amyloid OSEM image (floating) onto FDG OSEM image (reference)
registration = Reg.NiftyAladinSym()
registration.set_reference_image(osem_fdg)
registration.set_floating_image(osem_amyl)
# rigid only (affine stage disabled); warped voxels outside the floating image become 0
registration.set_parameter('SetPerformRigid','1')
registration.set_parameter('SetPerformAffine','0')
registration.set_parameter('SetWarpedPaddingValue','0')
registration.process()
# The forward transformation is the one used to resample the floating (amyloid)
# image onto the reference (FDG) grid; its inverse goes the other way.
# (NiftyReg convention -- confirm against the sirf.Reg documentation.)
tm_amyl_to_fdg = registration.get_transformation_matrix_forward()
tm_fdg_to_amyl = tm_amyl_to_fdg.get_inverse()
# amyloid OSEM image warped into FDG space
amyl_registered_to_fdg = registration.get_output()

# Visual check: the two slices should now be aligned
plt.figure();
subplot_([1,2,1],osem_fdg.as_array()[60,:,:],"FDG")
subplot_([1,2,2],amyl_registered_to_fdg.as_array()[60,:,:],"Amyloid in FDG space")

# 3. A resample function?

In [None]:
# Resampling helper used to move images between FDG and amyloid space
def resample(tm, flo, ref):
    """Resample image `flo` onto the grid of image `ref` using transformation `tm`.

    Uses sirf.Reg.NiftyResample with linear interpolation and a padding value
    of 0. Returns the resampled image (on the geometry of `ref`).
    """
    nr = Reg.NiftyResample()
    nr.set_reference_image(ref)
    nr.set_floating_image(flo)
    nr.set_interpolation_type_to_linear()
    nr.set_padding_value(0)
    nr.add_transformation(tm)
    nr.process()
    resampled = nr.get_output()
    return resampled

# 4. Maybe some de Pierro functions

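Copy the de Pierro/Bowsher code from the "no motion case" here. The loop in section 5 expects `update_bowsher_weights`, `MAPEM_iteration`, `fdg_prior`, `amyl_prior`, `nhoodIndVec_fdg`, `nhoodIndVec_amyl` and `num_bowsher_neighbours` to be defined. The amyloid prior should be set up using the amyloid image, which is no longer aligned with the FDG image.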

# 5. Are we ready?

In [None]:
# Prior weight passed to MAPEM_iteration (presumably the penalty strength)
beta = 0.1

# NOTE: this cell needs the de Pierro/Bowsher definitions from section 4
# (update_bowsher_weights, MAPEM_iteration, fdg_prior, amyl_prior,
# nhoodIndVec_fdg, nhoodIndVec_amyl, num_bowsher_neighbours).

# Fresh objective functions and reconstructors for both tracers
fdg_obj_fn = get_obj_fun(fdg_sino, fdg_uMap)
fdg_reconstructor = get_reconstructor(num_subsets, num_subiters, fdg_obj_fn, fdg_init_image)
amyl_obj_fn = get_obj_fun(amyl_sino, amyl_uMap)
amyl_reconstructor = get_reconstructor(num_subsets, num_subiters, amyl_obj_fn, amyl_init_image)

# Start both tracers from their uniform initial images
current_fdg_image = fdg_init_image.clone()
current_amyl_image = amyl_init_image.clone()

# Store every sub-iterate (index 0 = initial image) for plotting afterwards
all_images_fdg = numpy.empty((num_subiters + 1,) + current_fdg_image.as_array().shape)
all_images_amyl = numpy.empty((num_subiters + 1,) + current_amyl_image.as_array().shape)
all_images_fdg[0] = current_fdg_image.as_array()
all_images_amyl[0] = current_amyl_image.as_array()

for subiter in range(1, num_subiters + 1):
    start_time = time.time()

    # FDG update: Bowsher weights come from the amyloid image, which must
    # first be moved into FDG space because the two scans are misaligned.
    current_amyl_in_fdg_space = resample(tm_amyl_to_fdg, current_amyl_image, current_fdg_image)
    weights_fdg = update_bowsher_weights(fdg_prior, current_amyl_in_fdg_space, num_bowsher_neighbours)
    current_fdg_image = MAPEM_iteration(fdg_reconstructor, current_fdg_image, weights_fdg,
                                        nhoodIndVec_fdg, beta)
    all_images_fdg[subiter] = current_fdg_image.as_array()

    # Amyloid update: likewise the FDG image must be moved into amyloid space
    # and the *resampled* image used for the weights (previously the FDG image
    # in FDG space was used, so the motion was ignored here).
    current_fdg_in_amyl_space = resample(tm_fdg_to_amyl, current_fdg_image, current_amyl_image)
    weights_amyl = update_bowsher_weights(amyl_prior, current_fdg_in_amyl_space, num_bowsher_neighbours)
    current_amyl_image = MAPEM_iteration(amyl_reconstructor, current_amyl_image, weights_amyl,
                                         nhoodIndVec_amyl, beta)
    all_images_amyl[subiter] = current_amyl_image.as_array()

    print('\n------------- Subiteration %i finished in %i s.' % (subiter, time.time() - start_time))

In [None]:
#%% now call this function to see how we went along
# plot_progress_compare opens its own figure per sub-iteration, so no extra
# (empty) plt.figure() is created here. Only sub-iterations that were actually
# run are requested, so lowering num_subiters does not cause an IndexError.
# Both tracers share one colour scale: the FDG maximum.
subiterations = tuple(i for i in (1, 2, 4, 8, 16, 32, 42) if i <= num_subiters)
plot_progress_compare([all_images_fdg], [all_images_amyl], ['FDG MAPEM'], ['Amyloid MAPEM'],
                      subiterations, all_images_fdg.max())