# Imports and Setup

In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import json
from datetime import datetime

# Make modules under ./src (relative to the working directory) importable
sys.path.append(os.path.join(os.getcwd(), 'src'))

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"‚úÖ Device: {DEVICE}")

# On a GPU machine, also report the name and total memory of device 0
if DEVICE == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

‚úÖ Device: cuda
   GPU: NVIDIA L4
   Memory: 23.67 GB


In [2]:
# Setup path to access custom modules.
# The parent of the current working directory is treated as the project root and
# its `src` folder is appended to sys.path so that the import below can resolve.
# NOTE(review): the setup cell appends <cwd>/src while this cell appends
# <parent of cwd>/src - presumably the notebook runs from a sub-folder of the
# project; confirm which of the two entries is actually needed.
root_directory = os.path.dirname(os.getcwd())
sys.path.append(os.path.join(root_directory, 'src'))

# Project-local factory; it is called later with split="train"/"val"/"test"
# to obtain the three data loaders.
from ModelDataGenerator import build_dataloader

# Model - UNet Architecture

In [3]:
class UNetBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    With ``kernel_size=3`` and ``padding=1`` the spatial size is preserved, so
    the block only maps ``in_channels`` feature maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage consumes `in_channels`, second stage consumes `out_channels`
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are kept so state_dict keys stay "conv.<idx>.*"
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv -> batch-norm -> ReLU stages."""
        return self.conv(x)

In [4]:
class UNet(nn.Module):
    """
    UNet Architecture for medical image super-resolution
    Input: (B, 2, H, W) - prior and posterior slices
    Output: (B, 1, H, W) - predicted middle slice

    The network has four encoder stages (each followed by 2x2 max-pooling), a
    bottleneck, and four decoder stages that upsample with transposed
    convolutions and concatenate the matching encoder output (skip connection).

    H and W do not have to be multiples of 16: when pooling rounds an odd size
    down, the upsampled decoder tensor is zero-padded back to the size of its
    skip tensor before concatenation, so the output keeps the input's H x W.

    Args:
        in_channels: number of input channels (default 2).
        out_channels: number of output channels (default 1).
        init_features: channel width of the first encoder stage; it doubles at
            every stage down to ``init_features * 16`` in the bottleneck.
    """

    def __init__(self, in_channels=2, out_channels=1, init_features=64):
        super().__init__()

        features = init_features

        # Encoder: channel width doubles at every stage, resolution halves after each pool
        self.enc1 = UNetBlock(in_channels, features)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc2 = UNetBlock(features, features * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc3 = UNetBlock(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc4 = UNetBlock(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = UNetBlock(features * 8, features * 16)

        # Decoder: each upconv halves the channels; the following block sees
        # upsampled + skip channels concatenated (hence twice the width as input)
        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
        self.dec4 = UNetBlock(features * 16, features * 8)

        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.dec3 = UNetBlock(features * 8, features * 4)

        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.dec2 = UNetBlock(features * 4, features * 2)

        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.dec1 = UNetBlock(features * 2, features)

        # Final output layer: 1x1 convolution down to the requested channel count
        self.final_conv = nn.Conv2d(features, out_channels, kernel_size=1)

    @staticmethod
    def _merge_skip(x, skip):
        """Concatenate upsampled ``x`` with ``skip`` along channels, padding ``x`` if smaller.

        Max-pooling with stride 2 floors odd sizes, so after the 2x transposed
        convolution ``x`` can be one pixel short of ``skip`` in H and/or W.
        The difference is zero-padded (split between both sides); when the
        sizes already match this is a plain ``torch.cat``.
        """
        dh = skip.size(-2) - x.size(-2)
        dw = skip.size(-1) - x.size(-1)
        if dh or dw:
            # F.pad order for the last two dims: (left, right, top, bottom)
            x = nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Map a (B, in_channels, H, W) tensor to (B, out_channels, H, W)."""
        # Encoder; the pre-pool activations are kept for the skip connections
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))

        # Bottleneck
        x = self.bottleneck(self.pool4(enc4))

        # Decoder with skip connections (deepest stage first)
        x = self.dec4(self._merge_skip(self.upconv4(x), enc4))
        x = self.dec3(self._merge_skip(self.upconv3(x), enc3))
        x = self.dec2(self._merge_skip(self.upconv2(x), enc2))
        x = self.dec1(self._merge_skip(self.upconv1(x), enc1))

        # Output
        return self.final_conv(x)

