In [None]:
!pip install -q segmentation-models-pytorch

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import random


In [None]:
class Config:
    """Central collection of paths, hyperparameters and constants for the notebook."""

    # ---- Input locations (Google Drive) ----
    DRIVE_PATH = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/"
    IMAGE_DIR = os.path.join(DRIVE_PATH, "images")
    MASK_DIR = os.path.join(DRIVE_PATH, "masks")

    # ---- Output locations ----
    # Root folder for everything the notebook writes
    OUTPUT_DIR = os.path.join(DRIVE_PATH, "outputs")
    # Class-index (grayscale) predictions, used for metrics
    OUTPUT_MASK_DIR = os.path.join(OUTPUT_DIR, "pred_masks")
    # Human-readable colourised predictions
    COLOR_MASK_DIR = os.path.join(OUTPUT_DIR, "color_masks")

    # ---- Model definition ----
    ARCHITECTURE = 'unet'          # segmentation architecture
    ENCODER = 'resnet34'           # backbone used as the encoder
    ENCODER_WEIGHTS = 'imagenet'   # pretrained weights for the encoder

    # ---- Optimisation ----
    LEARNING_RATE = 1e-4
    LOSS_FUNCTION = 'DiceLoss'
    OPTIMIZER = 'AdamW'

    # ---- Dataset split ----
    # The three sizes together must not exceed the number of available images.
    TRAIN_SIZE, VAL_SIZE, TEST_SIZE = 412, 51, 52

    # ---- Training run ----
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 4
    NUM_EPOCHS = 25
    IMAGE_HEIGHT = IMAGE_WIDTH = 256
    NUM_CLASSES = 5

    # ---- Visualisation ----
    # Class index -> RGB colour; only used when rendering colourised masks.
    COLOR_MAP = {
        0: (0, 0, 0),        # background: black
        1: (0, 255, 0),      # stem: green
        2: (255, 255, 0),    # leaf: yellow
        3: (139, 69, 19),    # root: brown
        4: (255, 255, 255),  # seed: white
    }

# Make sure both prediction folders exist before anything is written into them
for _prediction_dir in (Config.OUTPUT_MASK_DIR, Config.COLOR_MASK_DIR):
    os.makedirs(_prediction_dir, exist_ok=True)


