# U-Net for CT Image Reconstruction
## LoDoPaB Dataset - Training on Lambda Labs GPU

This notebook trains a U-Net to enhance filtered back-projection (FBP) reconstructions of low-dose CT images.

**Pipeline:**
```
Sinogram → FBP → Noisy Image → U-Net → Enhanced Image
```

---

## 1. Setup and Imports

In [None]:
# Install required packages (run once)
!pip install torch torchvision h5py scikit-image scipy tqdm matplotlib tensorboard numpy pandas

In [None]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from skimage.transform import iradon
from scipy.ndimage import zoom

# Report the runtime stack so the notebook output records what produced a run.
_cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ready}")
if _cuda_ready:
    # Device 0 is the only GPU this notebook inspects.
    _gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {_gpu_bytes / 1e9:.2f} GB")

## 2. Configuration

In [None]:
# -----------------------------------------------------------
# Run configuration -- edit the values below for your machine
# -----------------------------------------------------------

# Directory containing the LoDoPaB HDF5 files. Update this!
DATA_DIR = Path("../data") / "lodopab"

# Training hyperparameters and output locations.
CONFIG = dict(
    num_epochs=50,
    batch_size=16,              # Increase for GPU (Lambda Labs can handle 16-32)
    learning_rate=1e-4,
    num_workers=8,              # DataLoader worker processes
    device='cuda',              # Falls back to CPU later if CUDA is unavailable
    pin_memory=True,            # Enable for GPU
    save_dir='../data/results/checkpoints/unet',
    log_dir='../data/results/logs/unet',
    use_all_train_files=False,  # Set True to use all training data
    num_train_files=10,         # Training files used when use_all_train_files=False
    num_val_files=2,            # Validation files used
)

# Relative weights of the two terms in the combined loss.
LOSS_WEIGHTS = dict(mse=1.0, ssim=0.1)

print("Configuration:")
for _name, _setting in CONFIG.items():
    print(f"  {_name}: {_setting}")

## 3. Dataset Class