In [5]:
# Smoke test: build the network once and report its total parameter count
model = UNet(in_channels=2, out_channels=1, init_features=64)
print(f"‚úÖ Model created successfully!")
print(f"üìä Total parameters: {sum(map(torch.numel, model.parameters())):,}")

‚úÖ Model created successfully!
üìä Total parameters: 31,042,945


# Load Data from ModelDataGenerator

In [6]:
# Data-loading settings
BATCH_SIZE = 2
NUM_WORKERS = 8  # Set to 0 for Windows, increase for Linux
AUGMENT = True

# Keyword arguments shared by all three splits
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Build dataloaders; only the training split uses the AUGMENT flag
train_loader = build_dataloader(split="train", augment=AUGMENT, **_loader_kwargs)
val_loader = build_dataloader(split="val", augment=False, **_loader_kwargs)
test_loader = build_dataloader(split="test", augment=False, **_loader_kwargs)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
print(f"Test loader: {len(test_loader)} batches")

# Show the tensor shapes of the first training batch (nothing is printed if the loader is empty)
print("\nüìä Sample batch shapes:")
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    (pre, post), target = _first_batch
    print(f"Prior: {pre.shape}")
    print(f"Posterior: {post.shape}")
    print(f"Target (middle): {target.shape}")

Train loader: 36537 batches
Val loader: 6441 batches
Test loader: 9120 batches

üìä Sample batch shapes:
Prior: torch.Size([2, 1, 256, 256])
Posterior: torch.Size([2, 1, 256, 256])
Target (middle): torch.Size([2, 1, 256, 256])


# Training Configuration and Setup

In [7]:
# Training hyper-parameters and checkpoint location
EPOCHS = 10
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 5
MODEL_SAVE_DIR = Path('../models')
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Echo the effective configuration, one entry per line, followed by a blank line
for _line in (
    "Training Configuration:",
    f"Epochs: {EPOCHS}",
    f"Learning rate: {LEARNING_RATE}",
    f"Early stopping patience: {EARLY_STOPPING_PATIENCE}",
    f"Batch size: {BATCH_SIZE}",
    f"Augmentation: {AUGMENT}",
    f"Model save dir: {MODEL_SAVE_DIR}",
    "",
):
    print(_line)

# Fresh network on the selected device, Adam optimizer, mean-squared-error objective
model = UNet(in_channels=2, out_channels=1, init_features=64).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

Training Configuration:
Epochs: 10
Learning rate: 0.0001
Early stopping patience: 5
Batch size: 2
Augmentation: True
Model save dir: ../models



# Training Loop

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is ``((pre, post), target)``; ``pre`` and ``post`` are joined
    along the channel axis to form the model input. Returns the mean of the
    per-batch loss values (sum of batch losses divided by the number of batches).
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(train_loader, desc="Training", leave=False)
    for (pre, post), target in progress:
        # Two-channel input: prior and posterior stacked on dim 1
        batch_inputs = torch.cat([pre, post], dim=1).to(device)
        batch_targets = target.to(device)

        # Forward
        loss = criterion(model(batch_inputs), batch_targets)

        # Backward + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss / len(train_loader)


def validate(model, val_loader, criterion, device):
    """Compute the mean per-batch loss on ``val_loader`` without updating the model.

    The model is switched to eval mode and gradients are disabled. Batches
    have the form ``((pre, post), target)``; ``pre`` and ``post`` are joined
    along the channel axis before being fed to the model.
    """
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        progress = tqdm(val_loader, desc="Validating", leave=False)
        for (pre, post), target in progress:
            # Two-channel input: prior and posterior stacked on dim 1
            batch_inputs = torch.cat([pre, post], dim=1).to(device)
            batch_targets = target.to(device)

            batch_loss = criterion(model(batch_inputs), batch_targets).item()

            running_loss += batch_loss
            progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss / len(val_loader)


# Training variables (read again by the later plotting / logging cells)
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0

print("Starting training...\n")
print("=" * 70)

