In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import os
import glob
from pathlib import Path
import nibabel as nib
from torchvision.transforms.functional import pad, center_crop
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms.functional as TF




In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def double_conv_unpadded(in_channels, out_channels):
    """
    Build two 3x3 convolutions without padding, each followed by BatchNorm and ReLU.

    With padding=0 every convolution trims one pixel from each border, so the
    whole block shrinks the spatial size by 4 (H -> H - 4, W -> W - 4).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.

    Returns:
        nn.Sequential: Conv -> BN -> ReLU -> Conv -> BN -> ReLU.
    """
    layers = []
    # First conv maps in_channels -> out_channels; the second keeps out_channels.
    for conv_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)

class EncoderBlock(nn.Module):
    """
    One contracting step of the U-Net encoder.

    Applies an unpadded double convolution, then a 2x2 max pool. The
    pre-pool activation is returned too, because the decoder consumes it as a
    skip connection.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = double_conv_unpadded(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(skip_features, pooled_features)`` for input ``x``."""
        features = self.conv(x)
        return features, self.pool(features)

class DecoderBlock(nn.Module):
    """
    One expanding step of the U-Net decoder.

    The coarser feature map is upsampled 2x with a transposed convolution.
    The (larger) skip feature map is center-cropped to the upsampled size.
    Both are then concatenated along the channel axis and refined by an
    unpadded double convolution.

    Args:
        in_channels: channels of the coarser input (``x_bottom``).
        skip_channels: channels of the encoder skip tensor (``x_skip``).
        out_channels: channels after upsampling and after refinement.
    """
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = double_conv_unpadded(out_channels + skip_channels, out_channels)

    def forward(self, x_bottom, x_skip):
        """
        Upsample ``x_bottom``, merge it with a center-cropped ``x_skip``, and refine.

        Raises:
            RuntimeError: if the upsampled tensor is spatially larger than
                ``x_skip`` (cropping cannot enlarge the skip tensor).
        """
        x_up = self.up(x_bottom)

        up_h, up_w = x_up.size(2), x_up.size(3)
        extra_h = x_skip.size(2) - up_h
        extra_w = x_skip.size(3) - up_w
        if min(extra_h, extra_w) < 0:
            raise RuntimeError(f"Cropping error: x_up ({x_up.shape}) is larger than x_skip ({x_skip.shape}).")

        # Center crop. When the surplus is odd, the extra row/column is dropped
        # from the bottom/right edge.
        top = extra_h // 2
        left = extra_w // 2
        x_skip_cropped = x_skip[:, :, top:top + up_h, left:left + up_w]

        return self.conv(torch.cat([x_up, x_skip_cropped], dim=1))

class Encoder(nn.Module):
    """
    U-Net contracting path.

    Four EncoderBlocks (each ending in a 2x2 max pool, i.e. four
    downsampling steps) are followed by an unpadded double-conv bottleneck.
    Channel widths: in_channels -> 64 -> 128 -> 256 -> 512 -> 1024 (bottleneck).

    Args:
        in_channels: channels of the input image. Defaults to 1 (grayscale),
            which was previously hard-coded.

    ``forward(x)`` returns ``(x1, x2, x3, x4, x5)``: the four pre-pool skip
    tensors (shallowest first) followed by the bottleneck tensor ``x5``.
    """
    def __init__(self, in_channels=1):
        super().__init__()
        self.block1 = EncoderBlock(in_channels, 64)  # in_channels -> 64
        self.block2 = EncoderBlock(64, 128)          # 64 -> 128
        self.block3 = EncoderBlock(128, 256)         # 128 -> 256
        self.block4 = EncoderBlock(256, 512)         # 256 -> 512

        # Bottleneck: deepest features; no pooling follows it.
        self.bottleneck = double_conv_unpadded(512, 1024)

    def forward(self, x):
        x1, p1 = self.block1(x)
        x2, p2 = self.block2(p1)
        x3, p3 = self.block3(p2)
        x4, p4 = self.block4(p3)
        x5 = self.bottleneck(p4)  # bottleneck tensor

        return x1, x2, x3, x4, x5