In [None]:
class LoDoPaBDataset(Dataset):
    """Paired (network input, ground truth) images read from HDF5 files.

    Observation file ``i`` is paired with ground-truth file ``i``; samples are
    read from the ``'data'`` dataset of each file. Files are opened per item
    rather than kept open, so no file handle is shared between DataLoader
    worker processes.

    Args:
        obs_files: Paths of the observation (sinogram) HDF5 files.
        gt_files: Paths of the ground-truth HDF5 files, in the same order and
            of the same length as ``obs_files``.
        transform: Stored on the instance as ``self.transform``.
            NOTE(review): it is not applied anywhere in this class.
        use_fbp: When True the input image is an FBP reconstruction of the
            sinogram; when False it is the ground truth plus Gaussian noise
            (intended for testing).

    Raises:
        ValueError: If ``obs_files`` and ``gt_files`` differ in length.
    """

    # Side length in pixels of the square image produced by the FBP step.
    IMAGE_SIZE = 362

    def __init__(self, obs_files, gt_files, transform=None, use_fbp=True):
        obs_files = list(obs_files)
        gt_files = list(gt_files)
        # Files are paired by position, so a length mismatch would either
        # crash late in __getitem__ or silently ignore trailing files.
        if len(obs_files) != len(gt_files):
            raise ValueError(
                f"Got {len(obs_files)} observation files but "
                f"{len(gt_files)} ground truth files; they must pair up 1:1"
            )

        self.obs_files = obs_files
        self.gt_files = gt_files
        self.transform = transform
        self.use_fbp = use_fbp

        # file_offsets[i] is the global index of the first sample in file i;
        # the final entry is the total sample count.
        self.file_offsets = [0]
        for obs_file in obs_files:
            with h5py.File(obs_file, 'r') as f:
                self.file_offsets.append(self.file_offsets[-1] + f['data'].shape[0])

        self.total_samples = self.file_offsets[-1]

    def __len__(self):
        return self.total_samples

    def _get_file_and_index(self, idx):
        """Convert a global index to ``(file_index, local_index)``.

        Negative indices count from the end of the dataset, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len, len)``.
        """
        total = self.file_offsets[-1]
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(f"Index {idx} out of range")

        # Rightmost offset that is <= position. side='right' also skips over
        # any empty files, whose offsets repeat.
        file_idx = int(np.searchsorted(self.file_offsets, position, side='right')) - 1
        return file_idx, position - self.file_offsets[file_idx]

    def _simple_fbp(self, sinogram):
        """Reconstruct an image from a sinogram with skimage's ``iradon``.

        LoDoPaB sinogram: (1000, 513) = (angles, detectors)
        iradon expects: (detectors, angles)

        Returns:
            A float32 array of shape ``(IMAGE_SIZE, IMAGE_SIZE)``.
        """
        size = self.IMAGE_SIZE

        # One projection angle per sinogram row, evenly spread over [0, 180).
        theta = np.linspace(0, 180, sinogram.shape[0], endpoint=False)

        reconstructed = iradon(
            sinogram.T,  # (angles, detectors) -> (detectors, angles)
            theta=theta,
            filter_name='ramp',
            interpolation='linear',
            circle=False,
            output_size=size,
        )

        # Safety net: rescale if the reconstruction did not come back at the
        # requested size.
        if reconstructed.shape != (size, size):
            scale_y = size / reconstructed.shape[0]
            scale_x = size / reconstructed.shape[1]
            reconstructed = zoom(reconstructed, (scale_y, scale_x), order=1)

        # Quarter-turn rotation; presumably aligns the reconstruction with the
        # ground-truth orientation -- TODO confirm against the data.
        reconstructed = np.rot90(reconstructed, k=-1)
        return reconstructed.astype(np.float32)

    def _normalize(self, img):
        """Min-max scale ``img`` to [0, 1].

        A (near-)constant image has no range to stretch, so it maps to all
        zeros instead of being returned unscaled.
        """
        img_min = img.min()
        span = img.max() - img_min
        if span > 1e-8:
            return (img - img_min) / span
        return np.zeros_like(img)

    def __getitem__(self, idx):
        """Return ``(input_img, target_img)`` tensors, each with a leading channel axis."""
        file_idx, local_idx = self._get_file_and_index(idx)

        # Open files on-demand (multiprocessing-safe)
        with h5py.File(self.obs_files[file_idx], 'r') as f_obs, \
             h5py.File(self.gt_files[file_idx], 'r') as f_gt:

            sinogram = f_obs['data'][local_idx].astype(np.float32)
            ground_truth = f_gt['data'][local_idx].astype(np.float32)

        if self.use_fbp:
            input_img = self._simple_fbp(sinogram)
        else:
            # For testing: use ground truth with noise
            input_img = ground_truth + np.random.normal(0, 0.05, ground_truth.shape).astype(np.float32)

        # Input and target are each scaled by their own min/max.
        input_img = self._normalize(input_img)
        ground_truth = self._normalize(ground_truth)

        # Add the channel axis expected by the network: (H, W) -> (1, H, W).
        input_img = torch.from_numpy(input_img).unsqueeze(0)
        target_img = torch.from_numpy(ground_truth).unsqueeze(0)

        return input_img, target_img