for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, DEVICE)
    val_loss = validate(model, val_loader, criterion, DEVICE)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Checkpointing and early stopping run every epoch, independent of how
    # often progress is printed.
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        patience_counter = 0

        # Save best model together with the optimizer state and loss history so far
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        torch.save(checkpoint, MODEL_SAVE_DIR / 'unet_best.pt')
        status = " (Best)"
    else:
        patience_counter += 1
        status = f" (patience: {patience_counter}/{EARLY_STOPPING_PATIENCE})"

    # One progress line per epoch; the status suffix carries its own leading space
    print(f"Epoch {epoch:3d}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}{status}")

    # Stop once validation loss has failed to improve for the configured number of epochs
    if not improved and patience_counter >= EARLY_STOPPING_PATIENCE:
        print(f"\nEarly stopping triggered after {epoch} epochs\n")
        break

print("=" * 70)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

Starting training...



Training:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 15389/36537 [18:56<22:07, 15.92it/s, loss=0.0277]   

## 6️⃣ Model Evaluation on Test Set

In [None]:
def evaluate(model, test_loader, criterion, device, max_keep=None):
    """Evaluate model on test set.

    Args:
        model: network to evaluate; it is put into eval mode.
        test_loader: iterable of ``((pre, post), target)`` batches that also supports ``len()``.
        criterion: loss function called as ``criterion(outputs, targets)``.
        device: device the inputs and targets are moved to.
        max_keep: maximum number of samples (not batches) whose predictions and
            targets are returned. ``None`` (default) keeps every sample, which
            matches the previous behaviour but holds the whole test set in memory.

    Returns:
        ``(avg_loss, predictions, targets_list)`` where ``avg_loss`` is the sum
        of per-batch losses divided by the number of batches, and the two lists
        hold CPU tensors for the kept samples, in loader order.
    """
    model.eval()
    total_loss = 0.0
    predictions = []
    targets_list = []
    kept = 0

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Testing", leave=False)
        for (pre, post), target in pbar:
            # Stack prior and posterior as input (2 channels)
            inputs = torch.cat([pre, post], dim=1).to(device)
            targets = target.to(device)

            outputs = model(inputs)
            # Read the scalar once; it is used for both the total and the progress bar
            batch_loss = criterion(outputs, targets).item()
            total_loss += batch_loss

            # The loss always covers the full test set; only the stored samples are capped
            if max_keep is None:
                predictions.append(outputs.cpu())
                targets_list.append(targets.cpu())
            else:
                take = min(outputs.shape[0], max_keep - kept)
                if take > 0:
                    predictions.append(outputs[:take].cpu())
                    targets_list.append(targets[:take].cpu())
                    kept += take
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    avg_loss = total_loss / len(test_loader)
    return avg_loss, predictions, targets_list


# Only a handful of samples are plotted afterwards, so do not keep the whole
# test set's predictions and targets in RAM (9120 batches of two 256x256
# slices each would be several GB per list).
MAX_KEPT_SAMPLES = 16

# Load best model
best_model_path = MODEL_SAVE_DIR / 'unet_best.pt'
if best_model_path.exists():
    # NOTE(review): torch.load unpickles the file - only load checkpoints written by this notebook
    checkpoint = torch.load(best_model_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úÖ Best model loaded from epoch {checkpoint['epoch']}")
else:
    print("‚ö†Ô∏è  Best model not found, using current model")

# Evaluate on test set
print("\nüß™ Evaluating on test set...\n")
test_loss, predictions, targets_list = evaluate(
    model, test_loader, criterion, DEVICE, max_keep=MAX_KEPT_SAMPLES
)

print(f"‚úÖ Test Loss: {test_loss:.4f}")

## 7️⃣ Training Visualization

In [None]:
# Plot training curves: both losses on the left, validation loss alone on the right
fig, (ax_both, ax_val) = plt.subplots(1, 2, figsize=(12, 5))

ax_both.plot(train_losses, label='Training Loss', linewidth=2, marker='o', markersize=3)
ax_both.plot(val_losses, label='Validation Loss', linewidth=2, marker='s', markersize=3)
ax_val.plot(val_losses, label='Validation Loss', linewidth=2, color='orange')

# Labels, legend and grid are the same for both panels; only the title differs
for ax, panel_title in (
    (ax_both, 'Training Progress'),
    (ax_val, 'Validation Loss (Zoomed)'),
):
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Loss (MSE)', fontsize=11)
    ax.set_title(panel_title, fontsize=12, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plot_path = MODEL_SAVE_DIR / 'training_curves.png'
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"üìä Training curves saved to {plot_path}")

## 8️⃣ Save Training Logs

In [None]:
# Hyper-parameters recorded alongside the results
run_config = {
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'early_stopping_patience': EARLY_STOPPING_PATIENCE,
    'augmentation': AUGMENT,
    'init_features': 64
}

# Loss curves, final metrics, configuration and a wall-clock timestamp
history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'test_loss': test_loss,
    'best_val_loss': best_val_loss,
    'epochs_trained': len(train_losses),
    'config': run_config,
    'timestamp': datetime.now().isoformat()
}