In [None]:
class PlantDataset(Dataset):
    """Dataset of (image, class-index mask) pairs read from two directories.

    For an image file ``<stem>.<ext>`` inside ``image_dir`` the matching mask
    is read from ``<stem>_mask.png`` inside ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_mask.png`` files.
        image_filenames: File names (relative to ``image_dir``) that make up
            this dataset / split.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, image_dir, mask_dir, image_filenames, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = image_filenames

    def __len__(self):
        """Number of images in this split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, mask)`` where ``mask`` is always an int64 ``torch.Tensor``.
            ``image`` is whatever the transform produced, or the raw RGB array
            (H, W, 3) when no transform is configured.
        """
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_name = os.path.splitext(img_name)[0] + "_mask.png"
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Context managers release the underlying file handles right away.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # The mask may still be a NumPy array here (no transform, or a transform
        # without a to-tensor step); ndarray has no .long(), so convert first.
        return image, torch.as_tensor(mask).long()

# Define augmentations.
!pip install -q albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

def _resize_step():
    """Fresh resize transform targeting the configured network input size."""
    return A.Resize(height=Config.IMAGE_HEIGHT, width=Config.IMAGE_WIDTH)


def _finishing_steps():
    """Fresh [normalise, to-tensor] tail shared by every pipeline."""
    scale_to_unit = A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    )
    return [scale_to_unit, ToTensorV2()]


# Training pipeline: resize, random geometric augmentation, normalise, tensorise
train_transform = A.Compose([
    _resize_step(),
    A.Rotate(limit=35, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    *_finishing_steps(),
])

# Validation / test pipeline: deterministic, resize and normalise only
val_transform = A.Compose([
    _resize_step(),
    *_finishing_steps(),
])

In [None]:
def mask_to_rgb(mask_tensor, color_map):
    """Render a class-index mask tensor as a colour PIL image.

    Singleton dimensions are squeezed away; every pixel whose value equals a
    key of ``color_map`` is painted with that key's RGB tuple, all remaining
    pixels stay black.
    """
    class_ids = mask_tensor.cpu().numpy().squeeze()
    canvas = np.zeros(class_ids.shape + (3,), dtype=np.uint8)
    for label in color_map:
        canvas[class_ids == label] = color_map[label]
    return Image.fromarray(canvas)

In [None]:
def save_predictions_fn(loader, model, folder_basename="", mask_ext=None):
    """Run the model over ``loader.dataset`` and write its predictions to disk.

    For every sample two files are written under the same file name:
    the class-index mask into ``Config.OUTPUT_MASK_DIR/<folder_basename>`` and
    its colourised version into ``Config.COLOR_MASK_DIR/<folder_basename>``.

    Args:
        loader: DataLoader whose ``dataset`` yields ``(image_tensor, mask)``
            pairs and exposes the file names as ``dataset.images``. Samples are
            read one at a time from the dataset, so the loader's batch size is
            not used.
        model: Segmentation network returning per-class scores of shape
            (N, C, H, W).
        folder_basename: Sub-folder name (e.g. ``"validation"``, ``"test"``).
        mask_ext: Optional extension (e.g. ``".png"``) that replaces the
            extension of the source image name in the output file names.
            ``None`` (default) keeps the source image's name unchanged.

    Note:
        PIL picks the file format from the extension. With a JPEG extension the
        class-index mask goes through lossy compression, so the stored values
        may differ from the predicted class ids. A warning is emitted in that
        case; pass ``mask_ext=".png"`` for lossless masks (downstream cells
        that derive names from ``.jpg`` then need the matching change).
    """
    import warnings

    print(f"\n--- Saving predictions for {folder_basename} set ---")

    # Create specific subdirectories for train/val/test predictions
    output_mask_dir = os.path.join(Config.OUTPUT_MASK_DIR, folder_basename)
    color_mask_dir = os.path.join(Config.COLOR_MASK_DIR, folder_basename)
    os.makedirs(output_mask_dir, exist_ok=True)
    os.makedirs(color_mask_dir, exist_ok=True)

    dataset = loader.dataset
    warned_lossy = False

    # Remember the caller's mode so it can be restored instead of forcing train().
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for idx, (img_tensor, _) in enumerate(tqdm(dataset, desc=f"Saving {folder_basename} Predictions")):
                # The dataset returns single images, so we add a batch dimension
                batch = img_tensor.to(Config.DEVICE).unsqueeze(0)
                final_mask = torch.argmax(model(batch), dim=1).squeeze(0)

                original_filename = dataset.images[idx]
                if mask_ext is None:
                    out_filename = original_filename
                else:
                    out_filename = os.path.splitext(original_filename)[0] + mask_ext

                if not warned_lossy and os.path.splitext(out_filename)[1].lower() in (".jpg", ".jpeg"):
                    warnings.warn(
                        "Class-index masks are being written as JPEG, which is lossy; "
                        "stored label values may differ from the predictions. "
                        "Pass mask_ext='.png' for lossless masks.",
                        stacklevel=2,
                    )
                    warned_lossy = True

                # Save the raw integer mask
                pred_mask_img = Image.fromarray(final_mask.cpu().numpy().astype(np.uint8))
                pred_mask_img.save(os.path.join(output_mask_dir, out_filename))

                # Save the color mask
                color_mask_img = mask_to_rgb(final_mask, Config.COLOR_MAP)
                color_mask_img.save(os.path.join(color_mask_dir, out_filename))
    finally:
        model.train(was_training)

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Train the model for one epoch.

    Args:
        loader: DataLoader yielding ``(images, masks)`` batches; masks hold
            class indices with shape (N, H, W).
        model: Network to optimise.
        optimizer: Optimiser stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)``.
        scaler: Gradient scaler used for mixed-precision training.

    Returns:
        Mean of the per-batch losses (0.0 for an empty loader).
    """
    # Do not rely on a previous call having left the model in training mode.
    model.train()

    device_type = torch.device(Config.DEVICE).type
    # Mixed precision is only used on CUDA, matching the CUDA-only original.
    use_amp = device_type == "cuda"

    loop = tqdm(loader, desc="Training")
    total_loss = 0.0

    for data, targets in loop:
        data = data.to(device=Config.DEVICE)
        # Keep targets as (N, H, W) class indices: nn.CrossEntropyLoss rejects
        # (N, 1, H, W) index targets, and smp's multiclass losses document
        # (N, H, W) as a valid target shape.
        targets = targets.to(device=Config.DEVICE)

        # Forward pass
        with torch.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / max(len(loader), 1)

