In [1]:
#%% Initial imports etc
import os
os.environ['OMP_NUM_THREADS'] = '15'

import numpy as np
import matplotlib.pyplot as plt
import shutil
import sirf.STIR as pet

from sirf.Utilities import examples_data_path
from ccpi.optimisation.algorithms import CGLS, PDHG, FISTA
from ccpi.optimisation.operators import BlockOperator, LinearOperator
from ccpi.optimisation.functions import KullbackLeibler, IndicatorBox, \
         FunctionOperatorComposition, BlockFunction, MixedL21Norm , ZeroFunction, KullbackLeibler
from ccpi.framework import ImageData
from ccpi.plugins.regularisers import FGP_TV, FGP_dTV
setattr(FGP_TV, 'convex_conjugate', lambda self,x: 0.0)

%matplotlib inline

ImportError: No module named sirf.STIR

In [2]:
# imports for plotting
from __future__ import print_function, division
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy

def display_slice(container, direction, title, cmap, minmax, size, axis_labels):
    
        
    def get_slice_3D(x):
        
        if direction == 0:
            img = container[x]
            x_lim = container.shape[2]
            y_lim = container.shape[1]
            x_label = axis_labels[2]
            y_label = axis_labels[1] 
            
        elif direction == 1:
            img = container[:,x,:]
            x_lim = container.shape[2]
            y_lim = container.shape[0] 
            x_label = axis_labels[2]
            y_label = axis_labels[0]             
            
        elif direction == 2:
            img = container[:,:,x]
            x_lim = container.shape[1]
            y_lim = container.shape[0]    
            x_label = axis_labels[1]
            y_label = axis_labels[0]             
        
        if size is None:
            fig = plt.figure()
        else:
            fig = plt.figure(figsize=size)
        
        if isinstance(title, (list, tuple)):
            dtitle = title[x]
        else:
            dtitle = title
        
        gs = gridspec.GridSpec(1, 2, figure=fig, width_ratios=(1,.05), height_ratios=(1,))
        # image
        ax = fig.add_subplot(gs[0, 0])
      
        ax.set_xlabel(x_label)     
        ax.set_ylabel(y_label)
 
        aximg = ax.imshow(img, cmap=cmap, origin='upper', extent=(0,x_lim,y_lim,0))
        aximg.set_clim(minmax)
        ax.set_title(dtitle + " {}".format(x))
        # colorbar
        ax = fig.add_subplot(gs[0, 1])
        plt.colorbar(aximg, cax=ax)
        plt.tight_layout()
        plt.show(fig)
        
    return get_slice_3D

    
def islicer(data, direction, title="", slice_number=None, cmap='gray', minmax=None, size=None, axis_labels=None):

    '''Creates an interactive integer slider that slices a 3D volume along direction
    
    :param data: DataContainer or numpy array
    :param direction: slice direction, int, should be 0,1,2 or the axis label
    :param title: optional title for the display
    :slice_number: int start slice number, optional. If None defaults to center slice
    :param cmap: matplotlib color map
    :param minmax: colorbar min and max values, defaults to min max of container
    :param size: int or tuple specifying the figure size in inch. If int it specifies the width and scales the height keeping the standard matplotlib aspect ratio 
    '''
    
    if axis_labels is None:
        if hasattr(data, "dimension_labels"):
            axis_labels = [data.dimension_labels[0],data.dimension_labels[1],data.dimension_labels[2]]
        else:
            axis_labels = ['X', 'Y', 'Z']

    
    if hasattr(data, "as_array"):
        container = data.as_array()
        
        if not isinstance (direction, int):
            if direction in data.dimension_labels.values():
                direction = data.get_dimension_axis(direction)                             

    elif isinstance (data, numpy.ndarray):
        container = data
        
    if slice_number is None:
        slice_number = int(data.shape[direction]/2)
        
    slider = widgets.IntSlider(min=0, max=data.shape[direction]-1, step=1, 
                             value=slice_number, continuous_update=False, description=axis_labels[direction])

    if minmax is None:
        amax = container.max()
        amin = container.min()
    else:
        amin = min(minmax)
        amax = max(minmax)
    
    if isinstance (size, (int, float)):
        default_ratio = 6./8.
        size = ( size , size * default_ratio )
    
    interact(display_slice(container, 
                           direction, 
                           title=title, 
                           cmap=cmap, 
                           minmax=(amin, amax),
                           size=size, axis_labels=axis_labels),
             x=slider);
    
    return slider
    

