In [None]:
import os
from sirf import STIR as pet
from sirf.contrib.partitioner import partitioner

from cil.optimisation.functions import SGFunction
from cil.optimisation.algorithms import GD
from cil.optimisation.utilities import Sampler, ConstantStepSize
from cil.optimisation.utilities.callbacks import ProgressCallback

from img_quality_cil_stir import ImageQualityCallback


# engine's messages go to files, except error messages, which go to stdout.
# Keep the redirector under a real name: `_` is IPython's output-cache variable,
# and if this object were dropped the redirection may be undone (TODO confirm in SIRF docs).
msg_redirector = pet.MessageRedirector('info.txt', 'warn.txt')
# Needed for get_subsets()
pet.AcquisitionData.set_storage_scheme('memory')
# fewer messages from STIR and SIRF
pet.set_verbosity(0)

def initial_OSEM(acquired_data, additive_term, mult_factors, initial_image,
                 num_subsets=2, num_subiterations=14):
    '''Run a short OSMAPOSL reconstruction (no prior is set, i.e. OSEM).

    Parameters
    ----------
    acquired_data, additive_term, mult_factors : pet.AcquisitionData
        Prompts, additive term and multiplicative factors, handed to the partitioner.
    initial_image : pet.ImageData
        Starting estimate; also used to set up the reconstructor.
    num_subsets : int, optional
        Number of OSEM subsets (default 2).
    num_subiterations : int, optional
        Number of OSEM sub-iterations (default 14).

    Returns
    -------
    The image returned by ``recon.get_output()``.

    Raises
    ------
    ValueError
        If ``num_subsets`` or ``num_subiterations`` is smaller than 1.
    '''
    if num_subsets < 1 or num_subiterations < 1:
        raise ValueError("num_subsets and num_subiterations must be >= 1")

    # a single partition (presumably the whole data set) with one acquisition model;
    # the OSEM subsetting is done by the reconstructor itself below
    data, acq_models, _ = partitioner.data_partition(acquired_data, additive_term, mult_factors, 1)

    obj_fun = pet.make_Poisson_loglikelihood(data[0])
    obj_fun.set_acquisition_model(acq_models[0])
    recon = pet.OSMAPOSLReconstructor()
    recon.set_objective_function(obj_fun)
    recon.set_current_estimate(initial_image)
    recon.set_num_subsets(num_subsets)
    recon.set_num_subiterations(num_subiterations)
    recon.set_up(initial_image)
    recon.process()
    return recon.get_output()


In [None]:
def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3, num_subsets=1):
    '''
    Construct a smoothed Relative Difference Prior (RDP).

    Parameters
    ----------
    penalty_strength : float
        Penalisation factor (beta) of the prior.
    initial_image : pet.ImageData
        Sets the smoothing parameter epsilon = initial_image.max() * max_scaling
        (a non-zero epsilon is what makes the prior differentiable), and is
        used to set up the prior.
    kappa : pet.ImageData
        Voxel-dependent weights, passed to ``set_kappa``.
    max_scaling : float, optional
        Fraction of the image maximum used as epsilon (default 1e-3).
    num_subsets : int, optional
        When > 1 the penalisation factor is set to penalty_strength / num_subsets,
        e.g. when the same prior is added to every subset objective function
        (as BSREM-style implementations need). The default 1 uses
        penalty_strength unchanged.

    Returns
    -------
    The set-up ``pet.RelativeDifferencePrior``.

    Raises
    ------
    ValueError
        If ``num_subsets`` is smaller than 1.
    '''
    if num_subsets < 1:
        raise ValueError(f"num_subsets must be >= 1, got {num_subsets}")
    beta = penalty_strength if num_subsets == 1 else penalty_strength / num_subsets

    prior = pet.RelativeDifferencePrior()
    # need to make it differentiable
    epsilon = initial_image.max() * max_scaling
    prior.set_epsilon(epsilon)
    prior.set_penalisation_factor(beta)
    prior.set_kappa(kappa)
    prior.set_up(initial_image)
    return prior
    
def add_prior(prior, objective_functions):
    '''Attach one and the same ``prior`` object to each objective function.

    WARNING: the objective functions are modified in place (``set_prior`` is
    called on each); nothing is returned.'''
    for obj_fun in objective_functions:
        obj_fun.set_prior(prior)

In [None]:
# https://github.com/SyneRBI/PETRIC/blob/recon_with_metrics/metrics/NEMA-IQ-CIL.ipynb
import tensorboardX
from datetime import datetime
import numpy as np
# one tensorboardX log directory per run, named after the start time (YYYYmmdd-HHMMSS)
dt_string = f"{datetime.now():%Y%m%d-%H%M%S}"
log_dir = f'recons/exp-{dt_string}'
tb_summary_writer = tensorboardX.SummaryWriter(log_dir)
def MSE(x, y):
    """Mean squared error between two numpy arrays: mean((x - y)**2)."""
    diff = x - y
    return (diff * diff).mean()

def MAE(x, y):
    """Mean absolute error between two numpy arrays: mean(|x - y|)."""
    absolute_error = np.absolute(x - y)
    return absolute_error.mean()