## 4. U-Net Model Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    ``mid_channels`` sets the width between the two stages and falls back to
    ``out_channels`` when not given. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Build both stages from one template so they cannot drift apart.
        layers = []
        for width_in, width_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.extend([
                nn.Conv2d(width_in, width_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        halve = nn.MaxPool2d(2)
        convolve = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(halve, convolve)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concat, convolve.

    With ``bilinear=True`` upsampling is parameter-free interpolation;
    otherwise a stride-2 transposed convolution halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1`` and merge it with the skip connection ``x2``."""
        upsampled = self.up(x1)

        # The skip tensor can be larger when an encoder size was odd; pad the
        # difference as evenly as possible across both sides.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = nn.functional.pad(
            upsampled, [left, gap_w - left, top, gap_h - top]
        )

        # Skip features first, then the upsampled ones, along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class UNet(nn.Module):
    """U-Net for CT image reconstruction.

    Four ``Down`` stages and four ``Up`` stages with skip connections,
    bracketed by an input ``DoubleConv`` and a 1x1 ``OutConv``.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output channels.
        bilinear: Passed to every ``Up`` stage; also halves the bottleneck
            and intermediate decoder widths when True.
    """

    def __init__(self, n_channels=1, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling does not change the channel count, so the
        # widths feeding each Up stage are halved up front in that mode.
        shrink = 2 if bilinear else 1

        # Contracting path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)

        # Expanding path
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Projection to the requested number of output channels
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Contracting path: keep every stage's output for the skip connections.
        features = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(stage(features[-1]))

        # Expanding path: start at the bottleneck and consume the stored
        # features from deepest to shallowest.
        out = features.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            out = stage(out, features.pop())

        return self.outc(out)

## 5. Loss Function

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of MSE and a simplified (global-statistics) SSIM loss.

    Args:
        mse_weight: Multiplier applied to the MSE term.
        ssim_weight: Multiplier applied to the ``1 - SSIM`` term.
    """

    def __init__(self, mse_weight=1.0, ssim_weight=0.1):
        super().__init__()
        self.mse_weight = mse_weight
        self.ssim_weight = ssim_weight
        self.mse_loss = nn.MSELoss()

    def ssim_loss(self, pred, target):
        """Return ``1 - SSIM`` using statistics over the whole tensor.

        This is a simplification: one mean/variance/covariance is taken over
        every element of ``pred`` and ``target`` (all samples and pixels)
        instead of over local windows.
        """
        # SSIM stabilising constants (K1=0.01, K2=0.03) for a data range of 1.
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        mu_pred = pred.mean()
        mu_target = target.mean()

        pred_centered = pred - mu_pred
        target_centered = target - mu_target

        # Variances and covariance all use the same 1/N estimator. Mixing an
        # unbiased variance with a 1/N covariance would make SSIM(x, x) < 1,
        # so a perfect prediction would still incur a loss.
        sigma_pred = (pred_centered ** 2).mean()
        sigma_target = (target_centered ** 2).mean()
        sigma_pred_target = (pred_centered * target_centered).mean()

        ssim = ((2 * mu_pred * mu_target + C1) * (2 * sigma_pred_target + C2)) / \
               ((mu_pred ** 2 + mu_target ** 2 + C1) * (sigma_pred + sigma_target + C2))

        return 1 - ssim

    def forward(self, pred, target):
        """Return ``(total_loss, parts)``.

        ``total_loss`` is a tensor suitable for ``backward()``. ``parts`` is a
        dict of Python floats with keys ``'mse'`` and ``'ssim'``; the
        ``'ssim'`` entry is the SSIM *loss* (``1 - SSIM``), not the index.
        """
        mse = self.mse_loss(pred, target)
        ssim = self.ssim_loss(pred, target)

        total_loss = self.mse_weight * mse + self.ssim_weight * ssim

        return total_loss, {'mse': mse.item(), 'ssim': ssim.item()}

## 6. Training and Validation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``dataloader``.

    ``criterion`` must return ``(loss_tensor, parts)`` where ``parts`` has
    ``'mse'`` and ``'ssim'`` entries. ``epoch`` is only used to label the
    progress bar.

    Returns:
        A tuple of per-batch means: (total loss, 'mse' part, 'ssim' part).
    """
    model.train()
    running = {'loss': 0, 'mse': 0, 'ssim': 0}

    progress = tqdm(dataloader, desc=f'Epoch {epoch}')
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear stale gradients before the forward/backward pass.
        optimizer.zero_grad()
        loss, parts = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running['loss'] += batch_loss
        running['mse'] += parts['mse']
        running['ssim'] += parts['ssim']

        progress.set_postfix({
            'loss': f"{batch_loss:.4f}",
            'mse': f"{parts['mse']:.4f}"
        })

    num_batches = len(dataloader)
    return (
        running['loss'] / num_batches,
        running['mse'] / num_batches,
        running['ssim'] / num_batches,
    )


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Returns:
        A tuple of per-batch means: (total loss, 'mse' part, 'ssim' part).
    """
    model.eval()
    loss_sum, mse_sum, ssim_sum = 0, 0, 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc='Validation'):
            inputs, targets = inputs.to(device), targets.to(device)

            loss, parts = criterion(model(inputs), targets)

            loss_sum += loss.item()
            mse_sum += parts['mse']
            ssim_sum += parts['ssim']

    num_batches = len(dataloader)
    return tuple(total / num_batches for total in (loss_sum, mse_sum, ssim_sum))


def save_sample_images(model, dataloader, device, save_dir, epoch):
    """Save a PNG comparing input, model output and target for one batch.

    Takes the first batch from ``dataloader``, runs the model in eval mode and
    writes ``epoch_XXX.png`` into ``save_dir`` (created if missing). Up to four
    samples are shown, one per row.

    Note: the model is left in eval mode on return.
    """
    model.eval()
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        inputs, targets = next(iter(dataloader))
        outputs = model(inputs.to(device))

    # Size the grid to the samples actually available so a short batch does
    # not leave empty, framed axes in the figure.
    num_rows = min(4, inputs.shape[0])
    columns = (
        ('Input (FBP)', inputs),
        ('Output (U-Net)', outputs),
        ('Target (Ground Truth)', targets),
    )

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        num_rows, len(columns), figsize=(12, 4 * num_rows), squeeze=False
    )
    for row in range(num_rows):
        for col, (title, batch) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(batch[row, 0].cpu().numpy(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    # Use the figure handle, not pyplot's implicit "current figure", so this
    # cannot touch a figure opened elsewhere.
    fig.tight_layout()
    fig.savefig(save_dir / f'epoch_{epoch:03d}.png', dpi=150, bbox_inches='tight')
    plt.close(fig)

## 7. Prepare Data

In [None]:
# Find data files
print(f"Looking for data in: {DATA_DIR}")

# Fail fast: merely printing the problem would let the notebook carry on into
# a NameError, or into stale file lists left over from an earlier run.
if not DATA_DIR.exists():
    raise FileNotFoundError(
        f"Data directory not found: {DATA_DIR}. "
        "Please update DATA_DIR in the configuration cell!"
    )

# Observation and ground-truth files are paired by position after sorting.
train_obs = sorted(DATA_DIR.glob("observation_train_*.hdf5"))
train_gt = sorted(DATA_DIR.glob("ground_truth_train_*.hdf5"))

val_obs = sorted(DATA_DIR.glob("observation_validation_*.hdf5"))
val_gt = sorted(DATA_DIR.glob("ground_truth_validation_*.hdf5"))

print(f"\nFound {len(train_obs)} training observation files")
print(f"Found {len(train_gt)} training ground truth files")
print(f"Found {len(val_obs)} validation observation files")
print(f"Found {len(val_gt)} validation ground truth files")

# Subset if needed
if not CONFIG['use_all_train_files']:
    train_obs = train_obs[:CONFIG['num_train_files']]
    train_gt = train_gt[:CONFIG['num_train_files']]
    print(f"\nUsing {len(train_obs)} training files (subset)")

val_obs = val_obs[:CONFIG['num_val_files']]
val_gt = val_gt[:CONFIG['num_val_files']]
print(f"Using {len(val_obs)} validation files")

if not (train_obs and train_gt and val_obs and val_gt):
    raise FileNotFoundError(
        "Missing data files! "
        "Make sure your LoDoPaB data is in the correct directory."
    )

# Positional pairing only makes sense when both lists have the same length.
if len(train_obs) != len(train_gt) or len(val_obs) != len(val_gt):
    raise ValueError(
        "Observation / ground truth file counts differ "
        f"(train: {len(train_obs)} vs {len(train_gt)}, "
        f"validation: {len(val_obs)} vs {len(val_gt)})"
    )

print("\n✓ Data files found successfully!")

## 8. Create Datasets and Dataloaders

In [None]:
# Build the train/validation datasets; both use FBP reconstructions as input.
print("Creating datasets...")
train_dataset = LoDoPaBDataset(train_obs, train_gt, use_fbp=True)
val_dataset = LoDoPaBDataset(val_obs, val_gt, use_fbp=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Settings shared by both loaders; only shuffling differs between them.
_loader_options = {
    'batch_size': CONFIG['batch_size'],
    'num_workers': CONFIG['num_workers'],
    'pin_memory': CONFIG['pin_memory'],
}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f"\nBatches per epoch: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 9. Initialize Model, Loss, and Optimizer

In [None]:
# Pick the configured device, falling back to CPU when CUDA is unavailable.
_target_device = CONFIG['device'] if torch.cuda.is_available() else 'cpu'
device = torch.device(_target_device)
print(f"Using device: {device}")

# Single-channel in, single-channel out U-Net with bilinear upsampling.
model = UNet(n_channels=1, n_classes=1, bilinear=True).to(device)

# Parameter statistics
_param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in _param_info)
trainable_params = sum(count for count, trains in _param_info if trains)
print(f"\nModel parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss, optimizer and plateau-based learning-rate schedule
criterion = CombinedLoss(
    mse_weight=LOSS_WEIGHTS['mse'],
    ssim_weight=LOSS_WEIGHTS['ssim'],
)
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
# Halve the learning rate once the monitored value stops improving for 5 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5,
)

print("\n✓ Model initialized successfully!")

## 10. Training Loop

In [None]:
# Create directories
save_dir = Path(CONFIG['save_dir'])
log_dir = Path(CONFIG['log_dir'])
save_dir.mkdir(parents=True, exist_ok=True)
log_dir.mkdir(parents=True, exist_ok=True)

# TensorBoard: one sub-directory per run, named by start time.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_dir / f'run_{timestamp}')

# Training loop
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'train_mse': [], 'val_mse': []}

print("\n" + "="*60)
print("Starting Training")
print("="*60)

# try/finally guarantees the TensorBoard writer is flushed and closed even if
# training raises or the cell is interrupted.
try:
    for epoch in range(1, CONFIG['num_epochs'] + 1):
        # Train
        train_loss, train_mse, train_ssim = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )

        # Validate
        val_loss, val_mse, val_ssim = validate(model, val_loader, criterion, device)

        # The plateau scheduler monitors the validation loss.
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        # Log metrics. The "SSIM" values are the loss term (1 - SSIM) reported
        # by the criterion.
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('MSE/train', train_mse, epoch)
        writer.add_scalar('MSE/val', val_mse, epoch)
        writer.add_scalar('SSIM/train', train_ssim, epoch)
        writer.add_scalar('SSIM/val', val_ssim, epoch)
        writer.add_scalar('LR', current_lr, epoch)

        # Store history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_mse'].append(train_mse)
        history['val_mse'].append(val_mse)

        # Print summary
        print(f"\nEpoch {epoch}/{CONFIG['num_epochs']}")
        print(f"Train - Loss: {train_loss:.4f}, MSE: {train_mse:.4f}, SSIM: {train_ssim:.4f}")
        print(f"Val   - Loss: {val_loss:.4f}, MSE: {val_mse:.4f}, SSIM: {val_ssim:.4f}")
        print(f"LR: {current_lr:.6f}")

        # Save sample images every 5 epochs
        if epoch % 5 == 0:
            save_sample_images(model, val_loader, device, save_dir / 'samples', epoch)
            print("✓ Saved sample images")

        is_best = val_loss < best_val_loss
        is_periodic = epoch % 10 == 0

        if is_best or is_periodic:
            # One payload for both kinds of checkpoint. The scheduler state is
            # included so a resumed run can restore its plateau counters.
            checkpoint_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
            }

        # Save best model
        if is_best:
            best_val_loss = val_loss
            torch.save(checkpoint_state, save_dir / 'best_model.pth')
            print(f"✓ Saved best model (val_loss: {val_loss:.4f})")

        # Save checkpoint every 10 epochs
        if is_periodic:
            torch.save(checkpoint_state, save_dir / f'checkpoint_epoch_{epoch:03d}.pth')
            print("✓ Saved checkpoint")
