# CT Reconstruction with Besov Prior

This notebook demonstrates computed tomography (CT) reconstruction with a Besov prior and the RTO-MH (randomize-then-optimize Metropolis–Hastings) sampler.
The workflow generates a phantom image, simulates CT data with the forward model, and then draws samples from the resulting Bayesian posterior for three choices of the Besov parameters $(s, p)$.

In [1]:
# Importing necessary libraries and modules
from skimage.transform import rescale
from skimage.data import shepp_logan_phantom
import numpy as np
from scipy import sparse
from besov_prior2D import besov_prior2D
from inverse_problem_2D import inverse_problem_2D
from CT import CT
from rto_mh_2D import rto_mh_2D
import pywt

## Step 1: Generate the Phantom Image and Set Up the CT Model
We use the Shepp–Logan phantom, rescaled to $n \times n$ pixels with $n = 2^J = 64$, as the ground-truth image for CT reconstruction. The CT forward model uses 30 projection angles equally spaced over $[0^\circ, 180^\circ)$.

In [2]:
# Set random seed for reproducibility
np.random.seed(100)

# Define parameters for the CT model
J = 6                # Wavelet level; the image side length is 2**J
n = 2**J             # Image size (n x n)
N = n * n            # Total number of pixels (= 2**(2*J))
n_theta = 30         # Number of projection angles
im_size = n          # Rescaled image size

# Generate the Shepp-Logan phantom and rescale it to im_size x im_size.
# The scale factor is derived from the phantom's own side length rather than a
# hard-coded 400, and the result is checked so a size mismatch fails loudly here
# instead of later inside the forward model or sampler.
im = shepp_logan_phantom()
im_1 = rescale(im, im_size / im.shape[0])
assert im_1.shape == (n, n), f"rescaled phantom has shape {im_1.shape}, expected {(n, n)}"

# Projection angles in degrees, equally spaced over [0, 180)
theta = np.linspace(0, 180, n_theta, endpoint=False)

# Initialize the CT forward model and set the observed data.
# NOTE(review): the meaning of the second argument to CT.set_data (2.0) is not
# visible here -- presumably a noise level; confirm against the CT module.
likelihood = CT(theta)
likelihood.set_data(im_1, 2.0)

# Save the phantom image and CT data for later use
np.save('CT_IM.npy', im_1)
np.save("CT_data.npy", likelihood.data)

# Number of projection data points (total element count of the data array)
m = likelihood.data.size

## Step 2: Save the CT System Matrix
The sparse system matrix of the CT forward model is saved to `A_64.npz`, so the sampling cells below can reload it from disk instead of rebuilding it.

In [3]:
# Assemble the CT forward-model matrix via likelihood.jac_const(n) and write it
# to disk in SciPy's sparse .npz format; the sampling cells reload it from there.
system_matrix = likelihood.jac_const(n)
sparse.save_npz("A_64.npz", system_matrix)

## Step 3: Perform Bayesian Sampling (s=1.0, p=1.5)
Set up the Besov prior with $s = 1.0$ and $p = 1.5$ (Haar/`db1` wavelets, `delt = 0.025`) and draw 50,000 samples with the RTO-MH algorithm.

In [None]:
# Re-seed the global NumPy RNG so this run is reproducible on its own.
np.random.seed(100)

# Reload the CT system matrix saved in Step 2.
A = sparse.load_npz('A_64.npz')

# Besov prior settings for this run.
s, p = 1.0, 1.5
delt = 0.025
level = 0
wavelet = 'db1'

# Wavelet coefficient layout of an n x n image. Only `slices` and `shape` are
# passed to the prior; `a` (the flattened coefficients) is not used in this cell.
wavelet_coeffs = pywt.wavedec2(np.ones((n, n)), wavelet, mode='periodization', level=J - level)
a, slices, shape = pywt.ravel_coeffs(wavelet_coeffs)
prior = besov_prior2D(J, delt, level, slices, shape, s=s, p=p, wavelet=wavelet)

# RTO-MH sampler of dimension N + m (presumably N prior rows stacked on m data rows -- confirm).
nsamp = 50000
x0 = np.ones(N)
sampler = rto_mh_2D(x0, N + m, samp=nsamp)

# Forward operator A composed with prior.jac_const()
# (presumably the wavelet synthesis operator -- confirm in besov_prior2D).
problem = inverse_problem_2D(likelihood, prior, A @ prior.jac_const())

# initialize_Q returns z_map; the chain is started from 8 * z_map.
# NOTE(review): the factor 8 is not explained here -- confirm its rationale.
z_map = sampler.initialize_Q(problem)
sampler.x0 = 8 * z_map

chain, acc_rate, index_accept, log_c_chain = sampler.sample(problem)