def eval_fn(loader, model, loss_fn):
    """Compute the mean loss of ``model`` over ``loader`` without gradients.

    Args:
        loader: DataLoader yielding ``(images, masks)`` batches; masks hold
            class indices with shape (N, H, W).
        model: Network to evaluate.
        loss_fn: Callable ``loss_fn(predictions, targets)``.

    Returns:
        Mean of the per-batch losses (0.0 for an empty loader).
    """
    # Remember the caller's mode so it is restored rather than forced to train().
    was_training = model.training
    model.eval()
    total_loss = 0.0
    loop = tqdm(loader, desc="Validation")

    try:
        with torch.no_grad():
            for data, targets in loop:
                data = data.to(device=Config.DEVICE)
                # Keep targets as (N, H, W) class indices: nn.CrossEntropyLoss
                # rejects (N, 1, H, W) index targets, and smp's multiclass
                # losses document (N, H, W) as a valid target shape.
                targets = targets.to(device=Config.DEVICE)
                predictions = model(data)
                batch_loss = loss_fn(predictions, targets).item()
                total_loss += batch_loss
                loop.set_postfix(val_loss=batch_loss)
    finally:
        model.train(was_training)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / max(len(loader), 1)

In [None]:
def main():
    """Train the segmentation model, keep the best checkpoint, export predictions.

    Steps:
      1. List the files in ``Config.IMAGE_DIR``, shuffle them with a fixed seed
         and cut them into train / validation / test splits.
      2. Build the model, loss and optimiser selected in ``Config``.
      3. Train for ``Config.NUM_EPOCHS`` epochs, saving the weights whenever
         the validation loss improves.
      4. Reload the best weights and write predictions for the validation and
         test splits.

    Raises:
        ValueError: If the configured split sizes exceed the number of images.
    """
    all_entries = sorted(os.listdir(Config.IMAGE_DIR))
    all_images = [entry for entry in all_entries if os.path.isfile(os.path.join(Config.IMAGE_DIR, entry))]

    random.seed(42) # for reproducibility
    random.shuffle(all_images)

    total_size = Config.TRAIN_SIZE + Config.VAL_SIZE + Config.TEST_SIZE
    if total_size > len(all_images):
        raise ValueError("Sum of split sizes is larger than the total number of images!")

    train_end = Config.TRAIN_SIZE
    val_end = train_end + Config.VAL_SIZE
    train_files = all_images[:train_end]
    val_files = all_images[train_end:val_end]
    test_files = all_images[val_end:total_size]

    print(f"Total images (after filtering): {len(all_images)}")
    print(f"Training set size: {len(train_files)}")
    print(f"Validation set size: {len(val_files)}")
    print(f"Test set size: {len(test_files)}")

    # --- Create Model ---
    model = smp.create_model(
        arch=Config.ARCHITECTURE,
        encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS,
        in_channels=3,
        classes=Config.NUM_CLASSES,
    ).to(Config.DEVICE)

    # --- Select Loss Function ---
    if Config.LOSS_FUNCTION == 'DiceLoss':
        loss_fn = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
    elif Config.LOSS_FUNCTION == 'FocalLoss':
        loss_fn = smp.losses.FocalLoss(mode='multiclass')
    else: # Default to CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    # --- Select Optimizer ---
    if Config.OPTIMIZER == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    else: # Default to Adam
        optimizer = torch.optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)

    # --- Create Datasets and Dataloaders ---
    train_dataset = PlantDataset(
        image_dir=Config.IMAGE_DIR, mask_dir=Config.MASK_DIR,
        image_filenames=train_files, transform=train_transform
    )
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)

    val_dataset = PlantDataset(
        image_dir=Config.IMAGE_DIR, mask_dir=Config.MASK_DIR,
        image_filenames=val_files, transform=val_transform
    )
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

    test_dataset = PlantDataset(
        image_dir=Config.IMAGE_DIR, mask_dir=Config.MASK_DIR,
        image_filenames=test_files, transform=val_transform
    )
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) # Batch size 1 for saving

    # --- Start Training ---
    checkpoint_path = os.path.join(Config.DRIVE_PATH, "best_model.pth")
    # torch.cuda.amp.GradScaler() is deprecated; scaling is only enabled on CUDA.
    use_amp = torch.device(Config.DEVICE).type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    best_val_loss = float('inf')
    # Tracks whether THIS run wrote the checkpoint, so a stale file left by an
    # earlier run is never loaded by mistake.
    saved_checkpoint = False

    for epoch in range(Config.NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{Config.NUM_EPOCHS} ---")
        train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        val_loss = eval_fn(val_loader, model, loss_fn)

        print(f"Average Train Loss: {train_loss:.4f}")
        print(f"Average Val Loss: {val_loss:.4f}")

        # Save model if validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True
            print("=> Saved new best model")

    # --- Save Predictions After Training ---
    print("\n--- Saving predictions ---")
    if saved_checkpoint:
        # map_location lets a checkpoint written on GPU be restored on CPU too.
        best_state = torch.load(checkpoint_path, map_location=Config.DEVICE, weights_only=True)
        model.load_state_dict(best_state)
    else:
        print("No checkpoint was written during this run; using the final weights.")
    model.eval()

    # Save predictions for the validation set
    save_predictions_fn(val_loader, model, folder_basename="validation")
    # Save predictions for the test set
    save_predictions_fn(test_loader, model, folder_basename="test")

In [None]:
# Entry point: run the pipeline defined in main() when this cell / script
# is executed directly (not when the code is imported as a module).
if __name__ == "__main__":
    main()

Total images (after filtering): 515
Training set size: 412
Validation set size: 51
Test set size: 52


  scaler = torch.cuda.amp.GradScaler()



--- Epoch 1/25 ---


  with torch.cuda.amp.autocast():
Training: 100%|██████████| 103/103 [00:43<00:00,  2.35it/s, loss=0.775]
Validation: 100%|██████████| 13/13 [00:05<00:00,  2.53it/s, val_loss=0.766]


Average Train Loss: 0.8293
Average Val Loss: 0.7719
=> Saved new best model

--- Epoch 2/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.66it/s, loss=0.624]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.06it/s, val_loss=0.605]


