# Writing MAP-EM with SIRF
This notebook provides a starting point for writing MAP-EM yourself. It gives you basic code for setting up a simulation, a simplified implementation of MAP-EM (with one known problem), and a few lines to help plot the results.

You are strongly advised to complete (at least) the first half of the `ML_reconstruction` exercise before starting this one. In particular, the simulation and plotting code here is taken directly from the `ML_reconstruction` code, but with fewer comments.

Author: Kris Thielemans  
First version: 19th of May 2018

CCP PETMR Synergistic Image Reconstruction Framework (SIRF)  
Copyright 2017 - 2018 University College London.

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
     http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Initial set-up

In [None]:
#%% Initial imports etc
import numpy
from numpy.linalg import norm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
import sys
import shutil
#import scipy
#from scipy import optimize
import sirf.STIR as pet
from sirf.Utilities import examples_data_path
# plotting settings
plt.ion() # interactive 'on' such that plots appear during loops

%matplotlib notebook

In [None]:
#%% some handy function definitions
def imshow(image, limits, title=''):
    """Usage: imshow(image, [min,max], title)"""
    plt.title(title)
    bitmap=plt.imshow(image)
    if len(limits)==0:
        limits=[image.min(),image.max()]
                
    plt.clim(limits[0], limits[1])
    plt.colorbar(shrink=.6)
    plt.axis('off');
    return bitmap

def make_positive(image_array):
    """truncate any negatives to zero"""
    image_array[image_array<0] = 0;
    return image_array;

def make_cylindrical_FOV(image):
    """truncate to cylindrical FOV"""
    filter = pet.TruncateToCylinderProcessor()
    filter.apply(image)



# Create some sample data

In [None]:
#%% go to directory with input files
# adapt this path to your situation (or start everything in the relevant directory)
os.chdir(examples_data_path('PET'))


In [None]:
#%% copy files to working folder and change directory to where the output files are
shutil.rmtree('working_folder/thorax_single_slice',True)
shutil.copytree('thorax_single_slice','working_folder/thorax_single_slice')
os.chdir('working_folder/thorax_single_slice')

In [None]:
#%% Read in image
image = pet.ImageData('emission.hv');
# save a display maximum (60% of the image maximum) for future displays
image_array=image.as_array()
cmax = image_array.max()*.6

In [None]:
#%% create acquisition model
am = pet.AcquisitionModelUsingRayTracingMatrix()
templ = pet.AcquisitionData('template_sinogram.hs')
am.set_up(templ,image); 

In [None]:
#%% simulate some data using forward projection
acquired_data=am.forward(image)

# Create the OSEM reconstructor

In [None]:
#%% create objective function
obj_fun = pet.make_Poisson_loglikelihood(acquired_data)
obj_fun.set_acquisition_model(am)

In [None]:
#%% create OSEM reconstructor
OSEM_reconstructor = pet.OSMAPOSLReconstructor()
OSEM_reconstructor.set_objective_function(obj_fun)
OSEM_reconstructor.set_num_subsets(1)
num_subiters=10;
OSEM_reconstructor.set_num_subiterations(num_subiters)

In [None]:
#%%  create initial image
init_image=image.get_uniform_copy(cmax/4)
make_cylindrical_FOV(init_image)

In [None]:
#%% initialise
OSEM_reconstructor.set_up(init_image)

# Implement MAP-EM!
Actually, we will implement MAP-OSEM, as that is easy to do with the OSEM reconstructor set up above.

The lines below (almost) implement MAP-OSEM with a prior that only smooths along the horizontal direction. If you run them, you should see some warnings and get the wrong image. Try to fix that. Once you have done so, you can evaluate the results, extend the prior to 2D, and/or speed up your Python code.

We will use the algorithm described in

Guobao Wang and Jinyi Qi,  
Penalized Likelihood PET Image Reconstruction using Patch-based Edge-preserving Regularization  
IEEE Trans Med Imaging. 2012 Dec; 31(12): 2194–2204.  
[doi: 10.1109/TMI.2012.2211378](https://dx.doi.org/10.1109%2FTMI.2012.2211378)

However, we will not use patches here, only a simple quadratic prior. Also, for simplicity and speed, we will only consider a 1D neighbourhood (along horizontal lines).

We will occasionally use the notation of the paper (but not consistently, mainly to avoid conflicts between a coordinate *x* and the image *x*).

Please note: this code was written to be simple. It is not terribly safe, it is very slow, and it does not follow best programming practices (it uses global variables, has no docstrings, does no input validation, etc.).

## define weights as an array
Note that the code further below assumes that this array has 3 elements (the weights for the left neighbour, the pixel itself and the right neighbour).


In [None]:
w=numpy.array([1.,0.,1.])
# normalise to have sum 1
w/=w.sum()

## define a function that computes $x^{reg}$
For a simple quadratic prior (with normalised weights):

$x^{reg}_j={1\over 2}\sum_{k\in N_j} w_{jk}(x_j+x_k)$
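With the weights defined above, $w = (\tfrac{1}{2}, 0, \tfrac{1}{2})$ for the left neighbour, the pixel itself and the right neighbour. So for a pixel $j$ on a horizontal line, with neighbours $j-1$ and $j+1$, the formula reduces to

$$x^{reg}_j = \frac{1}{2}\left[\frac{1}{2}(x_j + x_{j-1}) + \frac{1}{2}(x_j + x_{j+1})\right] = \frac{x_j}{2} + \frac{x_{j-1} + x_{j+1}}{4}$$

i.e. $x^{reg}_j$ lies halfway between $x_j$ and the mean of its two horizontal neighbours.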

In [None]:
def compute_xreg(image_array):
    sizes=image_array.shape
    image_reg= image_array*0 # zero array of the same shape; values are accumulated below
    for z in range(0,sizes[0]):
        for y in range(0,sizes[1]):
            for x in range(1,sizes[2]-1): # ignore first and last pixel for simplicity
                for dx in (-1,0,1):
                    image_reg[z,y,x] += w[dx+1]/2*(image_array[z,y,x]+image_array[z,y,x+dx])
            
    return image_reg
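The triple loop above is very slow in Python. One possible way to speed it up is sketched below: a vectorised NumPy version that uses array slicing instead of explicit loops. `compute_xreg_vectorised` is a hypothetical name, and unlike the cell above it takes the weights `w` as an argument rather than reading the global variable. It is intended to give the same result as `compute_xreg`, including leaving the first and last pixel of each line at zero.

```python
import numpy

def compute_xreg_vectorised(image_array, w):
    """Vectorised version of the 1D (horizontal) x_reg computation.

    image_array has shape (z, y, x); w holds the 3 normalised weights for
    the left neighbour, the pixel itself and the right neighbour.
    As in the loop version, the first and last pixel of each line stay zero.
    """
    image_reg = numpy.zeros_like(image_array)
    centre = image_array[:, :, 1:-1]  # pixels x = 1 .. nx-2
    left = image_array[:, :, :-2]     # their left neighbours (x-1)
    right = image_array[:, :, 2:]     # their right neighbours (x+1)
    image_reg[:, :, 1:-1] = (w[0] / 2 * (centre + left)
                             + w[1] / 2 * (centre + centre)
                             + w[2] / 2 * (centre + right))
    return image_reg
```

In the MAP-OSEM loop you could then call `compute_xreg_vectorised(current_image.as_array(), w)` in place of `compute_xreg(current_image.as_array())`.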

## define a function that computes the MAP-EM update, given $x^{EM}$ and $x^{reg}$

$x^{new}={2 x^{EM}  \over \sqrt{(1-\beta x^{reg})^2 + 4 \beta x^{EM}} + (1-\beta x^{reg})}$

In [None]:
def compute_MAPEM_update(xEM,xreg, beta):
    return 2*xEM/(numpy.sqrt((1 - beta*xreg)**2 + 4*beta*xEM) + (1 - beta*xreg))
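The update formula can be checked by noting that $x^{new}$ is the positive root of the quadratic $\beta x^2 + (1-\beta x^{reg})\,x - x^{EM} = 0$: multiplying the numerator and denominator of the usual root formula $\big(-(1-\beta x^{reg}) + \sqrt{(1-\beta x^{reg})^2 + 4\beta x^{EM}}\big)/(2\beta)$ by $\sqrt{(1-\beta x^{reg})^2 + 4\beta x^{EM}} + (1-\beta x^{reg})$ gives the expression above. In particular, for $\beta = 0$ it reduces to $x^{EM}$, i.e. plain EM. The sketch below checks this numerically on small made-up arrays with strictly positive $x^{EM}$; `quadratic_residual` is a hypothetical helper, and `compute_MAPEM_update` is copied from the cell above so that the block is self-contained.

```python
import numpy

def compute_MAPEM_update(xEM, xreg, beta):
    # copy of the notebook's MAP-EM update
    return 2*xEM/(numpy.sqrt((1 - beta*xreg)**2 + 4*beta*xEM) + (1 - beta*xreg))

def quadratic_residual(x, xEM, xreg, beta):
    """Value of beta*x^2 + (1 - beta*xreg)*x - xEM (zero at the MAP-EM update)."""
    return beta*x**2 + (1 - beta*xreg)*x - xEM

# made-up test values (strictly positive xEM)
xEM = numpy.array([0.5, 1.0, 2.0, 3.0])
xreg = numpy.array([1.0, 0.2, 2.5, 0.0])
for beta in (0.1, 1.0, 5.0):
    x_new = compute_MAPEM_update(xEM, xreg, beta)
    # the largest absolute residual should be at rounding-error level
    print(beta, numpy.abs(quadratic_residual(x_new, xEM, xreg, beta)).max())
```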

## define some functions for testing

In [None]:
#%% useful for timing (using `time.time()`)
import time
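For example, a small helper like the one below can be used to compare the speed of different versions of your code, such as the loop-based `compute_xreg`. `time_call` is a hypothetical name; it simply measures the wall-clock time of a single call with `time.time()`.

```python
import time

def time_call(func, *args, **kwargs):
    """Run func(*args, **kwargs) once; return (result, elapsed wall-clock seconds)."""
    start = time.time()
    result = func(*args, **kwargs)
    return result, time.time() - start
```

Usage in this notebook could look like `_, seconds = time_call(compute_xreg, current_image.as_array())`. For very short timings, `time.perf_counter()` offers higher resolution than `time.time()`.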

In [None]:
#%% define a function for plotting images and the updates
# This is the same function as in `ML_reconstruction`
def plot_progress(all_images, title, subiterations = []):
    if len(subiterations)==0:
        num_subiters = all_images[0].shape[0]-1;
        subiterations = range(1, num_subiters+1);
    num_rows = len(all_images);
    for iter in subiterations:
        plt.figure()
        for r in range(num_rows):
            plt.subplot(num_rows,2,2*r+1)
            imshow(all_images[r][iter,slice,:,:], [0,cmax], '%s at %d' % (title[r],  iter))
            plt.subplot(num_rows,2,2*r+2)
            imshow(all_images[r][iter,slice,:,:]-all_images[r][iter-1,slice,:,:],[-cmax*.1,cmax*.1], 'update')
        plt.show()

## Finally, let's implement MAP-OSEM!
We already have a SIRF (OSEM) reconstructor that does the hard work, i.e. computing $x^{EM}$, so let's use it!
Note, however, that you will get warning messages when you execute this, and the resulting images will not be good. Try to fix it!

In [None]:
#%% do a loop, saving images as we go along
num_subiters = 16;
beta =1;
current_image = init_image.clone()
all_images = numpy.ndarray(shape=(num_subiters+1,) + current_image.as_array().shape );
all_images[0,:,:,:] =  current_image.as_array();
for iter in range(1, num_subiters+1):
    image_reg= compute_xreg(current_image.as_array()) # compute xreg
    OSEM_reconstructor.update(current_image); # compute EM update
    image_EM=current_image.as_array() # get xEM as a numpy array
    updated = compute_MAPEM_update(image_EM, image_reg, beta) # compute new update
    current_image.fill(updated) # store for next iteration
    print(iter)
    all_images[iter,:,:,:] =  updated; # save for plotting later on

In [None]:
#%% display
slice = 0
imshow(current_image.as_array()[slice,:,:],[])

In [None]:
#%% now check how the updates progressed
subiterations = (1,2,4,8,16);
plot_progress([all_images], ['MAP-OSEM'],subiterations)
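To evaluate the reconstructions quantitatively, one option (a sketch, not part of the original exercise code) is to compute the relative error of each stored image with respect to the true image used for the simulation. `relative_errors` is a hypothetical helper; it uses `norm` from `numpy.linalg`, which is imported in the first cell but not otherwise used in this notebook.

```python
import numpy
from numpy.linalg import norm

def relative_errors(all_images, reference):
    """Relative L2 error ||x_i - x_true|| / ||x_true|| for each stored image x_i.

    all_images has shape (num_images, z, y, x); reference has shape (z, y, x).
    """
    ref_norm = norm(reference.ravel())
    return numpy.array([norm((img - reference).ravel()) / ref_norm
                        for img in all_images])
```

In this notebook you could call `errs = relative_errors(all_images, image_array)`, where `image_array` is the true image read in at the start, and then plot `errs` against the subiteration number.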
    

## What now?
Some ideas: experiment with the value of `beta`, optimise the `compute_xreg` code, check the effect of using subsets, and compare with another algorithm that solves the same penalised reconstruction problem.
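As a starting point for extending the prior to 2D, here is a sketch that computes $x^{reg}$ over a 3×3 in-plane neighbourhood, with the weights stored in a normalised 3×3 array `w2d` indexed as `w2d[dy+1, dx+1]`. Both `compute_xreg_2d` and the choice of a 4-neighbourhood are assumptions for illustration only; as in the 1D version, the outermost rows and columns of each slice are left at zero.

```python
import numpy

def compute_xreg_2d(image_array, w2d):
    """x_reg over a 3x3 in-plane neighbourhood.

    image_array has shape (z, y, x); w2d is a 3x3 array of weights (summing
    to 1) indexed as w2d[dy+1, dx+1]. Border rows/columns stay zero.
    """
    image_reg = numpy.zeros_like(image_array)
    ny, nx = image_array.shape[1], image_array.shape[2]
    centre = image_array[:, 1:-1, 1:-1]
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            # neighbour at (y+dy, x+dx) for every interior pixel (y, x)
            neighbour = image_array[:, 1 + dy:ny - 1 + dy, 1 + dx:nx - 1 + dx]
            image_reg[:, 1:-1, 1:-1] += w2d[dy + 1, dx + 1] / 2 * (centre + neighbour)
    return image_reg

# example weights: 4-neighbourhood (up, down, left, right), normalised to sum 1
w2d = numpy.array([[0., 1., 0.],
                   [1., 0., 1.],
                   [0., 1., 0.]])
w2d /= w2d.sum()
```

With `w2d` equal to zero except for a middle row of `[0.5, 0, 0.5]`, this reduces to the 1D prior above on all but the top and bottom rows of each slice.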