In [13]:
import numpy as np

def read_xcat_bin_image(file_path: str,
                        width: int = 128,
                        height: int = 128,
                        depth: int = 128,
                        dtype=np.float32) -> np.ndarray:
    """Read a raw XCAT ``.bin`` volume into a ``(depth, height, width)`` array.

    The file is read from byte 0 as headerless raw data with ``np.fromfile``
    (native byte order of ``dtype``). The first ``width * height * depth``
    values are reshaped in C order to ``(depth, height, width)``. The volume is
    then rotated by 180 degrees with ``np.rot90(..., 2)``; with its default
    ``axes=(0, 1)`` this reverses the depth and height axes.

    Args:
        file_path: Path to the raw binary file.
        width: Number of voxels along the last axis.
        height: Number of voxels along the middle axis.
        depth: Number of voxels along the first axis.
        dtype: Element type stored in the file. Defaults to ``np.float32``.

    Returns:
        A ``(depth, height, width)`` array of ``dtype``. It is a view with
        negative strides (as produced by ``np.rot90``), so ``.copy()`` it
        before passing it to APIs that reject negative strides, such as
        ``torch.from_numpy``.

    Raises:
        ValueError: If the file holds fewer than ``width * height * depth``
            values of ``dtype``.
    """
    num_elements = width * height * depth

    # `count` caps how much is read; a shorter file silently yields fewer values.
    image_data = np.fromfile(file_path, dtype=dtype, count=num_elements)
    if image_data.size != num_elements:
        raise ValueError(
            f"{file_path!r} contains {image_data.size} values of "
            f"{np.dtype(dtype).name}, expected {num_elements} "
            f"(depth={depth}, height={height}, width={width})"
        )

    volume = image_data.reshape((depth, height, width))
    # NOTE(review): rot90's default axes=(0, 1) flips depth and height, not the
    # in-plane (height, width) axes -- confirm this orientation is intended.
    return np.rot90(volume, 2)

In [14]:
from src.operators.radon import Radon
from src.operators.total_variation import TotalVariation
from src.solvers.pwls import PWLS
from src.solvers.chambolle_pock import ChambollePock
from src.solvers.admm import ADMM


from PIL import Image
import torch
import numpy as np
import torch_radon as tr

device = 'cuda'

# ---- projector settings ----------------------------------------------------
n_rays = 128                # rays per projection view (passed to Radon as n_rays)
height, width = 128, 128    # image grid size
n_angles = 52               # number of projection views
volume = tr.Volume2D()
volume.set_size(height=height, width=width)
# `n_angles` equispaced views on [0, 2*pi): the stop value is one step short of
# 2*pi so the last view does not duplicate the view at angle 0.
angles = torch.linspace(0, 2 * torch.pi * (n_angles - 1) / n_angles, n_angles, device=device)

# ---- projector -------------------------------------------------------------
radon = Radon(n_rays=n_rays,
              angles=angles,
              volume=volume)

# ---- ground-truth image ----------------------------------------------------
import os

# The phantom location can be overridden without editing the notebook by
# setting XCAT_PHANTOM_PATH; the default is the original author's local path.
img_path = os.environ.get(
    'XCAT_PHANTOM_PATH',
    '/home/adepaepe/Data/xcat/xcat_extractions/default_phantom_diff_breathing/phantoms/p_1_atn_3.bin',
)
slice_index = 64  # index along the first (depth) axis of the loaded volume
x_true = read_xcat_bin_image(img_path)[slice_index, :, :]
# .copy() materialises the negative-stride view returned by np.rot90, which
# torch.from_numpy does not accept; [None, None] adds batch and channel dims.
x_true = torch.from_numpy(x_true.copy()).float()[None, None].to(device)

# ---- noisy measurements ----------------------------------------------------
# Poisson transmission model: detected counts yi ~ Poisson(I * exp(-sino)),
# and the noisy line integrals are recovered as log(I / yi).
I = 3e5          # incident photon count per ray
EPS = 1e-10      # avoids log(I / 0); a zero-count ray still maps to log(I / EPS)
noise_seed = 0   # fixes the Poisson draw so Restart & Run All is reproducible
sino = radon.transform(x_true)
# Dedicated generator on the sinogram's device: seeds this draw only and leaves
# the global torch RNG state untouched.
noise_rng = torch.Generator(device=sino.device).manual_seed(noise_seed)
yi = torch.poisson(I * torch.exp(-sino), generator=noise_rng)
noisy_sino = torch.log(I / (yi + EPS))

# ---- parameters shared by the solvers --------------------------------------
x0 = torch.zeros_like(x_true, device=device)   # zero initial image
b = noisy_sino                                  # measured (noisy) sinogram
n_iter = 300
n_inner_iter = 100
# Detected counts as statistical weights -- presumably inverse-variance weights
# for the weighted data term (log-domain variance is roughly 1 / yi).
weights = yi


#############################################################################################################################
#                                                           FBP                                                             #
#############################################################################################################################


# Filtered back-projection of the noisy sinogram as an analytic baseline.
# The (commented-out) iterative solver sections below also assign `x`.
x = radon.fbp(b)

import os
import matplotlib.pyplot as plt

# Save the ground-truth slice. Create the output directory first so a fresh
# checkout does not fail with FileNotFoundError.
gt_path = './data/gt/0.png'
os.makedirs(os.path.dirname(gt_path), exist_ok=True)
plt.imsave(gt_path, x_true.cpu().squeeze().numpy(), cmap='gray')


#############################################################################################################################
#                                                           PWLS                                                            #
#############################################################################################################################

# tv = TotalVariation(penalty='l2')
# solver = PWLS(radon=radon, regularizer=tv)

# beta = 1e5

# x = solver.solve(x0,
#                 b,
#                 beta,
#                 n_iter,
#                 weights)


#############################################################################################################################
#                                                           ADMM                                                            #
#############################################################################################################################

# tv = TotalVariation(penalty='l1')
# solver = ADMM(radon=radon, regularizer=tv)

# beta = 3e2
# rho = 5000

# x = solver.solve(x0,
#                 b,
#                 beta,
#                 rho,
#                 n_iter,
#                 n_inner_iter,
#                 weights)


#############################################################################################################################
#                                                      Chambolle Pock                                                       #
#############################################################################################################################




# tv = TotalVariation(penalty='l1')
# solver = ChambollePock(radon=radon, regularizer=tv)

# beta = 3e2
# theta = 1
# L = tv.norm(height, width)
# sigma = 0.99 * (1e8 / (np.sqrt(1e8 * 1) * L))
# tau = 0.99 * (1 / (np.sqrt(1e8 * 1) * L))

# x = solver.solve( x0,
#                     b,
#                     beta,
#                     tau,
#                     sigma,
#                     theta,
#                     n_iter,
#                     n_inner_iter,
#                     weights)