Average Train Loss: 0.6892
Average Val Loss: 0.6201
=> Saved new best model

--- Epoch 3/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 15.71it/s, loss=0.457]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.87it/s, val_loss=0.495]


Average Train Loss: 0.5581
Average Val Loss: 0.4926
=> Saved new best model

--- Epoch 4/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 15.68it/s, loss=0.426]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.36it/s, val_loss=0.455]


Average Train Loss: 0.4611
Average Val Loss: 0.4328
=> Saved new best model

--- Epoch 5/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 15.96it/s, loss=0.479]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.83it/s, val_loss=0.384]


Average Train Loss: 0.3980
Average Val Loss: 0.3768
=> Saved new best model

--- Epoch 6/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 15.68it/s, loss=0.371]
Validation: 100%|██████████| 13/13 [00:00<00:00, 28.18it/s, val_loss=0.263]


Average Train Loss: 0.3385
Average Val Loss: 0.2831
=> Saved new best model

--- Epoch 7/25 ---


Training: 100%|██████████| 103/103 [00:08<00:00, 12.36it/s, loss=0.283]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.80it/s, val_loss=0.264]


Average Train Loss: 0.2797
Average Val Loss: 0.2596
=> Saved new best model

--- Epoch 8/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.64it/s, loss=0.222]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.34it/s, val_loss=0.214]


Average Train Loss: 0.2490
Average Val Loss: 0.2376
=> Saved new best model

--- Epoch 9/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 15.87it/s, loss=0.188]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.80it/s, val_loss=0.198]


Average Train Loss: 0.2370
Average Val Loss: 0.2249
=> Saved new best model

--- Epoch 10/25 ---