def link_islicer(*args):
    '''links islicers IntSlider widgets'''
    linked = [(widg, 'value') for widg in args]
    # link pair-wise
    pairs = [(linked[i+1],linked[i]) for i in range(len(linked)-1)]
    for pair in pairs:
        widgets.link(*pair)

def psnr(img1, img2, data_range=1):
    mse = numpy.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 1000
    return 20 * numpy.log10(data_range / numpy.sqrt(mse))


def plotter2D(datacontainers, titles=None, fix_range=False, stretch_y=False, cmap='gray', axis_labels=None):
    '''plotter2D(datacontainers=[], titles=[], fix_range=False, stretch_y=False, cmap='gray', axes_labels=['X','Y'])
    
    plots 1 or more 2D plots in an (n x 2) matix
    multiple datasets can be passed as a list
    
    Can take ImageData, AquistionData or numpy.ndarray as input
    '''
    if(isinstance(datacontainers, list)) is False:
        datacontainers = [datacontainers]

    if titles is not None:
        if(isinstance(titles, list)) is False:
            titles = [titles]
            

    
    nplots = len(datacontainers)
    rows = int(round((nplots+0.5)/2.0))

    fig, (ax) = plt.subplots(rows, 2,figsize=(15,15))

    axes = ax.flatten() 

    range_min = float("inf")
    range_max = 0
    
    if fix_range == True:
        for i in range(nplots):
            if type(datacontainers[i]) is numpy.ndarray:
                dc = datacontainers[i]
            else:
                dc = datacontainers[i].as_array()
                
            range_min = min(range_min, numpy.amin(dc))
            range_max = max(range_max, numpy.amax(dc))
        
    for i in range(rows*2):
        axes[i].set_visible(False)

    for i in range(nplots):
        axes[i].set_visible(True)
        
        if titles is not None:
            axes[i].set_title(titles[i])
       
        if axis_labels is not None:
            axes[i].set_ylabel(axis_labels[1])
            axes[i].set_xlabel(axis_labels[0]) 
            
        if type(datacontainers[i]) is numpy.ndarray:
            dc = datacontainers[i]          
        else:
            dc = datacontainers[i].as_array()
            
            if axis_labels is None:
                axes[i].set_ylabel(datacontainers[i].dimension_labels[0])
                axes[i].set_xlabel(datacontainers[i].dimension_labels[1])        
        
        
        sp = axes[i].imshow(dc, cmap=cmap, origin='upper', extent=(0,dc.shape[1],dc.shape[0],0))
    
        
        im_ratio = dc.shape[0]/dc.shape[1]
        
        if stretch_y ==True:   
            axes[i].set_aspect(1/im_ratio)
            im_ratio = 1
            
        plt.colorbar(sp, ax=axes[i],fraction=0.0467*im_ratio, pad=0.02)
        
        if fix_range == True:
            sp.set_clim(range_min,range_max) 

In [3]:
# Kullback Leibler methods with numba
try:
    import numba
    from numba import jit, prange
    import numpy
    from numpy import sqrt, log, inf
    has_numba = True
    '''Some parallelisation of KL calls'''
    @jit(nopython=True)
    def kl_proximal(x,b, bnoise, tau, out):
            for i in prange(x.size):
                out.flat[i] = 0.5 *  ( 
                    ( x.flat[i] - bnoise.flat[i] - tau ) +\
                    numpy.sqrt( (x.flat[i] + bnoise.flat[i] - tau)**2. + \
                        (4. * tau * b.flat[i]) 
                    )
                )
    @jit(nopython=True)
    def kl_proximal_conjugate(x, b, bnoise, tau, out):
        #z = x + tau * self.bnoise
        #return 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())

        for i in prange(x.size):
            z = x.flat[i] + ( tau * bnoise.flat[i] )
            out.flat[i] = 0.5 * ( 
                (z + 1) - numpy.sqrt((z-1)*(z-1) + 4 * tau * b.flat[i])
                )
    @jit(nopython=True)
    def kl_gradient(x, b, bnoise, out):
        for i in prange(x.size):
            out.flat[i] = 1 - b.flat[i]/(x.flat[i] + bnoise.flat[i])

    @jit(nopython=True)
    def kl_div(x, y, out):
        for i in prange(x.size):
            X = x.flat[i]
            Y = y.flat[i]    
            if x.flat[i] > 0 and y.flat[i] > 0:
                out.flat[i] = X * numpy.log(X/Y) - X + Y
            elif X == 0 and Y >= 0:
                out.flat[i] = Y
            else:
                out.flat[i] = numpy.inf
    
    # force a jit
    x = numpy.asarray(numpy.random.random((10,10)), dtype=numpy.float32)
    b = numpy.asarray(numpy.random.random((10,10)), dtype=numpy.float32)
    bnoise = numpy.zeros_like(x)
    out = numpy.empty_like(x)
    tau = 1.
    kl_div(b,x,out)
    kl_gradient(x,b,bnoise,out)
    kl_proximal(x,b, bnoise, tau, out)
    kl_proximal_conjugate(x,b, bnoise, tau, out)

