# CT-PINN-DADif: Physics-Informed Deep Adaptive Diffusion Network for CT Reconstruction

This notebook provides a complete training pipeline for CT-PINN-DADif, adapted from the MRI PINN-DADif paper (Ahmed et al., 2025) for CT reconstruction.

## Key Differences from MRI Version:
- **Physics**: Radon Transform / Sinogram instead of Fourier / k-space
- **Noise Model**: Poisson photon-counting noise instead of Gaussian noise
- **Data Consistency**: Sinogram consistency instead of k-space consistency
- **Regularization**: Total Variation + Non-negativity instead of Bloch equations

## Architecture Overview:
1. **CT-LPCE**: Encodes the sinogram into latent features under physics constraints
2. **CT-PACE**: Multi-scale context encoding with ASPP and attention
3. **CT-ADRN**: Adaptive diffusion refinement with sinogram consistency
4. **CT-ART**: Final synthesis with dynamic convolutions

## 1. Setup & Installation

In [None]:
# Clone the repository and move into the CT sub-project (needed on Colab).
# Guarded so the cell is idempotent: re-running it must not fail on `git clone`
# into an existing directory, nor on a relative `cd` from the already-changed cwd.
import os
import subprocess

REPO_URL = 'https://github.com/Iammohithhh/PINN_Dadiff.git'
REPO_DIR = 'PINN_Dadiff'
PROJECT_SUBDIR = 'ct_reconstruction'

if os.path.basename(os.getcwd()) != PROJECT_SUBDIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(['git', 'clone', REPO_URL], check=True)
    os.chdir(os.path.join(REPO_DIR, PROJECT_SUBDIR))

print(f"Working directory: {os.getcwd()}")

In [None]:
# Install dependencies
!pip install torch torchvision numpy scipy tqdm matplotlib

In [None]:
import random
import sys

# Make the project's `src/` directory importable. The guard prevents duplicate
# sys.path entries from piling up when the cell is re-run.
if 'src' not in sys.path:
    sys.path.insert(0, 'src')

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Seed the global RNGs the pipeline may draw from, so runs are reproducible.
# NOTE(review): whether the project's data loader uses these global RNGs or its
# own generators is not visible here — confirm in data_loader.py.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Check GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Import CT-PINN-DADif Modules

In [None]:
from ct_physics import (
    RadonTransform, FilteredBackProjection, CTForwardModel,
    create_sparse_view_mask, create_limited_angle_mask
)
from model import CT_PINN_DADif, CTReconstructionLoss, SAM, create_model, create_loss
from data_loader import SimulatedCTDataset, create_dataloaders, create_shepp_logan_phantom
from train import Trainer, compute_metrics, DEFAULT_CONFIG

## 3. Visualize CT Physics

In [None]:
# Simulation parameters for this demo. They are defined once so that the
# geometry passed to every operator and the plot titles cannot drift apart.
DEMO_IMG_SIZE = 256
DEMO_NUM_ANGLES = 180
DEMO_I0 = 1e4            # value passed as CTForwardModel's I0
DEMO_SPARSE_VIEWS = 30

# Create a Shepp-Logan phantom
phantom = create_shepp_logan_phantom(DEMO_IMG_SIZE)

# Convert to tensor: add batch and channel dimensions, cast to float, move to device
phantom_tensor = torch.from_numpy(phantom).unsqueeze(0).unsqueeze(0).float().to(device)

# Forward model and FBP operator sharing the same geometry
ct_model = CTForwardModel(img_size=DEMO_IMG_SIZE, num_angles=DEMO_NUM_ANGLES, I0=DEMO_I0, device=device)
fbp = FilteredBackProjection(DEMO_IMG_SIZE, DEMO_NUM_ANGLES, device=device)