Training: 100%|██████████| 103/103 [00:08<00:00, 12.43it/s, loss=0.239]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.26it/s, val_loss=0.187]


Average Train Loss: 0.2214
Average Val Loss: 0.2040
=> Saved new best model

--- Epoch 11/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.40it/s, loss=0.238]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.86it/s, val_loss=0.216]


Average Train Loss: 0.2191
Average Val Loss: 0.2230

--- Epoch 12/25 ---


Training: 100%|██████████| 103/103 [00:08<00:00, 12.62it/s, loss=0.254]
Validation: 100%|██████████| 13/13 [00:00<00:00, 28.04it/s, val_loss=0.22]


Average Train Loss: 0.2111
Average Val Loss: 0.2177

--- Epoch 13/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 14.92it/s, loss=0.197]
Validation: 100%|██████████| 13/13 [00:00<00:00, 28.11it/s, val_loss=0.186]


Average Train Loss: 0.2082
Average Val Loss: 0.2079

--- Epoch 14/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.27it/s, loss=0.172]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.86it/s, val_loss=0.218]


Average Train Loss: 0.2021
Average Val Loss: 0.2011
=> Saved new best model

--- Epoch 15/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.57it/s, loss=0.203]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.27it/s, val_loss=0.194]


Average Train Loss: 0.1976
Average Val Loss: 0.1981
=> Saved new best model

--- Epoch 16/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.08it/s, loss=0.276]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.86it/s, val_loss=0.209]


Average Train Loss: 0.1990
Average Val Loss: 0.2023

--- Epoch 17/25 ---


Training: 100%|██████████| 103/103 [00:08<00:00, 12.70it/s, loss=0.267]
Validation: 100%|██████████| 13/13 [00:00<00:00, 28.87it/s, val_loss=0.191]


Average Train Loss: 0.1945
Average Val Loss: 0.1914
=> Saved new best model

--- Epoch 18/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.12it/s, loss=0.219]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.55it/s, val_loss=0.186]


Average Train Loss: 0.1908
Average Val Loss: 0.1998

--- Epoch 19/25 ---


Training: 100%|██████████| 103/103 [00:08<00:00, 12.54it/s, loss=0.195]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.55it/s, val_loss=0.185]


Average Train Loss: 0.1898
Average Val Loss: 0.1949

--- Epoch 20/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.59it/s, loss=0.18]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.79it/s, val_loss=0.19]


Average Train Loss: 0.1876
Average Val Loss: 0.1952

--- Epoch 21/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.53it/s, loss=0.171]
Validation: 100%|██████████| 13/13 [00:00<00:00, 28.19it/s, val_loss=0.188]


Average Train Loss: 0.1858
Average Val Loss: 0.1971

--- Epoch 22/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.54it/s, loss=0.163]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.80it/s, val_loss=0.171]


Average Train Loss: 0.1813
Average Val Loss: 0.1827
=> Saved new best model

--- Epoch 23/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.51it/s, loss=0.189]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.55it/s, val_loss=0.193]


Average Train Loss: 0.1836
Average Val Loss: 0.2199

--- Epoch 24/25 ---


Training: 100%|██████████| 103/103 [00:08<00:00, 12.72it/s, loss=0.157]
Validation: 100%|██████████| 13/13 [00:00<00:00, 28.06it/s, val_loss=0.164]


Average Train Loss: 0.1801
Average Val Loss: 0.1854

--- Epoch 25/25 ---


Training: 100%|██████████| 103/103 [00:06<00:00, 16.81it/s, loss=0.13]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.68it/s, val_loss=0.163]


Average Train Loss: 0.1794
Average Val Loss: 0.1911

--- Saving predictions ---

--- Saving predictions for validation set ---


Saving validation Predictions: 100%|██████████| 51/51 [00:01<00:00, 35.47it/s]



--- Saving predictions for test set ---


Saving test Predictions: 100%|██████████| 52/52 [00:06<00:00,  8.23it/s]


Unsqueezing: resizing the predicted masks back to the original image dimensions

In [None]:
import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