except ImportError as ie:
    has_numba = False

class ChangeSign(object):
    def __init__(self):
        pass
    
    @staticmethod
    def get_instance(class_name, *args,**kwargs):
        
        setattr(class_name, '__call__', ChangeSign.KL_call)
        setattr(class_name, 'gradient', ChangeSign.KL_gradient)
        setattr(class_name, 'proximal_conjugate', ChangeSign.KL_proximal_conjugate)
        
        instance = class_name(*args, **kwargs)
        # swap the original set_acquisition_data with a modified one
        set_acquisition_data_sirf = instance.set_acquisition_data
        setattr(class_name, 'set_acquisition_data_sirf', set_acquisition_data_sirf)
        setattr(class_name, 'set_acquisition_data', ChangeSign.set_acquisition_data)
        
        # return the new instance
        return instance

    ### Few fixes for common interface
    @staticmethod
    def set_acquisition_data(self, ad):
        #save a reference to acquisition_data in the class
        self.b = ad
        self.set_acquisition_data_sirf(ad)
        
    @staticmethod
    def KL_call(self, x):
        return - self.get_value(x)
    @staticmethod
    def KL_gradient(self, image, subset = -1, out = None):

        assert_validity(image, pet.ImageData)
        grad = pet.ImageData()
        grad.handle = pystir.cSTIR_objectiveFunctionGradient\
            (self.handle, image.handle, subset)
        check_status(grad.handle)
        # change sign
        #grad*=-1
        if out is None:
            return -1 * grad  
        else:
            out.fill(-1 * grad)
    @staticmethod
    def KL_proximal_conjugate(self, x, tau, out=None):

        r'''Proximal operator of the convex conjugate of KullbackLeibler at x:

           .. math::     prox_{\tau * f^{*}}(x)
        '''
        
        self.bnoise = x * 0.
        if has_numba:
            if out is None:
                out = (x * 0.)
                out_np = out.as_array()
                kl_proximal_conjugate(x.as_array(), self.b.as_array(), self.bnoise.as_array(), tau, out_np)
                out.fill(out_np)
                return out
            else:
                out_np = out.as_array()
                kl_proximal_conjugate(x.as_array(), self.b.as_array(), self.bnoise.as_array(), tau, out_np)
                out.fill(out_np)                    
        else:
            if out is None:
                z = x + tau * self.bnoise
                return 0.5*((z + 1) - ((z-1)**2 + 4 * tau * self.b).sqrt())
            else:
                
                tmp = tau * self.bnoise
                tmp += x
                tmp -= 1
                
                self.b.multiply(4*tau, out=out)    
                
                out.add(tmp.power(2), out=out)
                out.sqrt(out=out)
                out *= -1
                tmp += 2
                out += tmp
                out *= 0.5

In [4]:
# Define norm for the acquisition model
def norm(self, **kwargs):
    return LinearOperator.PowerMethod(self, kwargs.get('iterations',10))[0]

setattr(pet.AcquisitionModelUsingRayTracingMatrix, 'norm', norm)

    



#% go to directory with input files

EXAMPLE = 'SIMULATION'


if EXAMPLE == 'SIMULATION':
    data_dir = os.path.abspath('/home/edo/GitHub/PETMR/sympdata')
    os.chdir(data_dir)
    ##%% copy files to working folder and change directory to where the output files are
    new_dir = os.path.abspath(os.path.join(data_dir, 'CIL-numba'))
    
    # adapt this path to your situation (or start everything in the relevant directory)
    #os.chdir('/mnt/data/CCPPETMR/201909_hackathon/Simulations/PET/SimulationData')
    #shutil.rmtree(new_dir,True)
    if not os.path.exists(new_dir):
        shutil.copytree(data_dir,new_dir)
    os.chdir(new_dir)
    
    ground_truth = 'FDG_small.hv'
    attenuation_header = 'uMap_small.hv'
    image_header = attenuation_header
    sinogram_header = 'FDG_sino_noisy.hs'