# Pure simulation for display/metrics: no gradients are needed.
with torch.no_grad():
    # Generate noise-free and noisy sinograms
    sinogram_clean = ct_model.forward_project(phantom_tensor)
    sinogram_noisy, counts = ct_model(phantom_tensor, add_noise=True, return_counts=True)

    # FBP reconstruction from the full noisy sinogram
    fbp_recon = fbp(sinogram_noisy)

    # Sparse-view simulation: keep DEMO_SPARSE_VIEWS of the DEMO_NUM_ANGLES views.
    # (Assumes create_sparse_view_mask returns a 0/1 mask over angles — TODO confirm.)
    sparse_mask = create_sparse_view_mask(DEMO_NUM_ANGLES, DEMO_SPARSE_VIEWS, device)
    sinogram_sparse = sinogram_noisy * sparse_mask
    # NOTE(review): FBP still integrates over all DEMO_NUM_ANGLES angles, so the
    # zero-filled sinogram likely yields an intensity-scaled image. imshow auto-scales
    # the display, but rescale before computing metrics against the phantom.
    fbp_sparse = fbp(sinogram_sparse)

# Visualize
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

axes[0, 0].imshow(phantom, cmap='gray')
axes[0, 0].set_title('Ground Truth Phantom')
axes[0, 0].axis('off')

axes[0, 1].imshow(sinogram_clean.squeeze().cpu().numpy(), cmap='gray', aspect='auto')
axes[0, 1].set_title('Clean Sinogram')
axes[0, 1].set_xlabel('Detector')
axes[0, 1].set_ylabel('Angle')

axes[0, 2].imshow(sinogram_noisy.squeeze().cpu().numpy(), cmap='gray', aspect='auto')
axes[0, 2].set_title('Noisy Sinogram (Poisson)')
axes[0, 2].set_xlabel('Detector')
axes[0, 2].set_ylabel('Angle')

axes[1, 0].imshow(fbp_recon.squeeze().cpu().numpy(), cmap='gray')
axes[1, 0].set_title('FBP Reconstruction')
axes[1, 0].axis('off')

axes[1, 1].imshow(sinogram_sparse.squeeze().cpu().numpy(), cmap='gray', aspect='auto')
axes[1, 1].set_title(f'Sparse-View Sinogram ({DEMO_SPARSE_VIEWS} views)')

axes[1, 2].imshow(fbp_sparse.squeeze().cpu().numpy(), cmap='gray')
axes[1, 2].set_title('FBP from Sparse-View')
axes[1, 2].axis('off')

plt.tight_layout()
plt.savefig('ct_physics_demo.png', dpi=150)
plt.show()

# Compute metrics for the full-view FBP against the phantom
metrics_fbp = compute_metrics(fbp_recon, phantom_tensor)
print("\nFBP Reconstruction Metrics:")
print(f"  PSNR: {metrics_fbp['psnr']:.2f} dB")
print(f"  SSIM: {metrics_fbp['ssim']:.2f} %")

## 4. Configure Training

In [None]:
# Training configuration, assembled from one sub-dict per concern.
# dict() preserves keyword order and unpacking preserves insertion order, so the
# merged `config` (and the printout below) lists keys in the usual order.
arch_cfg = dict(
    img_size=256,
    num_angles=180,
    num_detectors=None,        # Auto-compute
    base_channels=64,
    latent_dim=128,
    context_dim=256,
    num_diffusion_steps=12,
    lambda_phys_lpce=0.3,
    lambda_phys_pace=0.1,
    use_final_dc=True,
)

data_cfg = dict(
    dataset_type='simulated',
    num_train_samples=500,     # Reduce for Colab
    num_val_samples=100,
    num_test_samples=100,
    phantom_type='mixed',      # 'shepp_logan', 'random', 'mixed'
    noise_level='low',         # 'none', 'low', 'medium', 'high'
    acquisition_type='full',   # 'full', 'sparse', 'limited'
    num_views=60,              # For sparse-view
)

train_cfg = dict(
    batch_size=2,              # Reduce for Colab GPU memory
    num_epochs=100,            # Reduce for demo
    learning_rate=6e-3,
    use_sam=True,              # Sharpness-Aware Minimization
    sam_rho=0.05,
    use_amp=True,              # Mixed precision
    num_workers=2,
)

loss_cfg = dict(
    alpha=0.5,                 # Pixel loss
    beta=0.2,                  # Perceptual loss
    gamma=0.3,                 # Physics loss
    tv_weight=1e-4,
    nonneg_weight=1e-3,
    use_poisson=False,
    use_perceptual=False,      # Disable for faster training
)