def resize_saved_masks(
    original_image_dir,
    predicted_mask_dir,
    predicted_color_mask_dir,
    output_dir,
    output_color_dir
):
    """
    Resizes saved data masks and color masks back to their original dimensions.

    Each file in ``predicted_mask_dir`` is paired with the original image whose
    file name (without extension) is identical to the mask's. The mask is then
    resized to that image's width/height and written under the same file name.

    Args:
    - original_image_dir (str): Directory containing original images.
    - predicted_mask_dir (str): Directory containing predicted masks.
    - predicted_color_mask_dir (str): Directory containing predicted color masks.
    - output_dir (str): Directory receiving the resized data masks.
    - output_color_dir (str): Directory receiving the resized color masks.

    Files that cannot be matched or processed are reported and skipped.
    """
    print(f"Resizing masks from: {predicted_mask_dir}")
    print(f"Saving unsqueezed masks to: {output_dir}")
    print(f"Resizing color masks from: {predicted_color_mask_dir}")
    print(f"Saving unsqueezed color masks to: {output_color_dir}")

    # Create the output directories if they don't exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_color_dir, exist_ok=True)

    # Index the originals once by extension-less name. Matching on the exact
    # stem (instead of a prefix test) keeps e.g. "plant_1" from being paired
    # with "plant_19.jpg", and avoids re-listing the folder for every mask.
    originals_by_stem = {}
    for fname in os.listdir(original_image_dir):
        originals_by_stem.setdefault(os.path.splitext(fname)[0], fname)

    pred_masks = os.listdir(predicted_mask_dir)

    for mask_filename in tqdm(pred_masks, desc="Resizing all masks"):
        mask_basename = os.path.splitext(mask_filename)[0]

        # Find the corresponding original image to get its dimensions
        original_img_filename = originals_by_stem.get(mask_basename, "")
        if not original_img_filename:
            print(f"  - Warning: Could not find original image for mask: {mask_filename}")
            continue

        try:
            # --- Get original dimensions ---
            original_img_path = os.path.join(original_image_dir, original_img_filename)
            with Image.open(original_img_path) as original_img:
                original_w, original_h = original_img.size

            # --- 1. Process the DATA MASK ---
            pred_mask_path = os.path.join(predicted_mask_dir, mask_filename)
            predicted_mask = cv2.imread(pred_mask_path, cv2.IMREAD_UNCHANGED)
            if predicted_mask is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise ValueError(f"could not read mask file: {pred_mask_path}")
            # Nearest-neighbour keeps the values discrete (no blended labels).
            resized_mask = cv2.resize(
                predicted_mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST
            )
            output_mask_path = os.path.join(output_dir, mask_filename)
            cv2.imwrite(output_mask_path, resized_mask)

            # --- 2. Process the COLOR MASK ---
            pred_color_path = os.path.join(predicted_color_mask_dir, mask_filename)
            if os.path.exists(pred_color_path):
                predicted_color_mask = cv2.imread(pred_color_path)
                if predicted_color_mask is None:
                    raise ValueError(f"could not read color mask file: {pred_color_path}")
                # Use a smoother interpolation for better visual quality
                resized_color_mask = cv2.resize(
                    predicted_color_mask, (original_w, original_h), interpolation=cv2.INTER_LINEAR
                )
                output_color_path = os.path.join(output_color_dir, mask_filename)
                cv2.imwrite(output_color_path, resized_color_mask)

        except Exception as e:
            # Best effort: report the failing file and continue with the rest.
            print(f"  - Error processing {mask_filename}: {e}")

    print("Done resizing all masks!")


# Where the original images and the 256x256 predictions live
ORIGINAL_IMAGES_DIR = Config.IMAGE_DIR
SAVED_OUTPUTS_DIR = Config.OUTPUT_DIR
# Root folder for predictions resized back to the original dimensions
UNSQUEEZED_OUTPUTS_DIR = os.path.join(Config.OUTPUT_DIR, "Unsqueezed")

SAVED_PRED_MASKS_DIR, SAVED_COLOR_MASKS_DIR = (
    os.path.join(SAVED_OUTPUTS_DIR, kind) for kind in ("pred_masks", "color_masks")
)
UNSQUEEZED_MASKS_DIR, UNSQUEEZED_COLOR_MASKS_DIR = (
    os.path.join(UNSQUEEZED_OUTPUTS_DIR, kind) for kind in ("pred_masks", "color_masks")
)

