# Part 1: Imports
- Import the necessary libraries for training and analysis, and define the paths to the images/masks folders to be used for training.
- The naming convention is "image_i" for image i and "image_i_mask_j" for the j-th mask of image i.
- Images used for training are in the .png format.

In [None]:
# %%capture
# Uncomment last line to install PyTorch with CUDA (check compatible version) 
# If there are conflicting dependencies when installing with Anaconda, it can be installed directly from the notebook
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
%%capture
#Imports
import torch
from torch.utils.data import Dataset, DataLoader, random_split,Subset
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

import torchvision # Pre-trained Mask R-CNN model
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F
from torchvision.models.detection.rpn import AnchorGenerator
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.ops import nms

from PIL import Image, ImageDraw
import os
import numpy as np
import cv2
import random
import copy
import time
import matplotlib.pyplot as plt
import glob
import tifffile
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score, accuracy_score, roc_auc_score,log_loss
from sklearn.model_selection import GroupKFold
import csv
import json
import pandas as pd

In [None]:
# Seeding all the modules ensures that the random processes required for training can be reproduced deterministically.
# This is needed for reporting, but not strictly necessary for predicting.

def seed_everything(seed=42):
    """Seed every random number generator used by training so runs are repeatable.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    switches cuDNN to deterministic, non-benchmarking mode and exports
    ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read by Python only at interpreter start-up, so
    it affects processes launched after this call, not the running one.

    Args:
        seed (int): Seed value shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # for multi-GPU
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)

seed_everything(42)
# Dedicated generator (passed to the training DataLoader as `generator=g`)
# so the shuffling order is reproducible on its own.
g = torch.Generator().manual_seed(42)

In [None]:
#Check that GPU is available for calculations, and set the "device" variable
# Per-device queries are only made when CUDA is present: on a CPU-only
# machine torch.cuda.current_device() / get_device_name() raise, which used to
# crash this cell before the CPU fallback below was reached.
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())  # 0 when CUDA is unavailable
if cuda_available:
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))
# Check if GPU is available, otherwise use CPU
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print("Device: ",device)

In [None]:
# Paths to image and masks for training
# NOTE: these look like placeholder Windows paths -- edit them to match your setup.
# Training images ("image_i.*") and their instance masks ("image_i_mask_j.*")
image_dir = r"C:\Users\USER\PROJECT\TRAINING\Images"
mask_dir = r"C:\Users\USER\PROJECT\TRAINING\Masks"
# Where the pre-processed crops are saved / loaded (dataset_crops.pt)
output_dir = r"C:\Users\USER\PROJECT\Pre-process"
# Where model checkpoints go (reused as save_dir by the training cells)
models_dir = r"C:\Users\USER\PROJECT\Models"
# Held-out test images/masks; not referenced in this part of the notebook,
# presumably used by a later evaluation section -- TODO confirm
test_image_dir = r"C:\Users\USER\PROJECT\TRAINING\ANALYSIS\Test\Images"
test_mask_dir = r"C:\Users\USER\PROJECT\TRAINING\ANALYSIS\Test\Masks"

# Part 2: Data pre-processing
- Define the pre-processing of the data that will be used for the training of the model.
- The pre-processing consists of:
    - load the data into the correct shape for training. Some steps require the channel dimension or a dummy dimension to work; 
    - normalizing image values to the 0-1 range;
    - splitting each image into overlapping 512x512 squares on a grid (3 rows × 5 columns, as passed to `preprocess_data`), maintaining the link with the corresponding masks;
    - removing instances whose mask is not fully inside a crop (i.e. touches the crop border); 
    - applying augmentation functions to improve generalization of the model.
- The data is saved as a .pt file to free up memory (large training datasets can consume a lot of memory).

In [None]:
def _load_masks_and_boxes(mask_files, height, width):
    """Load the instance masks of one image and their boxes, skipping empty masks.

    Args:
        mask_files (list of str): Paths of the mask files for one image.
        height (int): Image height, used for the shape of an empty result.
        width (int): Image width, used for the shape of an empty result.

    Returns:
        tuple: ``(masks, boxes)``. ``masks`` is a uint8 tensor [N, H, W] and
        ``boxes`` a float32 tensor [N, 4] of inclusive pixel extents
        (xmin, ymin, xmax, ymax). Row k of ``boxes`` belongs to mask k.
    """
    masks, boxes = [], []
    for mask_path in mask_files:
        with Image.open(mask_path) as mask_img:  # close the file handle promptly
            mask = np.array(mask_img.convert("L"), dtype=np.uint8)
        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            # An empty mask has no box. Keeping it in `masks` while skipping its
            # box would shift every later mask out of line with its box.
            continue
        masks.append(mask)
        boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])

    if masks:
        masks_tensor = torch.as_tensor(np.stack(masks, axis=0), dtype=torch.uint8)
    else:
        # Images without (non-empty) masks yield zero instances instead of crashing np.stack.
        masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8)
    boxes_tensor = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)  # [N, 4] even when N == 0
    return masks_tensor, boxes_tensor


def preprocess_data(image_dir, mask_dir, output_dir, crop_size=(512, 512), grid_size=(2, 3), augment=False):
    """
    Preprocess images and masks by cropping them on a grid and saving the crops.

    Every file matching ``image_*`` in ``image_dir`` is loaded in PIL mode "I"
    (32-bit integer, single channel). Its instance masks are the files
    ``image_<i>_mask_*`` in ``mask_dir``. Each non-empty mask gets a bounding
    box and label 1. Image and targets are then split into crops with
    ``split_image_and_targets``, and all crops are saved together to
    ``<output_dir>/dataset_crops.pt``.

    Args:
        image_dir (str): Directory containing the input images (``image_<i>.<ext>``).
        mask_dir (str): Directory containing the masks (``image_<i>_mask_<j>.<ext>``).
        output_dir (str): Directory to save preprocessed data (created if missing).
        crop_size (tuple): Size of the crops (H, W).
        grid_size (tuple): Grid dimensions for cropping (rows, cols).
        augment (bool): Whether to apply the augmentation function to each crop.

    Returns:
        list[dict]: All crops, as returned by ``split_image_and_targets``.
    """
    dataset = []  # List to hold all samples
    os.makedirs(output_dir, exist_ok=True)

    # List all images in the image directory (lexicographic order)
    image_paths = sorted(glob.glob(os.path.join(image_dir, "image_*")))

    for image_path in image_paths:
        # 'image_12.png' -> '12'
        image_idx = os.path.basename(image_path).split("_")[1].split(".")[0]

        with Image.open(image_path) as img:
            image_tensor = transforms.ToTensor()(img.convert("I"))  # [1, H, W]
        height, width = image_tensor.shape[-2:]

        # 'image_1_mask_*' cannot match 'image_10_mask_*', so indices stay distinct.
        mask_files = sorted(glob.glob(os.path.join(mask_dir, f"image_{image_idx}_mask_*")))
        masks_tensor, boxes_tensor = _load_masks_and_boxes(mask_files, height, width)

        # Single foreground class: label 1 for every instance
        labels_tensor = torch.ones((len(boxes_tensor),), dtype=torch.int64)  # [N]

        crops = split_image_and_targets(
            image=image_tensor,
            masks=masks_tensor,
            boxes=boxes_tensor,
            labels=labels_tensor,
            grid_size=grid_size,
            crop_size=crop_size,
            augment=augment,
            image_index=image_idx
        )
        dataset.extend(crops)

    # Save as .pt file
    crop_filename = "dataset_crops.pt"
    torch.save(dataset, os.path.join(output_dir, crop_filename))

    print(f"Processed images -> Saved crops to {crop_filename}")

    return dataset

In [None]:
# Define augmentations
# Consumed by apply_augmentation when crops are built with augment=True.
# Transform order matters: it fixes both the sequence of image operations and
# the order in which random parameters are drawn.
# NOTE(review): GaussianBlur is not wrapped in RandomApply, so every augmented
# crop is blurred with sigma in [10, 15]. torchvision reads kernel_size as
# (kx, ky), so (1, 25) blurs along the vertical axis only -- confirm intended.
# NOTE(review): saturation/hue jitter presumably has no effect on the
# single-channel images used in this notebook.
augmentation_pipeline = T.Compose([
    T.RandomHorizontalFlip(p=0.3),
    T.RandomVerticalFlip(p=0.3),
    T.RandomRotation(degrees=30),  # angle drawn from [-30, 30] degrees
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.RandomResizedCrop(size=(512, 512), scale=(0.2, 1)),  # 20-100% of the area, resized to 512x512
    T.GaussianBlur(kernel_size=(1, 25), sigma=(10., 15.)),
    T.RandomPerspective(distortion_scale=0.2, p=0.35)
])

In [None]:
def apply_augmentation(crop_data):
    """
    Apply ``augmentation_pipeline`` to a crop, keeping masks and boxes aligned.

    Each transform of the pipeline is replayed here rather than calling the
    Compose directly. Calling it on the image alone would move the pixels away
    from their masks and boxes for every geometric transform.

    * Flips, rotation, resized crop and perspective draw their random
      parameters once and apply them to the image and, with nearest-neighbour
      interpolation, to every instance mask.
    * Any other transform is treated as photometric (e.g. ColorJitter,
      GaussianBlur) and is applied to the image only. New geometric
      transforms added to the pipeline must get their own branch below.

    If a geometric transform fired, boxes are recomputed from the augmented
    masks (inclusive pixel extents). Instances whose mask left the frame are
    dropped together with their label.

    Args:
        crop_data (dict): Contains 'image' [C, H, W], 'masks' [N, H, W],
            'boxes' [N, 4] and 'labels' [N] tensors.

    Returns:
        dict: The same dict, updated in place with the augmented data.
    """
    nearest = T.InterpolationMode.NEAREST
    image = to_pil_image(crop_data['image'])  # Convert tensor to PIL
    masks = crop_data['masks']
    # Geometric tensor ops need at least one mask, so empty stacks are not transformed.
    has_masks = masks.shape[0] > 0
    geometry_changed = False

    # F is torchvision.transforms.functional (imported at the top of the notebook).
    for t in augmentation_pipeline.transforms:
        if isinstance(t, (T.RandomHorizontalFlip, T.RandomVerticalFlip)):
            if torch.rand(1).item() < t.p:
                flip = F.hflip if isinstance(t, T.RandomHorizontalFlip) else F.vflip
                image = flip(image)
                if has_masks:
                    masks = flip(masks)
                geometry_changed = True
        elif isinstance(t, T.RandomRotation):
            angle = T.RandomRotation.get_params(t.degrees)
            image = F.rotate(image, angle, interpolation=t.interpolation,
                             expand=t.expand, center=t.center, fill=t.fill)
            if has_masks:
                masks = F.rotate(masks, angle, interpolation=nearest,
                                 expand=t.expand, center=t.center, fill=0)
            geometry_changed = True
        elif isinstance(t, T.RandomResizedCrop):
            top, left, height, width = T.RandomResizedCrop.get_params(image, t.scale, t.ratio)
            image = F.resized_crop(image, top, left, height, width, t.size, t.interpolation)
            if has_masks:
                masks = F.resized_crop(masks, top, left, height, width, t.size, nearest)
            geometry_changed = True
        elif isinstance(t, T.RandomPerspective):
            if torch.rand(1).item() < t.p:
                img_w, img_h = image.size  # PIL size is (width, height)
                start, end = T.RandomPerspective.get_params(img_w, img_h, t.distortion_scale)
                image = F.perspective(image, start, end, t.interpolation, t.fill)
                if has_masks:
                    masks = F.perspective(masks, start, end, nearest, 0)
                geometry_changed = True
        else:
            image = t(image)  # photometric: image only

    crop_data['image'] = to_tensor(image)  # Convert back to tensor

    if geometry_changed:
        if has_masks:
            binary = masks > 0
            keep = binary.flatten(1).any(dim=1)  # instance still visible?
            masks = binary[keep].to(crop_data['masks'].dtype)
            crop_data['labels'] = crop_data['labels'][keep]
            crop_data['boxes'] = torchvision.ops.masks_to_boxes(masks)
        else:
            # Keep the empty mask stack consistent with the (possibly resized) image.
            masks = masks.new_zeros((0, image.size[1], image.size[0]))
        crop_data['masks'] = masks

    return crop_data

In [None]:
def split_image_and_targets(image, masks, boxes, labels, grid_size=(2, 3), crop_size=(512, 512), augment=False, image_index=0):
    """
    Splits the image, masks, and bounding boxes into overlapping crops with optional augmentations.

    Crops are laid out on an evenly spaced ``rows x cols`` grid whose last row
    and column end flush with the image border. If the image is smaller than
    the crop in a dimension, every crop starts at 0 in that dimension and is
    correspondingly smaller.

    Args:
        image (Tensor): Original image tensor of shape [C, H, W].
        masks (Tensor): Masks tensor of shape [N, H, W] (any value > 0 is foreground).
        boxes (Tensor): Bounding boxes tensor of shape [N, 4], inclusive pixel
            extents (xmin, ymin, xmax, ymax), aligned with ``masks``.
        labels (Tensor): Labels tensor of shape [N].
        grid_size (tuple): (rows, cols) for splitting the image.
        crop_size (tuple): (height, width) of each crop.
        augment (bool): Whether to apply augmentations (via ``apply_augmentation``).
        image_index (int or str): Index of the image in the dataset.

    Returns:
        list: A list of dictionaries (one per grid cell, including cells with
              no instance), each containing:
              - 'image': Cropped image tensor [C, crop_H, crop_W].
              - 'masks': Binary float32 masks tensor [N', crop_H, crop_W].
              - 'boxes': Crop-relative boxes tensor [N', 4], computed from the
                         mask pixels inside the crop (inclusive extents).
              - 'labels': Cropped labels tensor [N'].
              - 'metadata': Metadata dictionary containing:
                            - 'image_index': Index of the original image.
                            - 'x_start': x-coordinate of the crop's top-left corner.
                            - 'y_start': y-coordinate of the crop's top-left corner.
    """
    H, W = image.shape[1], image.shape[2]
    crop_H, crop_W = crop_size
    rows, cols = grid_size

    # Largest valid top-left offset, clamped at 0 so an image smaller than the
    # crop never produces negative (wrapping) slice indices.
    max_y = max(H - crop_H, 0)
    max_x = max(W - crop_W, 0)
    stride_H = max_y // (rows - 1) if rows > 1 else 0
    stride_W = max_x // (cols - 1) if cols > 1 else 0

    crops = []

    for row in range(rows):
        for col in range(cols):
            # Compute crop coordinates
            y_start = min(row * stride_H, max_y)
            x_start = min(col * stride_W, max_x)
            y_end = y_start + crop_H
            x_end = x_start + crop_W

            cropped_image = image[:, y_start:y_end, x_start:x_end]

            cropped_masks, cropped_boxes, cropped_labels = [], [], []
            for mask_idx, box in enumerate(boxes):
                x_min, y_min, x_max, y_max = box.tolist()
                # Cheap pre-check: can this (inclusive) box intersect the
                # half-open crop window [x_start, x_end) x [y_start, y_end)?
                if not (x_min < x_end and x_max >= x_start and y_min < y_end and y_max >= y_start):
                    continue
                # Derive the crop-relative box from the mask pixels actually
                # inside the crop. Clipping the full-image box can enclose no
                # mask pixel at all for concave or diagonal objects.
                cropped_mask = masks[mask_idx, y_start:y_end, x_start:x_end] > 0
                ys, xs = torch.nonzero(cropped_mask, as_tuple=True)
                if ys.numel() == 0:
                    continue
                cropped_boxes.append([xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()])
                cropped_labels.append(labels[mask_idx].item())  # Convert label to int
                cropped_masks.append(cropped_mask.float())

            # Same dtype whether or not the crop contains instances
            if cropped_masks:
                crop_masks = torch.stack(cropped_masks, dim=0)
            else:
                crop_masks = torch.zeros((0, *cropped_image.shape[1:]), dtype=torch.float32)

            # Metadata allows crops to be mapped back onto the original image
            metadata = {
                "image_index": image_index,
                "x_start": x_start,
                "y_start": y_start,
            }

            # Every crop is kept, even with no instance (ROIDataset skips those at load time).
            crop_data = {
                "image": cropped_image,
                "masks": crop_masks,
                "boxes": torch.tensor(cropped_boxes, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.tensor(cropped_labels, dtype=torch.int64),
                "metadata": metadata,
            }

            if augment:
                crop_data = apply_augmentation(crop_data)

            crops.append(crop_data)

    return crops

In [None]:
def normalize_image(image, percentile_min=2, percentile_max=98, padding_value=0):
    """
    Normalize the image to [0, 1] while handling outliers and ignoring padded edges.

    Pixels equal to ``padding_value`` are treated as padding. They are
    excluded from the percentile statistics and set to 0 in the output. The
    remaining pixels are clipped to the [percentile_min, percentile_max]
    percentile range and rescaled linearly to [0, 1].

    Args:
        image (torch.Tensor): Input image tensor.
        percentile_min (float): Lower percentile for clipping (default: 2%).
        percentile_max (float): Upper percentile for clipping (default: 98%).
        padding_value (float): Value of padded background (default: 0).

    Returns:
        torch.Tensor: Normalized float image tensor. Returns zeros (with a printed
        warning) if the image is entirely padding or has a zero percentile range.
    """
    # Mask out padded regions
    pad_mask = image != padding_value
    if not pad_mask.any():
        print("Warning: Entire image is padding. Returning zeros.")
        return torch.zeros_like(image)

    # quantile and the division below need a floating-point tensor
    image = image.float()

    # Compute percentile-based clipping bounds from valid (non-padding) pixels only
    valid_pixels = image[pad_mask]
    min_val = torch.quantile(valid_pixels, percentile_min / 100.0)
    max_val = torch.quantile(valid_pixels, percentile_max / 100.0)

    # Check before dividing: a zero range would otherwise produce NaN/inf first.
    if max_val == min_val:
        print("Warning: Image has uniform intensity. Returning zeros.")
        return torch.zeros_like(image)

    clipped_image = torch.clip(image, min=min_val, max=max_val)
    normalized_image = (clipped_image - min_val) / (max_val - min_val)
    normalized_image[~pad_mask] = 0  # Restore padding values to 0

    return normalized_image

In [None]:
# This can handle both data in memory or saved as a .pt file
class ROIDataset(Dataset):
    """Dataset of pre-cropped samples for Mask R-CNN training.

    Samples are dicts with 'image', 'masks', 'boxes' and 'labels' entries.
    They are either passed in memory (``dataset``) or loaded from a ``.pt``
    file written with ``torch.save`` (``data_path``). At access time, instances
    whose mask touches the crop border or is very small are dropped. A sample
    left with no instance is returned as ``None`` so the collate function can
    discard it.
    """

    def __init__(self, dataset=None, data_path=None, transform=None):
        """
        Args:
            dataset (list or None): Preloaded dataset (list of dictionaries with 'image', 'masks', etc.).
            data_path (str or None): Path to a preprocessed `.pt` file; takes precedence over ``dataset``.
            transform (callable, optional): Optional transform to be applied to the image.

        Raises:
            ValueError: If neither ``dataset`` nor ``data_path`` is provided.
        """
        if data_path:
            self.data_path = data_path
            # torch.load relies on pickle: only load files you produced yourself.
            self.data = torch.load(self.data_path)
        elif dataset is not None:
            # `is not None` rather than truthiness, so an empty list is a valid (empty) dataset.
            self.data = dataset
            self.data_path = None  # No files to load
        else:
            raise ValueError("Either 'dataset' or 'data_path' must be provided.")

        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def is_mask_touching_border(mask, threshold=0, threshold_size=0):
        """
        Decide whether an instance mask should be discarded.

        Args:
            mask (array-like): 2-D mask [H, W]; any value > 0 counts as foreground.
            threshold (int): A border row/column with more than this many
                foreground pixels counts as touching the border.
            threshold_size (int): Masks with fewer foreground pixels than this
                are considered too small.

        Returns:
            bool: True if the mask touches any border or is too small.
        """
        foreground = np.asarray(mask) > 0
        edges = (foreground[0, :], foreground[-1, :], foreground[:, 0], foreground[:, -1])
        touches_border = any(edge.sum() > threshold for edge in edges)
        # Area of the instance itself. The former `h * w` was the crop area,
        # which made this size filter a no-op.
        too_small = foreground.sum() < threshold_size
        return bool(touches_border or too_small)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for crop ``idx``, or ``None`` if no instance survives filtering."""
        data = self.data[idx]
        masks = data["masks"]

        # Keep only instances fully inside the crop and of a usable size
        valid_indices = [
            i for i, mask in enumerate(masks)
            if not self.is_mask_touching_border(mask.numpy(), threshold=5, threshold_size=30)
        ]

        # Skip if no valid masks remain (checked before any image work)
        if not valid_indices:
            return None  # Indicate this crop should be skipped

        image = data["image"]
        if self.transform:
            if isinstance(image, torch.Tensor):
                image = transforms.ToPILImage()(image)  # Convert back to PIL for transforms
            image = self.transform(image)
        image = normalize_image(image)

        target = {
            "boxes": data["boxes"][valid_indices],
            "labels": data["labels"][valid_indices],
            "masks": masks[valid_indices],
        }

        return image, target

In [None]:
# With all the functions defined, the line below uses them
# This only needs to be run and saved once. If data was not previously pre-processed and saved, uncomment the line below

#data = preprocess_data(image_dir, mask_dir, output_dir, crop_size=(512, 512), grid_size=(3, 5))

In [None]:
# Image transform applied by ROIDataset before normalization; kept as a
# Compose so further pre-processing steps can be appended later.
transform = transforms.Compose([transforms.ToTensor()])

# Load the pre-cropped dataset saved by preprocess_data
data_path = os.path.join(output_dir, "dataset_crops.pt")
dataset = ROIDataset(data_path=data_path, transform=transform)

# Part 3: Build dataloaders
- With the input data in the correct shape, we can now build the training dataset in a way that the model can use it for training.
- The data is split into training and validation datasets. This is repeated for each of the k folds (k = `n_splits`, 5 by default) for k-fold cross-validation.
- Crops from the same image need to be collectively assigned to the same dataset to prevent data leakage. 
- A few examples are visualized to check that the correct data is being loaded.

In [None]:
def custom_collate_fn(batch):
    """Collate a detection batch into ``(images, targets)`` lists.

    Samples the dataset returned as ``None`` (crops without a usable instance)
    are dropped first. If nothing is left, a warning is printed and
    ``([], [])`` is returned so the training loop can skip the batch.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        print("Warning: empty batch encountered")
        return [], []

    images = [image for image, _ in kept]
    targets = [target for _, target in kept]
    return images, targets

In [None]:
# k-fold cross-validation is needed to assess the model's metrics, but it requires running the training k times and on a reduced dataset.
# It's possible to then choose the fold with the best metrics, or alternatively set k = 1 and re-run the training
# In the second case, validation can be avoided since with the current setup it does not influence training. 

n_splits = 5       # Number of folds. Can be set to 1 for production training

# One group per source image: GroupKFold keeps all crops of an image on the
# same side of every split.
groups = [crop["metadata"]["image_index"] for crop in dataset.data]

if n_splits != 1:
    gkf = GroupKFold(n_splits=n_splits)
    splits = list(gkf.split(X=np.zeros(len(groups)), groups=groups))
else:
    # Production mode: train on everything, no validation fold
    splits = [(np.arange(len(dataset)), [])]

# Build (and sanity-check) the loaders of every fold. The loaders of the last
# fold remain available afterwards, e.g. for visualization.
for fold, (train_idx, val_idx) in enumerate(splits):
    has_val = len(val_idx) > 0

    train_dataset = Subset(dataset, train_idx)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                              collate_fn=custom_collate_fn, drop_last=True)

    if has_val:
        val_dataset = Subset(dataset, val_idx)
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False,
                                collate_fn=custom_collate_fn, drop_last=True)
    else:
        val_dataset = None
        val_loader = None

    # Leakage check: no source image may appear in both train and validation.
    train_groups = {dataset.data[i]["metadata"]["image_index"] for i in train_idx}
    val_groups = ({dataset.data[i]["metadata"]["image_index"] for i in val_idx}
                  if has_val else set())
    overlap = train_groups & val_groups
    print(f"Fold {fold+1}/{n_splits}: Train size={len(train_idx)}, "
          f"Val size={len(val_idx)}, Overlap={overlap}")

In [None]:
# Visualization for both uncropped and cropped data
def visualize_data(loader, num_samples=1):
    """
    Plot the first ``num_samples`` samples yielded by ``loader``.

    Each sample is shown as two panels: the image with its target boxes drawn
    in yellow, and the union of its instance masks. Images with 1, 3 or 4
    channels are supported (an alpha channel is dropped). Other channel counts
    are reported and skipped without counting towards ``num_samples``.

    Args:
        loader: Iterable of ``(images, targets)`` batches, e.g. a DataLoader
            using ``custom_collate_fn``.
        num_samples (int): Number of samples to plot before returning.
    """
    to_pil = transforms.ToPILImage()
    shown = 0
    for batch_images, batch_targets in loader:
        for idx, img_tensor in enumerate(batch_images):
            if shown >= num_samples:
                return  # enough samples plotted

            # Unwrap nested lists around the image / target
            if isinstance(img_tensor, list):
                img_tensor = img_tensor[0]
            targets = batch_targets[idx]
            if isinstance(targets, list):
                targets = targets[0]

            n_channels = img_tensor.shape[0]
            if n_channels == 1:            # grayscale
                image = to_pil(img_tensor.squeeze(0).cpu())
            elif n_channels in (3, 4):     # RGB, or RGBA with alpha dropped
                image = to_pil(img_tensor[:3].cpu())
            else:
                print(f"Unexpected number of channels: {img_tensor.shape[0]}. Skipping visualization.")
                continue

            # Union of all instance masks; clipping keeps overlaps at 1
            masks = targets["masks"].cpu().numpy()  # [N, H, W]
            print(f"Mask shape: {masks.shape}")
            combined_mask = np.clip(masks.sum(axis=0), 0, 1)

            # Draw the target boxes onto the PIL image
            draw = ImageDraw.Draw(image)
            for box in targets["boxes"].cpu().numpy():
                if box.shape != (4,):
                    print(f"Invalid box format: {box}. Skipping.")
                    continue
                draw.rectangle([float(coord) for coord in box], outline="yellow", width=3)

            plt.figure(figsize=(10, 5))
            panels = (
                (image, f"Image {shown + 1} with Bounding Boxes"),
                (combined_mask, f"Combined Mask {shown + 1} (Sum Across Channels)"),
            )
            for position, (panel, title) in enumerate(panels, start=1):
                plt.subplot(1, 2, position)
                plt.imshow(panel, cmap="gray")
                plt.title(title)
            plt.show()

            shown += 1

In [None]:
# Sanity check: show one sample from the training loader built for the last fold
print("Visualizing Data:")
visualize_data(loader=train_loader, num_samples=1)

# Part 4: fine-tune pre-trained model
- Load the pre-trained model (maskrcnn_resnet50_fpn) from torchvision
- Define the necessary training parameters (num_classes) and hyperparameters:
    - learning rate
    - weight decay
    - scheduler's step_size and gamma
    - epochs
- Train the model for the defined number of epochs. The model's weights can be saved after each epoch (controlled by `model_save`, which is off by default); this can use a lot of storage and is not necessary in most cases.
### IMPORTANT - If the notebook is to be used only to predict (without training), the pre-trained model still needs to be imported

In [None]:
%%capture
# Load the pre-trained model
# COCO-pretrained weights via torchvision's `weights` API; the older
# `pretrained=True` flag is deprecated and maps to the same COCO_V1 weights.
image_height = 512  # Cropped image height
image_width = 512   # Cropped image width
# torchvision resizes inputs so the shorter side equals min_size without the
# longer side exceeding max_size; for 512x512 crops this keeps them at 512.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
    min_size=image_height,
    max_size=image_width,
)

# Modify the model's head for dataset: new box classifier/regressor for num_classes
num_classes = 2  # Background + 1 ROI class
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Update the mask predictor
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# Move the model to the device (GPU/CPU)
model = model.to(device)

# Print the model to verify. Uncomment if needed
#print(model)

In [None]:
def apply_nms(outputs, iou_threshold=0.5):
    """
    Apply class-agnostic Non-Maximum Suppression (NMS) to model predictions.

    Every per-instance tensor in a prediction is filtered with the same kept
    indices. A per-instance tensor is one whose first dimension equals the
    number of boxes, e.g. 'boxes', 'scores', 'masks' and also 'labels'. The
    previous version silently dropped 'labels'. Other entries are passed
    through unchanged.

    Args:
        outputs (list of dict): List of predictions, each containing at least:
                                - 'boxes': Tensor of shape [N, 4]
                                - 'scores': Tensor of shape [N]
                                and typically 'masks' [N, ...] and 'labels' [N].
        iou_threshold (float): IoU threshold for NMS.

    Returns:
        list of dict: Predictions after NMS, kept instances sorted by decreasing score.
    """
    filtered_outputs = []
    for output in outputs:
        boxes = output['boxes']
        n_instances = boxes.shape[0]

        # Indices of the boxes kept by NMS, in decreasing score order
        keep_indices = nms(boxes, output['scores'], iou_threshold)

        filtered = {}
        for key, value in output.items():
            is_per_instance = (
                isinstance(value, torch.Tensor)
                and value.dim() > 0
                and value.shape[0] == n_instances
            )
            filtered[key] = value[keep_indices] if is_per_instance else value
        filtered_outputs.append(filtered)

    return filtered_outputs

In [None]:
# This sets up the main training loop. We store the weights after each epoch for comparison purposes, but only the last one is needed.
# To save space, keep "model_save = False" (each checkpoint is ~300 MB)

# PARAMETERS (edit if needed)
num_epochs = 25               # epochs per fold
batch_size = 8                # training batch size (validation uses batch_size // 2)
num_workers = 0               # DataLoader worker processes (0 = load in the main process)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_dir = models_dir         # reuse existing models_dir variable
model_save = False            # save checkpoints
save_best_only = False        # set to True to only save best val model per fold
verbose = True

# Get `data_list` from your dataset instance so we can extract groups (image_index)
# (falls back to reading the .pt file if the dataset has no in-memory data)
if hasattr(dataset, "data") and dataset.data is not None:
    data_list = dataset.data
elif hasattr(dataset, "data_path") and dataset.data_path:
    data_list = torch.load(dataset.data_path)
else:
    raise RuntimeError("Dataset has no data loaded; please provide dataset.data or dataset.data_path")

#Build groups array: image_index must be present in crop metadata
# image_index is the string parsed from the file name ("image_12.png" -> "12"),
# so int() below assumes numeric image names.
groups = []
for i, d in enumerate(data_list):
    meta = d.get("metadata", {}) if isinstance(d, dict) else {}
    img_idx = meta.get("image_index", None)
    if img_idx is None:
        raise KeyError(f"Crop at index {i} has no metadata/image_index.")
    groups.append(int(img_idx))

# Make sure number of unique groups >= n_splits
# NOTE(review): `splits` was already built in the cross-validation cell from
# the string-valued groups; in the visible code this list only feeds the check below.
n_unique = len(set(groups))
if n_splits > n_unique:
    raise ValueError(f"n_splits={n_splits} is larger than number of unique images ({n_unique}). Reduce n_splits.")

# Save base model state to restart from for each fold (important!)
# Deep copy so later training steps cannot mutate the saved tensors; each fold
# reloads these weights before training.
base_state = copy.deepcopy(model.state_dict())

fold_results = []

In [None]:
# Main training loop: one independent run per cross-validation fold.
# Each fold restarts from `base_state`, gets its own DataLoaders, optimizer and scheduler,
# resumes from the newest "fold{N}_epoch{E}.pth" in `save_dir` if one exists, and rewrites
# "fold{N}_history.json" after every epoch.
# NOTE: relies on `splits`, `dataset`, `model`, `g` and `custom_collate_fn` from earlier cells.


def _move_target_to_device(target, device):
    """Move every tensor value of a target dict to `device`; non-tensor values are kept as-is."""
    return {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in target.items()}


for fold_idx, (train_idx, val_idx) in enumerate(splits):
    print(f"\n=== Fold {fold_idx + 1}/{n_splits} ===")
    t0_fold = time.time()

    # Restore base model weights so every fold starts from the same initialisation
    model.load_state_dict(base_state)
    model.to(device)
    torch.cuda.empty_cache()

    # Build Subset objects and DataLoaders
    train_subset = Subset(dataset, train_idx)
    train_loader_fold = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        generator=g,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
        drop_last=False,)

    val_loader_fold = None
    if len(val_idx) > 0:  # Only create if validation exists
        val_subset = Subset(dataset, val_idx)
        val_loader_fold = DataLoader(
            val_subset,
            batch_size=max(1, batch_size // 2),  # never 0, even when batch_size == 1
            shuffle=False,
            num_workers=num_workers,
            pin_memory=False,
            collate_fn=custom_collate_fn,
            drop_last=False,)

    # Recreate optimizer & scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0001)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

    best_val_loss = float("inf")
    history = {"train_loss": [], "val_loss": []}
    hist_path = os.path.join(save_dir, f"fold{fold_idx+1}_history.json")

    # Resume logic (skip val_loss if no validation)
    existing_ckpts = [f for f in os.listdir(save_dir) if f.startswith(f"fold{fold_idx+1}_epoch")]
    if existing_ckpts:
        existing_ckpts.sort(key=lambda x: int(x.split("epoch")[1].split(".")[0]))
        latest_ckpt = existing_ckpts[-1]
        ckpt_path = os.path.join(save_dir, latest_ckpt)
        print(f"Resuming from checkpoint {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"]  # number of epochs already completed

        if os.path.exists(hist_path):
            # Restore the full per-epoch history of the interrupted run, so the JSON rewritten
            # below still contains the epochs trained before the restart.
            with open(hist_path, "r") as f:
                saved_history = json.load(f)
            history["train_loss"] = list(saved_history.get("train_loss", []))[:start_epoch]
            if val_loader_fold is not None:
                history["val_loss"] = list(saved_history.get("val_loss", []))[:start_epoch]
        else:
            history["train_loss"] = [checkpoint.get("train_loss", float("inf"))]
            if val_loader_fold is not None:
                history["val_loss"] = [checkpoint.get("val_loss", float("inf"))]

        # Keep the save_best_only reference consistent across restarts
        previous_vals = [v for v in history["val_loss"] if v is not None]
        if previous_vals:
            best_val_loss = min(previous_vals)
    else:
        print("No checkpoint found for this fold, starting fresh.")
        start_epoch = 0

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        print(f" Fold {fold_idx+1} - Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (images, targets) in enumerate(train_loader_fold):
            if len(images) == 0 or len(targets) == 0:
                continue

            images = [img.to(device) for img in images]
            targets = [_move_target_to_device(t, device) for t in targets]

            optimizer.zero_grad()
            try:
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                losses.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
                optimizer.step()
            except Exception as e:
                # Best-effort: report and skip the failing batch instead of aborting the fold
                print(f"  Error during train batch {batch_idx+1}: {e}")
                continue

            epoch_loss += losses.item()
            num_batches += 1

        avg_train_loss = epoch_loss / num_batches if num_batches > 0 else float("inf")
        history["train_loss"].append(avg_train_loss)
        print(f"  -> Epoch {epoch+1} train loss: {avg_train_loss:.4f}")

        # Validation (only if val_loader exists)
        avg_val_loss = None
        if val_loader_fold is not None:
            torch.cuda.empty_cache()
            model.eval()
            val_loss_accum = 0.0
            val_batches = 0

            with torch.no_grad():
                for images, targets in val_loader_fold:
                    if len(images) == 0 or len(targets) == 0:
                        continue

                    images = [img.to(device) for img in images]
                    # Same device transfer as in training: non-tensor target values stay untouched
                    targets = [_move_target_to_device(t, device) for t in targets]

                    # The detection model only returns the loss dict in train mode, so switch
                    # temporarily (gradients are still disabled by no_grad)
                    model.train()
                    try:
                        loss_dict = model(images, targets)
                        losses = sum(loss for loss in loss_dict.values())
                    except Exception as e:
                        print(f"  Error during validation: {e}")
                        continue
                    finally:
                        model.eval()

                    val_loss_accum += losses.item()
                    val_batches += 1

            avg_val_loss = val_loss_accum / val_batches if val_batches > 0 else float("inf")
            history["val_loss"].append(avg_val_loss)
            print(f"  -> Epoch {epoch+1} val loss: {avg_val_loss:.4f}")

        # Step scheduler
        scheduler.step()

        # Save checkpoint(s)
        epoch_global = epoch + 1
        if model_save:
            chkname = f"fold{fold_idx+1}_epoch{epoch_global}.pth"
            chkpath = os.path.join(save_dir, chkname)
            torch.save({
                "fold": fold_idx + 1,
                "epoch": epoch_global,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
            }, chkpath)
            # optionally also keep a copy of the best-validation weights (needs a validation loss)
            if save_best_only and avg_val_loss is not None and avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_path = os.path.join(save_dir, f"fold{fold_idx+1}_best.pth")
                torch.save(model.state_dict(), best_path)
        elif epoch == (num_epochs - 1):
            # Without per-epoch checkpoints, still save the final weights
            last_path = os.path.join(save_dir, f"fold{fold_idx+1}_last.pth")
            torch.save(model.state_dict(), last_path)

        # Save history to JSON after every epoch
        with open(hist_path, "w") as f:
            json.dump(history, f)

    # fold finished
    fold_time = time.time() - t0_fold
    # Handle missing validation if k==1
    if history["val_loss"]:
        last_val = history["val_loss"][-1]
        best_val = min(history["val_loss"])
        print(f"Fold {fold_idx+1} finished in {fold_time/60:.2f} minutes. "
              f"Last epoch val loss: {last_val:.4f}")
    else:
        last_val = None
        best_val = None
        print(f"Fold {fold_idx+1} finished in {fold_time/60:.2f} minutes. "
              f"No validation performed for this fold.")

    fold_results.append({
        "fold": fold_idx + 1,
        "history": history,
        "last_val_loss": last_val,
        "best_val_loss": best_val,
    })

    # free GPU
    torch.cuda.empty_cache()

In [None]:
# Per-fold summary of the validation losses recorded in `fold_results`.
# last_val and best_val only differ right after training, not when a final model is reloaded.

for res in fold_results:
    fold, last_val, best_val = res["fold"], res["last_val_loss"], res["best_val_loss"]
    if best_val is None:
        print(f"Fold {fold}: no validation performed")
        continue
    print(f"Fold {fold}: last_val={last_val:.4f}, best_val={best_val:.4f}")

# With k = 1 there is no validation split, so there may be nothing to average
val_losses = [r["best_val_loss"] for r in fold_results if r["best_val_loss"] is not None]
if not val_losses:
    print("\nNo validation folds available — training completed on full dataset.")
else:
    mean_best = np.mean(val_losses)
    print(f"\nGroupKFold summary: mean best val loss across folds = {mean_best:.4f}")

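### (FOR REPORTING ONLY, OPTIONAL)
### Part 4.1: Loss-vs-epoch curves
- Reads the per-fold history files saved during training, finds the epoch with the lowest validation loss for each fold, locates its checkpoint, and plots the mean ± std of the training/validation losses across folds.
- Not needed if k = 1 (there is no validation split).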

In [None]:
# Paths. These folders were created manually to store different versions of the training.
# For every fold this cell reads the saved loss history, finds the epoch with the lowest
# validation loss, locates its checkpoint and (if save_best_models=True) re-saves only its
# model weights as "fold{N}_best.pth". It then plots mean ± std loss curves across folds.

history_dir = r"C:\Users\USER\PROJECT\Models\Model\History"              # where fold histories are stored
checkpoints_dir = r"C:\Users\USER\PROJECT\Models\Model\Checkpoints"      # where .pth files are stored
best_output_dir = r"C:\Users\USER\PROJECT\Models\Model\Best Models"      # where to save cleaned best models
save_best_models = False  # set to True to actually write fold{N}_best.pth into best_output_dir

if save_best_models:
    os.makedirs(best_output_dir, exist_ok=True)

train_losses_all = []
val_losses_all = []

for fold in range(1, n_splits + 1):
    # Load history json
    history_path = os.path.join(history_dir, f"fold{fold}_history.json")
    with open(history_path, "r") as f:
        history = json.load(f)

    train_losses = history["train_loss"]
    val_losses = history["val_loss"]

    train_losses_all.append(train_losses)
    val_losses_all.append(val_losses)

    # Without validation losses (e.g. k = 1) there is no best epoch to select
    if not val_losses:
        print(f"Fold {fold}: no validation losses recorded, skipping best-epoch selection")
        continue

    # Find best epoch (lowest val loss)
    best_epoch = int(np.argmin(val_losses))
    best_val = val_losses[best_epoch]
    print(f"Fold {fold}: best epoch = {best_epoch+1}, val_loss = {best_val:.4f}")

    # Locate corresponding checkpoint
    checkpoint_path = os.path.join(checkpoints_dir, f"fold{fold}_epoch{best_epoch+1}.pth")
    if not os.path.exists(checkpoint_path):
        print(f"!! Warning: checkpoint not found for fold {fold}, epoch {best_epoch+1}")
        continue

    save_path = os.path.join(best_output_dir, f"fold{fold}_best.pth")
    if not save_best_models:
        # Don't claim a save that did not happen; skip loading the (large) checkpoint too
        print(f"Best checkpoint for fold {fold}: {checkpoint_path} (not re-saved; set save_best_models=True)")
        continue

    # Load and re-save a clean version (model state_dict only, no optimizer/scheduler state)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "model_state_dict" in checkpoint:
        model_state = checkpoint["model_state_dict"]
    else:
        model_state = checkpoint  # assume already state_dict
    torch.save(model_state, save_path)
    print(f"Saved best model for fold {fold} at {save_path}")


# Plot mean ± std of train/val losses.
# Use the longest history of either kind so neither list gets a negative pad width.
max_epochs = max(len(v) for v in train_losses_all + val_losses_all)

# Pad to same length (NaN = epoch missing for that fold) for aggregation
train_padded = [np.pad(np.asarray(v, dtype=float), (0, max_epochs - len(v)), constant_values=np.nan)
                for v in train_losses_all]
val_padded   = [np.pad(np.asarray(v, dtype=float), (0, max_epochs - len(v)), constant_values=np.nan)
                for v in val_losses_all]

train_array = np.vstack(train_padded)
val_array = np.vstack(val_padded)
has_val = not np.all(np.isnan(val_array))

train_mean = np.nanmean(train_array, axis=0)
train_std  = np.nanstd(train_array, axis=0)

epochs = np.arange(1, max_epochs + 1)

plt.figure(figsize=(8,6))
plt.plot(epochs, train_mean, label="Train loss (mean)")
plt.fill_between(epochs, train_mean-train_std, train_mean+train_std, alpha=0.2)
if has_val:
    val_mean   = np.nanmean(val_array, axis=0)
    val_std    = np.nanstd(val_array, axis=0)
    plt.plot(epochs, val_mean, label="Val loss (mean)")
    plt.fill_between(epochs, val_mean-val_std, val_mean+val_std, alpha=0.2)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Cross-validation loss curves")
plt.legend()
plt.grid(True)
#plt.savefig(r"C:\Users\USER\PROJECT\TRAINING\ANALYSIS")
plt.show()

### (OPTIONAL)
### Part 4.2: Re-load model
- Once the model's weights have been trained and stored, they can be reloaded quickly without going through training again.
- If the notebook is only used for prediction (without training), the pre-trained model also needs to be loaded first (see Part 4).

In [None]:
# Re-load weights and check that all model's keys are matched.
# Accepts both full training checkpoints (dict with "model_state_dict", "optimizer_state_dict", ...)
# and plain state dicts such as fold{N}_best.pth / fold{N}_last.pth.
weights_path = r"C:\Users\USER\PROJECT\Models\Model\Checkpoints\foldX_epochY.pth" # Change to desired model to re-load
state_dict = torch.load(weights_path, map_location=device, weights_only=True)
model_weights = state_dict["model_state_dict"] if "model_state_dict" in state_dict else state_dict
load_result = model.load_state_dict(model_weights, strict=False)
# strict=False never raises on mismatches, so report them explicitly
print(f"Missing keys: {load_result.missing_keys}")
print(f"Unexpected keys: {load_result.unexpected_keys}")

# Part 5: Visualize predictions
- Visualize predictions on the cropped images in the test dataloader.
- If the test images were not seen during training, the calculated metrics can be used for reporting.
- Metrics can be calculated for each fold to estimate their standard deviation across folds.
- The same functions can be used to analyze any data, including new SEM images from the lab.

In [None]:
# As with the training data, the test data only needs to be preprocessed and saved once. Uncomment the line below on the first run only.

#test_data = preprocess_data(test_image_dir, test_mask_dir, output_dir, crop_size=(512, 512), grid_size=(3, 5))

In [None]:
# Dataset over the preprocessed test crops stored in output_dir
test_data_path = os.path.join(output_dir, "test_dataset_crops.pt")
test_dataset = ROIDataset(transform=transform, data_path=test_data_path)

In [None]:
# Unshuffled loader so evaluation always visits the test crops in the same order
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    collate_fn=custom_collate_fn,
)

In [None]:
def evaluate_model(model, loader, device, iou_threshold=1, mask_threshold=0.7):
    """
    Compute quantitative metrics for one model and one loader.

    For every image:
      - IoU: the predicted instance masks (after NMS, binarised at `mask_threshold`) are merged
        into one foreground mask and compared with the union of the ground-truth masks.
      - Precision / Recall / F1 / Accuracy: predicted and ground-truth boxes are rasterised into
        filled box masks and compared pixel by pixel.
      - AUC_ROC / LogLoss: the pixel-wise maximum of the raw mask probabilities is scored against
        the ground-truth union; NaN when the ground truth contains only one class.

    Args:
        model: detection model returning dicts with 'boxes', 'scores' and 'masks'; set to eval mode.
        loader: iterable of (images, targets) batches; each target provides 'masks' and 'boxes'.
        device: device the images are moved to before inference.
        iou_threshold (float): IoU threshold forwarded to `apply_nms` (1 keeps every box).
        mask_threshold (float): probability above which a predicted mask pixel is foreground.

    Returns:
        dict of str -> np.ndarray: one value per evaluated image for
        "IoU", "Precision", "Recall", "F1", "Accuracy", "AUC_ROC" and "LogLoss".
    """
    model.eval()
    ious, precisions, recalls, f1s, accuracies, aucs, log_losses = [], [], [], [], [], [], []

    def _boxes_to_mask(boxes, shape):
        # Rasterise [x1, y1, x2, y2] boxes (truncated to int pixel coordinates) into a filled 0/1 mask
        canvas = np.zeros(shape, dtype=int)
        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            canvas[y1:y2, x1:x2] = 1
        return canvas

    with torch.no_grad():
        for images, targets in loader:
            if len(images) == 0 or len(targets) == 0:
                continue

            images = [img.to(device) for img in images]
            outputs = apply_nms(model(images), iou_threshold=iou_threshold)

            for idx, output in enumerate(outputs):
                target = targets[idx]
                # Per-instance probability maps; assumes masks come as [N, 1, H, W]
                prob_maps = output["masks"].squeeze(1)
                combined_pred_mask = torch.any(prob_maps > mask_threshold, dim=0).cpu().numpy()

                true_masks = target["masks"].cpu().numpy().astype(int)
                combined_true_mask = np.any(true_masks, axis=0).astype(int)
                flat_true = combined_true_mask.flatten()

                # IoU
                ious.append(jaccard_score(flat_true, combined_pred_mask.flatten()))

                # Boxes → masks
                frame_shape = combined_true_mask.shape
                pred_box_mask = _boxes_to_mask(output["boxes"].cpu().numpy(), frame_shape).flatten()
                true_box_mask = _boxes_to_mask(target["boxes"].cpu().numpy(), frame_shape).flatten()

                precisions.append(precision_score(true_box_mask, pred_box_mask, zero_division=0))
                recalls.append(recall_score(true_box_mask, pred_box_mask, zero_division=0))
                f1s.append(f1_score(true_box_mask, pred_box_mask, zero_division=0))
                accuracies.append(accuracy_score(true_box_mask, pred_box_mask))

                # Probabilistic metrics
                raw_pred_probs = prob_maps.cpu().numpy()
                if raw_pred_probs.shape[0] > 0:
                    combined_prob_map = np.max(raw_pred_probs, axis=0)
                else:
                    combined_prob_map = np.zeros(frame_shape, dtype=np.float32)
                # Keep probabilities away from 0/1 so the log loss stays finite
                flat_pred = np.clip(combined_prob_map, 1e-7, 1 - 1e-7).flatten()

                # Compute both scores before appending, so AUC_ROC and LogLoss always get exactly
                # one entry per image (even if only the second call fails)
                auc, ll = np.nan, np.nan
                if np.any(flat_true) and np.any(flat_true == 0):  # both classes are required
                    try:
                        auc = roc_auc_score(flat_true, flat_pred)
                        ll = log_loss(flat_true, flat_pred)
                    except ValueError:
                        auc, ll = np.nan, np.nan
                aucs.append(auc)
                log_losses.append(ll)

    return {
        "IoU": np.array(ious),
        "Precision": np.array(precisions),
        "Recall": np.array(recalls),
        "F1": np.array(f1s),
        "Accuracy": np.array(accuracies),
        "AUC_ROC": np.array(aucs),
        "LogLoss": np.array(log_losses)
    }

def visualize_predictions(model, loader, device, mask_threshold=0.7, score_threshold=0.5, iou_threshold=1):
    """
    Visualize predictions and masks for qualitative inspection.

    For every image in `loader`, shows the original image next to the same image overlaid with
    the union of the predicted masks (binarised at `mask_threshold`) and the predicted boxes whose
    score exceeds `score_threshold`, each annotated with its score.

    Args:
        model: detection model returning dicts with 'boxes', 'scores' and 'masks'; set to eval mode.
        loader: iterable of (images, targets) batches; targets are only used to skip empty batches.
        device: device the images are moved to before inference.
        mask_threshold (float): probability above which a mask pixel is drawn as foreground.
        score_threshold (float): minimum score for a box to be drawn. The mask overlay includes
            every detection kept by NMS, regardless of its score.
        iou_threshold (float): IoU threshold forwarded to `apply_nms`; the default 1 keeps every box.
    """
    model.eval()
    with torch.no_grad():
        for images, targets in loader:
            if len(images) == 0 or len(targets) == 0:
                continue
            images = [img.to(device) for img in images]
            outputs = apply_nms(model(images), iou_threshold=iou_threshold)

            for idx, output in enumerate(outputs):
                img_pil = F.to_pil_image(images[idx].cpu())
                pred_masks = output["masks"].squeeze(1) > mask_threshold
                combined_pred_mask = torch.any(pred_masks, dim=0).cpu().numpy()
                pred_boxes = output["boxes"].cpu().numpy()
                scores = output["scores"].cpu().numpy()

                fig, ax = plt.subplots(1, 2, figsize=(15, 7))
                ax[0].imshow(img_pil, cmap="gray")
                ax[0].set_title("Original")
                ax[0].axis("off")

                # Same grayscale colormap as the left panel; without it a single-channel image
                # would be drawn with matplotlib's default colormap
                ax[1].imshow(img_pil, cmap="gray", alpha=0.8)
                ax[1].imshow(combined_pred_mask, cmap="gray", alpha=0.5)
                for box, score in zip(pred_boxes, scores):
                    if score > score_threshold:
                        x1, y1, x2, y2 = box
                        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                             edgecolor="red", facecolor="none", linewidth=2)
                        ax[1].add_patch(rect)
                        ax[1].text(x1, y1 - 5, f"{score:.2f}", color="red",
                                   bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"))
                ax[1].set_title("Predicted Masks + Boxes")
                ax[1].axis("off")
                plt.show()
                # Release the figure so long loaders don't accumulate open figures
                plt.close(fig)

def _resolve_fold_weights(fold, best_output_dir, checkpoints_dir):
    """Return the weights file for `fold`: fold{N}_best.pth if present, else the newest fold{N}_epoch{E}.pth, else None."""
    if best_output_dir is not None:
        best_path = os.path.join(best_output_dir, f"fold{fold}_best.pth")
        if os.path.exists(best_path):
            return best_path
    if checkpoints_dir is None or not os.path.isdir(checkpoints_dir):
        return None
    ckpts = sorted(
        [f for f in os.listdir(checkpoints_dir) if f.startswith(f"fold{fold}_epoch")],
        key=lambda x: int(x.split("epoch")[1].split(".")[0])
    )
    if not ckpts:
        return None
    print(f"No best model found for fold {fold}, using last checkpoint: {ckpts[-1]}")
    return os.path.join(checkpoints_dir, ckpts[-1])


def _load_model_weights(model, path, device):
    """Load weights from `path` into `model` (plain state dict or full training checkpoint) and move it to `device`."""
    state = torch.load(path, map_location=device, weights_only=True)
    # Training checkpoints wrap the weights together with optimizer/scheduler state
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)
    model.to(device)


def evaluate_and_visualize(model, loader, device, best_output_dir=None, checkpoints_dir=None, history_dir=None,
                           n_splits=1, visualize=False, visualize_best_fold=False):
    """
    Unified evaluation:
      - n_splits > 1: evaluate every fold's model (fold{N}_best.pth in best_output_dir, falling back
        to the newest fold{N}_epoch{E}.pth in checkpoints_dir) and summarize mean ± std per metric.
        Folds without any weights file are skipped with a message.
      - n_splits == 1: evaluate `model` as it is.
    If visualize_best_fold=True (for k>1), predictions of the fold with the lowest recorded
    validation loss (read from history_dir) are visualized. If visualize=True (for k==1), the
    single model's predictions are visualized.

    Returns:
        k>1: dict metric -> {"mean": float, "std": float} ({} if no fold could be evaluated).
        k==1: the per-image metrics dict returned by evaluate_model.
    """
    all_results = []

    if n_splits > 1:
        print(f"\nEvaluating {n_splits}-fold models...\n")
        best_fold, best_val_loss = None, float("inf")
        fold_weights = {}  # fold -> weights file actually evaluated

        for fold in range(1, n_splits + 1):
            weights_path = _resolve_fold_weights(fold, best_output_dir, checkpoints_dir)
            if weights_path is None:
                print(f"!! No weights found for fold {fold}, skipping it.")
                continue
            fold_weights[fold] = weights_path

            # Retrieve best validation loss if available (only for folds that can be evaluated)
            hist_path = os.path.join(history_dir, f"fold{fold}_history.json") if history_dir else None
            if hist_path and os.path.exists(hist_path):
                with open(hist_path, "r") as f:
                    val_losses = json.load(f).get("val_loss", [])
                if len(val_losses) > 0 and min(val_losses) < best_val_loss:
                    best_val_loss = min(val_losses)
                    best_fold = fold

            # Load model for metrics computation
            _load_model_weights(model, weights_path, device)
            all_results.append(evaluate_model(model, loader, device))

        if not all_results:
            print("\nNo fold could be evaluated.")
            return {}

        # Aggregate results: mean over images per fold, then mean ± std over folds
        summary = {}
        for key in all_results[0].keys():
            values = [np.nanmean(res[key]) for res in all_results]
            summary[key] = {"mean": np.nanmean(values), "std": np.nanstd(values)}

        print("\nCross-validation test metrics (mean ± std):")
        for metric, stats in summary.items():
            print(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")

        # Visualize the best fold if requested, reusing the weights file that was evaluated
        if visualize_best_fold and best_fold is not None:
            print(f"\nBest fold: {best_fold} (val_loss={best_val_loss:.4f})")
            _load_model_weights(model, fold_weights[best_fold], device)
            visualize_predictions(model, loader, device)

        return summary

    else:
        print("\nEvaluating single model (production mode)...")
        metrics = evaluate_model(model, loader, device)
        print("\nSingle-model metrics:")
        for k, v in metrics.items():
            print(f"{k}: {np.nanmean(v):.4f}")

        if visualize:
            print("\nVisualizing predictions...")
            visualize_predictions(model, loader, device)

        return metrics

In [None]:
# Evaluate on the test crops: every fold (plus a visualization of the best one) when k > 1,
# or the single trained model with visualization when k == 1.
if n_splits > 1:
    results_summary = evaluate_and_visualize(
        model, test_loader, device,
        n_splits=n_splits,
        best_output_dir=best_output_dir,
        checkpoints_dir=checkpoints_dir,
        history_dir=history_dir,
        visualize_best_fold=True,
    )
elif n_splits == 1:
    metrics = evaluate_and_visualize(model, test_loader, device, n_splits=n_splits, visualize=True)

# Full-image predictions
- After loading a model (off-the-shelf or fine-tuned), this is the code actually used to predict etch pits on full images.

In [None]:
# Full-image prediction via cropping and reconstruction.
# Each image is split into crops and every crop is run through the model. Detections above
# `score_threshold` are shifted back to full-image coordinates, and overlapping detections
# (e.g. from neighbouring crops) are merged with NMS at `nms_iou_threshold`.
# Pixel_size was previously determined and inserted manually in the metadata.

test_dir = r"C:\Users\USER\PROJECT\TRAINING\ANALYSIS\Images-to-predict"
test_dir_pred = r"C:\Users\USER\PROJECT\TRAINING\ANALYSIS\Images-to-predict\Masks_predictions"
test_images_pred = r"C:\Users\USER\PROJECT\TRAINING\ANALYSIS\Images-to-predict\Images_predictions"
save_prediction_csv = False     # Save the pits data as CSV file
save_prediction_images = False  # Save the final reconstructed image with overlaid predictions as a PNG file
save_prediction_masks = False   # Save the summed prediction masks as a PNG file
score_threshold = 0.5           # Minimum confidence for a crop detection to be kept
nms_iou_threshold = 0.3         # IoU above which overlapping full-image detections are merged

test_paths = sorted(glob.glob(os.path.join(test_dir, "image_*")))
for test_image_path in test_paths:

    # Per-image accumulators
    sum_final_boxes = []
    sum_final_center_x = []
    sum_final_center_y = []
    sum_final_scores = []
    sum_final_masks = []
    sum_sides = []
    measurements_unit = "µm"
    test_image_idx = os.path.basename(test_image_path).split("_")[1].split(".")[0]  # Extract 'i' from 'image_i'

    # Load the test image
    test_image = Image.open(test_image_path).convert("I")
    if "resolution" not in test_image.info:
        raise KeyError(f"{test_image_path} has no 'resolution' metadata; cannot determine the pixel size.")
    x_res = test_image.info["resolution"][0] # Check image metadata format
    pixel_size_x = 1 / x_res

    # Convert to tensor and normalize
    test_image_tensor = transforms.ToTensor()(test_image)
    test_image_tensor = normalize_image(test_image_tensor)

    # Add batch dimension
    test_image_tensor = test_image_tensor.unsqueeze(0).to(device)

    test_crops = split_image_and_targets(test_image_tensor[0], [], [], [], grid_size=(3, 5), crop_size=(512, 512))

    # Blank canvas for the reconstructed full-size mask
    H, W = test_image_tensor.shape[2:]
    reconstructed_mask = np.zeros((H, W), dtype=np.float32)
    all_boxes = []
    all_scores = []
    # Crop-local masks with their (y_start, x_start) offsets. They are pasted into the canvas
    # only after NMS, instead of allocating one H x W array per detection.
    all_masks = []

    # Forward pass
    model.eval()
    with torch.no_grad():
        for crop in test_crops:
            image = crop['image'].unsqueeze(0).to(device)  # Add batch dimension
            y_start, x_start = crop['metadata']['y_start'], crop['metadata']['x_start']

            # Model predictions for this crop
            output = model(image)[0]
            masks = output['masks'].squeeze(1).cpu().numpy() if 'masks' in output else []
            boxes = output['boxes'].cpu().numpy() if 'boxes' in output else []
            scores = output['scores'].cpu().numpy() if 'scores' in output else []

            # Store predictions for merging
            for mask, box, score in zip(masks, boxes, scores):
                if score > score_threshold:
                    # Shift bounding boxes to full-image coordinates
                    x1, y1, x2, y2 = box
                    all_boxes.append([x1 + x_start, y1 + y_start, x2 + x_start, y2 + y_start])
                    all_scores.append(score)
                    all_masks.append((mask, y_start, x_start))

    # Convert to tensors for NMS; reshape keeps a valid [0, 4] shape when nothing was detected
    all_boxes = torch.from_numpy(np.asarray(all_boxes, dtype=np.float32).reshape(-1, 4))
    all_scores = torch.from_numpy(np.asarray(all_scores, dtype=np.float32))
    if len(all_boxes) == 0:
        print(f"Image {test_image_idx}: no detections above score {score_threshold}; nothing plotted or saved.")
        continue

    # Apply NMS to merge overlapping boxes (e.g. the same pit seen by neighbouring crops)
    keep_indices = nms(all_boxes, all_scores, iou_threshold=nms_iou_threshold).tolist()
    final_boxes = all_boxes[keep_indices].numpy()
    final_scores = all_scores[keep_indices].numpy()

    # Paste the kept masks into the canvas (pixel-wise maximum), clipped to the image bounds
    for i in keep_indices:
        mask, y0, x0 = all_masks[i]
        y0, x0 = int(y0), int(x0)
        h = max(0, min(mask.shape[0], H - y0))
        w = max(0, min(mask.shape[1], W - x0))
        region = reconstructed_mask[y0:y0 + h, x0:x0 + w]
        np.maximum(region, mask[:h, :w], out=region)

    # Plot summed predicted masks
    if save_prediction_masks:
        plt.imsave(os.path.join(test_dir_pred, f"image_{test_image_idx}.png"), reconstructed_mask, cmap='gray')
    plt.figure(figsize=(10, 10))
    plt.imshow(test_image_tensor.cpu().squeeze(), cmap='gray', alpha=0.5)
    plt.imshow(reconstructed_mask, cmap='gray', alpha=0.5)

    for box, score in zip(final_boxes, final_scores):
        x1, y1, x2, y2 = box
        center_x = x1 + (x2 - x1)/2
        center_y = y1 + (y2 - y1)/2
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, edgecolor='red', facecolor='none', linewidth=2)
        plt.gca().add_patch(rect)
        # Area = half the box area (a triangle inscribed in the box), in squared units of the
        # resolution metadata; `size` is the side of the equilateral triangle with that area
        # (A = sqrt(3)/4 * s^2).
        area = 0.5 * (x2 - x1) * (y2 - y1) * (pixel_size_x**2)
        size = np.sqrt((4/np.sqrt(3))*area)
        sum_final_boxes.append(box)
        sum_final_scores.append(score)
        sum_final_center_x.append(center_x)
        sum_final_center_y.append(center_y)
        sum_sides.append(size)

    # Plot the reconstructed images with overlaid predictions. Plot pits size histograms
    sum_dict = {'x': sum_final_center_x, 'y': sum_final_center_y, 'side': sum_sides}
    df = pd.DataFrame(sum_dict)
    if save_prediction_csv:
        save_path_data = os.path.join(test_images_pred, f"image_{test_image_idx}.csv")
        df.to_csv(save_path_data)
    plt.axis('off')  # Turn off the axis for better visualization
    if save_prediction_images:
        save_path_img = os.path.join(test_images_pred, f"image_{test_image_idx}.png")
        plt.savefig(save_path_img, bbox_inches='tight', pad_inches=0)
    plt.figure(figsize=(10, 6))
    plt.hist(sum_sides, bins=np.arange(0, 0.45 + 0.02, 0.02), color='blue', alpha=0.7, edgecolor='black',
             weights=np.ones_like(sum_sides) / len(sum_sides))
    plt.title("Histogram of image "+str(test_image_idx), fontsize=16)
    plt.xlabel(f"Size ({measurements_unit})", fontsize=14)
    plt.ylabel("Relative frequency", fontsize=14)
    ax = plt.gca()
    ax.set_xlim([0, 0.45])
    ax.set_ylim([0, 0.45])
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

### End of notebook