# Write the history as indented JSON next to the saved model
log_path = MODEL_SAVE_DIR / 'training_history.json'
with log_path.open('w') as f:
    f.write(json.dumps(history, indent=4))

print(f"üìù Training history saved to {log_path}")
print(f"\nüìä Training Summary:")
print(f"   Epochs trained: {len(train_losses)}")
print(f"   Best validation loss: {best_val_loss:.4f}")
print(f"   Final test loss: {test_loss:.4f}")
print(f"   Final train loss: {train_losses[-1]:.4f}")

## 9️⃣ Prediction Visualization

In [None]:
# Concatenate all predictions and targets
all_predictions = torch.cat(predictions, dim=0)  # (N, 1, H, W)
all_targets = torch.cat(targets_list, dim=0)      # (N, 1, H, W)

# Visualize some predictions - at most 4, and never more rows than samples exist
n_samples = min(4, all_predictions.shape[0])
# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so the axes[i, j] indexing below works for any n_samples >= 1
fig, axes = plt.subplots(n_samples, 3, figsize=(12, 4*n_samples), squeeze=False)

for i in range(n_samples):
    # Channel 0 of sample i as 2-D arrays
    pred = all_predictions[i, 0].numpy()
    target = all_targets[i, 0].numpy()
    diff = np.abs(pred - target)  # per-pixel absolute error

    ax_target, ax_pred, ax_diff = axes[i]

    ax_target.imshow(target, cmap='gray')
    ax_target.set_title(f'Target Slice {i+1}', fontweight='bold')
    ax_target.axis('off')

    ax_pred.imshow(pred, cmap='gray')
    ax_pred.set_title(f'Predicted Slice {i+1}', fontweight='bold')
    ax_pred.axis('off')

    im = ax_diff.imshow(diff, cmap='hot')
    ax_diff.set_title(f'Difference {i+1}', fontweight='bold')
    ax_diff.axis('off')
    # Colorbar only on the error map, sized to sit beside the image
    fig.colorbar(im, ax=ax_diff, fraction=0.046, pad=0.04)

fig.tight_layout()
pred_path = MODEL_SAVE_DIR / 'predictions_visualization.png'
fig.savefig(pred_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"üñºÔ∏è  Predictions visualization saved to {pred_path}")

## 🔟 Model Save Summary

In [None]:
# Closing banner
_rule = "="*70
print("\n" + _rule)
print("‚úÖ TRAINING COMPLETED SUCCESSFULLY!")
print(_rule)
print(f"\nüìÇ Models saved to: {MODEL_SAVE_DIR}")
print(f"\nüìã Files created:")

# List every entry of the output directory, sorted by path, with its size in MB
for file in sorted(MODEL_SAVE_DIR.glob('*')):
    size = file.stat().st_size / (1024*1024)  # MB
    print(f"   ‚úì {file.name:40s} ({size:6.2f} MB)")

# Headline numbers; "Improvement" is the relative drop from the first epoch's
# training loss to the final test loss, in percent
print(f"\nüìä Key Results:")
print(f"   Best Validation Loss: {best_val_loss:.4f}")
print(f"   Final Test Loss: {test_loss:.4f}")
print(f"   Total Epochs: {len(train_losses)}")
print(f"   Improvement: {(train_losses[0] - test_loss) / train_losses[0] * 100:.2f}%")
print("\n" + _rule)