# Part 1: Imports
- Import the necessary libraries for training and analysis, and define the paths to the images/masks folders to be used for training.
- The naming convention is "image_i" and "image_i_mask_j".
- Images used for training are in the .png format.

In [None]:
# %%capture
# Uncomment last line to install PyTorch with CUDA (check compatible version) 
# If there are conflicting dependencies when installing with Anaconda, it can be installed directly from the notebook
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
%%capture
#Imports
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

import torchvision # Pre-trained Mask R-CNN model
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F
from torchvision.models.detection.rpn import AnchorGenerator
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.ops import nms

from PIL import Image, ImageDraw
import os
import numpy as np
import cv2
import random
import matplotlib.pyplot as plt
import glob
import tifffile
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
import csv
import pandas as pd

In [None]:
# Report GPU availability and select the "device" used for the model and tensors.
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
if cuda_available:
    # These queries initialise CUDA and raise on a machine without a GPU,
    # so they are only run when CUDA is actually available.
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.get_device_name(0))
# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print("Device: ",device)

In [None]:
# Folders used by the training workflow; all of them live under one project root.
# NOTE(review): these are absolute, machine-specific paths - adjust before running elsewhere.
project_root = r"C:\Users\leona\0_etch pits_analysis\etch pits test"
image_dir = project_root + r"\Training\Images"   # input images ("image_i")
mask_dir = project_root + r"\Training\Masks"     # instance masks ("image_i_mask_j")
output_dir = project_root + r"\Pre-process"      # pre-processed crops (.pt file)
models_dir = project_root + r"\Models"           # model checkpoints

# Part 2: Data pre-processing
- Define the pre-processing of the data that will be used for the training of the model.
- The pre-processing consists of:
    - load the data into the correct shape for training. Some steps require the channel dimension or a dummy dimension to work; 
    - normalizing image values to the 0-1 range;
    - splitting the image into overlapping square crops (512x512 pixels) using a grid of 5 columns x 3 rows, maintaining the link with the corresponding masks;
    - removing training data that is not fully in a crop; 
    - applying augmentation functions to improve generalization of the model.
- The data is saved as a .pt file to free up memory (large training datasets can consume a lot of memory).

In [None]:
def preprocess_data(image_dir, mask_dirs, output_dir, crop_size=(512, 512), grid_size=(2, 3), augment=False):
    """
    Load every image with its instance masks, cut them into a grid of crops and save the result.

    Images are matched with the glob "image_*" inside ``image_dir``; the masks of
    "image_i" are matched with "image_i_mask_*" inside ``mask_dirs``.

    Args:
        image_dir (str): Directory containing the input images.
        mask_dirs (str or iterable of str): Directory (or directories) containing the masks.
        output_dir (str): Directory where "dataset_crops.pt" is written (created if missing).
        crop_size (tuple): Size of the crops (H, W).
        grid_size (tuple): Grid dimensions for cropping (rows, cols).
        augment (bool): Forwarded to split_image_and_targets to enable augmentation.

    Returns:
        list: All crop dictionaries produced by split_image_and_targets, for all images.
    """
    dataset = []  # List to hold all samples
    os.makedirs(output_dir, exist_ok=True)

    # Use the argument (not a notebook global) and accept one or several mask directories
    if isinstance(mask_dirs, (str, bytes, os.PathLike)):
        mask_dir_list = [mask_dirs]
    else:
        mask_dir_list = list(mask_dirs)

    # List all images in the image directory
    image_paths = sorted(glob.glob(os.path.join(image_dir, "image_*")))  # Match image_i format

    for image_path in image_paths:
        # Extract 'i' from 'image_i.<ext>'
        image_idx = os.path.basename(image_path).split("_")[1].split(".")[0]

        # Match corresponding masks for the image (image_i_mask_j format)
        mask_files = []
        for directory in mask_dir_list:
            mask_files.extend(glob.glob(os.path.join(directory, f"image_{image_idx}_mask_*")))
        mask_files.sort()

        # Load the image; the context manager closes the file once it has been converted
        with Image.open(image_path) as image_file:
            image = image_file.convert("I")
        image_tensor = transforms.ToTensor()(image)  # Convert to tensor [C, H, W]
        img_H, img_W = image_tensor.shape[1], image_tensor.shape[2]

        # Load masks and derive one bounding box per mask.
        # Empty masks are dropped from BOTH lists so that box i always belongs to mask i.
        masks = []
        boxes = []
        for mask_path in mask_files:
            with Image.open(mask_path) as mask_file:
                mask = np.array(mask_file.convert('L'), dtype=np.uint8)
            pos = np.where(mask > 0)
            if pos[0].size == 0:
                continue
            masks.append(mask)
            boxes.append([int(np.min(pos[1])), int(np.min(pos[0])),
                          int(np.max(pos[1])), int(np.max(pos[0]))])  # [xmin, ymin, xmax, ymax]

        if masks:
            masks_tensor = torch.tensor(np.stack(masks, axis=0), dtype=torch.uint8)  # [N, H, W]
        else:
            # Image without any usable mask: keep well-formed empty targets
            masks_tensor = torch.zeros((0, img_H, img_W), dtype=torch.uint8)
        boxes_tensor = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)  # [N, 4]

        # Assign dummy labels (1 for all instances, can be adjusted later)
        labels_tensor = torch.ones((len(boxes),), dtype=torch.int64)  # [N]

        # Split into crops
        crops = split_image_and_targets(
            image=image_tensor,
            masks=masks_tensor,
            boxes=boxes_tensor,
            labels=labels_tensor,
            grid_size=grid_size,
            crop_size=crop_size,
            augment=augment,
            image_index=image_idx
        )
        # Append all crops to the dataset
        dataset.extend(crops)

    # Save as .pt file
    crop_filename = "dataset_crops.pt"
    torch.save(dataset, os.path.join(output_dir, crop_filename))

    print(f"Processed images -> Saved crops to {crop_filename}")

    return dataset

In [None]:
# Define augmentations
# Random transforms, applied in the listed order by apply_augmentation().
_augmentation_steps = [
    T.RandomHorizontalFlip(p=0.3),                                         # mirror left/right
    T.RandomVerticalFlip(p=0.3),                                           # mirror up/down
    T.RandomRotation(degrees=30),                                          # rotate within +/- 30 degrees
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # photometric jitter
    T.RandomResizedCrop(size=(512, 512), scale=(0.2, 1)),                  # random zoom back to 512x512
    T.GaussianBlur(kernel_size=(1, 25), sigma=(10., 15.)),                 # blur
    T.RandomPerspective(distortion_scale=0.2, p=0.35),                     # perspective warp
]
augmentation_pipeline = T.Compose(_augmentation_steps)