# Same resize step for each exported split: test first, then validation
for split_name in ("test", "validation"):
    resize_saved_masks(
        original_image_dir=ORIGINAL_IMAGES_DIR,
        predicted_mask_dir=os.path.join(SAVED_PRED_MASKS_DIR, split_name),
        predicted_color_mask_dir=os.path.join(SAVED_COLOR_MASKS_DIR, split_name),
        output_dir=os.path.join(UNSQUEEZED_MASKS_DIR, split_name),
        output_color_dir=os.path.join(UNSQUEEZED_COLOR_MASKS_DIR, split_name),
    )

Analysis of the above model's performance:

In [None]:
# =================================================================================
# 0. SETUP AND IMPORTS
# =================================================================================
!pip install -q monai pandas

import os
import cv2
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from monai.metrics import (
    compute_dice,
    compute_iou,
    compute_hausdorff_distance
)

# =================================================================================
# 1. CONFIGURATION
# =================================================================================
# Folder holding the ORIGINAL ground-truth masks
GT_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/masks/"

# Folder holding the UNSQUEEZED test-set predictions
PRED_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs/Unsqueezed/pred_masks/test/"

# --- CLASS DEFINITIONS ---
# Position in the list is the class index stored in the masks
CLASS_MAP = dict(enumerate(["Background", "Stem", "Leaf", "Root", "Seed"]))
NUM_CLASSES = len(CLASS_MAP)

# =================================================================================
# 2. HELPER FUNCTION (Converts masks to one-hot format)
# =================================================================================
def to_one_hot(mask, num_classes):
    """
    Expand an (H, W) array of integer class indices into a (C, H, W)
    one-hot tensor, where C == num_classes.
    """
    identity = np.eye(num_classes)
    # Each pixel's class id selects the matching identity row -> (H, W, C)
    per_pixel = identity[mask]
    # Bring the class axis to the front -> (C, H, W)
    channels_first = per_pixel.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)

# =================================================================================
# 3. MAIN ANALYSIS LOOP
# =================================================================================

print("Starting analysis...")
print(f"GT Directory: {GT_MASK_DIR}")
print(f"Pred Directory: {PRED_MASK_DIR}")

# A list to store the metric results for each image
results_list = []

# Get the list of predicted files
pred_files = os.listdir(PRED_MASK_DIR)

for filename in tqdm(pred_files):
    # Construct full paths
    pred_path = os.path.join(PRED_MASK_DIR, filename)

    # The ground-truth mask is named after the prediction's stem:
    # e.g. 'Rep1_0%.jpg' -> 'Rep1_0%_mask.png'.
    # splitext only touches the real extension (any extension, not just
    # '.jpg') and never a '.jpg' that happens to occur inside the name.
    gt_filename = os.path.splitext(filename)[0] + "_mask.png"
    gt_path = os.path.join(GT_MASK_DIR, gt_filename)

    # --- Load Masks ---
    if not os.path.exists(gt_path):
        print(f"Warning: Missing GT for {filename}, skipping.")
        continue

    gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None for unreadable files; skip rather than crash
    # on the attribute access below.
    if gt_mask is None or pred_mask is None:
        print(f"Warning: Could not read mask(s) for {filename}, skipping.")
        continue

    # Values outside the valid class range are mapped to background (0) so the
    # one-hot lookup cannot go out of bounds.
    gt_mask[gt_mask >= NUM_CLASSES] = 0
    pred_mask[pred_mask >= NUM_CLASSES] = 0

    # --- Sanity Check ---
    if gt_mask.shape != pred_mask.shape:
        print(f"Warning: Shape mismatch for {filename}, skipping.")
        print(f"  GT Shape: {gt_mask.shape}, Pred Shape: {pred_mask.shape}")
        continue

    # --- Convert to One-Hot Format ---
    # The metrics functions expect (Batch, Classes, H, W)
    gt_onehot = to_one_hot(gt_mask, NUM_CLASSES).unsqueeze(0) # (1, C, H, W)
    pred_onehot = to_one_hot(pred_mask, NUM_CLASSES).unsqueeze(0) # (1, C, H, W)

    # --- Calculate Metrics ---
    # Background is included so that column i always corresponds to class i.

    # Dice and IOU (higher is better, 0-1); indexed below as [0, class]
    dice_scores = compute_dice(pred_onehot, gt_onehot, include_background=True)
    iou_scores = compute_iou(pred_onehot, gt_onehot, include_background=True)

    # Hausdorff Distance 95th Percentile (lower is better, in pixels)
    # HD95 is more robust to outliers than the standard Hausdorff.
    hd95_scores = compute_hausdorff_distance(
        pred_onehot,
        gt_onehot,
        include_background=True,
        percentile=95
    )

    # --- Store Results ---
    # One flat record per image: '<Class>_<Metric>' -> value
    file_metrics = {'filename': filename}
    for i in range(NUM_CLASSES):
        class_name = CLASS_MAP[i]

        # .item() converts the tensor value to a plain Python number
        file_metrics[f"{class_name}_Dice"] = dice_scores[0, i].item()
        file_metrics[f"{class_name}_IOU"] = iou_scores[0, i].item()

        # Stored as returned; may be non-finite when a class is absent from
        # a mask (TODO confirm exact MONAI behaviour for the installed version).
        file_metrics[f"{class_name}_HD95"] = hd95_scores[0, i].item()

    results_list.append(file_metrics)

