**MBIRJAX: Cropped Center Reconstruction Demo**

See the [MBIRJAX documentation](https://mbirjax.readthedocs.io/en/latest/) for an overview and details.  

This script demonstrates how to do a reconstruction restricted to an area around the center of rotation.  This is a difficult problem for model-based methods, so the reconstruction has some unavoidable artifacts and intensity shifts, but it does show features near the center without having to do a full reconstruction, which may be useful for large sinograms.  

For this demo, we use a larger detector since the artifacts are more pronounced for very small-scale problems.  To maintain a short run-time, we reduce the number of slices.

Also, we use a phantom that projects fully onto the detector, but you can use the same method when the object projects partially outside the detector.  

For simplicity, we show this only for parallel beam, but the same steps apply for cone beam.   

See [demo_1_shepp_logan.py](https://colab.research.google.com/drive/1zG_H6CDjuQxeMRQHan3XEyX2YVKcSSNC) for the basic steps of synthetic sinogram generation and reconstruction.

Select a GPU as runtime type for best performance.

In [1]:
import numpy as np
import pprint
import time
import jax.numpy as jnp
import mbirjax


In [2]:
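import os
# Use only the first GPU. This takes effect only if it is set before JAX
# initializes its GPU backend, so it is safest to set it before importing jax.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"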

**Set the geometry parameters**

In [3]:
# Set parameters for the problem size.  Here we use a fairly large number of views and channels to illustrate
# the behavior of a center cropped recon.
num_views = 400
num_det_rows = 20
num_det_channels = 400

start_angle = -np.pi / 2
end_angle = np.pi / 2
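For parallel beam, the views at θ and θ + π measure the same rays (mirror-image projections), so half a rotation is enough. The short NumPy sketch below reproduces the angle grid used in this notebook, with `np.linspace` standing in for `jnp.linspace` (both take the same `endpoint` argument). It shows that `endpoint=False` leaves out the redundant view at +π/2 and gives a uniform spacing of π / `num_views`, which matches the first entries of the `angles` array printed in the recon output further down.

```python
import numpy as np

# Same settings as the geometry cell.
num_views = 400
start_angle = -np.pi / 2
end_angle = np.pi / 2

# endpoint=False: the last view stops one step short of +pi/2, which for
# parallel beam would only mirror the view at -pi/2.
angles = np.linspace(start_angle, end_angle, num_views, endpoint=False)
angle_step = angles[1] - angles[0]

print(f"first angles: {angles[:3]}")
print(f"step = {angle_step:.7f} rad (pi / num_views = {np.pi / num_views:.7f})")
print(f"last angle = {angles[-1]:.7f} rad, below end_angle = {end_angle:.7f}")
```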

**Data generation:** For demo purposes, we create a phantom and then project it to create a sinogram.

For data generation we use the default recon shape for parallel beam, which for this sinogram is (400, 400, 20).

Note: the sliders on the viewer won't work in notebook form. For that you'll need to run the Python code with an interactive matplotlib backend, typically from the command line or from a development environment such as Spyder or PyCharm.


In [4]:
# Initialize sinogram
sinogram_shape = (num_views, num_det_rows, num_det_channels)
angles = jnp.linspace(start_angle, end_angle, num_views, endpoint=False)

ct_model_for_generation = mbirjax.ParallelBeamModel(sinogram_shape, angles)

# Generate large 3D Shepp Logan phantom
print('Creating phantom')
phantom = ct_model_for_generation.gen_modified_3d_sl_phantom()
print("Shape of phantom: ", phantom.shape)

# Generate synthetic sinogram data
print('Creating sinogram')
sinogram = ct_model_for_generation.forward_project(phantom)
print("Type of sinogram: ", type(sinogram))
sinogram = np.asarray(sinogram)
print("sinogram.shape: ", sinogram.shape)


# View sinogram
title='Original sinogram\nVery few rows to allow for fast execution'
save_path = './results/sinogram.png'
mbirjax.slice_viewer(sinogram, title=title, slice_axis=0, slice_label='View', save=True, save_path=save_path)

delta_voxel:  None
Creating phantom
Shape of phantom:  (400, 400, 20)
Creating sinogram
Type of sinogram:  <class 'jaxlib.xla_extension.ArrayImpl'>
sinogram.shape:  (400, 20, 400)


**Do a cropped center reconstruction**

We scale the recon shape down from the default so that only a subregion around the center of rotation is reconstructed. Here the row and column dimensions are each scaled by 0.5, which gives a 200 × 200 region in each slice.
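It helps to look at the detector footprint of the cropped region. The sketch below is not part of MBIRJAX. It assumes unit voxel and detector-channel spacing (the printed parameters show `delta_voxel = 1.0` and `delta_det_channel = 1.0`) and a square region centered on the axis of rotation. In parallel beam, a square of side w viewed at angle θ covers w(|cos θ| + |sin θ|) channels, so the 200 × 200 region spans between 200 and about 283 of the 400 channels. Rays outside that footprint never touch the region and leave a residual the cropped model cannot remove. Rays inside it also pass through voxels outside the region, and that extra attenuation has to be explained by the modeled voxels.

```python
import numpy as np

def square_footprint_channels(side, angles):
    """Detector width (in channels) covered by a centered side x side square of
    voxels in parallel beam, assuming unit voxel and channel spacing."""
    angles = np.asarray(angles)
    # The projection of a square rotated by theta has width side * (|cos| + |sin|).
    return side * (np.abs(np.cos(angles)) + np.abs(np.sin(angles)))

num_views, num_det_channels = 400, 400
angles = np.linspace(-np.pi / 2, np.pi / 2, num_views, endpoint=False)

# Cropped region: 0.5 x 400 = 200 rows and columns.
extent = square_footprint_channels(200, angles)
print(f"footprint: {extent.min():.1f} to {extent.max():.1f} of {num_det_channels} channels")
```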

As in demo_2_large_object, a cropped recon yields artifacts, because the sinogram contains contributions from voxels outside the reconstructed region that the model cannot represent. We reduce the high-frequency artifacts by lowering the sharpness (here to -0.5), but a residual intensity shift remains.
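The intensity shift can be reproduced in a toy problem that is independent of MBIRJAX's solver. In the sketch below, the hypothetical function `toy_cropped_recon` builds a uniform n × n object with value 1 and measures it with only two parallel-beam views: horizontal and vertical ray sums. It models only the central m × m block. A least-squares fit (`np.linalg.lstsq`, which returns the minimum-norm solution) has to explain ray sums of 8 with only 4 modeled pixels per ray, so every modeled pixel comes out at 2 instead of 1. The demo has many views, a regularizer, and a non-uniform phantom, so its shift is not uniform (note the bright outer ring). The mechanism is the same, though: attenuation from unmodeled voxels is attributed to the modeled ones.

```python
import numpy as np

def toy_cropped_recon(n=8, m=4):
    """Toy illustration (not MBIRJAX): a uniform n x n object measured only by
    horizontal and vertical ray sums, reconstructed with unknowns restricted to
    the central m x m block."""
    obj = np.ones((n, n))
    # Measured "sinogram": row sums (one view) and column sums (a second view).
    data = np.concatenate([obj.sum(axis=1), obj.sum(axis=0)])

    # System matrix that models only the central m x m pixels.
    lo = (n - m) // 2
    rows = []
    for r in range(n):  # horizontal rays
        a = np.zeros((m, m))
        if lo <= r < lo + m:
            a[r - lo, :] = 1.0
        rows.append(a.ravel())
    for c in range(n):  # vertical rays
        a = np.zeros((m, m))
        if lo <= c < lo + m:
            a[:, c - lo] = 1.0
        rows.append(a.ravel())
    A = np.array(rows)

    # Least-squares (minimum-norm) solution for the cropped region only.
    x, *_ = np.linalg.lstsq(A, data, rcond=None)
    return x.reshape(m, m)

recon_center = toy_cropped_recon()
print(recon_center.mean())  # close to 2.0, although every true pixel value is 1.0
```

Setting `m = n`, so that every pixel is modeled, recovers the true value of 1.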

In [5]:
# Initialize model for reconstruction.
weights = None
ct_model_for_recon = mbirjax.ParallelBeamModel(sinogram_shape, angles)

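# Print model parameters. These are printed before the recon shape and sharpness
# are changed below, so they show the default values.
ct_model_for_recon.print_params()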

sharpness = -0.5
recon_row_scale = 0.5
recon_col_scale = 0.5
ct_model_for_recon.scale_recon_shape(row_scale=recon_row_scale, col_scale=recon_col_scale)
ct_model_for_recon.set_params(sharpness=sharpness)

print('Starting cropped center recon')
time0 = time.time()
recon, recon_params = ct_model_for_recon.recon(sinogram, weights=weights)
recon.block_until_ready()
elapsed = time.time() - time0

# Print out parameters used in recon
if isinstance(recon_params, dict):
    pprint.pprint(recon_params, compact=True)
else:
    pprint.pprint(recon_params._asdict(), compact=True)
print('Elapsed time for recon is {:.3f} seconds'.format(elapsed))

delta_voxel:  None
----
geometry_type = <class 'mbirjax.parallel_beam.ParallelBeamModel'>
file_format = 1.0
sinogram_shape = (400, 20, 400)
delta_det_channel = 1.0
delta_det_row = 1.0
det_row_offset = 0.0
det_channel_offset = 0.0
sigma_y = 1.0
recon_shape = (400, 400, 20)
delta_voxel = 1.0
sigma_x = 1.0
sigma_prox = 1.0
p = 2.0
q = 1.2
T = 1.0
qggmrf_nbr_wts = [1.0, 1.0, 1.0]
auto_regularize_flag = True
positivity_flag = False
snr_db = 30.0
sharpness = 1.0
granularity = [1, 2, 4, 8, 16, 32, 64, 128, 256]
partition_sequence = [0, 2, 4, 6, 7]
verbose = 1
use_gpu = automatic
view_params_name = angles
----


GPU used for: full
Estimated GPU memory required = 0.509 GB, available = 91.131 GB
Estimated CPU memory required = 0.030 GB, available = 1446.084 GB


Starting cropped center recon


Starting direct recon for initial reconstruction
Initializing error sinogram
Computing Hessian diagonal
Starting VCD iterations

After iteration 0 of a max of 15: Pct change=24.1444, Forward loss=12.5234
Relative step size (alpha)=0.02, Error sino RMSE=23.8305
Number subsets = 1

After iteration 1 of a max of 15: Pct change=27.4062, Forward loss=12.0599
Relative step size (alpha)=0.10, Error sino RMSE=22.9485
Number subsets = 4

After iteration 2 of a max of 15: Pct change=24.2946, Forward loss=11.8987
Relative step size (alpha)=0.28, Error sino RMSE=22.6417
Number subsets = 16

After iteration 3 of a max of 15: Pct change=23.0440, Forward loss=11.8472
Relative step size (alpha)=1.03, Error sino RMSE=22.5438
Number subsets = 64

After iteration 4 of a max of 15: Pct change=14.2731, Forward loss=11.8362
Relative step size (alpha)=1.30, Error sino RMSE=22.5229
Number subsets = 128

After iteration 5 of a max of 15: Pct change=6.5725, Forward loss=11.8328
Relative step size (alpha)=1.32, 

{'model_params': {'T': Param(val=1.0, recompile_flag=False),
                  'angles': Param(val=[-1.5707964  -1.5629424  -1.5550884  -1.5472344  -1.5393804  -1.5315264
 -1.5236725  -1.5158185  -1.5079645  -1.5001105  -1.4922565  -1.4844025
 -1.4765487  -1.4686946  -1.4608406  -1.4529866  -1.4451326  -1.4372786
 -1.4294246  -1.4215707  -1.4137167  -1.4058627  -1.3980087  -1.3901547
 -1.3823007  -1.3744467  -1.3665928  -1.3587388  -1.3508848  -1.3430308
 -1.335177   -1.327323   -1.319469   -1.311615   -1.303761   -1.295907
 -1.288053   -1.280199   -1.2723451  -1.2644911  -1.2566371  -1.2487831
 -1.2409291  -1.2330751  -1.2252212  -1.2173672  -1.2095132  -1.2016592
 -1.1938052  -1.1859514  -1.1780974  -1.1702434  -1.1623894  -1.1545354
 -1.1466814  -1.1388274  -1.1309735  -1.1231195  -1.1152655  -1.1074115
 -1.0995575  -1.0917034  -1.0838495  -1.0759954  -1.0681416  -1.0602875
 -1.0524336  -1.0445796  -1.0367256  -1.0288717  -1.0210177  -1.0131637
 -1.0053097  -0.9974557  -0.98960173 -

**Display the cropped center reconstruction**

In [6]:
print("Shape of phantom: ", phantom.shape)

Shape of phantom:  (400, 400, 20)


In [7]:
title = 'Cropped center recon with sharpness = {:.1f}: Phantom (left) vs VCD Recon (right)'.format(sharpness)
title += '\nThis recon does not include all pixels used to generate the sinogram.'
title += '\nThe missing pixels lead to an intensity shift (adjust intensity to [0, 1]) and a bright outer ring.'
recon_shape = ct_model_for_recon.get_params('recon_shape')
recon_radius = [length // 2 for length in recon_shape]
start_inds = [phantom.shape[j] // 2 - recon_radius[j] for j in range(2)]
end_inds = [start_inds[j] + recon_shape[j] for j in range(2)]
cropped_phantom = phantom[start_inds[0]:end_inds[0], start_inds[1]:end_inds[1]]

mse = np.mean((cropped_phantom - recon)**2)
print("MSE: ", mse)

print("Shape of cropped_phantom: ", cropped_phantom.shape)
print("recon_shape: ", recon_shape)
print("recon_radius: ", recon_radius)
print("start_inds: ", start_inds)
print("end_inds: ", end_inds)

save_path = './results/cropped_center_recon.png'
mbirjax.slice_viewer(cropped_phantom, recon, title=title, vmin=0.0, vmax=2.0, save=True, save_path=save_path)
# mbirjax.slice_viewer(phantom, recon, title=title, vmin=0.0, vmax=2.0, save=True, save_path=save_path)


MSE:  0.1552195
Shape of cropped_phantom:  (200, 200, 20)
recon_shape:  (200, 200, 20)
recon_radius:  [100, 100, 10]
start_inds:  [100, 100]
end_inds:  [300, 300]
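Part of the MSE above comes from the global intensity shift, so it can help to separate the two. The hypothetical helper `shift_adjusted_mse` below (not an MBIRJAX function) estimates the shift as the mean difference between the recon and the cropped phantom. It reports the MSE before and after removing that shift; the second value equals the variance of the difference. It uses only NumPy, and `np.asarray` also accepts the JAX arrays from this notebook.

```python
import numpy as np

def shift_adjusted_mse(reference, estimate):
    """Return (mean offset, MSE, MSE after subtracting the mean offset)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    diff = estimate - reference
    offset = diff.mean()                          # global intensity shift
    mse = np.mean(diff ** 2)                      # equals offset**2 + np.var(diff)
    mse_no_shift = np.mean((diff - offset) ** 2)  # error left after removing the shift
    return offset, mse, mse_no_shift

# Usage with this notebook's arrays:
# offset, mse, mse_no_shift = shift_adjusted_mse(cropped_phantom, recon)

# Small synthetic check: a constant shift of 0.5 plus one perturbed voxel.
reference = np.zeros((4, 4, 2))
estimate = reference + 0.5
estimate[0, 0, 0] += 0.1
offset, mse, mse_no_shift = shift_adjusted_mse(reference, estimate)
print(offset, mse, mse_no_shift)
```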


**Next:** Try changing some of the parameters and re-running or try [some of the other demos](https://mbirjax.readthedocs.io/en/latest/demos_and_faqs.html).  