# Method 2: Iterative Reconstruction
## SIRT (Simultaneous Iterative Reconstruction Technique) + TV Regularization

This notebook implements iterative reconstruction as an optimization-based method.

**Pipeline:**
```
Low-dose Sinogram → SIRT + TV Regularization → Reconstructed CT Image
```

**Advantages over FBP:**
- Better noise suppression
- Can incorporate prior knowledge (TV regularization)
- More robust to incomplete data

---

## 1. Setup and Imports

In [12]:
# Install required packages
!pip install numpy h5py scipy matplotlib pandas tqdm astra-toolbox scikit-image

Collecting astra-toolbox
  Using cached astra-toolbox-1.8b5.tar.gz (452 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25lerror
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mPreparing metadata [0m[1;32m([0m[32mpyproject.toml[0m[1;32m)[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m [31m[26 lines of output][0m
  [31m   [0m ./autogen.sh: line 3: aclocal: command not found
  [31m   [0m Error running aclocal
  [31m   [0m Traceback (most recent call last):
  [31m   [0m   File "/Users/davidranamagar/Project/COSC /4372 Fundametals of Medical Imaging/ct-reconstruction-pipeline/.venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
  [31m   [0m     main()
  [31m   [0m   File "/Users/davidranamagar/Project/COSC /4372 Fundameta

In [13]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import pandas as pd
import json
from datetime import datetime
from scipy.ndimage import zoom
from skimage.metrics import structural_similarity

# Try to import ASTRA
try:
    import astra
    ASTRA_AVAILABLE = True
    print("✓ ASTRA Toolbox available")
except ImportError:
    ASTRA_AVAILABLE = False
    print("⚠ ASTRA Toolbox not available, using fallback implementation")

print("Setup complete!")

⚠ ASTRA Toolbox not available, using fallback implementation
Setup complete!


## 2. Configuration

In [None]:
# ===========================
# CONFIGURATION
# ===========================

# Data paths
DATA_DIR = Path("../data/prepared/lodopab")  # Update for Lambda Labs
OUTPUT_DIR = Path("../data/results/iterative_reconstruction")
IR_CONFIG = {}
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Iterative reconstruction parameters
IR_CONFIG = {
    'num_iterations': 50,        # SIRT iterations
    'tv_lambda': 0.01,           # TV regularization strength
    'tv_iterations': 10,         # TV minimization iterations
    'num_test_samples': 100,     # Number of samples to process (None = all)
    'use_tv': True,              # Enable TV regularization
}

print("Configuration:")
print(f"  Data directory: {DATA_DIR}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  SIRT iterations: {IR_CONFIG['num_iterations']}")
print(f"  TV regularization: {IR_CONFIG['use_tv']}")
if IR_CONFIG['use_tv']:
    print(f"  TV lambda: {IR_CONFIG['tv_lambda']}")
    print(f"  TV iterations: {IR_CONFIG['tv_iterations']}")

## 3. Forward and Back-Projection Operators

In [15]:
class ProjectionOperator:
    """
    Forward and back projection operators for CT reconstruction
    Uses ASTRA if available, otherwise uses skimage as fallback
    """
    def __init__(self, image_size=362, num_angles=1000, num_detectors=513):
        self.image_size = image_size
        self.num_angles = num_angles
        self.num_detectors = num_detectors
        self.angles = np.linspace(0, np.pi, num_angles, endpoint=False)
        
        if ASTRA_AVAILABLE:
            self._setup_astra()
    
    def _setup_astra(self):
        """Setup ASTRA projector"""
        # Create geometries
        self.vol_geom = astra.create_vol_geom(self.image_size, self.image_size)
        self.proj_geom = astra.create_proj_geom(
            'parallel', 1.0, self.num_detectors, self.angles
        )
        
        # Create projector
        self.proj_id = astra.create_projector('cuda', self.proj_geom, self.vol_geom)
    
    def forward(self, image):
        """
        Forward projection: image → sinogram
        """
        if ASTRA_AVAILABLE:
            sinogram_id, sinogram = astra.create_sino(image, self.proj_id)
            astra.data2d.delete(sinogram_id)
            return sinogram
        else:
            # Fallback: use radon transform
            from skimage.transform import radon
            angles_deg = np.degrees(self.angles)
            sino = radon(image, theta=angles_deg, circle=True)
            # Resize if needed
            if sino.shape[1] != self.num_detectors:
                sino = zoom(sino, (1, self.num_detectors/sino.shape[1]), order=1)
            return sino.T  # (num_angles, num_detectors)
    
    def backward(self, sinogram):
        """
        Back projection: sinogram → image
        """
        if ASTRA_AVAILABLE:
            recon_id = astra.data2d.create('-vol', self.vol_geom)
            sinogram_id = astra.data2d.create('-sino', self.proj_geom, sinogram)
            
            cfg = astra.astra_dict('BP')
            cfg['ProjectionDataId'] = sinogram_id
            cfg['ReconstructionDataId'] = recon_id
            cfg['ProjectorId'] = self.proj_id
            
            alg_id = astra.algorithm.create(cfg)
            astra.algorithm.run(alg_id, 1)
            
            recon = astra.data2d.get(recon_id)
            
            astra.algorithm.delete(alg_id)
            astra.data2d.delete([recon_id, sinogram_id])
            
            return recon
        else:
            # Fallback: use iradon
            from skimage.transform import iradon
            angles_deg = np.degrees(self.angles)
            recon = iradon(sinogram.T, theta=angles_deg, filter_name=None, circle=True)
            # Resize to target size
            if recon.shape[0] != self.image_size:
                scale = self.image_size / recon.shape[0]
                recon = zoom(recon, scale, order=1)
            return recon

print("✓ Projection operators defined")

✓ Projection operators defined


## 4. SIRT Algorithm

In [16]:
def sirt_reconstruction(sinogram, proj_op, num_iterations=50, verbose=False):
    """
    Simultaneous Iterative Reconstruction Technique (SIRT)
    
    Parameters:
    -----------
    sinogram : np.ndarray
        Measured sinogram (num_angles, num_detectors)
    proj_op : ProjectionOperator
        Projection operator
    num_iterations : int
        Number of SIRT iterations
    verbose : bool
        Show progress
    
    Returns:
    --------
    recon : np.ndarray
        Reconstructed image
    """
    # Initialize with back-projection
    recon = proj_op.backward(sinogram)
    recon = np.maximum(recon, 0)  # Non-negativity constraint
    
    # Normalization factors
    ones_image = np.ones((proj_op.image_size, proj_op.image_size), dtype=np.float32)
    ones_sino = np.ones_like(sinogram, dtype=np.float32)
    
    R = proj_op.backward(ones_sino) + 1e-6  # Column normalization
    C = proj_op.forward(ones_image) + 1e-6  # Row normalization
    
    # SIRT iterations
    iterator = tqdm(range(num_iterations), desc="SIRT") if verbose else range(num_iterations)
    
    for iteration in iterator:
        # Forward project
        fp = proj_op.forward(recon)
        
        # Calculate residual
        residual = (sinogram - fp) / C
        
        # Back project residual
        bp = proj_op.backward(residual)
        
        # Update
        recon += bp / R
        
        # Non-negativity
        recon = np.maximum(recon, 0)
    
    return recon

print("✓ SIRT algorithm defined")

✓ SIRT algorithm defined


## 5. Total Variation (TV) Regularization

In [17]:
def tv_denoise(image, lambda_tv=0.01, num_iterations=10):
    """
    Total Variation denoising using gradient descent
    
    Minimizes: ||image - image_noisy||^2 + lambda * TV(image)
    
    Parameters:
    -----------
    image : np.ndarray
        Input image
    lambda_tv : float
        Regularization strength
    num_iterations : int
        Number of iterations
    
    Returns:
    --------
    denoised : np.ndarray
        Denoised image
    """
    denoised = image.copy()
    dt = 0.25  # Time step
    
    for _ in range(num_iterations):
        # Compute gradients
        grad_x = np.zeros_like(denoised)
        grad_y = np.zeros_like(denoised)
        
        grad_x[:, :-1] = denoised[:, 1:] - denoised[:, :-1]
        grad_y[:-1, :] = denoised[1:, :] - denoised[:-1, :]
        
        # TV gradient magnitude
        grad_mag = np.sqrt(grad_x**2 + grad_y**2 + 1e-8)
        
        # Normalized gradients
        grad_x_norm = grad_x / grad_mag
        grad_y_norm = grad_y / grad_mag
        
        # Divergence
        div = np.zeros_like(denoised)
        div[:, 1:] += grad_x_norm[:, :-1]
        div[:, :-1] -= grad_x_norm[:, :-1]
        div[1:, :] += grad_y_norm[:-1, :]
        div[:-1, :] -= grad_y_norm[:-1, :]
        
        # Update
        denoised = denoised + dt * ((image - denoised) + lambda_tv * div)
        denoised = np.maximum(denoised, 0)  # Non-negativity
    
    return denoised

print("✓ TV regularization defined")

✓ TV regularization defined


## 6. Combined SIRT + TV Reconstruction

In [18]:
def sirt_tv_reconstruction(
    sinogram,
    num_sirt_iterations=50,
    tv_lambda=0.01,
    tv_iterations=10,
    use_tv=True,
    verbose=False
):
    """
    SIRT reconstruction with optional TV regularization
    
    Parameters:
    -----------
    sinogram : np.ndarray
        Input sinogram (1000, 513) - LoDoPaB format
    num_sirt_iterations : int
        Number of SIRT iterations
    tv_lambda : float
        TV regularization strength
    tv_iterations : int
        TV denoising iterations
    use_tv : bool
        Apply TV regularization
    verbose : bool
        Show progress
    
    Returns:
    --------
    recon : np.ndarray
        Reconstructed image (362, 362)
    """
    # Create projection operator
    proj_op = ProjectionOperator(
        image_size=362,
        num_angles=sinogram.shape[0],
        num_detectors=sinogram.shape[1]
    )
    
    # SIRT reconstruction
    recon = sirt_reconstruction(
        sinogram,
        proj_op,
        num_iterations=num_sirt_iterations,
        verbose=verbose
    )
    
    # Apply TV denoising
    if use_tv:
        recon = tv_denoise(recon, lambda_tv=tv_lambda, num_iterations=tv_iterations)
    
    return recon.astype(np.float32)

print("✓ SIRT+TV reconstruction defined")

✓ SIRT+TV reconstruction defined


## 7. Evaluation Metrics

In [19]:
def calculate_psnr(img1, img2, data_range=1.0):
    """Calculate PSNR"""
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(data_range / np.sqrt(mse))

def calculate_ssim(img1, img2, data_range=1.0):
    """Calculate SSIM"""
    return structural_similarity(img1, img2, data_range=data_range)

def calculate_nrmse(img1, img2):
    """Calculate NRMSE"""
    rmse = np.sqrt(np.mean((img1 - img2) ** 2))
    return rmse / (img2.max() - img2.min())

def normalize_image(img):
    """Normalize to [0, 1]"""
    img_min, img_max = img.min(), img.max()
    if img_max - img_min > 1e-8:
        return (img - img_min) / (img_max - img_min)
    return img

print("✓ Metrics defined")

✓ Metrics defined


## 8. Load Test Data

In [20]:
# Find test files
test_obs_files = sorted(list(DATA_DIR.glob("observation_test_*.hdf5")))
test_gt_files = sorted(list(DATA_DIR.glob("ground_truth_test_*.hdf5")))

print(f"Found {len(test_obs_files)} observation files")
print(f"Found {len(test_gt_files)} ground truth files")

if test_obs_files and test_gt_files:
    total_samples = sum(h5py.File(f, 'r')['data'].shape[0] for f in test_obs_files)
    print(f"Total test samples: {total_samples}")
    
    if IR_CONFIG['num_test_samples']:
        print(f"Will process: {IR_CONFIG['num_test_samples']} samples")
    else:
        print(f"Will process: All samples")

Found 28 observation files
Found 28 ground truth files
Total test samples: 3553
Will process: 100 samples


## 9. Run SIRT+TV on Test Set

In [21]:
# Results storage
results = {
    'sample_id': [],
    'psnr': [],
    'ssim': [],
    'nrmse': [],
}

sample_count = 0
max_samples = IR_CONFIG['num_test_samples'] if IR_CONFIG['num_test_samples'] else float('inf')

print("\nRunning SIRT+TV reconstruction...")
print("="*60)
print("NOTE: This will take longer than FBP (iterative optimization)")
print("="*60)

for file_idx, (obs_file, gt_file) in enumerate(zip(test_obs_files, test_gt_files)):
    
    with h5py.File(obs_file, 'r') as f_obs, h5py.File(gt_file, 'r') as f_gt:
        
        num_in_file = f_obs['data'].shape[0]
        pbar = tqdm(range(num_in_file), desc=f"File {file_idx+1}")
        
        for local_idx in pbar:
            if sample_count >= max_samples:
                break
            
            sinogram = f_obs['data'][local_idx].astype(np.float32)
            ground_truth = f_gt['data'][local_idx].astype(np.float32)
            
            # Reconstruct
            reconstructed = sirt_tv_reconstruction(
                sinogram,
                num_sirt_iterations=IR_CONFIG['num_iterations'],
                tv_lambda=IR_CONFIG['tv_lambda'],
                tv_iterations=IR_CONFIG['tv_iterations'],
                use_tv=IR_CONFIG['use_tv'],
                verbose=False
            )
            
            # Normalize
            recon_norm = normalize_image(reconstructed)
            gt_norm = normalize_image(ground_truth)
            
            # Metrics
            psnr = calculate_psnr(recon_norm, gt_norm)
            ssim = calculate_ssim(recon_norm, gt_norm)
            nrmse = calculate_nrmse(recon_norm, gt_norm)
            
            results['sample_id'].append(sample_count)
            results['psnr'].append(psnr)
            results['ssim'].append(ssim)
            results['nrmse'].append(nrmse)
            
            pbar.set_postfix({'PSNR': f"{psnr:.2f}", 'SSIM': f"{ssim:.3f}"})
            sample_count += 1
    
    if sample_count >= max_samples:
        break

print(f"\n✓ Processed {sample_count} samples")


Running SIRT+TV reconstruction...
NOTE: This will take longer than FBP (iterative optimization)


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

## 10. Calculate and Display Results

In [11]:
# Convert to DataFrame
df_results = pd.DataFrame(results)

# Summary statistics
summary = {
    'method': 'SIRT+TV' if IR_CONFIG['use_tv'] else 'SIRT',
    'num_iterations': IR_CONFIG['num_iterations'],
    'tv_enabled': IR_CONFIG['use_tv'],
    'tv_lambda': IR_CONFIG['tv_lambda'] if IR_CONFIG['use_tv'] else None,
    'num_samples': len(df_results),
    'psnr_mean': df_results['psnr'].mean(),
    'psnr_std': df_results['psnr'].std(),
    'ssim_mean': df_results['ssim'].mean(),
    'ssim_std': df_results['ssim'].std(),
    'nrmse_mean': df_results['nrmse'].mean(),
    'nrmse_std': df_results['nrmse'].std(),
    'timestamp': datetime.now().isoformat()
}

# Display
print("\n" + "="*60)
print("ITERATIVE RECONSTRUCTION RESULTS")
print("="*60)
print(f"Method: {summary['method']}")
print(f"SIRT Iterations: {summary['num_iterations']}")
if summary['tv_enabled']:
    print(f"TV Lambda: {summary['tv_lambda']}")
print(f"Samples: {summary['num_samples']}")
print()
print(f"PSNR:  {summary['psnr_mean']:.2f} ± {summary['psnr_std']:.2f} dB")
print(f"SSIM:  {summary['ssim_mean']:.4f} ± {summary['ssim_std']:.4f}")
print(f"NRMSE: {summary['nrmse_mean']:.4f} ± {summary['nrmse_std']:.4f}")
print("="*60)

# Save results
df_results.to_csv(OUTPUT_DIR / 'ir_results.csv', index=False)
with open(OUTPUT_DIR / 'ir_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\n✓ Results saved to {OUTPUT_DIR}")


ITERATIVE RECONSTRUCTION RESULTS
Method: SIRT+TV
SIRT Iterations: 50
TV Lambda: 0.01
Samples: 0

PSNR:  nan ± nan dB
SSIM:  nan ± nan
NRMSE: nan ± nan

✓ Results saved to results/iterative