class Decoder(nn.Module):
    """
    U-Net expanding path: four DecoderBlocks followed by a 1x1 classifier conv.

    Channel widths mirror the encoder: 1024 -> 512 -> 256 -> 128 -> 64 -> num_classes.

    Args:
        num_classes: number of output channels (logits) per pixel.
    """
    def __init__(self, num_classes=1):
        super().__init__()
        # The first stage receives the 1024-channel bottleneck tensor.
        self.upconv4 = DecoderBlock(in_channels=1024, skip_channels=512, out_channels=512)
        self.upconv3 = DecoderBlock(in_channels=512, skip_channels=256, out_channels=256)
        self.upconv2 = DecoderBlock(in_channels=256, skip_channels=128, out_channels=128)
        self.upconv1 = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64)
        self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x5_bottleneck, x4_skip, x3_skip, x2_skip, x1_skip):
        """Decode upward from the bottleneck, merging skips deepest-first."""
        stages = (
            (self.upconv4, x4_skip),
            (self.upconv3, x3_skip),
            (self.upconv2, x2_skip),
            (self.upconv1, x1_skip),
        )
        features = x5_bottleneck
        for stage, skip in stages:
            features = stage(features, skip)
        return self.out_conv(features)

class Unet(nn.Module):
    """
    Full U-Net.

    The encoder returns its four skip tensors plus the bottleneck; the
    decoder consumes them deepest-first and produces per-pixel logits.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
    """
    def __init__(self, num_classes=1):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(num_classes)

    def forward(self, x):
        *skips, bottleneck = self.encoder(x)
        # Decoder signature is (bottleneck, x4, x3, x2, x1): deepest skip first.
        return self.decoder(bottleneck, *reversed(skips))


In [None]:
class BrainScanDataset(Dataset):
    """
    Made by AI
    Paired 2D segmentation dataset of grayscale .png images and .png masks.

    Images in ``image_dir`` are matched to masks in ``mask_dir`` by identical
    filename. Images without a matching mask are skipped with a warning, so
    ``image_paths[i]`` and ``mask_paths[i]`` always refer to the same scan.

    Images are resized to ``input_size`` (default 572x572, the U-Net input)
    with bilinear interpolation. Masks are resized to ``output_size``
    (default 388x388, the unpadded U-Net output) with nearest-neighbour
    interpolation, then binarized.
    """
    def __init__(self, image_dir, mask_dir,
                 input_size=(572, 572),
                 output_size=(388, 388)):

        self.input_size = input_size
        self.output_size = output_size

        # Escape the directory so characters like '[' in a path are not glob patterns.
        pattern = os.path.join(glob.escape(os.fspath(image_dir)), "*.png")
        candidate_images = sorted(glob.glob(pattern))

        # Keep only images that have a mask so both lists stay index-aligned.
        # (Appending masks alone would pair images with the wrong masks.)
        self.image_paths = []
        self.mask_paths = []
        for img_path in candidate_images:
            mask_path = os.path.join(mask_dir, os.path.basename(img_path))
            if os.path.exists(mask_path):
                self.image_paths.append(img_path)
                self.mask_paths.append(mask_path)
            else:
                print(f"Warning: Missing mask for image {img_path}")

        if not candidate_images:
            print(f"Warning: No '*.png' files found in {image_dir}")
        if not self.mask_paths:
            print(f"Warning: No corresponding '*.png' files found in {mask_dir}")

        print(f"Dataset created with {len(self.image_paths)} image/mask pairs.")

    def __len__(self):
        """Number of image/mask pairs (images without a mask are excluded)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load pair ``idx`` and return ``(image_tensor, mask_tensor)``.

        The image is a float tensor in [0, 1] at ``input_size``. The mask is a
        float tensor of 0.0/1.0 values at ``output_size``.
        """
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # Load both as single-channel grayscale; `with` releases the file handles.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # Image: bilinear resize to the network input size.
        img_resized = TF.resize(image, self.input_size, interpolation=TF.InterpolationMode.BILINEAR)
        # Mask: nearest-neighbour so label values are not blended at edges.
        mask_resized = TF.resize(mask, self.output_size, interpolation=TF.InterpolationMode.NEAREST)

        # to_tensor scales 8-bit pixel values into [0, 1].
        img_tensor = TF.to_tensor(img_resized)
        mask_tensor = TF.to_tensor(mask_resized)

        # Binarize. Assumes foreground is stored as 255 (-> 1.0 after to_tensor).
        # NOTE(review): a mask saved with label value 1 would become 1/255 and be
        # zeroed by this threshold; confirm the mask encoding.
        mask_tensor = (mask_tensor > 0.5).float()

        return img_tensor, mask_tensor