image = pet.ImageData(image_header);
image_array=image.as_array()
mu_map = pet.ImageData(attenuation_header);
mu_map_array=mu_map.as_array();

# Show Emission image
print('Size of emission: {}'.format(image.shape))


#%%
sinogram = pet.AcquisitionData(sinogram_header)
# rebin the data to speed up
sinogram = sinogram.rebin(11)

# attenuation
attn_acq_model = pet.AcquisitionModelUsingRayTracingMatrix()
asm_attn = pet.AcquisitionSensitivityModel(mu_map, attn_acq_model)
# converting attenuation into attenuation factors (see previous exercise)
asm_attn.set_up(sinogram)
attn_factors = pet.AcquisitionData(sinogram)
attn_factors.fill(1.0)
print('applying attenuation (please wait, may take a while)...')
asm_attn.unnormalise(attn_factors)
# use these in the final attenuation model
asm_attn = pet.AcquisitionSensitivityModel(attn_factors)

am = pet.AcquisitionModelUsingRayTracingMatrix()
# we will increate the number of rays used for every Line-of-Response (LOR) as an example
# (it is not required for the exercise of course)
am.set_num_tangential_LORs(5)
am.set_acquisition_sensitivity(asm_attn)

# this seems to use a lot of memory! 256 Gb went!
# pet.AcquisitionData.set_storage_scheme('memory')
am.set_up(sinogram,image)

#% simulate some data using forward projection
if EXAMPLE == 'SIMULATION':
    
    acquired_data = sinogram
    image.fill(1)
    noisy_data = acquired_data.clone()

# Show util per iteration
def show_data(it, obj, x):
    plt.imshow(x.as_array()[0])
    plt.colorbar()
    plt.show()

Size of emission: (127, 150, 150)
applying attenuation (please wait, may take a while)...


In [5]:
#%% TV reconstruction using algorithm below

alpha = 0


method = 'implicit'

if method == 'explicit':

    # Create operators
    op1 = GradientSIRF(image) 
    op2 = am

    # Create BlockOperator
    operator = BlockOperator(op1, op2, shape=(2,1) ) 

    f2 = KullbackLeibler(noisy_data)  
    g =  IndicatorBox(lower=0)    

    f1 = alpha * MixedL21Norm() 
    f = BlockFunction(f1, f2)  
    normK = operator.norm()

elif method == 'implicit':

    operator = am      
    # refdata, regularisation_parameter, iterations, tolerance, eta_const, methodTV, nonneg, device
    #g = FGP_dTV(mu_map, alpha, 500, 1e-7, 1e-2, 0, 1, 'gpu' )
    #g = FGP_TV(alpha, 500, 1e-7, 0, 1, 0, 'gpu' ) 
    # f = KullbackLeibler(noisy_data)
    if alpha == 0:
        g = IndicatorBox(lower=0)
    else:
        g = FGP_TV(alpha, 500, 1e-7, 0, 1, 0, 'gpu' )

#         fidelity = pet.PoissonLogLikelihoodWithLinearModelForMeanAndProjData()

    fidelity = ChangeSign.get_instance(pet.PoissonLogLikelihoodWithLinearModelForMeanAndProjData)  
    fidelity.set_acquisition_model(am)
    fidelity.set_acquisition_data(noisy_data)
    fidelity.set_num_subsets(4)
    fidelity.set_up(image)
    fidelity.L = 1e4
    print ("Calculating operator norm")
    normK = operator.norm(iterations=5)
    print ("done")

sigma = 100.
tau = 1/(sigma*normK**2)
tau = 1. / normK
sigma = 1. / normK

# Setup and run the PDHG algorithm
def sirf_update_objective(self):

    p1 = self.f((self.x)) + self.g(self.x)
    #d1 = -(self.f.convex_conjugate(self.y) + self.g.convex_conjugate(-1*self.operator.adjoint(self.y)))
    #p1 = 0.
    #d1 = 0.
    #self.loss.append([p1, d1, p1-d1])
    self.loss.append(p1)

setattr(PDHG, 'update_objective', sirf_update_objective )
algo = PDHG(f = fidelity, g = g, operator = operator, tau = tau, sigma = sigma)
algo.max_iteration = 500
algo.update_objective_interval = 2
algo.run(2)

Calculating operator norm
done
PDHG setting up
PDHG configured
     Iter   Max Iter     Time/Iter            Objective
                               [s]                     
        0        500         0.000          9.16145e+06
        2        500        12.171          3.57824e+06