In [None]:
def apply_augmentation(crop_data):
    """
    Run the module-level ``augmentation_pipeline`` on the image of a crop.

    Only ``crop_data['image']`` is transformed; 'masks', 'boxes' and 'labels'
    are left exactly as they were. The dictionary is modified in place and
    also returned.

    NOTE(review): the pipeline contains geometric transforms (flips, rotation,
    resized crop, perspective), so after augmentation the untouched masks/boxes
    presumably no longer line up with the image - confirm before training with
    augment=True.

    Args:
        crop_data (dict): Contains 'image', 'masks', 'boxes', 'labels'.

    Returns:
        dict: The same dictionary with an augmented 'image'.
    """
    pil_image = to_pil_image(crop_data['image'])  # tensor -> PIL
    crop_data['image'] = to_tensor(augmentation_pipeline(pil_image))  # augment, PIL -> tensor
    return crop_data

In [None]:
def split_image_and_targets(image, masks, boxes, labels, grid_size=(2, 3), crop_size=(512, 512), augment=False, image_index=0):
    """
    Splits the image, masks, and bounding boxes into overlapping crops with optional augmentations.

    Crops are laid out on a rows x cols grid whose stride is chosen so that the
    first crop starts at the top-left corner and the last one ends at the
    bottom-right corner of the image.

    An instance is assigned to a crop when its box overlaps the crop, its mask
    has at least one pixel inside the crop, and its box clipped to the crop
    still has positive width and height.

    Args:
        image (Tensor): Original image tensor of shape [C, H, W].
        masks (Tensor): Masks tensor of shape [N, H, W] (may be an empty list when there are no targets).
        boxes (Tensor): Bounding boxes tensor of shape [N, 4] as [xmin, ymin, xmax, ymax] (may be an empty list).
        labels (Tensor): Labels tensor of shape [N] (may be an empty list).
        grid_size (tuple): (rows, cols) for splitting the image.
        crop_size (tuple): (height, width) of each crop.
        augment (bool): Whether to apply augmentations.
        image_index (int): Index of the image in the dataset.

    Returns:
        list: A list of dictionaries, each containing:
              - 'image': Cropped image tensor [C, crop_H, crop_W].
              - 'masks': Cropped binary masks, float tensor [N', crop_H, crop_W].
              - 'boxes': Crop-relative bounding boxes, float tensor [N', 4].
              - 'labels': int64 tensor [N'].
              - 'metadata': Metadata dictionary containing:
                            - 'image_index': Index of the original image.
                            - 'x_start': x-coordinate of the crop's top-left corner.
                            - 'y_start': y-coordinate of the crop's top-left corner.

    Raises:
        ValueError: If the crop is larger than the image.
    """
    H, W = image.shape[1], image.shape[2]
    crop_H, crop_W = crop_size
    rows, cols = grid_size

    # A crop larger than the image would give negative start offsets and wrong slices
    if crop_H > H or crop_W > W:
        raise ValueError(
            f"crop_size {(crop_H, crop_W)} is larger than the image size {(H, W)}"
        )

    stride_H = (H - crop_H) // (rows - 1) if rows > 1 else 0
    stride_W = (W - crop_W) // (cols - 1) if cols > 1 else 0

    crops = []

    for row in range(rows):
        for col in range(cols):
            # Compute crop coordinates
            y_start = min(row * stride_H, H - crop_H)
            x_start = min(col * stride_W, W - crop_W)
            y_end = y_start + crop_H
            x_end = x_start + crop_W

            # Crop the image
            cropped_image = image[:, y_start:y_end, x_start:x_end]

            cropped_masks = []
            cropped_boxes = []
            cropped_labels = []

            for mask_idx, box in enumerate(boxes):
                # Plain floats keep the box list free of mixed tensor/int entries
                x_min, y_min, x_max, y_max = (float(v) for v in box)

                # Skip boxes that do not overlap the crop
                if not (x_min < x_end and x_max > x_start and y_min < y_end and y_max > y_start):
                    continue

                # Crop the corresponding mask and binarise it
                cropped_mask = (masks[mask_idx, y_start:y_end, x_start:x_end] > 0).float()
                if not bool(cropped_mask.any()):
                    # The box overlaps the crop but none of the object's pixels fall inside it
                    continue

                # Adjust box coordinates to crop-relative and clip them to the crop
                new_x_min = max(x_min - x_start, 0.0)
                new_y_min = max(y_min - y_start, 0.0)
                new_x_max = min(x_max - x_start, float(crop_W))
                new_y_max = min(y_max - y_start, float(crop_H))
                if new_x_max <= new_x_min or new_y_max <= new_y_min:
                    # Zero-area boxes are rejected by the detection model during training
                    continue

                cropped_boxes.append([new_x_min, new_y_min, new_x_max, new_y_max])
                cropped_labels.append(int(labels[mask_idx]))
                cropped_masks.append(cropped_mask)

            # Stack masks into a tensor; keep dtype and rank consistent when the crop is empty
            if cropped_masks:
                cropped_masks = torch.stack(cropped_masks, dim=0)
            else:
                cropped_masks = torch.zeros((0, crop_H, crop_W), dtype=torch.float32)

            # Add metadata to the crop for reconstruction
            metadata = {
                "image_index": image_index,
                "x_start": x_start,
                "y_start": y_start,
            }

            # Crops without any instance are kept too (the dataset decides whether to skip them)
            crop_data = {
                "image": cropped_image,
                "masks": cropped_masks,
                "boxes": torch.tensor(cropped_boxes, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.tensor(cropped_labels, dtype=torch.int64),
                "metadata": metadata
            }

            if augment:
                crop_data = apply_augmentation(crop_data)

            crops.append(crop_data)

    return crops

In [None]:
def normalize_image(image, percentile_min=2, percentile_max=98, padding_value=0):
    """
    Normalize the image to [0, 1] while handling outliers and ignoring padded pixels.

    Pixels equal to ``padding_value`` are excluded from the percentile estimate and
    are set to 0 in the output. All other pixels are clipped to the
    [percentile_min, percentile_max] range and rescaled linearly.

    Args:
        image (torch.Tensor): Input image tensor (any numeric dtype).
        percentile_min (float): Lower percentile for clipping (default: 2%).
        percentile_max (float): Upper percentile for clipping (default: 98%).
        padding_value (float): Value of padded background (default: 0).

    Returns:
        torch.Tensor: Float tensor with the same shape as ``image``. All zeros when
        the image is entirely padding or has uniform intensity.
    """
    # Convert first so that every return path yields a float tensor
    image = image.float()

    # Mask out padded regions
    pad_mask = image != padding_value
    if not pad_mask.any():
        print("Warning: Entire image is padding. Returning zeros.")
        return torch.zeros_like(image)

    # Extract valid pixel values (ignoring padding)
    valid_pixels = image[pad_mask]

    # Compute percentile-based clipping bounds
    min_val = torch.quantile(valid_pixels, percentile_min / 100.0)
    max_val = torch.quantile(valid_pixels, percentile_max / 100.0)

    # Guard BEFORE dividing: a uniform image would otherwise be divided by zero
    if max_val == min_val:
        print("Warning: Image has uniform intensity. Returning zeros.")
        return torch.zeros_like(image)

    # Clip values to the specified percentiles and rescale to [0, 1]
    clipped_image = torch.clip(image, min=min_val, max=max_val)
    normalized_image = (clipped_image - min_val) / (max_val - min_val)
    normalized_image[~pad_mask] = 0  # Restore padding values to 0

    return normalized_image

In [None]:
# This can handle both data in memory or saved as a .pt file
class ROIDataset(Dataset):
    """
    Dataset of pre-processed crops (dictionaries with 'image', 'masks', 'boxes', 'labels').

    Samples come either from a list already in memory or from a `.pt` file written
    by the pre-processing step. `__getitem__` returns `(image, target)`, or `None`
    when no instance of the crop survives the border/size filter.
    """

    def __init__(self, dataset=None, data_path=None, transform=None):
        """
        Args:
            dataset (list or None): Preloaded dataset (list of dictionaries with 'image', 'masks', etc.).
            data_path (str or None): Path to the preprocessed `.pt` file; takes precedence over `dataset`.
            transform (callable, optional): Optional transform to be applied to the image.

        Raises:
            ValueError: If neither `dataset` nor `data_path` is provided.
        """
        if data_path:
            self.data_path = data_path
            self.data = torch.load(self.data_path)  # Data will be loaded from file
        elif dataset is not None:
            # `is not None` so that an empty list is still a valid (empty) dataset
            self.data = dataset
            self.data_path = None  # No files to load
        else:
            raise ValueError("Either 'dataset' or 'data_path' must be provided.")

        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def is_mask_touching_border(mask, threshold=0, threshold_size=0):
        """
        Tell whether a mask should be discarded.

        Args:
            mask: 2D binary mask [H, W].
            threshold: Maximum summed mask value tolerated on each image border.
            threshold_size: Minimum number of mask pixels required.

        Returns:
            bool: True when the mask exceeds `threshold` on any of the four borders,
            or when it covers fewer than `threshold_size` pixels.
        """
        top_border = mask[0, :].sum()
        bottom_border = mask[-1, :].sum()
        left_border = mask[:, 0].sum()
        right_border = mask[:, -1].sum()
        # Area of the object itself (number of foreground pixels), not of the image
        mask_size = (mask > 0).sum()
        return bool(top_border > threshold or
                    bottom_border > threshold or
                    left_border > threshold or
                    right_border > threshold or
                    mask_size < threshold_size)

    def __getitem__(self, idx):
        # Samples live in self.data regardless of whether they came from a file or from memory
        data = self.data[idx]

        masks = data["masks"]
        boxes = data["boxes"]
        labels = data["labels"]

        # Keep only the instances that are not cut by the crop border and are large enough
        valid_indices = []
        for i, mask in enumerate(masks):
            if not self.is_mask_touching_border(mask.numpy(), threshold=5, threshold_size=30):
                valid_indices.append(i)

        # Skip if no valid masks remain (checked first, so skipped crops are not normalised)
        if len(valid_indices) == 0:
            return None  # Indicate this crop should be skipped

        image = data["image"]
        if self.transform:
            if isinstance(image, torch.Tensor):
                image = transforms.ToPILImage()(image)  # Convert back to PIL for transforms
            image = self.transform(image)
        image = normalize_image(image)

        target = {
            "boxes": boxes[valid_indices],
            "labels": labels[valid_indices],
            "masks": masks[valid_indices],
        }

        return image, target

In [None]:
# Run the pre-processing defined above: every training image is cut into a
# 3 x 5 grid of 512x512 crops and the crops are saved in output_dir.
data = preprocess_data(
    image_dir=image_dir,
    mask_dirs=mask_dir,
    output_dir=output_dir,
    crop_size=(512, 512),
    grid_size=(3, 5),
)

In [None]:
# Image transform applied by the dataset.
# Compose is used so that further pre-processing steps can be appended later on.
transform = transforms.Compose([transforms.ToTensor()])

# Build the dataset from the crops saved by preprocess_data
data_path = os.path.join(output_dir, "dataset_crops.pt")
dataset = ROIDataset(transform=transform, data_path=data_path)

# Part 3: Build dataloaders
- With the input data in the correct shape, we can now build the training dataset in a way that the model can use it for training.
- The data is split in training, validation and test datasets.
- A few examples are visualized to check that the correct data is being loaded.

In [None]:
def custom_collate_fn(batch):
    """
    Collate (image, target) samples into two parallel lists, ignoring skipped samples.

    Samples equal to None (crops the dataset decided to skip) are discarded.
    When nothing is left, two empty lists are returned so that the caller can
    skip the batch.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return [], []

    images, targets = [], []
    for image, target in kept:
        images.append(image)
        targets.append(target)
    return images, targets

In [None]:
# Define dataset sizes
dataset_size = len(dataset)
train_size = int(0.80 * dataset_size)  # 80% for training
val_size = int(0.12 * dataset_size)   # 12% for validation
test_size = dataset_size - train_size - val_size  # Remaining 8% for testing

# Split dataset with a fixed seed: the same crops end up in each split on every run,
# so a model re-loaded in a later session is still evaluated on crops it was not trained on.
split_seed = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create dataloaders for each split
batch_size = 8  # Define batch size
pin_memory = torch.cuda.is_available()  # Pinned memory only helps when copying to a GPU
train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=pin_memory, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
# Validation keeps its last, possibly smaller, batch so that no validation crop is discarded
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=pin_memory, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=pin_memory, shuffle=False, collate_fn=custom_collate_fn)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
# Visualization for both uncropped and cropped data
def visualize_data(loader, num_samples=1):
    """
    Plot the first `num_samples` samples of a dataloader.

    For each sample two panels are shown: the image with its bounding boxes drawn
    in yellow, and the union of all its instance masks.

    Args:
        loader: Iterable yielding (images, targets) batches, where each target has
                'masks' ([N, H, W]) and 'boxes' ([N, 4]) tensors.
        num_samples (int): Number of samples to display before returning.
    """
    sample_count = 0  # Number of samples displayed so far
    to_pil = transforms.ToPILImage()

    for batch_images, batch_targets in loader:
        for idx, img_tensor in enumerate(batch_images):
            if sample_count >= num_samples:
                return  # Enough samples have been shown

            # Unwrap nested lists, if any
            if isinstance(img_tensor, list):
                img_tensor = img_tensor[0]
            targets = batch_targets[idx]
            if isinstance(targets, list):
                targets = targets[0]

            # Convert to PIL according to the number of channels
            n_channels = img_tensor.shape[0]
            if n_channels == 1:  # Grayscale
                image = to_pil(img_tensor.squeeze(0).cpu())
            elif n_channels == 3:  # RGB
                image = to_pil(img_tensor.cpu())
            elif n_channels == 4:  # RGBA: the alpha channel is dropped
                image = to_pil(img_tensor[:3, :, :].cpu())
            else:
                print(f"Unexpected number of channels: {img_tensor.shape[0]}. Skipping visualization.")
                continue

            # Union of all instance masks, clipped so overlapping masks do not add up
            masks = targets["masks"].cpu().numpy()  # Shape: [N, H, W]
            print(f"Mask shape: {masks.shape}")
            combined_mask = np.clip(np.sum(masks, axis=0), 0, 1)

            # Draw the bounding boxes directly on the PIL image
            draw = ImageDraw.Draw(image)
            for box in targets["boxes"].cpu().numpy():  # Shape: [N, 4]
                if box.shape != (4,):
                    print(f"Invalid box format: {box}. Skipping.")
                    continue
                x_min, y_min, x_max, y_max = (float(coord) for coord in box)
                draw.rectangle([x_min, y_min, x_max, y_max], outline="yellow", width=3)

            # Left panel: image with boxes; right panel: combined mask
            fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
            ax_image.imshow(image, cmap="gray")
            ax_image.set_title(f"Image {sample_count + 1} with Bounding Boxes")
            ax_mask.imshow(combined_mask, cmap="gray")
            ax_mask.set_title(f"Combined Mask {sample_count + 1} (Sum Across Channels)")
            plt.show()

            sample_count += 1

In [None]:
# Sanity check: display one sample drawn from the training dataloader
print("Visualizing Data:")
visualize_data(loader=train_loader, num_samples=1)

# Part 4: fine-tune pre-trained model
- Load the pre-trained model (maskrcnn_resnet50_fpn) from torchvision
- Define the necessary training parameters (num_classes) and hyperparameters:
    - learning rate
    - weight decay
    - scheduler's step_size and gamma
    - epochs
- Train model for the number of epochs defined. Note that the model's weights are saved after each epoch by default. This can use a lot of storage and might not be necessary in most cases.
### IMPORTANT - If the notebook is to be used only to predict (without training), the pre-trained model still needs to be imported

In [None]:
%%capture
# Load the pre-trained model
image_height = 512  # Uncropped image height
image_width = 512   # Uncropped image width
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    min_size=image_height,
    max_size=image_width,
    pretrained=True
)

# Modify the model's head for dataset
num_classes = 2  # Background + 1 ROI class
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Update the mask predictor
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# Move the model to the device (GPU/CPU)
model = model.to(device)

# Print the model to verify. Uncomment if needed
#print(model)

In [None]:
# Optimizer: SGD over the trainable parameters only (Hyperparameters1)
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.SGD(
    params,
    lr=0.001,
    momentum=0.9,
    weight_decay=0.0001,
)

# Learning rate scheduler (Hyperparameters2): multiplies lr by gamma every step_size epochs
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
def apply_nms(outputs, iou_threshold=0.5):
    """
    Filter each prediction with Non-Maximum Suppression (NMS).

    Args:
        outputs (list of dict): One dictionary per image, each containing:
                                - 'boxes': Tensor of shape [N, 4]
                                - 'scores': Tensor of shape [N]
                                - 'masks': Tensor with N entries along dim 0
        iou_threshold (float): IoU threshold passed to `nms`.

    Returns:
        list of dict: One dictionary per image with only the 'boxes', 'scores'
        and 'masks' entries kept by NMS (other keys are not carried over).
    """
    def _suppress(prediction):
        # Read the three tensors first, then index all of them with the kept indices
        kept_fields = {key: prediction[key] for key in ('boxes', 'scores', 'masks')}
        keep = nms(kept_fields['boxes'], kept_fields['scores'], iou_threshold)
        return {key: tensor[keep] for key, tensor in kept_fields.items()}

    return [_suppress(prediction) for prediction in outputs]

In [None]:
# This is the main training loop. We store the weights after each epoch for comparison purposes, but only the last one is needed.
# To save space, "model_save" can be set to "False": only the weights of the last epoch are then saved.

model_save = True
num_epochs = 30
training_losses = []  # List to store average loss per epoch
validation_losses = []

# Make sure the checkpoint folder exists before the first torch.save
os.makedirs(models_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    epoch_loss = 0  # Accumulate loss for the epoch
    num_batches = 0  # Count batches in the epoch
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    for batch_idx, (images, targets) in enumerate(train_loader):
        # Skip empty batches (all crops of the batch were filtered out by the dataset)
        if len(images) == 0 or len(targets) == 0:
            print("Skipping empty batch...")
            continue

        # Move data to device
        images = [img.to(device) for img in images]
        targets = [
            {
                k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                for k, v in t.items()
            }
            for t in targets
        ]

        # Zero gradients
        optimizer.zero_grad()

        # Forward/backward pass; a failing batch is reported and skipped
        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values()) # Compute total loss
            losses.backward() # Backward pass
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0) # Gradient clipping
            optimizer.step() # Update parameters
        except Exception as e:
            print(f"Error during forward pass: {e}")
            continue
            
        # Update loss trackers
        epoch_loss += losses.item()
        num_batches += 1
        print(f"  Batch {batch_idx + 1}: Loss = {losses.item():.4f}")

    # Compute average training loss
    avg_train_loss = epoch_loss / num_batches if num_batches > 0 else float('inf')
    training_losses.append(avg_train_loss)
    print(f"Epoch {epoch + 1} Training Loss: {avg_train_loss:.4f}")
    
    # Validation phase
    model.eval()  # Set model to evaluation mode
    val_loss = 0
    val_batches = 0
    
    print(f"Epoch {epoch + 1}/{num_epochs} - Validation Phase")
    
    with torch.no_grad():  # Disable gradient computation
        for batch_idx, (images, targets) in enumerate(val_loader):  # Use validation dataloader
            # Skip empty batches, exactly as in the training phase
            if len(images) == 0 or len(targets) == 0:
                print("Skipping empty validation batch...")
                continue
    
            # Move data to device
            images = [img.to(device) for img in images]
            targets = [
                {
                    k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                    for k, v in t.items()
                }
                for t in targets
            ]
            # The model only returns its loss dictionary in training mode, so switch
            # temporarily; no weights change because gradients are disabled and the
            # optimizer is not stepped.
            model.train()
            try:
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                val_loss += losses.item()
                val_batches += 1
                print(f"  Validation Batch {batch_idx + 1}: Loss = {losses.item():.4f}")
            except Exception as e:
                print(f"Error during validation forward pass on batch {batch_idx + 1}: {e}")
                continue
            finally:
                model.eval()  # Switch back to evaluation mode
    
    # Compute average validation loss
    avg_val_loss = val_loss / val_batches if val_batches > 0 else float('inf')
    validation_losses.append(avg_val_loss)
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f}")

    # Adjust learning rate
    scheduler.step()
    
    # Save the model after each epoch (or only after the last one when model_save is False)
    model_filename = f"model_epoch_{epoch + 1}.pth"
    model_save_path = os.path.join(models_dir, model_filename)
    if model_save or epoch == (num_epochs - 1):
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
        }, model_save_path)
        print(f"Saved model at {model_save_path}")
    
# After training, plot the training and validation losses
plt.plot(range(1, num_epochs + 1), training_losses, label='Training Loss')
plt.plot(range(1, num_epochs + 1), validation_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss Over Epochs')
plt.show()

In [None]:
# Re-draw the training/validation loss curves (one point per epoch)
epochs_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, training_losses, label='Training Loss')
ax.plot(epochs_axis, validation_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training and Validation Loss Over Epochs')
plt.show()

In [None]:
# Store the per-epoch training and validation losses (one value per line) so that
# the training can be checked afterwards for convergence and overfitting.
for loss_path, loss_values in (
    (r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Models\training_losses.txt", training_losses),
    (r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Models\validation_losses.txt", validation_losses),
):
    with open(loss_path, "w") as output:
        output.write("".join(str(row) + '\n' for row in loss_values))
# Display the loss dictionary of the last processed batch
loss_dict

#### Part 5(optional): Re-load model
- Once the model's weights are calculated from the training and stored, they can be quickly reloaded without going through training again.
- If the notebook is to be used only to predict (without training), the pre-trained model still needs to be loaded first (see Part 4).

In [None]:
# Re-load a saved checkpoint into the model.
# With strict=False mismatching keys do not raise; the value shown as the cell
# output lists the missing/unexpected keys, which is how the match is checked.
weights_path = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Models\model4_512x512_3x5 grid\model_epoch_30.pth"
state_dict = torch.load(weights_path, map_location=device)
model_weights = state_dict['model_state_dict']
load_report = model.load_state_dict(model_weights, strict=False)
load_report

# Part 6: Visualize predictions
- Visualize predictions using cropped images in the test dataloader.
- The same functions can now be used to analyze any data, including fresh new SEM images from the lab. 

In [None]:
# Here we use the test_dataloader (for which we have a manually labelled ground truth)
# to visually and quantitatively assess the validity of the model's predictions.
# Per crop we compute: pixel IoU between the union of predicted masks and the union of
# ground-truth masks, and precision/recall/F1 between the pixel areas covered by the
# predicted boxes and by the ground-truth boxes.


model.eval()  # Set model to evaluation mode
ious = []  # To store IoU values
precision_list = []
recall_list = []
f1_list = []

with torch.no_grad():
    for images, targets in test_loader:
        # Skip empty batches
        if len(images) == 0 or len(targets) == 0:
            print("Skipping empty batch during testing...")
            continue
        
        # Move images to device
        images = [img.to(device) for img in images]

        # Model predictions
        outputs = model(images)
        # NOTE(review): with iou_threshold=1 no pair of boxes can exceed the threshold, so this
        # pass presumably keeps every prediction - confirm that this is intended.
        outputs = apply_nms(outputs, iou_threshold=1)

        for idx, output in enumerate(outputs):
            # Original image
            original_image = F.to_pil_image(images[idx].cpu())

            # Predicted masks: Combine all into one binary mask
            pred_masks = output['masks'].squeeze(1) > 0.7  # Shape: [N_pred, H, W]
            combined_pred_mask = torch.any(pred_masks, dim=0).cpu().numpy()  # Combine masks: [H, W]
            
            # Ground truth masks
            true_masks = targets[idx]['masks'].cpu().numpy().astype(int)  # Shape: [N_true, H, W]
            combined_true_mask = np.any(true_masks, axis=0).astype(int)  # Combine masks: [H, W]

            # IoU calculation
            iou = jaccard_score(combined_true_mask.flatten(), combined_pred_mask.flatten())
            ious.append(iou)
        
            # Bounding boxes
            pred_boxes = output['boxes'].cpu().numpy()  # Shape: [N_pred, 4]
            pred_scores = output['scores'].cpu().numpy()  # Shape: [N_pred]
            true_boxes = targets[idx]['boxes'].cpu().numpy()

            # Convert bounding boxes to binary masks for comparison
            pred_box_mask = np.zeros_like(combined_true_mask, dtype=int)
            true_box_mask = np.zeros_like(combined_true_mask, dtype=int)

            for box in pred_boxes:
                x1, y1, x2, y2 = map(int, box)
                pred_box_mask[y1:y2, x1:x2] = 1

            for box in true_boxes:
                x1, y1, x2, y2 = map(int, box)
                true_box_mask[y1:y2, x1:x2] = 1
            
            precision = precision_score(true_box_mask.flatten(), pred_box_mask.flatten())
            recall = recall_score(true_box_mask.flatten(), pred_box_mask.flatten())
            f1 = f1_score(true_box_mask.flatten(), pred_box_mask.flatten())

            precision_list.append(precision)
            recall_list.append(recall)
            f1_list.append(f1)  
            
            # Create a plot
            fig, ax = plt.subplots(1, 2, figsize=(15, 7))

            # Plot original image
            ax[0].imshow(original_image, cmap="gray")
            ax[0].set_title("Original Image")
            ax[0].axis("off")

            # Plot combined mask and bounding boxes
            ax[1].imshow(original_image, alpha=0.8)
            # 'gray' is accepted by every matplotlib version ('grey' only by recent ones)
            ax[1].imshow(combined_pred_mask, cmap='gray', vmin=0, vmax=1, alpha=0.5)
            for box, score in zip(pred_boxes, pred_scores):
                if score > 0.5:  # Confidence threshold (display only; the metrics use all predictions)
                    x1, y1, x2, y2 = box
                    rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                         edgecolor='red', facecolor='none', linewidth=2)
                    ax[1].add_patch(rect)
            ax[1].set_title("Predicted Masks and Bounding Boxes")
            ax[1].axis("off")

            plt.show()
            plt.close(fig)  # One figure per crop: release it once displayed

# Print average metrics
print(f"Average IoU: {np.mean(ious):.4f}")
print(f"Average Precision: {np.mean(precision_list):.4f}")
print(f"Average Recall: {np.mean(recall_list):.4f}")
print(f"Average F1-Score: {np.mean(f1_list):.4f}")

In [None]:
# Averages, over all evaluated test crops, of the metrics collected above
print(
    f"Average IoU: {np.mean(ious):.4f}",
    f"Average Precision: {np.mean(precision_list):.4f}",
    f"Average Recall: {np.mean(recall_list):.4f}",
    f"Average F1-Score: {np.mean(f1_list):.4f}",
    sep="\n",
)

In [None]:
# Here the predicted masks are visualized independently and with a score. Can be used for debugging if predictions are very far off.
# `outputs` holds the predictions of the last batch processed by the evaluation cell above.

for idx, output in enumerate(outputs):
    print(f"Image {idx}: Number of predicted boxes = {len(output['boxes'])}")
    print(f"Image {idx}: Number of predicted masks = {len(output['masks'])}")

    if len(output["boxes"]) == 0:
        print("No boxes found for this image. Skipping visualization.")
        continue

    pred_masks = output["masks"].squeeze(1).cpu().numpy()
    boxes = output['boxes'].cpu().numpy()  # Shape: [N_pred, 4]
    scores = output['scores'].cpu().numpy()  # Shape: [N_pred]
    
    # One figure per confident prediction (no extra blank per-image figure is created)
    for i, (mask, box, score) in enumerate(zip(pred_masks, boxes, scores)):
        if score > 0.5:  # Confidence threshold
            fig, ax = plt.subplots(1, 1, figsize=(6, 6))
            ax.set_title(f"Image {idx} - Mask {i} (Score: {score:.2f})")
            ax.axis("off")

            # Show the mask
            ax.imshow(mask, alpha=0.7, cmap="jet")  # Adjust alpha for transparency

            # Draw the bounding box
            x1, y1, x2, y2 = box
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                  edgecolor='red', facecolor='none', linewidth=2)
            ax.add_patch(rect)
            
            # Annotate with the score
            ax.text(x1, y1 - 10, f"{score:.2f}", color='red', fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))
            plt.show()
            plt.close(fig)

### Full-image predictions

In [None]:
# Full-image inference on the paper images (partly seen by training).
# Every image is split into crops, the model is run on each crop, the detections are shifted
# back to full-image coordinates and merged with NMS. For each image this saves the merged
# mask (PNG), the overlay with bounding boxes (PNG) and the pit centres/side lengths (CSV),
# and shows a histogram of the side lengths.
test_dir = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Training\analysis\images_tiff_batch-size8"
test_dir_pred = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Training\analysis\images_tiff_batch-size8\predictions"
test_images_pred = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Training\analysis\images_tiff_batch-size8\images_predictions"

PIT_SCORE_THRESHOLD = 0.5    # detections must score above this to be kept
PIT_NMS_IOU_THRESHOLD = 0.3  # NMS drops boxes whose IoU with a higher-scoring box exceeds this

test_paths = sorted(glob.glob(os.path.join(test_dir, "image_*")))
for test_image_path in test_paths:

    # Per-image accumulators for the pit statistics
    sum_final_boxes = []
    sum_final_center_x = []
    sum_final_center_y = []
    sum_sides = []
    measurements_unit = "µm"
    test_image_idx = os.path.basename(test_image_path).split("_")[1].split(".")[0]  # Extract 'i' from 'image_i'

    # Load the test image; the context manager closes the underlying file once the pixels are converted
    with Image.open(test_image_path) as source_image:
        test_image = source_image.convert("I")
    # NOTE(review): the resolution is presumably in pixels per µm (the sizes below are labelled in µm) - confirm.
    # Raises KeyError if the file carries no "resolution" metadata.
    x_res = test_image.info["resolution"][0]
    pixel_size_x = 1 / x_res

    # Convert to tensor and normalize
    test_image_tensor = transforms.ToTensor()(test_image)
    test_image_tensor = normalize_image(test_image_tensor)

    # Add batch dimension
    test_image_tensor = test_image_tensor.unsqueeze(0).to(device)

    test_crops = split_image_and_targets(test_image_tensor[0],[],[],[], grid_size=(3, 5), crop_size=(512, 512))

    # Initialize blank canvas for reconstruction
    H, W = test_image_tensor.shape[2:]
    reconstructed_mask = np.zeros((H, W), dtype=np.float32)
    all_boxes = []
    all_scores = []
    # (mask, y_start, x_start) per detection. Masks stay crop-sized until after NMS so that
    # no full-size (H, W) canvas has to be allocated for every single detection.
    all_mask_patches = []

    # Forward pass
    model.eval()
    with torch.no_grad():
        for crop in test_crops:
            # Prepare the crop
            image = crop['image'].unsqueeze(0).to(device)  # Add batch dimension
            y_start, x_start = crop['metadata']['y_start'], crop['metadata']['x_start']

            # Model predictions for the single crop in the batch
            output = model(image)[0]
            masks = output['masks'].squeeze(1).cpu().numpy() if 'masks' in output else []
            boxes = output['boxes'].cpu().numpy() if 'boxes' in output else []
            scores = output['scores'].cpu().numpy() if 'scores' in output else []

            # Store predictions for merging
            for mask, box, score in zip(masks, boxes, scores):
                if score > PIT_SCORE_THRESHOLD:
                    # Adjust bounding boxes to original coordinates
                    x1, y1, x2, y2 = box
                    all_boxes.append([x1 + x_start, y1 + y_start, x2 + x_start, y2 + y_start])
                    all_scores.append(score)
                    # Copy so that the per-crop array holding all masks can be freed
                    all_mask_patches.append((mask.copy(), y_start, x_start))

    if len(all_boxes) == 0:
        print(f"Image {test_image_idx}: no detections above the score threshold, nothing saved.")
        continue

    # Convert lists to tensors for NMS
    all_boxes = torch.tensor(all_boxes, dtype=torch.float32)
    all_scores = torch.tensor(all_scores, dtype=torch.float32)

    # Apply NMS to filter overlapping boxes (e.g. duplicates coming from different crops)
    keep_indices = nms(all_boxes, all_scores, iou_threshold=PIT_NMS_IOU_THRESHOLD)
    final_boxes = all_boxes[keep_indices].numpy()

    # Merge the surviving masks into the canvas at their original position (pixel-wise maximum)
    for kept in keep_indices.tolist():
        patch, patch_y, patch_x = all_mask_patches[kept]
        region = reconstructed_mask[patch_y:patch_y + patch.shape[0], patch_x:patch_x + patch.shape[1]]
        np.maximum(region, patch, out=region)  # region is a view, so this updates the canvas in place

    # Make sure the output folders exist before writing into them
    os.makedirs(test_dir_pred, exist_ok=True)
    os.makedirs(test_images_pred, exist_ok=True)

    # Save the merged mask and plot the original image with overlaid predictions
    plt.imsave(os.path.join(test_dir_pred, f"image_{test_image_idx}.png"), reconstructed_mask, cmap='gray')
    overlay_fig, overlay_ax = plt.subplots(figsize=(10, 10))
    overlay_ax.imshow(test_image_tensor.cpu().squeeze(), cmap='gray', alpha=0.5)
    overlay_ax.imshow(reconstructed_mask, cmap='gray', alpha=0.5)

    for box in final_boxes:
        x1, y1, x2, y2 = box
        center_x = x1 + (x2 - x1)/2
        center_y = y1 + (y2 - y1)/2
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, edgecolor='red', facecolor='none', linewidth=2)
        overlay_ax.add_patch(rect)
        # Pit area taken as half of the bounding-box area, converted with the pixel size
        area = 0.5 * (x2 - x1) * (y2 - y1) * (pixel_size_x**2)
        # Side s of the equilateral triangle with that area: area = sqrt(3)/4 * s**2
        size = np.sqrt((4/np.sqrt(3))*area)
        sum_final_boxes.append(box)
        sum_final_center_x.append(center_x)
        sum_final_center_y.append(center_y)
        sum_sides.append(size)

    # Save the pits data as CSV file and the final image as a PNG file
    sum_dict = {'x':sum_final_center_x, 'y':sum_final_center_y, 'side': sum_sides}
    df = pd.DataFrame(sum_dict)
    save_path_data = os.path.join(test_images_pred, f"image_{test_image_idx}.csv")
    df.to_csv(save_path_data)
    overlay_ax.axis('off')  # Turn off the axis for better visualization
    save_path_img = os.path.join(test_images_pred, f"image_{test_image_idx}.png")
    overlay_fig.savefig(save_path_img, bbox_inches='tight', pad_inches=0)

    # Relative-frequency histogram of the side lengths (weights make the bars sum to 1)
    hist_fig, hist_ax = plt.subplots(figsize=(10, 6))
    hist_ax.hist(sum_sides, bins=np.arange(0, 0.45 + 0.02, 0.02), color='blue', alpha=0.7, edgecolor='black', weights=np.ones_like(sum_sides) / len(sum_sides))
    hist_ax.set_title("Histogram of image "+str(test_image_idx), fontsize=16)
    hist_ax.set_xlabel(f"Size ({measurements_unit})", fontsize=14)
    hist_ax.set_ylabel("Relative frequency", fontsize=14)
    hist_ax.set_xlim([0, 0.45])
    hist_ax.set_ylim([0, 0.45])
    hist_ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

    # Release both figures so that processing many images does not accumulate open figures
    plt.close(overlay_fig)
    plt.close(hist_fig)

In [None]:
# Full-image inference on images not seen during training.
# Every image is split into crops, the model is run on each crop, the detections are shifted
# back to full-image coordinates and merged with NMS. For each image this saves the merged
# mask (PNG), the overlay with bounding boxes (PNG) and the pit centres/side lengths (CSV),
# and shows a histogram of the side lengths.

test_dir = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Training\analysis\images_tiff_batch-size8\Unseen_images"
test_dir_pred = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Training\analysis\images_tiff_batch-size8\Unseen_images\Unseen_predictions"
test_images_pred = r"C:\Users\leona\0_etch-pits-segmentation_Mask-CNN\etch pits test\Training\analysis\images_tiff_batch-size8\Unseen_images\Unseen_images_predictions"

PIT_SCORE_THRESHOLD = 0.5    # detections must score above this to be kept
PIT_NMS_IOU_THRESHOLD = 0.3  # NMS drops boxes whose IoU with a higher-scoring box exceeds this

test_paths = sorted(glob.glob(os.path.join(test_dir, "image_*")))
for test_image_path in test_paths:

    # Per-image accumulators for the pit statistics
    sum_final_boxes = []
    sum_final_center_x = []
    sum_final_center_y = []
    sum_sides = []
    measurements_unit = "µm"
    test_image_idx = os.path.basename(test_image_path).split("_")[1].split(".")[0]  # Extract 'i' from 'image_i'

    # Load the test image; the context manager closes the underlying file once the pixels are converted
    with Image.open(test_image_path) as source_image:
        test_image = source_image.convert("I")
    # NOTE(review): the resolution is presumably in pixels per µm (the sizes below are labelled in µm) - confirm.
    # Raises KeyError if the file carries no "resolution" metadata.
    x_res = test_image.info["resolution"][0]
    pixel_size_x = 1 / x_res

    # Convert to tensor and normalize
    test_image_tensor = transforms.ToTensor()(test_image)
    test_image_tensor = normalize_image(test_image_tensor)

    # Add batch dimension
    test_image_tensor = test_image_tensor.unsqueeze(0).to(device)

    test_crops = split_image_and_targets(test_image_tensor[0],[],[],[], grid_size=(3, 5), crop_size=(512, 512))

    # Initialize blank canvas for reconstruction
    H, W = test_image_tensor.shape[2:]
    reconstructed_mask = np.zeros((H, W), dtype=np.float32)
    all_boxes = []
    all_scores = []
    # (mask, y_start, x_start) per detection. Masks stay crop-sized until after NMS so that
    # no full-size (H, W) canvas has to be allocated for every single detection.
    all_mask_patches = []

    # Forward pass
    model.eval()
    with torch.no_grad():
        for crop in test_crops:
            # Prepare the crop
            image = crop['image'].unsqueeze(0).to(device)  # Add batch dimension
            y_start, x_start = crop['metadata']['y_start'], crop['metadata']['x_start']

            # Model predictions for the single crop in the batch
            output = model(image)[0]
            masks = output['masks'].squeeze(1).cpu().numpy() if 'masks' in output else []
            boxes = output['boxes'].cpu().numpy() if 'boxes' in output else []
            scores = output['scores'].cpu().numpy() if 'scores' in output else []

            # Store predictions for merging
            for mask, box, score in zip(masks, boxes, scores):
                if score > PIT_SCORE_THRESHOLD:
                    # Adjust bounding boxes to original coordinates
                    x1, y1, x2, y2 = box
                    all_boxes.append([x1 + x_start, y1 + y_start, x2 + x_start, y2 + y_start])
                    all_scores.append(score)
                    # Copy so that the per-crop array holding all masks can be freed
                    all_mask_patches.append((mask.copy(), y_start, x_start))

    if len(all_boxes) == 0:
        print(f"Image {test_image_idx}: no detections above the score threshold, nothing saved.")
        continue

    # Convert lists to tensors for NMS
    all_boxes = torch.tensor(all_boxes, dtype=torch.float32)
    all_scores = torch.tensor(all_scores, dtype=torch.float32)

    # Apply NMS to filter overlapping boxes (e.g. duplicates coming from different crops)
    keep_indices = nms(all_boxes, all_scores, iou_threshold=PIT_NMS_IOU_THRESHOLD)
    final_boxes = all_boxes[keep_indices].numpy()

    # Merge the surviving masks into the canvas at their original position (pixel-wise maximum)
    for kept in keep_indices.tolist():
        patch, patch_y, patch_x = all_mask_patches[kept]
        region = reconstructed_mask[patch_y:patch_y + patch.shape[0], patch_x:patch_x + patch.shape[1]]
        np.maximum(region, patch, out=region)  # region is a view, so this updates the canvas in place

    # Make sure the output folders exist before writing into them
    os.makedirs(test_dir_pred, exist_ok=True)
    os.makedirs(test_images_pred, exist_ok=True)

    # Save the merged mask and plot the original image with overlaid predictions
    plt.imsave(os.path.join(test_dir_pred, f"image_{test_image_idx}.png"), reconstructed_mask, cmap='gray')
    overlay_fig, overlay_ax = plt.subplots(figsize=(10, 10))
    overlay_ax.imshow(test_image_tensor.cpu().squeeze(), cmap='gray', alpha=0.5)
    overlay_ax.imshow(reconstructed_mask, cmap='gray', alpha=0.5)

    for box in final_boxes:
        x1, y1, x2, y2 = box
        center_x = x1 + (x2 - x1)/2
        center_y = y1 + (y2 - y1)/2
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, edgecolor='red', facecolor='none', linewidth=2)
        overlay_ax.add_patch(rect)
        # Pit area taken as half of the bounding-box area, converted with the pixel size
        area = 0.5 * (x2 - x1) * (y2 - y1) * (pixel_size_x**2)
        # Side s of the equilateral triangle with that area: area = sqrt(3)/4 * s**2
        size = np.sqrt((4/np.sqrt(3))*area)
        sum_final_boxes.append(box)
        sum_final_center_x.append(center_x)
        sum_final_center_y.append(center_y)
        sum_sides.append(size)

    # Save the pits data as CSV file and the final image as a PNG file
    sum_dict = {'x':sum_final_center_x, 'y':sum_final_center_y, 'side': sum_sides}
    df = pd.DataFrame(sum_dict)
    save_path_data = os.path.join(test_images_pred, f"image_{test_image_idx}.csv")
    df.to_csv(save_path_data)
    overlay_ax.axis('off')  # Turn off the axis for better visualization
    save_path_img = os.path.join(test_images_pred, f"image_{test_image_idx}.png")
    overlay_fig.savefig(save_path_img, bbox_inches='tight', pad_inches=0)

    # Relative-frequency histogram of the side lengths (weights make the bars sum to 1)
    hist_fig, hist_ax = plt.subplots(figsize=(10, 6))
    hist_ax.hist(sum_sides, bins=np.arange(0, 0.45 + 0.02, 0.02), color='blue', alpha=0.7, edgecolor='black', weights=np.ones_like(sum_sides) / len(sum_sides))
    hist_ax.set_title("Histogram of image "+str(test_image_idx), fontsize=16)
    hist_ax.set_xlabel(f"Size ({measurements_unit})", fontsize=14)
    hist_ax.set_ylabel("Relative frequency", fontsize=14)
    hist_ax.set_xlim([0, 0.45])
    hist_ax.set_ylim([0, 0.45])
    hist_ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

    # Release both figures so that processing many images does not accumulate open figures
    plt.close(overlay_fig)
    plt.close(hist_fig)

### End of notebook