finally:
    writer.close()

print("\n" + "="*60)
print("Training Completed!")
print("="*60)
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Models saved to: {save_dir}")

## 11. Plot Training History

In [None]:
# Side-by-side curves: total loss on the left, MSE on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# (axis, train key, train label, val key, val label, y-label, title)
_panels = (
    (ax1, 'train_loss', 'Train Loss', 'val_loss', 'Val Loss',
     'Loss', 'Training and Validation Loss'),
    (ax2, 'train_mse', 'Train MSE', 'val_mse', 'Val MSE',
     'MSE', 'Training and Validation MSE'),
)
for _ax, _train_key, _train_label, _val_key, _val_label, _ylabel, _title in _panels:
    _ax.plot(history[_train_key], label=_train_label)
    _ax.plot(history[_val_key], label=_val_label)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.legend()
    _ax.grid(True)

plt.tight_layout()
_history_png = save_dir / 'training_history.png'
plt.savefig(_history_png, dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved training history plot to: {_history_png}")

## 12. Load Best Model and Test

In [None]:
# Load best model. map_location remaps the stored tensors onto the active
# device, so a checkpoint written on a GPU also loads on a CPU-only host.
checkpoint = torch.load(save_dir / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']}")
print(f"Best validation loss: {checkpoint['val_loss']:.4f}")

# Run the restored model on the first validation batch.
with torch.no_grad():
    inputs, targets = next(iter(val_loader))
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)

# Visualize: one column per sample, rows = input / output / ground truth.
# The grid is sized to the samples available (at most 4) so a short batch
# leaves no empty axes; squeeze=False keeps `axes` two-dimensional.
num_shown = min(4, inputs.shape[0])
fig, axes = plt.subplots(3, num_shown, figsize=(4 * num_shown, 12), squeeze=False)

for i in range(num_shown):
    # Input (FBP)
    axes[0, i].imshow(inputs[i, 0].cpu().numpy(), cmap='gray')
    axes[0, i].set_title(f'Sample {i+1}\nInput (FBP)')
    axes[0, i].axis('off')

    # Output (U-Net)
    axes[1, i].imshow(outputs[i, 0].cpu().numpy(), cmap='gray')
    axes[1, i].set_title('Output (U-Net)')
    axes[1, i].axis('off')

    # Target (Ground Truth)
    axes[2, i].imshow(targets[i, 0].cpu().numpy(), cmap='gray')
    axes[2, i].set_title('Ground Truth')
    axes[2, i].axis('off')

fig.tight_layout()
fig.savefig(save_dir / 'final_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSaved final results to: {save_dir / 'final_results.png'}")