# Persist results; names encode wavelet, s and p (e.g. 'CTSamplesdb11.01.5.npy').
run_tag = wavelet + str(s) + str(p)
np.save('CTSamples' + run_tag + '.npy', chain)
np.save('CTacc_rate' + run_tag + '.npy', acc_rate)
np.save("CTMH" + run_tag + '.npy', log_c_chain)
np.save("CTindex_accept" + run_tag + '.npy', index_accept)


## Step 4: Perform Bayesian Sampling (s=2.5, p=1.5)
Repeat the Step 3 workflow with $s = 2.5$ (keeping $p = 1.5$). Note that this cell draws only 200 samples, compared with 50,000 in Step 3.

In [None]:
# Re-seed the global NumPy RNG so this run is reproducible on its own.
np.random.seed(100)

# Reload the CT system matrix saved in Step 2.
A = sparse.load_npz('A_64.npz')

# Besov prior settings for this run.
s, p = 2.5, 1.5
delt = 0.025
level = 0
wavelet = 'db1'

# Wavelet coefficient layout of an n x n image. Only `slices` and `shape` are
# passed to the prior; `a` (the flattened coefficients) is not used in this cell.
wavelet_coeffs = pywt.wavedec2(np.ones((n, n)), wavelet, mode='periodization', level=J - level)
a, slices, shape = pywt.ravel_coeffs(wavelet_coeffs)
prior = besov_prior2D(J, delt, level, slices, shape, s=s, p=p, wavelet=wavelet)

# NOTE(review): only 200 samples here versus 50000 in Step 3 -- confirm this is
# intentional before comparing posterior statistics across runs.
nsamp = 200
x0 = np.ones(N)
sampler = rto_mh_2D(x0, N + m, samp=nsamp)

# Forward operator A composed with prior.jac_const()
# (presumably the wavelet synthesis operator -- confirm in besov_prior2D).
problem = inverse_problem_2D(likelihood, prior, A @ prior.jac_const())

# initialize_Q returns z_map; the chain is started from 8 * z_map.
# NOTE(review): the factor 8 is not explained here -- confirm its rationale.
z_map = sampler.initialize_Q(problem)
sampler.x0 = 8 * z_map

chain, acc_rate, index_accept, log_c_chain = sampler.sample(problem)

# Persist results; names encode wavelet, s and p (e.g. 'CTSamplesdb12.51.5.npy').
run_tag = wavelet + str(s) + str(p)
np.save('CTSamples' + run_tag + '.npy', chain)
np.save('CTacc_rate' + run_tag + '.npy', acc_rate)
np.save("CTMH" + run_tag + '.npy', log_c_chain)
np.save("CTindex_accept" + run_tag + '.npy', index_accept)


## Step 5: Perform Bayesian Sampling (s=2.5, p=1.0)
Repeat the Step 3 workflow with $s = 2.5$ and $p = 1.0$. This cell also draws only 200 samples, compared with 50,000 in Step 3.

In [None]:
# Re-seed the global NumPy RNG so this run is reproducible on its own.
np.random.seed(100)

# Reload the CT system matrix saved in Step 2.
A = sparse.load_npz('A_64.npz')

# Besov prior settings for this run.
s, p = 2.5, 1.0
delt = 0.025
level = 0
wavelet = 'db1'

# Wavelet coefficient layout of an n x n image. Only `slices` and `shape` are
# passed to the prior; `a` (the flattened coefficients) is not used in this cell.
wavelet_coeffs = pywt.wavedec2(np.ones((n, n)), wavelet, mode='periodization', level=J - level)
a, slices, shape = pywt.ravel_coeffs(wavelet_coeffs)
prior = besov_prior2D(J, delt, level, slices, shape, s=s, p=p, wavelet=wavelet)

# NOTE(review): only 200 samples here versus 50000 in Step 3 -- confirm this is
# intentional before comparing posterior statistics across runs.
nsamp = 200
x0 = np.ones(N)
sampler = rto_mh_2D(x0, N + m, samp=nsamp)

# Forward operator A composed with prior.jac_const()
# (presumably the wavelet synthesis operator -- confirm in besov_prior2D).
problem = inverse_problem_2D(likelihood, prior, A @ prior.jac_const())

# initialize_Q returns z_map; the chain is started from 8 * z_map.
# NOTE(review): the factor 8 is not explained here -- confirm its rationale.
z_map = sampler.initialize_Q(problem)
sampler.x0 = 8 * z_map

chain, acc_rate, index_accept, log_c_chain = sampler.sample(problem)

# Persist results; names encode wavelet, s and p (e.g. 'CTSamplesdb12.51.0.npy').
run_tag = wavelet + str(s) + str(p)
np.save('CTSamples' + run_tag + '.npy', chain)
np.save('CTacc_rate' + run_tag + '.npy', acc_rate)
np.save("CTMH" + run_tag + '.npy', log_c_chain)
np.save("CTindex_accept" + run_tag + '.npy', index_accept)