def PSNR(x, y, scale=None):
    """ peak signal to noise ratio between two numpy arrays x and y
        y is considered to be the reference array and the default scale
        needed for the PSNR is assumed to be the max of this array

    Returns 10*log10(scale**2 / MSE(x, y)) in dB, and np.inf when x and y
    are identical (MSE == 0).
    """
    mse = ((x - y) ** 2).mean()

    # `is None`, not `== None`: for an array-valued scale `==` compares
    # elementwise and makes the `if` ambiguous (ValueError)
    if scale is None:
        scale = y.max()

    if mse == 0:
        # identical arrays: infinite PSNR, without numpy's divide-by-zero warning
        return np.inf

    return 10 * np.log10((scale ** 2) / mse)



In [None]:
# Data directory; set the CHALLENGE_DATA_DIR environment variable to run from
# a location other than the default below.
data_dir = os.environ.get('CHALLENGE_DATA_DIR', '/home/jovyan/work/Challenge24/data')
os.chdir(data_dir)

In [None]:
acquired_data = pet.AcquisitionData('prompts.hs')

additive_term = pet.AcquisitionData('additive.hs')

mult_factors = pet.AcquisitionData('multfactors.hs')

# the precomputed OSEM image is both the starting estimate and the OSEM reference
initial_image = pet.ImageData('OSEM_image.hv')
osem_sol = initial_image
# kappa gives voxel-dependent weights for the prior. Ideally it would be the
# row-sum of the Hessian of the log-likelihood at an initial OSEM reconstruction
# (see eq. 25 in reference [7], which is not listed in this notebook).
# NOTE: here it is simply all ones, i.e. spatially uniform weights.
kappa = initial_image.allocate(1.)

In [None]:
# load the ROIs

# NOTE(review): the reference ("ground truth") for the metrics is the OSEM image
# itself, so the metrics measure the distance from OSEM, not from a true image.
ground_truth = initial_image
# ROI masks S1..S7 from S1.hv..S7.hv (keys stay in the order S1 -> S7)
roi_image_dict = {f'S{i}': pet.ImageData(f'S{i}.hv') for i in range(1, 8)}

img_metrics = {'MSE': MSE, 'MAE': MAE, 'PSNR': PSNR}
roi_statistics = {
    'MEAN': lambda x: x.mean(),
    'STDDEV': lambda x: x.std(),
    'MAX': lambda x: x.max(),
    # TODO: placeholder - always returns the constant [3, 2, 1], not a real centre of mass
    'COM': lambda x: np.array([3, 2, 1]),
}

# instantiate ImageQualityCallback
img_qual_callback = ImageQualityCallback(
    ground_truth,
    tb_summary_writer,
    roi_mask_dict=roi_image_dict,
    metrics_dict=img_metrics,
    statistics_dict=roi_statistics,
)


## Using SIRF Objective Functions

In [None]:
# Split the data into `num_subsets` staggered subsets; the partitioner returns
# per-subset data, acquisition models and objective functions.
num_subsets = 7
data, acq_models, obj_funs = partitioner.data_partition(
    acquired_data,
    additive_term,
    mult_factors,
    num_subsets,
    mode='staggered',
    initial_image=initial_image,
)

In [None]:

# add RDP prior to the objective functions (and pick the GD step size accordingly)
add_regulariser = True
if add_regulariser:
    alpha = 500
    prior = construct_RDP(alpha, initial_image, kappa)
    # NOTE(review): the same prior object is attached to every subset objective,
    # so in the sum over subsets that SGFunction approximates it presumably
    # enters num_subsets times (effective strength ~ num_subsets * alpha) - confirm.
    add_prior(prior, obj_funs)
    # the regularised objective is run with a much smaller step
    step_size = 1e-10
else:
    step_size = 1e-7

In [None]:
# set up the stochastic gradient descent algorithm (it is run in the next cell)

# fixed seed so the random subset order, and hence the result, is reproducible
SAMPLER_SEED = 42
sampler = Sampler.random_without_replacement(len(obj_funs), seed=SAMPLER_SEED)
# the SIRF log-likelihoods are to be maximised while CIL algorithms minimise,
# hence the minus sign
F = - SGFunction(obj_funs, sampler=sampler)
# constant step size; `step_size` is chosen in the regulariser cell
step_size_rule = ConstantStepSize(step_size)

alg = GD(initial=initial_image, objective_function=F, step_size=step_size_rule)

In [None]:
# 10 iterations, with image-quality logging and a progress display
run_callbacks = [img_qual_callback, ProgressCallback()]
alg.run(10, callbacks=run_callbacks)

In [None]:
from cil.utilities.display import show2D

# Compare the OSEM image with the current GD solution on one slice.
im_slice = 70
# single upper colour limit shared by both panels
cmax = 0.2
osem_sol = initial_image
# each GD iteration uses one sampled subset, so one epoch = num_subsets iterations
epoch = alg.iteration / num_subsets
show2D([osem_sol.as_array()[im_slice, :, :],
        alg.solution.as_array()[im_slice, :, :]],
       title=['OSEM', f"{alg.__class__.__name__} epoch {epoch}"],
       cmap="PuRd", fix_range=[(0, cmax), (0, cmax)], origin='upper-left')