def get_dataloaders(image_dir, mask_dir, batch_size=4, debug_subset_size=None, seed=None, num_workers=0):
    """
    Made with AI
    Build the dataset and split it 70/15/15 into train/val/test DataLoaders.

    Args:
        image_dir, mask_dir: folders passed to BrainScanDataset.
        batch_size: batch size for all three loaders.
        debug_subset_size: if given and not larger than the dataset, a random
            subset of this many samples is used instead of the full dataset.
        seed: optional int. When set, the debug subset, the train/val/test
            split and the train-loader shuffle order are reproducible. When
            None, torch's global RNG is used (previous behaviour).
        num_workers: worker processes per DataLoader (0 = load in the main process).

    Returns:
        (train_loader, val_loader, test_loader); only the train loader shuffles.

    Raises:
        ValueError: if the dataset is empty.
    """
    full_dataset = BrainScanDataset(image_dir=image_dir, mask_dir=mask_dir)

    total_size = len(full_dataset)
    if total_size == 0:
        raise ValueError("Dataset is empty. Check file paths and file types.")

    # One generator drives every random_split so the whole partition follows from `seed`.
    split_generator = torch.Generator().manual_seed(seed) if seed is not None else torch.default_generator
    shuffle_generator = torch.Generator().manual_seed(seed) if seed is not None else None

    if debug_subset_size is not None:
        if debug_subset_size > total_size:
            print(f"Warning: debug_subset_size ({debug_subset_size}) is larger than total dataset ({total_size}). Using total dataset.")
        else:
            print(f"--- DEBUG MODE: Using a subset of {debug_subset_size} images ---")
            # Split the full dataset into a tiny subset and the rest.
            debug_set, _ = random_split(
                full_dataset, [debug_subset_size, total_size - debug_subset_size],
                generator=split_generator,
            )
            full_dataset = debug_set
            total_size = len(full_dataset)

    # 70% train, 15% val; test takes the remainder so the sizes always sum to total.
    train_size = int(total_size * 0.7)
    val_size = int(total_size * 0.15)
    test_size = total_size - train_size - val_size

    print(f"Splitting {total_size} images into: Train ({train_size}), Val ({val_size}), Test ({test_size})")

    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size], generator=split_generator
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, generator=shuffle_generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print("DataLoaders created:")
    print(f"  Train: {len(train_loader.dataset)} images")
    print(f"  Validation: {len(val_loader.dataset)} images")
    print(f"  Test: {len(test_loader.dataset)} images")

    return train_loader, val_loader, test_loader