ckpt_cfg = dict(
    checkpoint_dir='experiments/checkpoints',
    log_dir='experiments/logs',
    save_every=25,
)

config = {**arch_cfg, **data_cfg, **train_cfg, **loss_cfg, **ckpt_cfg}

print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in config.items()))

## 5. Create Model and Data

In [None]:
# Build the train/val/test loaders from the configuration above.
print("Creating datasets...")
train_loader, val_loader, test_loader = create_dataloaders(
    config,
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
)

for split_label, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_label} batches: {len(split_loader)}")

# Pull one training batch and report the shape of every tensor it carries.
sample = next(iter(train_loader))
print("\nSample shapes:")
for field, value in sample.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(f"  {field}: {value.shape}")

In [None]:
# Build the network from the config and place it on the selected device.
print("Creating CT-PINN-DADif model...")
model = create_model(config).to(device)

# Parameter budget: all parameters vs. only those with requires_grad set.
param_list = list(model.parameters())
total_params = sum(p.numel() for p in param_list)
trainable_params = sum(p.numel() for p in param_list if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss module built from the same config, on the same device as the model
loss_fn = create_loss(config).to(device)

## 6. Test Forward Pass

In [None]:
# Smoke-test one forward pass on the sample batch before training starts.
# Eval mode is used only for this check. The previous train/eval state is
# restored afterwards, so the Trainer cell below does not inherit eval mode
# in case it does not call model.train() itself.
was_training = model.training
model.eval()

try:
    with torch.no_grad():
        sinogram = sample['sinogram_noisy'].to(device)
        target = sample['image'].to(device)
        weights = sample['weights'].to(device)
        mask = sample['mask'].to(device)

        print(f"Input sinogram: {sinogram.shape}")

        # Forward pass (return_intermediate presumably adds per-stage outputs to the dict)
        outputs = model(sinogram, weights, mask, return_intermediate=True)

        print("\nOutput shapes:")
        for k, v in outputs.items():
            if isinstance(v, torch.Tensor):
                print(f"  {k}: {v.shape}")

        # Compute metrics of the untrained network against ground truth
        metrics = compute_metrics(outputs['reconstruction'], target)
        print("\nInitial metrics (before training):")
        print(f"  PSNR: {metrics['psnr']:.2f} dB")
        print(f"  SSIM: {metrics['ssim']:.2f} %")
finally:
    model.train(was_training)

## 7. Train Model

In [None]:
# Wire the model, loss and data loaders into the project's Trainer.
trainer = Trainer(
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
)

banner = "=" * 60

# Train from scratch for the configured number of epochs.
print("Starting training...")
print(banner)
history = trainer.train(num_epochs=config['num_epochs'], start_epoch=0)
print(banner)

# Summarize the best validation scores tracked by the Trainer.
print("Training complete!")
print(f"Best PSNR: {trainer.best_psnr:.2f} dB")
print(f"Best SSIM: {trainer.best_ssim:.2f} %")

## 8. Visualize Training History

In [None]:
# Plot train/val curves side by side: one panel per tracked quantity.
# Each entry: (train history key, val history key, y-axis label, panel title).
curve_panels = (
    ('train_loss', 'val_loss', 'Loss', 'Training Loss'),
    ('train_psnr', 'val_psnr', 'PSNR (dB)', 'PSNR'),
    ('train_ssim', 'val_ssim', 'SSIM (%)', 'SSIM'),
)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (train_key, val_key, y_label, panel_title) in zip(axes, curve_panels):
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()

## 9. Evaluate on Test Set

In [None]:
import os

# Restore the best checkpoint. The path is derived from config['checkpoint_dir'],
# so it stays in sync with the directory the Trainer was configured to use.
# NOTE(review): assumes the Trainer writes 'best_model.pt' containing a
# 'model_state_dict' key — confirm in train.py.
# NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True. If the
# checkpoint also stores arbitrary Python objects, loading may fail; only pass
# weights_only=False for files you trust.
best_ckpt_path = os.path.join(config['checkpoint_dir'], 'best_model.pt')
checkpoint = torch.load(best_ckpt_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Evaluate on the test set, scoring every image individually. The mean/std are
# then taken over images rather than over batch averages, which would understate
# the spread and mis-weight a smaller final batch.
test_psnr = []
test_ssim = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        sinogram = batch['sinogram_noisy'].to(device)
        target = batch['image'].to(device)
        weights = batch['weights'].to(device)
        mask = batch['mask'].to(device)

        outputs = model(sinogram, weights, mask)
        reconstruction = outputs['reconstruction']

        for i in range(reconstruction.shape[0]):
            metrics = compute_metrics(reconstruction[i:i+1], target[i:i+1])
            test_psnr.append(metrics['psnr'])
            test_ssim.append(metrics['ssim'])

print(f"\nTest Results:")
print(f"  PSNR: {np.mean(test_psnr):.2f} +/- {np.std(test_psnr):.2f} dB")
print(f"  SSIM: {np.mean(test_ssim):.2f} +/- {np.std(test_ssim):.2f} %")

## 10. Visualize Reconstructions

In [None]:
# Compare ground truth, the dataset's FBP baseline and the network output on a
# few test images, plus an absolute-error map.
MAX_VIS_SAMPLES = 4
ERROR_MAP_VMAX = 0.1  # fixed upper limit of the error-map colour scale

model.eval()
test_batch = next(iter(test_loader))

with torch.no_grad():
    sinogram = test_batch['sinogram_noisy'].to(device)
    target = test_batch['image'].to(device)
    weights = test_batch['weights'].to(device)
    mask = test_batch['mask'].to(device)
    # Distinct name so this batched baseline does not overwrite the
    # single-phantom `fbp_recon` from the physics demo.
    fbp_batch = test_batch['fbp'].to(device)

    outputs = model(sinogram, weights, mask)

rec = outputs['reconstruction']

# Visualize
n_samples = min(MAX_VIS_SAMPLES, sinogram.shape[0])
# squeeze=False keeps `axes` 2-D even when only one row is drawn (e.g.
# batch_size=1), so the axes[i, j] indexing below is always valid.
fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples), squeeze=False)

