# CT-PINN-DADif: Physics-Informed Deep Adaptive Diffusion for CT Reconstruction

This notebook provides a **complete and validated** training pipeline for CT-PINN-DADif.

## Features:
- **Physics Validation Tests**: Adjointness, FBP round-trip, comparison with scikit-image
- **Phantom Visualization**: Shepp-Logan and random phantoms
- **Sinogram Visualization**: Forward projection inspection
- **Poisson Noise Model**: Photon-counting noise applied to the transmitted intensities (no other noise source is simulated)
- **Training with Metrics**: PSNR, SSIM, RMSE, MAE
- **Sparse-View and Limited-Angle**: Demonstrations

## Key Physics (CT vs MRI):
- **Forward Model**: Radon Transform (line integrals) instead of Fourier
- **Noise Model**: Poisson photon counting instead of Gaussian
- **Reconstruction**: FBP instead of inverse FFT
- **Consistency**: Sinogram consistency instead of k-space
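
For reference, the quantities named above can be written out explicitly. The symbols ($\mu$ for the attenuation image, $p$ for the sinogram, $I_0$ for the incident photon count, $I$ for the detected count) are notation chosen here for illustration and are not taken from the package. The forward model is the line integral of $\mu$ along the ray at angle $\theta$ and detector offset $s$:

$$
p(\theta, s) = \iint \mu(x, y)\,\delta(x\cos\theta + y\sin\theta - s)\,dx\,dy
$$

The noise model follows the Beer-Lambert law with Poisson counting statistics, and the noisy sinogram is recovered by a log transform:

$$
I(\theta, s) \sim \mathrm{Poisson}\!\left(I_0\, e^{-p(\theta, s)}\right), \qquad \hat{p}(\theta, s) = -\ln\frac{I(\theta, s)}{I_0}
$$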

In [None]:
# Install dependencies
!pip install torch torchvision numpy scipy matplotlib tqdm -q

# Clone repository (if running on Colab)
import os
import sys

# Check if we're in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone if not already done
    if not os.path.exists('/content/PINN_Dadiff'):
        !git clone https://github.com/Iammohithhh/PINN_Dadiff.git /content/PINN_Dadiff
    
    # Add to Python path (the parent directory containing ct_reconstruction)
    if '/content/PINN_Dadiff' not in sys.path:
        sys.path.insert(0, '/content/PINN_Dadiff')
    
    print("Colab setup complete!")
else:
    # Running locally - add parent directory to path
    # Notebooks do not define __file__, so use the current working directory
    notebook_dir = os.getcwd()
    parent_dir = os.path.dirname(os.path.dirname(notebook_dir))  # PINN_Dadiff
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    print("Local setup complete!")

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
if device == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Import CT-PINN-DADif modules from the package
from ct_reconstruction.src.ct_physics import (
    RadonTransform, FilteredBackProjection, compute_num_detectors,
    create_sparse_view_mask, create_limited_angle_mask, test_adjoint
)
from ct_reconstruction.src.data_loader import (
    create_shepp_logan_phantom, create_random_phantom, SimulatedCTDataset, create_dataloaders
)
from ct_reconstruction.src.model import CT_PINN_DADif, CTReconstructionLoss, create_model, create_loss
from ct_reconstruction.src.train import Trainer, compute_metrics, DEFAULT_CONFIG

print('All modules imported successfully!')

## 1. Physics Validation Tests

**Critical**: Before training, we MUST verify that CT physics is correctly implemented.
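
One geometric point is worth keeping in mind when reading the configuration cell: a parallel-beam detector has to span the image diagonal, otherwise rays through the image corners are lost at oblique angles. The helper below is a hypothetical stand-in that assumes unit detector spacing. The rounding rule that `compute_num_detectors` actually uses is not shown in this notebook and may differ.

```python
import math

def detectors_for_diagonal(img_size: int) -> int:
    """Smallest detector count (unit spacing assumed) that covers the image diagonal."""
    return math.ceil(math.sqrt(2) * img_size)

print(detectors_for_diagonal(256))  # 363
```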

In [None]:
# Configuration
IMG_SIZE = 256
NUM_ANGLES = 180
NUM_DETECTORS = compute_num_detectors(IMG_SIZE)

print(f'Image size: {IMG_SIZE}x{IMG_SIZE}')
print(f'Number of angles: {NUM_ANGLES}')
print(f'Number of detectors: {NUM_DETECTORS} (sqrt(2) * {IMG_SIZE} = {np.sqrt(2)*IMG_SIZE:.1f})')

In [None]:
# Create operators
radon = RadonTransform(IMG_SIZE, NUM_ANGLES, NUM_DETECTORS, device=device).to(device)
fbp = FilteredBackProjection(IMG_SIZE, NUM_ANGLES, NUM_DETECTORS, device=device).to(device)

print('Forward (Radon) and inverse (FBP) operators created!')

### 1.1 Adjointness Test: <Ax, y> ≈ <x, A^T y>
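
The identity in the heading can be checked numerically with a dot-product test: draw random `x` and `y`, then compare the two inner products. The sketch below uses a random matrix as a stand-in for the projector, so that its transpose is the exact adjoint. The function name and the relative-error definition are assumptions for illustration, and the internals of `test_adjoint` are not shown in this notebook.

```python
import numpy as np

def adjoint_relative_error(forward, adjoint, x_shape, y_shape, seed=0):
    """Relative gap between <A x, y> and <x, A^T y> for random x and y."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x_shape)
    y = rng.standard_normal(y_shape)
    lhs = np.vdot(forward(x), y)   # <A x, y>
    rhs = np.vdot(x, adjoint(y))   # <x, A^T y>
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-12)

# Stand-in linear operator: a random matrix plays the role of the projector.
A = np.random.default_rng(1).standard_normal((30, 20))

err_true = adjoint_relative_error(lambda x: A @ x, lambda y: A.T @ y, (20,), (30,))
# A mis-scaled "adjoint" (factor 2) should fail the test clearly.
err_scaled = adjoint_relative_error(lambda x: A @ x, lambda y: 2.0 * (A.T @ y), (20,), (30,))

print(f"true adjoint: {err_true:.2e}, mis-scaled adjoint: {err_scaled:.2f}")
```

A scaling mistake of a factor of two gives a relative error of 0.5, well above the `tol=0.15` used in the next cell. Note that the exact adjoint of the Radon transform is unfiltered backprojection.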

In [None]:
passed, rel_error = test_adjoint(radon, fbp, device=device, tol=0.15)

print(f'Adjointness Test: {"PASSED" if passed else "FAILED"}')
print(f'Relative error: {rel_error:.6f}')

if not passed:
    print('WARNING: Large adjoint error may indicate geometry issues.')

### 1.2 FBP Round-Trip Test: FBP(Radon(x)) ≈ x
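
The round-trip quality is reported through `compute_metrics`. As a reminder of what the scalar metrics mean, here is a minimal NumPy sketch of PSNR, RMSE and MAE. It assumes images scaled to a `data_range` of 1.0 and omits SSIM. `basic_metrics` is a hypothetical helper, not the package implementation.

```python
import numpy as np

def basic_metrics(pred, target, data_range=1.0):
    """PSNR (dB), RMSE and MAE between two arrays of the same shape."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    err = pred - target
    mse = float(np.mean(err ** 2))
    rmse = float(np.sqrt(mse))
    mae = float(np.mean(np.abs(err)))
    # PSNR is infinite for a perfect reconstruction
    psnr = float("inf") if mse == 0 else float(10.0 * np.log10(data_range ** 2 / mse))
    return {"psnr": psnr, "rmse": rmse, "mae": mae}

target = np.zeros((8, 8))
pred = target + 0.1          # uniform error of 0.1
print(basic_metrics(pred, target))
```

A uniform error of 0.1 on a unit data range corresponds to 20 dB, which gives a sense of scale for the 25 dB threshold used in the next cell.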

In [None]:
# Create Shepp-Logan phantom
phantom = create_shepp_logan_phantom(IMG_SIZE)
x = torch.from_numpy(phantom).unsqueeze(0).unsqueeze(0).float().to(device)

# Forward and inverse
with torch.no_grad():
    sinogram = radon.forward(x)
    reconstruction = fbp.forward(sinogram)

# Compute metrics
metrics = compute_metrics(reconstruction, x)

print(f'FBP Round-Trip Test:')
print(f'  PSNR: {metrics["psnr"]:.2f} dB (should be > 25 dB)')
print(f'  SSIM: {metrics["ssim"]:.2f}%')
print(f'  RMSE: {metrics["rmse"]:.6f}')
print(f'  MAE:  {metrics["mae"]:.6f}')
print(f'  Status: {"PASSED" if metrics["psnr"] > 25 else "WARNING: PSNR < 25 dB"}')

In [None]:
# Visualize round-trip
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

axes[0].imshow(x[0, 0].cpu().numpy(), cmap='gray')
axes[0].set_title('Original Phantom')
axes[0].axis('off')

axes[1].imshow(sinogram[0, 0].cpu().numpy(), cmap='gray', aspect='auto')
axes[1].set_title(f'Sinogram ({NUM_ANGLES}x{NUM_DETECTORS})')
axes[1].set_xlabel('Detector')
axes[1].set_ylabel('Angle')

axes[2].imshow(reconstruction[0, 0].cpu().numpy(), cmap='gray')
axes[2].set_title(f'FBP Recon (PSNR: {metrics["psnr"]:.1f} dB)')
axes[2].axis('off')

error = torch.abs(reconstruction - x)[0, 0].cpu().numpy()
im = axes[3].imshow(error, cmap='hot')
axes[3].set_title('Absolute Error')
axes[3].axis('off')
plt.colorbar(im, ax=axes[3], fraction=0.046)

plt.tight_layout()
plt.savefig('physics_validation.png', dpi=150)
plt.show()

### 1.3 Compare with scikit-image (Optional)

In [None]:
try:
    from skimage.transform import radon as sk_radon, iradon as sk_iradon
    
    theta = np.linspace(0., 180., NUM_ANGLES, endpoint=False)
    sk_sinogram = sk_radon(phantom, theta=theta, circle=True)
    sk_recon = sk_iradon(sk_sinogram, theta=theta, circle=True)
    
    # Compare
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    
    axes[0, 0].imshow(sk_sinogram.T, cmap='gray', aspect='auto')
    axes[0, 0].set_title('scikit-image Sinogram')
    
    axes[0, 1].imshow(sinogram[0, 0].cpu().numpy(), cmap='gray', aspect='auto')
    axes[0, 1].set_title('Our Sinogram')
    
    axes[1, 0].imshow(sk_recon, cmap='gray')
    axes[1, 0].set_title('scikit-image FBP')
    
    axes[1, 1].imshow(reconstruction[0, 0].cpu().numpy(), cmap='gray')
    axes[1, 1].set_title('Our FBP')
    
    plt.tight_layout()
    plt.show()
    print('scikit-image comparison completed!')
except ImportError:
    print('scikit-image not available. Skipping comparison.')

## 2. Noise Model Verification (Poisson)
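
Before looking at full reconstructions, it can help to check the noise model on a single ray. Under Beer-Lambert attenuation with Poisson counting, the standard deviation of the log-transformed measurement is approximately one over the square root of the expected count. The NumPy sketch below tests that relation on a constant line integral. The function name and the test value `p = 1.0` are assumptions for illustration, and the sketch clamps counts at 1 photon, whereas the notebook cell clamps at `eps`.

```python
import numpy as np

def add_poisson_noise(line_integrals, I0, rng):
    """Beer-Lambert + Poisson counting noise, returned in the line-integral domain."""
    expected_counts = I0 * np.exp(-line_integrals)
    counts = rng.poisson(expected_counts).astype(np.float64)
    counts = np.maximum(counts, 1.0)  # avoid log(0) for photon-starved rays
    return -np.log(counts / I0)

rng = np.random.default_rng(0)
p = np.full(100_000, 1.0)  # constant line integral (hypothetical test value)

for I0 in (1e5, 1e4, 1e3):
    noisy = add_poisson_noise(p, I0, rng)
    predicted = 1.0 / np.sqrt(I0 * np.exp(-1.0))  # 1 / sqrt(expected count)
    print(f"I0={I0:.0e}: measured std={noisy.std():.4f}, predicted={predicted:.4f}")
```

Lowering `I0` by a factor of 100 raises the sinogram noise by roughly a factor of 10, which is the trend the dose-level figure below should show.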

In [None]:
# Simulate different dose levels
I0_levels = {'High (1e5)': 1e5, 'Standard (1e4)': 1e4, 'Low (5e3)': 5e3, 'Ultra-Low (1e3)': 1e3}

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

with torch.no_grad():
    # Clean reference sinogram; no autograd graph is needed for this demo
    sinogram_clean = radon.forward(x)
eps = 1e-6

for idx, (name, I0) in enumerate(I0_levels.items()):
    with torch.no_grad():
        # Poisson noise (physically correct)
        counts = I0 * torch.exp(-sinogram_clean)
        counts_noisy = torch.poisson(counts.clamp(min=eps))
        sinogram_noisy = -torch.log(counts_noisy.clamp(min=eps) / I0)
        recon = fbp.forward(sinogram_noisy)
    
    m = compute_metrics(recon, x)
    
    axes[0, idx].imshow(sinogram_noisy[0, 0].cpu().numpy(), cmap='gray', aspect='auto')
    axes[0, idx].set_title(f'{name}')
    axes[0, idx].axis('off')
    
    axes[1, idx].imshow(recon[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[1, idx].set_title(f'PSNR: {m["psnr"]:.1f} dB')
    axes[1, idx].axis('off')

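# axis('off') also hides y-labels, so draw the row labels as text instead
for row, label in enumerate(['Sinogram', 'FBP Recon']):
    axes[row, 0].text(-0.05, 0.5, label, transform=axes[row, 0].transAxes,
                      rotation=90, va='center', ha='right', fontsize=12)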
plt.suptitle('Poisson Noise at Different Dose Levels (I0)', fontsize=14)
plt.tight_layout()
plt.savefig('noise_simulation.png', dpi=150)
plt.show()

## 3. Sparse-View CT Demonstration
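
Sparse-view acquisition is simulated here by zeroing the sinogram rows of the angles that were not measured. A minimal sketch of such a mask follows, assuming evenly spaced views and a `(batch, channel, angle, detector)` sinogram layout. `uniform_view_mask` is a hypothetical helper, and the selection rule inside `create_sparse_view_mask` may differ.

```python
import numpy as np

def uniform_view_mask(num_angles: int, num_views: int) -> np.ndarray:
    """Binary mask of shape (1, 1, num_angles, 1) keeping num_views evenly spaced angles."""
    if not 0 < num_views <= num_angles:
        raise ValueError("num_views must be between 1 and num_angles")
    # Integer arithmetic gives strictly increasing, evenly spaced indices
    keep = (np.arange(num_views) * num_angles) // num_views
    mask = np.zeros((1, 1, num_angles, 1), dtype=np.float32)
    mask[0, 0, keep, 0] = 1.0
    return mask

# Dummy sinogram with the assumed (batch, channel, angle, detector) layout
sinogram = np.ones((1, 1, 180, 363), dtype=np.float32)
for views in (180, 90, 60, 30):
    sparse = sinogram * uniform_view_mask(180, views)  # broadcasts over detectors
    print(views, int(sparse[0, 0, :, 0].sum()))
```

Because the mask has a trailing dimension of 1, it broadcasts across all detector bins, so one mask serves any detector count.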

In [None]:
sparse_views = [180, 90, 60, 30]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for idx, num_views in enumerate(sparse_views):
    mask = create_sparse_view_mask(NUM_ANGLES, num_views, device)
    
    with torch.no_grad():
        sinogram_sparse = sinogram_clean * mask
        recon = fbp.forward(sinogram_sparse)
    
    m = compute_metrics(recon, x)
    
    axes[0, idx].imshow(sinogram_sparse[0, 0].cpu().numpy(), cmap='gray', aspect='auto')
    axes[0, idx].set_title(f'{num_views} views')
    axes[0, idx].axis('off')
    
    axes[1, idx].imshow(recon[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[1, idx].set_title(f'PSNR: {m["psnr"]:.1f} dB')
    axes[1, idx].axis('off')

plt.suptitle('Sparse-View CT (FBP Baseline)', fontsize=14)
plt.tight_layout()
plt.savefig('sparse_view.png', dpi=150)
plt.show()

## 4. Training Configuration

### Choose Your Training Data:

**Option 1: Simulated Phantoms (Recommended for beginners)**
- No download needed, generates data on-the-fly
- Uses Shepp-Logan and random geometric phantoms
- Good for learning and testing the model
- Set `USE_REAL_DATA = False` below

**Option 2: Real CT Images**
- Loads `.png`/`.jpg` CT slices from the `ct_data` folder (upload your own or copy in a public dataset)
- Falls back to generated placeholder phantoms if the folder is empty
- More realistic, so better suited to final model evaluation
- Set `USE_REAL_DATA = True` below

In [None]:
#@title Training Configuration { run: "auto" }
#@markdown **Choose your data source:**
USE_REAL_DATA = False  #@param {type:"boolean"}

# Training configuration (optimized for CT)
config = DEFAULT_CONFIG.copy()
config.update({
    'img_size': IMG_SIZE,
    'num_angles': NUM_ANGLES,
    'num_detectors': NUM_DETECTORS,
    
    # Dataset settings
    'num_train_samples': 200,   # For simulated: number of phantoms to generate
    'num_val_samples': 50,
    'num_test_samples': 50,
    'phantom_type': 'mixed',    # 'shepp_logan', 'random', or 'mixed'
    'noise_level': 'low',       # 'none', 'low', 'medium', 'high'
    'acquisition_type': 'full', # 'full', 'sparse', 'limited_angle'
    
    # Training hyperparameters
    'batch_size': 2,            # Reduce to 1 if out of memory
    'num_epochs': 20,           # Increase to 100-200 for better results
    'learning_rate': 6e-3,
    'use_sam': True,            # Sharpness-Aware Minimization
    'use_amp': True,            # Mixed precision (faster on GPU)
    
    # Loss weights (CT-optimized)
    'alpha': 0.4,   # pixel loss weight
    'beta': 0.1,    # perceptual loss weight (disabled)
    'gamma': 0.5,   # physics loss weight
    'use_perceptual': False,
})

print('='*50)
print('TRAINING CONFIGURATION')
print('='*50)
print(f'Data source: {"REAL DATASET" if USE_REAL_DATA else "SIMULATED PHANTOMS"}')
print(f'Image size: {IMG_SIZE}x{IMG_SIZE}')
print(f'Training samples: {config["num_train_samples"]}')
print(f'Batch size: {config["batch_size"]}')
print(f'Epochs: {config["num_epochs"]}')
print(f'Noise level: {config["noise_level"]}')
print('='*50)

## 5. Create Dataloaders

The next cell will either:
- **Generate simulated phantoms** (if `USE_REAL_DATA = False`)
- **Load CT images from the `ct_data` folder**, generating placeholder phantoms if it is empty (if `USE_REAL_DATA = True`)

In [None]:
if USE_REAL_DATA:
    # ============================================
    # OPTION 2: REAL CT DATASET
    # ============================================
    print("Loading CT images from disk...")
    
    # No automatic download happens here: images are read from DATA_DIR,
    # and placeholder phantoms are generated below if the folder is empty.
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader
    
    # Folder that holds the CT images
    DATA_DIR = Path('/content/ct_data') if IN_COLAB else Path('./ct_data')
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    
    # For demo, we'll use a simpler approach: download individual sample images
    # You can replace this with your own dataset path
    print("Note: For real datasets, you can:")
    print("  1. Upload your own CT images to /content/ct_data/")
    print("  2. Use kaggle datasets (COVID-CT, LIDC-IDRI subset)")
    print("  3. Use the AAPM Low-Dose CT Challenge data")
    print()
    
    # Check if user has uploaded images
    if not list(DATA_DIR.glob('*.png')) and not list(DATA_DIR.glob('*.jpg')):
        print("No images found. Creating synthetic 'real-like' data for demo...")
        print("(Upload your own .png/.jpg CT images to the ct_data folder)")
        
        # Fall back to simulated but with more realistic phantoms
        from ct_reconstruction.src.data_loader import RealCTDataset
        
        # Generate more realistic phantoms as placeholder
        for i in range(config['num_train_samples'] + config['num_val_samples'] + config['num_test_samples']):
            phantom = create_random_phantom(IMG_SIZE)
            img = Image.fromarray((phantom * 255).astype(np.uint8))
            img.save(DATA_DIR / f'phantom_{i:04d}.png')
        print(f"Generated {i+1} placeholder images in {DATA_DIR}")
    
    # Create dataset from images
    from ct_reconstruction.src.data_loader import RealCTDataset
    
    all_images = sorted(list(DATA_DIR.glob('*.png')) + list(DATA_DIR.glob('*.jpg')))
    n_train = int(0.7 * len(all_images))
    n_val = int(0.15 * len(all_images))
    
    train_dataset = RealCTDataset(
        image_paths=all_images[:n_train],
        img_size=IMG_SIZE, num_angles=NUM_ANGLES, num_detectors=NUM_DETECTORS,
        noise_level=config['noise_level'], device=device
    )
    val_dataset = RealCTDataset(
        image_paths=all_images[n_train:n_train+n_val],
        img_size=IMG_SIZE, num_angles=NUM_ANGLES, num_detectors=NUM_DETECTORS,
        noise_level=config['noise_level'], device=device
    )
    test_dataset = RealCTDataset(
        image_paths=all_images[n_train+n_val:],
        img_size=IMG_SIZE, num_angles=NUM_ANGLES, num_detectors=NUM_DETECTORS,
        noise_level=config['noise_level'], device=device
    )
    
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)

else:
    # ============================================
    # OPTION 1: SIMULATED PHANTOMS (Default)
    # ============================================
    print("Using simulated phantoms (no download needed)...")
    
    train_loader, val_loader, test_loader = create_dataloaders(
        config, batch_size=config['batch_size'], num_workers=2
    )

print()
print(f'Train: {len(train_loader.dataset)} samples')
print(f'Val: {len(val_loader.dataset)} samples')
print(f'Test: {len(test_loader.dataset)} samples')
print(f'Batches per epoch: {len(train_loader)}')

In [None]:
# Create model and loss
model = create_model(config).to(device)
loss_fn = create_loss(config).to(device)

params = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {params:,}')

## 6. Train Model

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device
)

# Train
print('Starting training...')
history = trainer.train(num_epochs=config['num_epochs'])

print(f'\nBest PSNR: {trainer.best_psnr:.2f} dB')
print(f'Best SSIM: {trainer.best_ssim:.2f}%')

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True)

axes[1].plot(history['train_psnr'], label='Train')
axes[1].plot(history['val_psnr'], label='Val')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('PSNR (dB)')
axes[1].legend()
axes[1].grid(True)

axes[2].plot(history['train_ssim'], label='Train')
axes[2].plot(history['val_ssim'], label='Val')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('SSIM (%)')
axes[2].legend()
axes[2].grid(True)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

## 7. Evaluation

In [None]:
# Evaluate on test set
model.eval()

test_metrics = {'psnr': [], 'ssim': [], 'rmse': [], 'mae': []}
fbp_metrics = {'psnr': [], 'ssim': [], 'rmse': [], 'mae': []}

with torch.no_grad():
    for batch in test_loader:
        sinogram = batch['sinogram_noisy'].to(device)
        target = batch['image'].to(device)
        weights = batch['weights'].to(device)
        mask = batch['mask'].to(device)
        fbp_recon = batch['fbp'].to(device)
        
        outputs = model(sinogram, weights, mask)
        
        m = compute_metrics(outputs['reconstruction'], target)
        for k in test_metrics:
            test_metrics[k].append(m[k])
        
        m_fbp = compute_metrics(fbp_recon, target)
        for k in fbp_metrics:
            fbp_metrics[k].append(m_fbp[k])

# Results table
print('\n' + '='*70)
print('TEST SET RESULTS')
print('='*70)
print(f'{"Metric":<10} {"FBP":>15} {"PINN-DADif":>15} {"Improvement":>15}')
print('-'*70)
for k in ['psnr', 'ssim', 'rmse', 'mae']:
    fbp_val = np.mean(fbp_metrics[k])
    model_val = np.mean(test_metrics[k])
    if k in ['psnr', 'ssim']:
        diff = model_val - fbp_val
        print(f'{k.upper():<10} {fbp_val:>15.2f} {model_val:>15.2f} {diff:>+15.2f}')
    else:
        diff = (1 - model_val/fbp_val) * 100
        print(f'{k.upper():<10} {fbp_val:>15.4f} {model_val:>15.4f} {diff:>+14.1f}%')

In [None]:
# Visual comparison
batch = next(iter(test_loader))
sinogram = batch['sinogram_noisy'].to(device)
target = batch['image'].to(device)
weights = batch['weights'].to(device)
mask = batch['mask'].to(device)
fbp_recon = batch['fbp'].to(device)

with torch.no_grad():
    outputs = model(sinogram, weights, mask)
    pred = outputs['reconstruction']

n = min(4, len(target))
# squeeze=False keeps axes 2-D even when the batch holds a single image
fig, axes = plt.subplots(3, n, figsize=(4*n, 12), squeeze=False)

for i in range(n):
    axes[0, i].imshow(target[i, 0].cpu().numpy(), cmap='gray')
    axes[0, i].axis('off')
    
    m_fbp = compute_metrics(fbp_recon[i:i+1], target[i:i+1])
    axes[1, i].imshow(fbp_recon[i, 0].cpu().numpy(), cmap='gray')
    axes[1, i].set_title(f'{m_fbp["psnr"]:.1f} dB')
    axes[1, i].axis('off')
    
    m_pred = compute_metrics(pred[i:i+1], target[i:i+1])
    axes[2, i].imshow(pred[i, 0].cpu().numpy(), cmap='gray')
    axes[2, i].set_title(f'{m_pred["psnr"]:.1f} dB')
    axes[2, i].axis('off')

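# axis('off') also hides y-labels, so draw the row labels as text instead
for row, label in enumerate(['Ground Truth', 'FBP', 'PINN-DADif']):
    axes[row, 0].text(-0.05, 0.5, label, transform=axes[row, 0].transAxes,
                      rotation=90, va='center', ha='right', fontsize=12)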

plt.suptitle('Reconstruction Comparison', fontsize=14)
plt.tight_layout()
plt.savefig('comparison.png', dpi=150)
plt.show()

## 8. Save Model

In [None]:
# Save model
save_path = Path('experiments/checkpoints/ct_pinn_dadif.pt')
save_path.parent.mkdir(parents=True, exist_ok=True)

torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
    # Plain Python floats keep the checkpoint free of NumPy scalar objects
    'test_psnr': float(np.mean(test_metrics['psnr'])),
    'test_ssim': float(np.mean(test_metrics['ssim'])),
}, save_path)

print(f'Model saved to {save_path}')

# Download (Colab)
try:
    from google.colab import files
    files.download(str(save_path))
except ImportError:
    pass  # not running in Colab; the checkpoint stays on local disk

## Summary

### What This Notebook Does:
1. **Physics validation** - Tests Radon/FBP operators (must pass before training!)
2. **Noise simulation** - Shows Poisson noise at different dose levels
3. **Training** - Trains CT-PINN-DADif with physics-informed loss
4. **Evaluation** - Compares model vs FBP baseline
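
To make the idea of a physics-informed loss in step 3 concrete, the sketch below combines an image-domain term with a sinogram-consistency term, weighted by the `alpha` and `gamma` values from the configuration cell. It uses NumPy and a toy random matrix in place of the Radon operator. The choice of L1 for the pixel term and mean squared error for the physics term is an assumption, and the actual terms inside `CTReconstructionLoss` are not shown in this notebook.

```python
import numpy as np

def composite_loss(pred, target, sino_meas, forward_op, alpha=0.4, gamma=0.5):
    """Weighted sum of a pixel loss and a sinogram-consistency (physics) loss."""
    pixel = float(np.mean(np.abs(pred - target)))                  # image-domain L1 term
    physics = float(np.mean((forward_op(pred) - sino_meas) ** 2))  # ||A x - y||^2 term
    return alpha * pixel + gamma * physics, pixel, physics

rng = np.random.default_rng(0)
A = rng.standard_normal((12, 16))   # toy linear projector (stand-in for Radon)
project = lambda x: A @ x

target = rng.random(16)             # toy "image", flattened
sino = project(target)              # noise-free measurement

total, pixel, physics = composite_loss(target + 0.05, target, sino, project)
print(f"total={total:.4f}, pixel={pixel:.4f}, physics={physics:.4f}")
```

The physics term needs only the measured sinogram, not the ground-truth image, which is what lets it constrain the reconstruction toward data consistency.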

### Two Training Options:
| Option | Setting | Best For |
|--------|---------|----------|
| **Simulated Phantoms** | `USE_REAL_DATA = False` | Learning, testing, quick experiments |
| **Real CT Images** | `USE_REAL_DATA = True` | Final evaluation, paper results |

### Tips for Better Results:
- Increase `num_epochs` to 100-200
- Increase `num_train_samples` to 500-1000
- Try different `noise_level`: 'low', 'medium', 'high'
- For sparse-view CT: set `acquisition_type: 'sparse'`
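
The tips above amount to overriding a few configuration keys. A small sketch of that pattern follows. `base_config` is a hypothetical stand-in that holds only the relevant keys from the configuration cell, and the new values are examples taken from the suggested ranges.

```python
# Stand-in for the notebook's config, reduced to the keys the tips mention
base_config = {
    'num_epochs': 20,
    'num_train_samples': 200,
    'noise_level': 'low',
    'acquisition_type': 'full',
}

# Example overrides following the tips (values chosen for illustration)
overrides = {
    'num_epochs': 100,
    'num_train_samples': 500,
    'noise_level': 'medium',
    'acquisition_type': 'sparse',
}

tuned_config = {**base_config, **overrides}  # later keys win; base_config is unchanged
print(tuned_config)
```

In the notebook itself the same effect can be had by editing the values passed to `config.update(...)` and re-running the cells from there.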

### Expected Results:
- **FBP baseline**: ~25-30 dB PSNR
- **CT-PINN-DADif**: ~30-35 dB PSNR (after proper training)