In [None]:
# OSMAPOSL reconstruction
#fidelity_sirf = pet.make_Poisson_loglikelihood(noisy_data)
# fidelity_sirf = pet.PoissonLogLikelihoodWithLinearModelForMeanAndProjData()

# #fidelity_sirf??
# fidelity_sirf.set_acquisition_model(am)
# fidelity_sirf.set_acquisition_data(noisy_data)
# fidelity_sirf.set_num_subsets(4)
# fidelity_sirf.set_up(image)

In [None]:
ground_truth_image = pet.ImageData(ground_truth)
reconstructed_image = ground_truth_image.allocate(1)
recon = pet.OSMAPOSLReconstructor()
recon.set_objective_function(fidelity)
recon.set_num_subsets(1)
num_iters=10;
recon.set_num_subiterations(num_iters)
#recon.set_input(noisy_data)
recon.set_up(reconstructed_image)
recon.reconstruct(reconstructed_image)

islicer(ground_truth_image, 1)

In [6]:
# load MR data for dTV
refdata = pet.ImageData('T2.hv')
print (refdata.shape)

(127, 285, 285)


In [23]:
### how to find a proper alpha
alpha = 5e-3
obj = []
r_iterations = 500
r_alpha = alpha
r_tolerance = 1e-7
r_eta_const = 1e-2
r_iso = 0
r_nonneg = 1
#FGP_dTV(refdata, r_alpha, r_iterations, r_tolerance, r_eta_const, r_iso, r_nonneg, device='gpu')

In [24]:
print (tau)
print (algo.get_output().shape)

0.00553003608753365
(127, 150, 150)


In [None]:
# reset algo and run 500 iterations
# it should take 86.107 s/iter * 500 iter / 3600 s/h = 11.9 h
alphas = [ r_alpha, r_alpha] #, 0.,]
gs = [ FGP_dTV(refdata, r_alpha, r_iterations, r_tolerance, r_eta_const, r_iso, r_nonneg, device='gpu'), 
      FGP_TV(alpha, 500, 1e-7, 0, 1, 0, 'gpu' )]#, IndicatorBox(lower=0.)]
regul = ['FGP_dTV', 'FGP_TV', 'IndicatorBox']
algos = []
for alpha, g, reg in zip(alphas, gs, regul):
    algo = PDHG(f = fidelity, g = g, operator = operator, tau = tau, sigma = sigma)
    algo.max_iteration = 500
    algo.update_objective_interval = 10

    if False:
        run = 10
        for i in range(algo.max_iteration / run):
            algo.run(run)
            # saves to os.getcwd()
            #print (os.getcwd())
            fname = os.path.join(os.getcwd(),"PDHG_{}_alpha{}_iter_{}".format(reg,alpha,algo.iteration))
            algo.get_output().write(fname)
    else:
        def save_output(iteration,obj,x):
            int_alpha = '1e-2'
            fname = os.path.join(os.getcwd(),"PDHG_{}_alpha{}_iter_{}".format(reg,int_alpha,iteration))
            #print (fname)
            x.write(fname)

        algo.run(500,callback=save_output,verbose=True)
        #algos.append(algo)

PDHG setting up
PDHG configured
     Iter   Max Iter     Time/Iter            Objective
                               [s]                     
        0        500         0.000          9.16145e+06
       10        500        14.124          3.57980e+06
       20        500        13.598          1.43798e+06
       30        500        12.554          1.52704e+06
       40        500        11.740          1.62166e+06
       50        500        11.582          1.60829e+06
       60        500        11.695          1.43776e+06
       70        500        11.212          1.36896e+06
       80        500        11.580          1.36988e+06
       90        500        11.737          1.37458e+06
      100        500        11.468          1.38151e+06
      110        500        11.373          1.36226e+06
      120        500        11.815          1.35911e+06
      130        500        11.772          1.35196e+06
      140        500        11.901          1.35066e+06
      150       

In [None]:
# fidelity_sirf = pet.PoissonLogLikelihoodWithLinearModelForMeanAndProjData()

# fidelity_sirf??
# fidelity_sirf.set_acquisition_model(am)
# fidelity_sirf.set_acquisition_data(noisy_data)
# fidelity_sirf.set_num_subsets(1)
# fidelity_sirf.set_up(image)

In [29]:
reg = regul[0]
algos[0].run()

reg = regul[1]
algos[1].run()


IndexError: list index out of range