# Module 2: Medical Image Enhancement Model Training

This notebook trains a deep learning model for medical image enhancement using U-Net architecture.

## Objectives:
- Train U-Net model for image denoising and enhancement
- Implement a custom combined loss function (weighted L1 + MSE)
- Evaluate model performance with PSNR/SSIM metrics
- Save trained model for deployment

---

## 1. Setup and Imports

In [None]:
import sys
import os

# Treat the parent of the current working directory as the project root.
project_root = os.path.abspath('..')

# Register the project root and its `src` directory on the import path.
# Each entry is checked on its own so that `src` is still added when the
# project root is already present (for example when it was put on the path
# before this cell ran). Inserting in this order leaves `src` ahead of the
# root, as before.
for import_dir in (project_root, os.path.join(project_root, 'src')):
    if import_dir not in sys.path:
        sys.path.insert(0, import_dir)

print(f"Project root: {project_root}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import json
from datetime import datetime

# Report the runtime environment; query CUDA availability once and reuse it.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Name of the first visible GPU (device index 0).
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Define U-Net Architecture

U-Net is an encoder–decoder convolutional neural network with skip connections. It was originally designed for biomedical image segmentation and is used here for image enhancement.

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. A 3x3 kernel with ``padding=1`` leaves the spatial size
    unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage inputs: the block input first, then the first stage's output.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        # Attribute name is part of the state_dict keys; keep it stable.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        return self.double_conv(x)


class UNet(nn.Module):
    """U-Net for Medical Image Enhancement.

    Structure (as built below):
      * Encoder: one ``DoubleConv`` per entry in ``features``, each followed
        by 2x2 max-pooling in ``forward``.
      * Bottleneck: ``DoubleConv`` that doubles the last encoder width.
      * Decoder: for each encoder level (deepest first) a stride-2 transposed
        convolution, concatenation with the matching encoder output, and a
        ``DoubleConv``.
      * Head: 1x1 convolution followed by a sigmoid, so outputs lie in [0, 1].

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
        features: Sequence of channel widths, one per encoder level
            (shallowest first). Must contain at least one entry.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default (tuple) instead of a shared mutable list; copy to a
        # list so callers may pass any sequence of widths.
        features = list(features)

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder (downsampling): each level consumes the previous level's width.
        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder (upsampling): modules are stored in (upsample, conv) pairs,
        # so even indices are transposed convs and odd indices are DoubleConvs.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input is the upsampled tensor concatenated with the skip tensor.
            self.decoder.append(DoubleConv(feature * 2, feature))

        # Final output layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output."""
        skip_connections = []

        # Encoder: remember each level's output before pooling.
        for encode in self.encoder:
            x = encode(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)
        # Deepest skip first, to line up with the decoder order.
        skip_connections = skip_connections[::-1]

        # Decoder
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip map. Only the spatial dims can differ here, so
            # compare just those before resizing.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoder[idx + 1](concat_skip)

        return torch.sigmoid(self.final_conv(x))


# Smoke-test the architecture with one random 256x256 single-channel image.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=1, out_channels=1).to(device)
test_input = torch.randn(1, 1, 256, 256).to(device)
# Only the output shape is inspected, so skip autograd bookkeeping instead of
# keeping a full computation graph alive for a throwaway forward pass.
with torch.no_grad():
    test_output = model(test_input)
print(f"Model input shape: {test_input.shape}")
print(f"Model output shape: {test_output.shape}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

## 3. Create Dataset and DataLoader

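Generate synthetic training data on the fly: each sample pairs a randomly generated clean image with a degraded copy (additive Gaussian noise followed by Gaussian blur).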

In [None]:
class MedicalImageDataset(Dataset):
    """Dataset of synthetic (degraded, clean) image pairs for enhancement training.

    Images are generated on demand rather than stored. Each clean image is a
    float32 array of blurred filled circles and ellipses; the degraded version
    adds Gaussian noise, blurs again and clips to [0, 1].

    Args:
        num_samples: Reported dataset length.
        image_size: Side length of the square images. Must be greater than 100
            because shape centres are drawn at least 50 pixels from each border.
        noise_level: Standard deviation of the additive Gaussian noise.
        seed: When ``None`` (default) samples are drawn from NumPy's global
            random state, so every access yields a new random pair. When set,
            sample ``idx`` is generated from a random state seeded with
            ``(seed, idx)``, so the same index always yields the same pair.
            Must be a non-negative integer below 2**32.

    Raises:
        ValueError: If ``image_size`` is too small to place shape centres.
    """

    # Shape centres are kept this many pixels away from every image border.
    _CENTER_MARGIN = 50

    def __init__(self, num_samples=1000, image_size=256, noise_level=0.1, seed=None):
        # Fail early with a clear message instead of deep inside randint().
        if image_size <= 2 * self._CENTER_MARGIN:
            raise ValueError(
                f"image_size must be greater than {2 * self._CENTER_MARGIN}, got {image_size}"
            )
        self.num_samples = num_samples
        self.image_size = image_size
        self.noise_level = noise_level
        self.seed = seed

    def generate_synthetic_image(self, rng=None):
        """Generate synthetic medical image.

        Args:
            rng: Optional ``np.random.RandomState`` to draw from; defaults to
                NumPy's global random state.

        Returns:
            float32 array of shape ``(image_size, image_size)``.
        """
        rng = np.random if rng is None else rng
        margin = self._CENTER_MARGIN
        image = np.zeros((self.image_size, self.image_size), dtype=np.float32)

        # Add circular structures (simulating organs/tissues)
        num_circles = rng.randint(3, 8)
        for _ in range(num_circles):
            center_x = rng.randint(margin, self.image_size - margin)
            center_y = rng.randint(margin, self.image_size - margin)
            radius = rng.randint(20, 60)
            intensity = rng.uniform(0.3, 0.9)
            cv2.circle(image, (center_x, center_y), radius, intensity, -1)

        # Add ellipses
        num_ellipses = rng.randint(2, 5)
        for _ in range(num_ellipses):
            center = (rng.randint(margin, self.image_size - margin),
                      rng.randint(margin, self.image_size - margin))
            axes = (rng.randint(15, 40), rng.randint(15, 40))
            angle = rng.randint(0, 180)
            intensity = rng.uniform(0.2, 0.7)
            cv2.ellipse(image, center, axes, angle, 0, 360, intensity, -1)

        # Apply Gaussian smoothing (tissue-like texture)
        image = cv2.GaussianBlur(image, (15, 15), 3)

        return image

    def add_degradation(self, image, rng=None):
        """Add noise and blur to simulate degraded images.

        Args:
            image: Clean image array; it is copied, not modified.
            rng: Optional ``np.random.RandomState`` to draw the noise from;
                defaults to NumPy's global random state.

        Returns:
            Degraded copy of ``image`` with values clipped to [0, 1].
        """
        rng = np.random if rng is None else rng
        degraded = image.copy()

        # Add Gaussian noise
        noise = rng.normal(0, self.noise_level, image.shape).astype(np.float32)
        degraded = degraded + noise

        # Add blur
        degraded = cv2.GaussianBlur(degraded, (5, 5), 1.5)

        # Clip values
        degraded = np.clip(degraded, 0, 1)

        return degraded

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(degraded, clean)`` tensor pair for ``idx``.

        Both tensors have a leading channel dimension of size 1.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
        """
        # Without a bounds check, plain iteration over the dataset never stops
        # because Python's sequence protocol relies on IndexError to terminate.
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"index out of range for dataset of size {self.num_samples}")

        # A per-index random state makes seeded datasets reproducible
        # regardless of access order.
        rng = None if self.seed is None else np.random.RandomState([self.seed, idx])

        # Generate clean image
        clean_image = self.generate_synthetic_image(rng)

        # Generate degraded version
        degraded_image = self.add_degradation(clean_image, rng)

        # Convert to tensors
        clean_tensor = torch.from_numpy(clean_image).unsqueeze(0)  # Add channel dim
        degraded_tensor = torch.from_numpy(degraded_image).unsqueeze(0)

        return degraded_tensor, clean_tensor


# Seed for the validation set. A fixed validation set keeps the validation
# loss comparable across epochs, which both best-checkpoint selection and the
# plateau scheduler depend on.
VAL_SEED = 1234

# Create datasets: training samples stay freshly random on every access,
# validation samples are pinned by VAL_SEED.
train_dataset = MedicalImageDataset(num_samples=800, image_size=256, noise_level=0.1)
val_dataset = MedicalImageDataset(num_samples=200, image_size=256, noise_level=0.1, seed=VAL_SEED)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")

## 4. Visualize Sample Data

In [None]:
# Show the first four pairs of one training batch: degraded inputs on the top
# row, their clean targets on the bottom row.
degraded, clean = next(iter(train_loader))

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle('Training Data Samples', fontsize=16, fontweight='bold')

panel_rows = ((degraded, 'Degraded Image'), (clean, 'Clean Image'))
for row, (batch, label) in enumerate(panel_rows):
    for col in range(4):
        ax = axes[row, col]
        # Channel 0 is the only channel of these single-channel images.
        ax.imshow(batch[col, 0].cpu().numpy(), cmap='gray')
        ax.set_title(f'{label} {col+1}')
        ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Define Loss Functions and Metrics

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of the L1 (mean absolute) and MSE (mean squared) losses.

    Args:
        l1_weight: Multiplier applied to the L1 term.
        mse_weight: Multiplier applied to the MSE term.
    """

    def __init__(self, l1_weight=0.7, mse_weight=0.3):
        super().__init__()
        self.l1_weight, self.mse_weight = l1_weight, mse_weight
        self.l1_loss, self.mse_loss = nn.L1Loss(), nn.MSELoss()

    def forward(self, pred, target):
        """Return ``l1_weight * L1(pred, target) + mse_weight * MSE(pred, target)``."""
        weighted_l1 = self.l1_weight * self.l1_loss(pred, target)
        weighted_mse = self.mse_weight * self.mse_loss(pred, target)
        return weighted_l1 + weighted_mse


def calculate_metrics(pred, target):
    """Return the batch-mean PSNR and SSIM between ``pred`` and ``target``.

    Both tensors are detached, moved to the CPU and converted to NumPy. Each
    sample is scored on its first channel only, with the target as reference
    and an assumed intensity range of 1.0.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)`` over the samples of the batch.
    """
    pred_batch = pred.detach().cpu().numpy()
    target_batch = target.detach().cpu().numpy()

    # One (reference, estimate) pair per sample, first channel only.
    pairs = [(target_batch[i, 0], pred_batch[i, 0]) for i in range(pred_batch.shape[0])]

    psnr_values = []
    ssim_values = []
    for reference, estimate in pairs:
        psnr_values.append(psnr(reference, estimate, data_range=1.0))
        ssim_values.append(ssim(reference, estimate, data_range=1.0))

    return np.mean(psnr_values), np.mean(ssim_values)


# Build the training objective from named weights so the split between the
# L1 and MSE terms is visible in one place.
L1_WEIGHT, MSE_WEIGHT = 0.7, 0.3
criterion = CombinedLoss(l1_weight=L1_WEIGHT, mse_weight=MSE_WEIGHT)
print("Loss function initialized: Combined L1 (70%) + MSE (30%)")

## 6. Training Configuration

In [None]:
# Training hyperparameters
EPOCHS = 50
LEARNING_RATE = 0.001
SEED = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global generators before the model is built so weight
# initialisation, loader shuffling and NumPy-driven sampling are repeatable
# from this cell onward.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Initialize model, optimizer, scheduler
model = UNet(in_channels=1, out_channels=1).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate after 5 epochs without improvement of the monitored
# value. The `verbose` argument is deliberately not passed: it is deprecated
# in recent PyTorch releases; read the current rate from
# optimizer.param_groups[0]['lr'] when it needs to be logged.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Create model directory
model_dir = Path('../models/image_enhancement')
model_dir.mkdir(parents=True, exist_ok=True)

print(f"Training configuration:")
print(f"  Device: {DEVICE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Seed: {SEED}")
print(f"  Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Model save path: {model_dir}")

## 7. Training Loop

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch.

    Args:
        model: Network to optimise; switched to training mode.
        dataloader: Sized iterable yielding ``(degraded, clean)`` tensor batches.
        criterion: Callable ``criterion(prediction, target)`` returning a scalar loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(loss, psnr, ssim)`` of Python floats, averaged over all
        samples seen during the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Explicit error instead of a bare ZeroDivisionError after the loop.
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; cannot train for an epoch")

    model.train()
    running_loss = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    num_samples = 0

    progress_bar = tqdm(dataloader, desc='Training')
    for degraded, clean in progress_bar:
        degraded = degraded.to(device)
        clean = clean.to(device)
        batch_size = degraded.size(0)

        # Forward pass
        optimizer.zero_grad()
        enhanced = model(degraded)
        loss = criterion(enhanced, clean)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Calculate metrics
        psnr_val, ssim_val = calculate_metrics(enhanced, clean)
        loss_val = loss.item()

        # Weight every batch by its size so a smaller final batch does not
        # count as much as a full one (assumes the criterion and metrics are
        # per-batch means).
        running_loss += loss_val * batch_size
        running_psnr += psnr_val * batch_size
        running_ssim += ssim_val * batch_size
        num_samples += batch_size

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss_val:.4f}',
            'psnr': f'{psnr_val:.2f}',
            'ssim': f'{ssim_val:.4f}'
        })

    # Plain Python floats keep the results safe to serialise.
    epoch_loss = float(running_loss / num_samples)
    epoch_psnr = float(running_psnr / num_samples)
    epoch_ssim = float(running_ssim / num_samples)

    return epoch_loss, epoch_psnr, epoch_ssim


def validate_epoch(model, dataloader, criterion, device):
    """Validate for one epoch.

    Runs the model in evaluation mode with gradient tracking disabled.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Sized iterable yielding ``(degraded, clean)`` tensor batches.
        criterion: Callable ``criterion(prediction, target)`` returning a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(loss, psnr, ssim)`` of Python floats, averaged over all
        samples seen during the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Explicit error instead of a bare ZeroDivisionError after the loop.
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; cannot validate for an epoch")

    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    running_ssim = 0.0
    num_samples = 0

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc='Validation')
        for degraded, clean in progress_bar:
            degraded = degraded.to(device)
            clean = clean.to(device)
            batch_size = degraded.size(0)

            # Forward pass
            enhanced = model(degraded)
            loss = criterion(enhanced, clean)

            # Calculate metrics
            psnr_val, ssim_val = calculate_metrics(enhanced, clean)
            loss_val = loss.item()

            # Weight every batch by its size so a smaller final batch does
            # not count as much as a full one (assumes the criterion and
            # metrics are per-batch means).
            running_loss += loss_val * batch_size
            running_psnr += psnr_val * batch_size
            running_ssim += ssim_val * batch_size
            num_samples += batch_size

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss_val:.4f}',
                'psnr': f'{psnr_val:.2f}',
                'ssim': f'{ssim_val:.4f}'
            })

    # Plain Python floats keep the results safe to serialise.
    epoch_loss = float(running_loss / num_samples)
    epoch_psnr = float(running_psnr / num_samples)
    epoch_ssim = float(running_ssim / num_samples)

    return epoch_loss, epoch_psnr, epoch_ssim

In [None]:
# Training history: one entry per epoch for every tracked quantity.
history = {
    'train_loss': [],
    'train_psnr': [],
    'train_ssim': [],
    'val_loss': [],
    'val_psnr': [],
    'val_ssim': []
}

best_val_loss = float('inf')
best_model_path = model_dir / 'best_unet_model.pth'

print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70 + "\n")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    print("-" * 70)

    # Train. Metrics are converted to built-in floats so that no NumPy
    # scalars end up in the checkpoint or the JSON history; recent PyTorch
    # versions refuse to unpickle NumPy scalars under the default
    # weights_only loading mode.
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    train_loss, train_psnr, train_ssim = map(float, train_metrics)

    # Validate
    val_metrics = validate_epoch(model, val_loader, criterion, DEVICE)
    val_loss, val_psnr, val_ssim = map(float, val_metrics)

    # Update learning rate, remembering the old value so a change can be reported.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']

    # Save history
    history['train_loss'].append(train_loss)
    history['train_psnr'].append(train_psnr)
    history['train_ssim'].append(train_ssim)
    history['val_loss'].append(val_loss)
    history['val_psnr'].append(val_psnr)
    history['val_ssim'].append(val_ssim)

    # Print epoch summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | PSNR: {train_psnr:.2f} dB | SSIM: {train_ssim:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | PSNR: {val_psnr:.2f} dB | SSIM: {val_ssim:.4f}")
    if lr_after != lr_before:
        print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    # Save best model (lowest validation loss so far). `epoch` is stored
    # zero-based.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_psnr': val_psnr,
            'val_ssim': val_ssim,
        }, best_model_path)
        print(f"  ✓ Best model saved! (Val Loss: {val_loss:.4f})")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)

## 8. Visualize Training Progress

In [None]:
# One panel per tracked quantity; each shows the training and validation curve.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('Training Progress', fontsize=16, fontweight='bold')

epochs_range = range(1, len(history['train_loss']) + 1)

# (history key suffix, legend name, y-axis label, panel title)
panel_specs = (
    ('loss', 'Loss', 'Loss', 'Loss Curve'),
    ('psnr', 'PSNR', 'PSNR (dB)', 'PSNR Progress'),
    ('ssim', 'SSIM', 'SSIM', 'SSIM Progress'),
)

for ax, (key, name, ylabel, title) in zip(axes, panel_specs):
    ax.plot(epochs_range, history[f'train_{key}'], 'b-', label=f'Train {name}', linewidth=2)
    ax.plot(epochs_range, history[f'val_{key}'], 'r-', label=f'Val {name}', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)

curves_path = model_dir / 'training_curves.png'
plt.tight_layout()
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Training curves saved to: {curves_path}")

## 9. Test Model on Validation Samples

In [None]:
# Load best model.
# - map_location remaps the stored tensors onto the device in use now, so a
#   checkpoint written on a GPU still loads on a CPU-only machine.
# - weights_only=False: the checkpoint may hold NumPy scalar metrics, which
#   the restricted loader (the default in recent PyTorch) rejects. This runs
#   the full unpickler, so only use it on checkpoints written by this
#   notebook, never on files from an untrusted source.
checkpoint = torch.load(best_model_path, map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# The stored epoch index is zero-based; add one for display.
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"  Val Loss: {checkpoint['val_loss']:.4f}")
print(f"  Val PSNR: {checkpoint['val_psnr']:.2f} dB")
print(f"  Val SSIM: {checkpoint['val_ssim']:.4f}")

# Get sample batch
degraded, clean = next(iter(val_loader))
degraded = degraded.to(DEVICE)
clean = clean.to(DEVICE)

# Generate predictions
with torch.no_grad():
    enhanced = model(degraded)

# Visualize results: rows are degraded input, model output, ground truth.
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
fig.suptitle('Model Predictions - Validation Set', fontsize=16, fontweight='bold')

for i in range(4):
    # Degraded image
    axes[0, i].imshow(degraded[i, 0].cpu().numpy(), cmap='gray')
    axes[0, i].set_title(f'Degraded {i+1}', fontsize=12)
    axes[0, i].axis('off')

    # Enhanced image. Slicing with i:i+1 keeps the batch dimension that
    # calculate_metrics indexes into.
    axes[1, i].imshow(enhanced[i, 0].cpu().numpy(), cmap='gray')
    psnr_val, ssim_val = calculate_metrics(enhanced[i:i+1], clean[i:i+1])
    axes[1, i].set_title(f'Enhanced {i+1}\nPSNR: {psnr_val:.2f} dB', fontsize=12)
    axes[1, i].axis('off')

    # Ground truth. The SSIM in this title is that of the enhanced image
    # measured against this ground truth.
    axes[2, i].imshow(clean[i, 0].cpu().numpy(), cmap='gray')
    axes[2, i].set_title(f'Ground Truth {i+1}\nSSIM: {ssim_val:.4f}', fontsize=12)
    axes[2, i].axis('off')

plt.tight_layout()
plt.savefig(model_dir / 'validation_results.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Validation results saved to: {model_dir / 'validation_results.png'}")

## 10. Save Training Metadata

In [None]:
# Save training history. `default=float` converts any NumPy scalar that the
# JSON encoder cannot serialise natively.
with open(model_dir / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2, default=float)

# Save model metadata. Values are read from the live objects rather than
# repeated as literals, so the record cannot drift from what was actually run.
metadata = {
    'model_architecture': 'U-Net',
    'input_channels': 1,
    'output_channels': 1,
    'image_size': train_dataset.image_size,
    'total_parameters': sum(p.numel() for p in model.parameters()),
    'training': {
        'epochs': EPOCHS,
        'learning_rate': LEARNING_RATE,
        'batch_size': train_loader.batch_size,
        'optimizer': 'Adam',
        'loss_function': 'Combined L1 + MSE',
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset)
    },
    'best_results': {
        # The checkpoint stores a zero-based epoch index.
        'epoch': checkpoint['epoch'] + 1,
        'val_loss': float(checkpoint['val_loss']),
        'val_psnr': float(checkpoint['val_psnr']),
        'val_ssim': float(checkpoint['val_ssim'])
    },
    'timestamp': datetime.now().isoformat()
}

with open(model_dir / 'model_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2, default=float)

print("\n" + "="*70)
print("MODEL ARTIFACTS SAVED")
print("="*70)
print(f"\nModel directory: {model_dir}")
print(f"\nFiles saved:")
print(f"  ✓ best_unet_model.pth - Trained model weights")
print(f"  ✓ training_history.json - Training metrics")
print(f"  ✓ model_metadata.json - Model configuration")
print(f"  ✓ training_curves.png - Training visualization")
print(f"  ✓ validation_results.png - Validation samples")
print("\n" + "="*70)

## 11. Model Summary

In [None]:
# Assemble the final report line by line, then emit it in one print call.
rule = "=" * 70
param_count = sum(p.numel() for p in model.parameters())

summary_lines = [
    "\n" + rule,
    "FINAL MODEL SUMMARY",
    rule,
    "\nArchitecture: U-Net for Medical Image Enhancement",
    f"Total Parameters: {param_count:,}",
    "\nTraining Configuration:",
    f"  Epochs: {EPOCHS}",
    f"  Initial Learning Rate: {LEARNING_RATE}",
    "  Batch Size: 8",
    "  Loss Function: Combined L1 (70%) + MSE (30%)",
    "\nBest Model Performance:",
    f"  Validation Loss: {checkpoint['val_loss']:.4f}",
    f"  Validation PSNR: {checkpoint['val_psnr']:.2f} dB",
    f"  Validation SSIM: {checkpoint['val_ssim']:.4f}",
    # The checkpoint stores a zero-based epoch index.
    f"  Saved at Epoch: {checkpoint['epoch'] + 1}",
    f"\nModel saved to: {best_model_path}",
    "\n" + rule,
    "✓ Training notebook completed successfully!",
    rule,
]
print("\n".join(summary_lines))