def visualize_results(model, loader, device, num_to_show=3):
    """
    Made by AI
    Plot (input, true mask, predicted mask) for the first images of one batch.

    Takes the first batch from ``loader`` and runs ``model`` on it without
    gradients. Predictions are thresholded as ``sigmoid(logits) > 0.5``. One
    row of three panels is drawn per image, and the figure is saved to
    'model_visualization.png'.

    Args:
        model: segmentation network returning logits (as used with BCEWithLogitsLoss).
        loader: DataLoader yielding ``(images, masks)`` batches.
        device: device the model lives on.
        num_to_show: maximum number of rows; clamped to the batch size so no
            blank rows are drawn.

    Note: leaves ``model`` in eval mode.
    """
    print(f"\n--- Visualizing {num_to_show} validation examples ---")
    model.eval()

    try:
        images, masks = next(iter(loader))
    except StopIteration:
        print("Data loader is empty, skipping visualization.")
        return

    num_rows = min(num_to_show, len(images))
    if num_rows <= 0:
        print("Nothing to visualize.")
        return

    with torch.no_grad():
        outputs = model(images.to(device))
        # Sigmoid because the model emits logits, then threshold at 0.5.
        preds = (torch.sigmoid(outputs) > 0.5).float().cpu()

    images = images.cpu()
    masks = masks.cpu()

    # squeeze=False keeps `axes` 2-D even for a single row, so indexing is uniform.
    fig, axes = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows), squeeze=False)
    fig.suptitle("Model Input vs. True Mask vs. Predicted Mask", fontsize=16)

    for i in range(num_rows):
        panels = (
            (images[i], f"Input Image (Resized {images[i].shape[-2:]})"),
            (masks[i], f"True Mask (Resized {masks[i].shape[-2:]})"),
            (preds[i], f"Predicted Mask ({preds[i].shape[-2:]})"),
        )
        for ax, (tensor, title) in zip(axes[i], panels):
            ax.imshow(tensor.squeeze(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig("model_visualization.png")
    print("Visualization saved as 'model_visualization.png'")


def train_model(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=5):
    """
    Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each epoch runs one optimisation pass over ``train_loader``, printing the
    running summed loss on every second batch. If ``val_loader`` is
    non-empty, a no-grad validation pass follows. The epoch then prints the
    mean per-batch train (and validation) loss.

    An empty ``train_loader`` skips the epoch with a warning. The model and
    optimizer are updated in place; nothing is returned.
    """
    print("\n--- Starting Training ---")

    for epoch_no in range(1, num_epochs + 1):
        model.train()
        running_train = 0.0

        n_train_batches = len(train_loader)
        if not n_train_batches:
            print("Warning: Training data loader is empty.")
            continue

        for step, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            running_train += batch_loss.item()
            if step % 2 == 0:
                print(f"Batch {step}/{n_train_batches} | Total Loss: {running_train:.4f}")

        mean_train = running_train / n_train_batches
        model.eval()

        if not len(val_loader):
            print("Warning: Validation data loader is empty.")
            print(f"Epoch {epoch_no}/{num_epochs} Summary | Train Loss: {mean_train:.4f} | Val Loss: N/A")
            continue

        running_val = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                running_val += criterion(model(inputs), targets).item()
        mean_val = running_val / len(val_loader)

        print(f"\nEpoch {epoch_no}/{num_epochs} summary | Avg Train Loss: {mean_train:.4f} | Avg Val Loss: {mean_val:.4f}\n")

    print("--- Training Finished ---")


if __name__ == "__main__":
    # Seed torch's global RNG so weight init, random_split and shuffling are reproducible.
    SEED = 42
    torch.manual_seed(SEED)

    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Model
    model = Unet(num_classes=1).to(device)

    # Data lives in ../archive/{images,masks} relative to the notebook's working dir.
    base_dir = Path.cwd()
    DATA_ROOT = base_dir.parent / "archive"
    image_folder = DATA_ROOT / "images"
    mask_folder = DATA_ROOT / "masks"

    print(f"Loading data from: {image_folder}")
    print(f"Loading labels from: {mask_folder}")

    BATCH_SIZE = 5
    DEBUG_SUBSET_SIZE = 50  # set to None to train on the full dataset
    NUM_EPOCHS = 5

    train_loader, val_loader, test_loader = get_dataloaders(
        image_dir=image_folder,
        mask_dir=mask_folder,
        batch_size=BATCH_SIZE,
        debug_subset_size=DEBUG_SUBSET_SIZE)

    # Loss and optimizer; the model outputs logits, hence BCEWithLogitsLoss.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train
    train_model(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=NUM_EPOCHS)

    # Visualize
    if len(val_loader) > 0:
        visualize_results(model, val_loader, device, num_to_show=3)

    # Evaluate on the held-out test split
    if len(test_loader) > 0:
        print("\n--- Evaluating on Test Set ---")
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for images, masks in test_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                test_loss += loss.item()

        avg_test_loss = test_loss / len(test_loader)
        print(f"Average Test Loss: {avg_test_loss:.4f}")
    else:
        print("\nNo test set to evaluate.")




Using device: mps
Loading data from: /Users/holgermaxfloelyng/Desktop/BioMed/MSc_Biomed/SEM_2/02456_Deep_learning/Project/archive/images
Loading labels from: /Users/holgermaxfloelyng/Desktop/BioMed/MSc_Biomed/SEM_2/02456_Deep_learning/Project/archive/masks
Dataset created with 3064 image/mask pairs.
--- DEBUG MODE: Using a subset of 50 images ---
Splitting 50 images into: Train (35), Val (7), Test (8)
DataLoaders created:
  Train: 35 images
  Validation: 7 images
  Test: 8 images

--- Starting Training ---


ValueError: Target size (torch.Size([5, 1, 129, 129])) must be the same as input size (torch.Size([5, 1, 4, 4]))

In [None]:
# Shape sanity check: the U-Net takes 1-channel input (Encoder's first block
# is 1 -> 64), and a 572x572 input must yield a 388x388 logit map, which is
# the mask size BrainScanDataset produces.
check_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
unet = Unet().to(check_device).eval()  # eval: don't update BatchNorm running stats
with torch.no_grad():
    output = unet(torch.randn(1, 1, 572, 572, device=check_device))
print(output.shape)
assert output.shape == (1, 1, 388, 388), output.shape