print("Analysis complete.")

# =================================================================================
# 4. REPORTING (The "Deliverable")
# =================================================================================

# Convert the list of results into a pandas DataFrame
if not results_list:
    # Without this check set_index('filename') fails with an opaque KeyError.
    raise RuntimeError(
        "No prediction/ground-truth pairs were evaluated; "
        "check PRED_MASK_DIR, GT_MASK_DIR and the mask naming convention."
    )
df = pd.DataFrame(results_list).set_index('filename')

# --- 1. Show a sample of the full results table ---
print("\n--- Full Results Table (Sample) ---")
print(df.head())

# --- 2. Show the overall average statistics ---
# .mean() skips NaNs but not infinities: a single inf HD95 value would turn the
# whole class average into inf, so non-finite distances are treated as missing.
overall_stats = df.replace([np.inf, -np.inf], np.nan).mean()

print("\n\n--- Overall Average Statistics (Test Set) ---")
print("This is the main result. We will use this table for our discussion.")

# Reshape the data for a cleaner summary table: one row per class
summary_data = [
    {
        "Class": class_name,
        "Dice (↑)": overall_stats.get(f"{class_name}_Dice"),
        "IOU (↑)": overall_stats.get(f"{class_name}_IOU"),
        "HD95 (↓)": overall_stats.get(f"{class_name}_HD95"),
    }
    for class_name in (CLASS_MAP[i] for i in range(NUM_CLASSES))
]

summary_df = pd.DataFrame(summary_data)
print(summary_df.to_markdown(index=False, floatfmt=".4f"))

Starting analysis...
GT Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/masks/
Pred Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs/Unsqueezed/pred_masks/test/


100%|██████████| 52/52 [00:19<00:00,  2.63it/s]

Analysis complete.

--- Full Results Table (Sample) ---
                                              Background_Dice  Background_IOU  \
filename                                                                        
Rep1_0%Sucrose_gaut10-3gaut11-3+_19.jpg              0.984967        0.970380   
Rep1_0%Sucrose_gaut3-1gaut11-3_29.jpg                0.991010        0.982180   
Rep2_0.5%Sucrose_gaut10-3_8.jpg                      0.987785        0.975866   
Rep2_0.5%Sucrose_gaut10-3gaut11-3_5.jpg              0.990434        0.981049   
Rep1_0%Sucrose_gaut3-1gaut10-3gaut11-3_8.jpg         0.995549        0.991138   

                                              Background_HD95  Stem_Dice  \
filename                                                                   
Rep1_0%Sucrose_gaut10-3gaut11-3+_19.jpg              7.565374   0.093120   
Rep1_0%Sucrose_gaut3-1gaut11-3_29.jpg                5.385165   0.097307   
Rep2_0.5%Sucrose_gaut10-3_8.jpg                      7.364026   0.269670


