In [None]:
#import libraries
import os

import numpy as np

from cil.framework import AcquisitionGeometry 
from cil.io import NEXUSDataWriter, NEXUSDataReader
from cil.optimisation.algorithms import PDHG
from cil.optimisation.functions import L2NormSquared
from cil.plugins.astra.operators import ProjectionOperator
from cil.plugins.ccpi_regularisation.functions import FGP_dTV 
from cil.utilities.display import show2D

In [None]:
def load_nexus(stem):
    """Load and return the dataset stored in ``<stem>.nxs`` via NEXUSDataReader."""
    return NEXUSDataReader(file_name=stem + ".nxs").load_data()


# Reference FBP reconstruction of the pre-scan (720 projections); used as the
# dTV reference image for frame 0 in the reconstruction loop.
fbp_recon_pre_scan = load_nexus("FBP_reconstructions/FBP_projections_720")

# Reference FBP reconstruction of the post-scan (1600 projections); used as the
# dTV reference image for all remaining frames.
fbp_recon_post_scan = load_nexus("FBP_reconstructions/FBP_projections_1600")

In [None]:
# Regularisation strength for each sparse-data case, keyed by the number of
# projections as a string (the reconstruction loop looks it up with str(i)).
alpha = {
    str(n_proj): strength
    for n_proj, strength in ((18, 0.0081), (36, 0.0072), (72, 0.0081), (360, 0.03))
}

In [None]:
# Additional settings for the FGP_dTV regulariser:
#   eta       - smoothing constant for the reference image (presumably for its
#               gradient; see the FGP_dTV docs)
#   iterations, tolerance
#             - inner (regulariser) iteration cap and stopping tolerance; not the
#               PDHG iteration count, which is `max_iterations`
#   methodTV  - TV variant selector; assumes 0 = isotropic, 1 = anisotropic (TODO confirm)
#   nonneg    - assumes 1 enables the non-negativity constraint (TODO confirm)
eta, iterations, tolerance, methodTV, nonneg = 0.005, 100, 1e-6, 0, 1

In [None]:
# PDHG iterations per frame. Use 1000 instead to reproduce the paper's results.
max_iterations: int = 100

In [None]:
# Reconstructions are written here. Create the folder up front so the final
# write of each (long, GPU-bound) case does not fail on a missing directory.
output_dir = "dTV_reconstructions"
os.makedirs(output_dir, exist_ok=True)


def _make_dtv(reference, strength):
    """Return ``strength * FGP_dTV(reference, ...)`` on the GPU.

    All dTV settings come from the dTV parameter cell (``eta``, ``iterations``,
    ``tolerance``, ``methodTV``, ``nonneg``). Passing them explicitly means that
    editing them there actually changes the reconstruction, instead of FGP_dTV
    silently falling back to its own defaults.
    """
    return strength * FGP_dTV(reference=reference,
                              max_iteration=iterations,
                              tolerance=tolerance,
                              eta=eta,
                              # assumes methodTV 0 = isotropic, 1 = anisotropic
                              # (CCPi-Regularisation convention) - confirm
                              isotropic=(methodTV == 0),
                              nonnegativity=bool(nonneg),
                              device='gpu')


def _solve_frame(frame_data, regulariser, operator, tau, sigma):
    """Reconstruct a single frame with PDHG and return the solution.

    Minimises ``0.5 * ||operator(x) - frame_data||^2 + regulariser(x)`` for
    ``max_iterations`` PDHG iterations.
    """
    fidelity = 0.5 * L2NormSquared(b=frame_data)
    solver = PDHG(f=fidelity, g=regulariser, operator=operator,
                  tau=tau, sigma=sigma,
                  max_iteration=max_iterations,
                  update_objective_interval=100)
    solver.run(verbose=0)
    return solver.solution


for i in [18, 36, 72, 360]:

    # load sparse data
    name_proj = "SparseData/data_{}".format(i)
    reader = NEXUSDataReader(file_name=name_proj + ".nxs")
    data = reader.load_data()

    # get acquisition geometry and a 256x256 image geometry
    ag = data.geometry
    ig = ag.get_ImageGeometry()
    ig.voxel_num_x = 256
    ig.voxel_num_y = 256

    # Single-channel ProjectionOperator; the same operator is reused for every frame
    ag_sc = ag.subset(channel=0)
    ig_sc = ig.subset(channel=0)
    Aop = ProjectionOperator(ig_sc, ag_sc, 'gpu')

    print("Shape for data is {}".format(ag.shape))

    K = Aop
    normK = K.norm()
    sigma = 1./normK
    tau = 1./normK

    recon = ig.allocate()
    # Derive the number of frames from the geometry instead of hard-coding 17,
    # so every channel of `recon` is filled whatever the dataset size.
    n_frames = ig.channels
    alpha_i = alpha[str(i)]

    ########################## Solve for the first frame ##########################
    # Frame 0 is regularised against the pre-scan reference image.
    print("Start dTV regularisation for the frame 0 with {} projections".format(i))
    G0 = _make_dtv(fbp_recon_pre_scan, alpha_i)
    recon.fill(_solve_frame(data.subset(channel=0), G0, K, tau, sigma), channel=0)
    print("Finish dTV regularisation for the frame 0 with {} projections\n".format(i))
    ##############################################################################

    ########################## Solve for all the others ##########################
    # The remaining frames share one regulariser built on the post-scan reference.
    G = _make_dtv(fbp_recon_post_scan, alpha_i)

    for tf in range(1, n_frames):
        print("Start dTV regularisation for the frame {} with {} projections".format(tf, i))
        recon.fill(_solve_frame(data.subset(channel=tf), G, K, tau, sigma), channel=tf)
        print("Finish dTV regularisation for the frame {} with {} projections\n".format(tf, i))

    # Four evenly spaced frames (first to last); for 17 frames this is [0, 5, 10, 16].
    slice_list = [int(k) for k in np.linspace(0, n_frames - 1, 4)]
    show2D(recon, slice_list=slice_list, num_cols=4, origin="upper",
           cmap="inferno", title="Projections {}".format(i), size=(25, 20))

    name_recon = "dTVReconstruction_projections_{}".format(i)
    writer = NEXUSDataWriter(file_name=os.path.join(output_dir, name_recon + ".nxs"),
                             data=recon)
    writer.write()
    