for i in range(n_samples):
    # Ground truth
    axes[i, 0].imshow(target[i, 0].cpu().numpy(), cmap='gray')
    axes[i, 0].set_title('Ground Truth')
    axes[i, 0].axis('off')

    # FBP
    fbp_metrics = compute_metrics(fbp_batch[i:i+1], target[i:i+1])
    axes[i, 1].imshow(fbp_batch[i, 0].cpu().numpy(), cmap='gray')
    axes[i, 1].set_title(f'FBP (PSNR: {fbp_metrics["psnr"]:.1f})')
    axes[i, 1].axis('off')

    # CT-PINN-DADif
    rec_metrics = compute_metrics(rec[i:i+1], target[i:i+1])
    axes[i, 2].imshow(rec[i, 0].cpu().numpy(), cmap='gray')
    axes[i, 2].set_title(f'CT-PINN-DADif (PSNR: {rec_metrics["psnr"]:.1f})')
    axes[i, 2].axis('off')

    # Error map
    error = torch.abs(rec[i, 0] - target[i, 0]).cpu().numpy()
    axes[i, 3].imshow(error, cmap='hot', vmin=0, vmax=ERROR_MAP_VMAX)
    axes[i, 3].set_title('Error Map')
    axes[i, 3].axis('off')

plt.tight_layout()
plt.savefig('reconstructions.png', dpi=150)
plt.show()

## 11. Save Model for Deployment

In [None]:
# Save the current model weights together with the config needed to rebuild the
# architecture. If Section 9 was run, these are the best-checkpoint weights it loaded.
FINAL_MODEL_PATH = 'ct_pinn_dadif_final.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config
}, FINAL_MODEL_PATH)

print(f"Model saved to {FINAL_MODEL_PATH}")

# Offer the file for download when running on Colab. A missing Colab module is
# the expected case elsewhere. The download is best-effort: failures are
# reported rather than silently swallowed (and KeyboardInterrupt is not caught).
try:
    from google.colab import files
except ImportError:
    files = None  # not running on Colab

if files is not None:
    try:
        files.download(FINAL_MODEL_PATH)
    except Exception as exc:
        print(f"Automatic download failed ({exc}); the file is still at {FINAL_MODEL_PATH}")