Installs:

In [None]:
!pip install -q segmentation-models-pytorch albumentations monai

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m81.1 MB/s[0m eta [36m0:00:00[0m
[?25h

Imports:

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tqdm import tqdm
import random

# Import advanced loss functions
from monai.losses import DiceLoss, HausdorffDTLoss



In [None]:
class Config:
    """Central settings for the V3 training run.

    Everything is a class attribute, so values are read as ``Config.NAME``
    without creating an instance.
    """

    # ---- Input locations: the COMBINED dataset (original + synthetic) ----
    BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/"
    DATA_PATH = os.path.join(BASE_PATH, "combined_data")
    IMAGE_DIR = os.path.join(BASE_PATH, "combined_data", "images")
    MASK_DIR = os.path.join(BASE_PATH, "combined_data", "masks")

    # Quality-control CSV listing masks that must be excluded
    QC_REPORT_CSV = os.path.join(BASE_PATH, "missing_classes_from_mask.csv")

    # ---- Output locations (V3) ----
    OUTPUT_DIR = os.path.join(BASE_PATH, "outputs_v3")
    OUTPUT_MASK_DIR = os.path.join(BASE_PATH, "outputs_v3", "pred_masks_unsqueezed")
    COLOR_MASK_DIR = os.path.join(BASE_PATH, "outputs_v3", "color_masks_unsqueezed")

    # ---- Split sizes ----
    # Filtering changes how many images are available (diff of 49 images),
    # so only validation/test are fixed; training takes whatever remains.
    VAL_SIZE, TEST_SIZE = 51, 52

    # ---- Model hyperparameters ----
    ARCHITECTURE = 'unet'
    ENCODER = 'resnet34'
    ENCODER_WEIGHTS = 'imagenet'
    LEARNING_RATE = 1e-4
    OPTIMIZER = 'AdamW'

    # ---- Training settings ----
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 4
    NUM_EPOCHS = 30
    IMAGE_HEIGHT = IMAGE_WIDTH = 256
    NUM_CLASSES = 5

    # ---- Loss configuration (carried over from the V2 settings) ----
    # Per-class weights in the order [Background, Stem, Leaf, Root, Seed]
    CLASS_WEIGHTS = torch.tensor([1.0, 7.0, 5.0, 5.0, 7.0], device=DEVICE)

    # Mixing factors of the hybrid loss; 1.0 / 0.0 means pure weighted Dice
    HYBRID_WEIGHT_DICE = 1.0
    HYBRID_WEIGHT_HD = 0.0

    # ---- Visualization: class index -> RGB colour ----
    COLOR_MAP = dict(
        enumerate(
            [
                (0, 0, 0),       # 0: background - black
                (139, 69, 19),   # 1: stem - brown
                (0, 255, 0),     # 2: leaf - green
                (255, 255, 0),   # 3: root - yellow
                (255, 0, 0),     # 4: seed - red
            ]
        )
    )

# Make sure both prediction folders exist before anything tries to write to them
# (exist_ok=True turns this into a no-op on re-runs).
for _output_dir in (Config.OUTPUT_MASK_DIR, Config.COLOR_MASK_DIR):
    os.makedirs(_output_dir, exist_ok=True)

Load every image in the combined folder, drop those whose mask is marked FALSE in the QC CSV, then shuffle the remainder and split it into train, validation and test sets.

In [None]:
def get_filtered_data_splits(image_dir, qc_csv_path, val_size, test_size):
    """Build reproducible train/val/test filename lists, skipping QC-rejected masks.

    Args:
        image_dir: Folder scanned for files ending in ``.jpg`` or ``.png``.
        qc_csv_path: CSV with ``filename`` and ``Mask_correct`` columns. Rows whose
            ``Mask_correct`` reads FALSE (boolean or string, any letter case) name
            the *mask* files to exclude. If the file does not exist, a warning
            is printed and nothing is filtered.
        val_size: Number of images placed in the validation split.
        test_size: Number of images placed in the test split.

    Returns:
        ``(train_files, val_files, test_files)`` lists of image filenames. Test
        and validation are filled first; training receives the remainder.

    Raises:
        ValueError: If fewer than ``val_size + test_size`` images survive filtering.

    Note:
        Calls ``random.seed(42)``, i.e. it reseeds Python's global RNG.
    """
    # 1. Candidate images, sorted so the seeded shuffle below is reproducible
    candidates = [name for name in os.listdir(image_dir) if name.endswith(('.jpg', '.png'))]
    candidates.sort()
    print(f"Total images found in folder: {len(candidates)}")

    # 2. Mask filenames rejected by quality control
    rejected_masks = set()
    if not os.path.exists(qc_csv_path):
        print("Warning:CSV not found. Proceeding without filtering.")
    else:
        qc_table = pd.read_csv(qc_csv_path)
        # astype(str) + upper() makes boolean False and the strings "false"/"FALSE" compare equal
        is_rejected = qc_table['Mask_correct'].astype(str).str.upper() == 'FALSE'
        rejected_masks = set(qc_table.loc[is_rejected, 'filename'].tolist())
        print(f"Found {len(rejected_masks)} images marked as FALSE in CSV to exclude.")

    # 3. Keep an image only if its mask (image.jpg -> image_mask.png) is not rejected
    valid_images = [
        name
        for name in candidates
        if os.path.splitext(name)[0] + "_mask.png" not in rejected_masks
    ]
    print(f"Total Valid Images for Training: {len(valid_images)}")

    # 4. Deterministic shuffle, then carve out test / val / train in that order
    random.seed(42)
    random.shuffle(valid_images)

    reserved = test_size + val_size
    if len(valid_images) < reserved:
        raise ValueError("Not enough valid images to create Validation and Test sets!")

    test_files = valid_images[:test_size]
    val_files = valid_images[test_size:reserved]
    train_files = valid_images[reserved:]

    for label, files in (("Training", train_files), ("Validation", val_files), ("Test", test_files)):
        print(f"{label} set size: {len(files)}")

    return train_files, val_files, test_files

In [None]:
class PlantDataset(Dataset):
    """Image / label-mask pairs for plant-part segmentation.

    An image ``<image_dir>/<stem>.<ext>`` is paired with the mask
    ``<mask_dir>/<stem>_mask.png``. Masks are read as 8-bit grayscale and their
    pixel values are used directly as class indices.

    Args:
        image_dir: Folder containing the input images.
        mask_dir: Folder containing the ``*_mask.png`` label masks.
        image_filenames: Image filenames (relative to ``image_dir``) in this split.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, image_dir, mask_dir, image_filenames, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = image_filenames

    def __len__(self):
        """Number of images in this split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, mask, (original_height, original_width))`` where ``image``
            is whatever the transform produced (the raw RGB array when there is
            no transform), ``mask`` is an int64 tensor of class indices, and the
            size tuple records the image dimensions before any transform.
        """
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_name = os.path.splitext(img_name)[0] + "_mask.png"
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Context managers release the file handles as soon as the pixels are copied
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        # Remember the pre-resize size so predictions can be scaled back later
        original_height, original_width = image.shape[:2]

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a transform (or with one that does not convert to tensors) the
        # mask is still a NumPy array; make it a tensor so .long() below works.
        mask = torch.as_tensor(mask)

        # Labels outside [0, NUM_CLASSES) are reassigned to background (0), not clipped
        mask[mask >= Config.NUM_CLASSES] = 0

        return image, mask.long(), (original_height, original_width)

# Augmentation pipelines.
# Both pipelines share the same first step (resize to the model input size) and
# the same last two steps (scale pixels, convert to tensor); only the training
# pipeline inserts random geometric augmentation in between.
def _resize_to_model_input():
    """Resize step targeting Config.IMAGE_HEIGHT x Config.IMAGE_WIDTH."""
    return A.Resize(height=Config.IMAGE_HEIGHT, width=Config.IMAGE_WIDTH)


def _scale_and_to_tensor():
    """Final steps: normalize with mean 0 / std 1 / max_pixel_value 255, then ToTensorV2."""
    return [
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]


train_transform = A.Compose(
    [_resize_to_model_input()]
    + [
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]
    + _scale_and_to_tensor()
)

# Validation / test: deterministic, no augmentation
val_transform = A.Compose([_resize_to_model_input()] + _scale_and_to_tensor())

In [None]:
class WeightedHybridLoss(nn.Module):
    """Class-weighted Dice loss with an optional class-weighted Hausdorff-DT term.

    Args:
        class_weights: 1-D tensor with one weight per class, in channel order.
        weight_dice: Multiplier applied to the weighted Dice term.
        weight_hd: Multiplier applied to the weighted Hausdorff-DT term; the
            term is skipped entirely when this is not greater than 0.
    """

    def __init__(self, class_weights, weight_dice=1.0, weight_hd=0.0):
        super().__init__()
        self.w_dice = weight_dice
        self.w_hd = weight_hd
        self.class_weights = class_weights

        # reduction='none' keeps one loss value per (sample, class) so that the
        # per-class weights can be applied before averaging.
        self.dice_loss = DiceLoss(
            softmax=True, to_onehot_y=True, include_background=True, reduction='none'
        )
        self.hd_loss = HausdorffDTLoss(
            softmax=True, to_onehot_y=True, include_background=True, reduction='none'
        )

    def _weighted_mean(self, per_class_loss):
        """Multiply each class channel (dim 1) by its weight and average everything.

        The weights are reshaped to ``(1, C, 1, ...)`` so they line up with the
        class axis whatever the number of trailing dimensions. Multiplying by the
        raw ``(C,)`` vector is only correct for a ``(B, C)`` loss: against a
        ``(B, C, 1, 1)`` loss (the shape current MONAI releases return for
        ``reduction='none'`` on 2-D inputs) it broadcasts along the *last* axis
        and degenerates to ``mean(loss) * mean(weights)``, i.e. no per-class
        weighting at all.
        """
        trailing_dims = [1] * (per_class_loss.dim() - 2)
        weights = self.class_weights.reshape(1, -1, *trailing_dims)
        return (per_class_loss * weights).mean()

    def forward(self, preds_logits, targets_idx):
        """Compute the scalar training loss.

        Args:
            preds_logits: Raw model outputs with the class axis at dim 1.
            targets_idx: Integer class-index masks without a channel axis; a
                singleton channel axis is inserted at dim 1 before the losses
                are called (they are configured with ``to_onehot_y=True``).

        Returns:
            Scalar tensor ``w_dice * weighted_dice (+ w_hd * weighted_hd)``.
        """
        targets_idx = targets_idx.unsqueeze(1)

        total_loss = self.w_dice * self._weighted_mean(
            self.dice_loss(preds_logits, targets_idx)
        )

        # Hausdorff term is only evaluated when enabled (weight_hd > 0)
        if self.w_hd > 0:
            total_loss = total_loss + self.w_hd * self._weighted_mean(
                self.hd_loss(preds_logits, targets_idx)
            )

        return total_loss

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one mixed-precision training epoch.

    Args:
        loader: Iterable yielding ``(images, targets, extra)`` batches; the
            third element is ignored here.
        model: Network being trained; it is put into training mode first.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: Gradient scaler used for the scaled backward pass and the step.

    Returns:
        Mean of the per-batch loss values over the epoch.
    """
    # Set the mode explicitly instead of relying on eval_fn / save_predictions_fn
    # having switched the model back to training mode on their way out.
    model.train()

    loop = tqdm(loader, desc="Training")
    total_loss = 0.0

    for data, targets, _ in loop:
        data = data.to(device=Config.DEVICE)
        targets = targets.to(device=Config.DEVICE)

        # Forward pass and loss under autocast; backward/step go through the scaler
        with torch.amp.autocast('cuda'):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once; each .item() call forces a device synchronisation
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    return total_loss / len(loader)

In [None]:
def eval_fn(loader, model, loss_fn):
    """Compute the mean loss over ``loader`` without updating the model.

    The model is switched to eval mode for the pass and put back into
    training mode before returning.

    Args:
        loader: Iterable yielding ``(images, targets, extra)`` batches; the
            third element is ignored.
        model: Network to evaluate.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.

    Returns:
        Mean of the per-batch loss values.
    """
    model.eval()
    running_loss = 0
    progress = tqdm(loader, desc="Validation")

    with torch.no_grad():
        for images, labels, _ in progress:
            images = images.to(device=Config.DEVICE)
            labels = labels.to(device=Config.DEVICE)
            batch_loss = loss_fn(model(images), labels).item()
            running_loss += batch_loss
            progress.set_postfix(val_loss=batch_loss)

    # Hand the model back in training mode for the next epoch
    model.train()
    return running_loss / len(loader)

In [None]:
def mask_to_rgb(mask_tensor, color_map):
    """Render a class-index mask as a colour image.

    Args:
        mask_tensor: Tensor of class indices; singleton dimensions are squeezed away.
        color_map: Mapping ``class index -> (R, G, B)``. Pixels whose class is
            not in the mapping stay black.

    Returns:
        A PIL image built from a uint8 array of shape ``mask.shape + (3,)``.
    """
    class_ids = mask_tensor.cpu().numpy().squeeze()
    canvas = np.zeros(class_ids.shape + (3,), dtype=np.uint8)
    # Paint one class at a time using a boolean selection of its pixels
    for label, rgb in color_map.items():
        selected = class_ids == label
        canvas[selected] = rgb
    return Image.fromarray(canvas)

In [None]:
def save_predictions_fn(loader, model, folder_basename=""):
    """Predict every sample of ``loader.dataset`` and write the masks to disk.

    For each sample two PNGs named ``<image stem>_mask.png`` are written: the
    raw class-index mask under ``Config.OUTPUT_MASK_DIR/<folder_basename>`` and
    a colour rendering under ``Config.COLOR_MASK_DIR/<folder_basename>``. Both
    are resized back to the image's original size with nearest-neighbour
    interpolation. Samples are taken one by one from ``loader.dataset``; the
    loader's own batching is not used.

    The model is switched to eval mode for the pass and put back into
    training mode at the end.
    """
    print(f"\n--- Saving predictions for {folder_basename} set ---")
    model.eval()

    # Per-split output folders
    raw_dir = os.path.join(Config.OUTPUT_MASK_DIR, folder_basename)
    color_dir = os.path.join(Config.COLOR_MASK_DIR, folder_basename)
    for target_dir in (raw_dir, color_dir):
        os.makedirs(target_dir, exist_ok=True)

    dataset = loader.dataset
    for idx in tqdm(range(len(dataset)), desc=f"Saving {folder_basename} Predictions"):
        img_tensor, _, (original_h, original_w) = dataset[idx]

        with torch.no_grad():
            batch = img_tensor.to(Config.DEVICE).unsqueeze(0)  # add batch axis
            logits = model(batch)
            class_map = torch.argmax(logits, dim=1).squeeze(0)

        small_mask = class_map.cpu().numpy().astype(np.uint8)

        # "Unsqueeze": scale back to the original size; nearest-neighbour keeps
        # the values valid class indices. cv2.resize takes (width, height).
        full_mask = cv2.resize(
            small_mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST
        )

        out_name = os.path.splitext(dataset.images[idx])[0] + "_mask.png"

        # Raw class-index mask
        Image.fromarray(full_mask).save(os.path.join(raw_dir, out_name))

        # Colour rendering of the same mask
        colour_img = mask_to_rgb(torch.from_numpy(full_mask), Config.COLOR_MAP)
        colour_img.save(os.path.join(color_dir, out_name))

    model.train()

In [None]:
def main():
    """Train the V3 model end to end and export test-set predictions.

    Steps: build the QC-filtered splits, create the data loaders, build the
    model / loss / optimizer from ``Config``, train for ``Config.NUM_EPOCHS``
    epochs while checkpointing the weights with the lowest validation loss,
    then reload that checkpoint and save predictions for the test split.
    """
    print(f"Using device: {Config.DEVICE}")
    print(f"Dataset Path: {Config.DATA_PATH}")

    # --- 1. Get Filtered Splits ---
    train_files, val_files, test_files = get_filtered_data_splits(
        Config.IMAGE_DIR, Config.QC_REPORT_CSV, Config.VAL_SIZE, Config.TEST_SIZE
    )

    # --- 2. Create Loaders ---
    train_dataset = PlantDataset(Config.IMAGE_DIR, Config.MASK_DIR, train_files, train_transform)
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)

    val_dataset = PlantDataset(Config.IMAGE_DIR, Config.MASK_DIR, val_files, val_transform)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

    test_dataset = PlantDataset(Config.IMAGE_DIR, Config.MASK_DIR, test_files, val_transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # --- 3. Model & Loss ---
    model = smp.create_model(
        arch=Config.ARCHITECTURE,
        encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS,
        in_channels=3,
        classes=Config.NUM_CLASSES,
    ).to(Config.DEVICE)

    print("\n--- Initializing V3 Loss Function ---")
    loss_fn = WeightedHybridLoss(
        class_weights=Config.CLASS_WEIGHTS,
        weight_dice=Config.HYBRID_WEIGHT_DICE,
        weight_hd=Config.HYBRID_WEIGHT_HD
    )
    print(f"Class Weights: {Config.CLASS_WEIGHTS.cpu().numpy()}")

    # Honour Config.OPTIMIZER (e.g. 'AdamW') instead of hard-coding the class,
    # so the setting in Config actually controls what is used.
    optimizer_cls = getattr(torch.optim, Config.OPTIMIZER)
    optimizer = optimizer_cls(model.parameters(), lr=Config.LEARNING_RATE)
    scaler = torch.amp.GradScaler('cuda')
    best_val_loss = float('inf')

    # Single definition of the checkpoint location, used for both save and load
    checkpoint_path = os.path.join(Config.BASE_PATH, "best_model_v3.pth")

    # --- 4. Training Loop ---
    for epoch in range(Config.NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{Config.NUM_EPOCHS} ---")
        train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        val_loss = eval_fn(val_loader, model, loss_fn)

        print(f"Average Train Loss: {train_loss:.4f}")
        print(f"Average Val Loss: {val_loss:.4f}")

        # Keep only the weights with the lowest validation loss seen so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("=> Saved new best model")

    # --- 5. Save Test Predictions ---
    print("\n--- Loading best model for final testing ---")
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one currently in use.
    model.load_state_dict(torch.load(checkpoint_path, map_location=Config.DEVICE))
    save_predictions_fn(test_loader, model, folder_basename="test_set")
    print("--- V3 Training Complete ---")

In [None]:
# Entry point: start the full training pipeline when this cell is executed.
if __name__ == "__main__":
    main()

Using device: cuda
Dataset Path: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/combined_data
Total images found in folder: 595
Found 32 images marked as FALSE in CSV to exclude.
Total Valid Images for Training: 563
Training set size: 460
Validation set size: 51
Test set size: 52


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]


--- Initializing V3 Loss Function ---
Class Weights: [1. 7. 5. 5. 7.]

--- Epoch 1/30 ---


Training: 100%|██████████| 115/115 [06:21<00:00,  3.32s/it, loss=3.7]
Validation: 100%|██████████| 13/13 [00:39<00:00,  3.05s/it, val_loss=4.04]


Average Train Loss: 4.3135
Average Val Loss: 3.7911
=> Saved new best model

--- Epoch 2/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.33it/s, loss=3.41]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.52it/s, val_loss=3.17]


Average Train Loss: 3.5842
Average Val Loss: 3.0804
=> Saved new best model

--- Epoch 3/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.85it/s, loss=3.1]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.31it/s, val_loss=2.85]


Average Train Loss: 3.0366
Average Val Loss: 2.8079
=> Saved new best model

--- Epoch 4/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.76it/s, loss=2.64]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.56it/s, val_loss=2.64]


Average Train Loss: 2.8407
Average Val Loss: 2.7043
=> Saved new best model

--- Epoch 5/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.25it/s, loss=2.71]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.26it/s, val_loss=2.63]


Average Train Loss: 2.7526
Average Val Loss: 2.6528
=> Saved new best model

--- Epoch 6/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.09it/s, loss=2.67]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.31it/s, val_loss=2.47]


Average Train Loss: 2.6647
Average Val Loss: 2.5042
=> Saved new best model

--- Epoch 7/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.40it/s, loss=2.29]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.95it/s, val_loss=1.96]


Average Train Loss: 2.3966
Average Val Loss: 2.2028
=> Saved new best model

--- Epoch 8/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.25it/s, loss=2.11]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.79it/s, val_loss=1.86]


Average Train Loss: 2.1797
Average Val Loss: 2.1293
=> Saved new best model

--- Epoch 9/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.10it/s, loss=1.92]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.88it/s, val_loss=1.8]


Average Train Loss: 2.0932
Average Val Loss: 2.0099
=> Saved new best model

--- Epoch 10/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.26it/s, loss=2.25]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.27it/s, val_loss=1.73]


Average Train Loss: 2.0366
Average Val Loss: 2.0159

--- Epoch 11/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.42it/s, loss=1.91]
Validation: 100%|██████████| 13/13 [00:00<00:00, 23.96it/s, val_loss=1.74]


Average Train Loss: 1.9944
Average Val Loss: 1.9860
=> Saved new best model

--- Epoch 12/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.95it/s, loss=1.99]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.48it/s, val_loss=1.75]


Average Train Loss: 1.9748
Average Val Loss: 1.9690
=> Saved new best model

--- Epoch 13/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.06it/s, loss=2.12]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.38it/s, val_loss=1.55]


Average Train Loss: 1.9137
Average Val Loss: 1.7627
=> Saved new best model

--- Epoch 14/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.17it/s, loss=1.87]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.97it/s, val_loss=1.55]


Average Train Loss: 1.7910
Average Val Loss: 1.5326
=> Saved new best model

--- Epoch 15/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.36it/s, loss=1.81]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.68it/s, val_loss=1.14]


Average Train Loss: 1.6321
Average Val Loss: 1.4085
=> Saved new best model

--- Epoch 16/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.36it/s, loss=1.51]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.98it/s, val_loss=1.11]


Average Train Loss: 1.5207
Average Val Loss: 1.4077
=> Saved new best model

--- Epoch 17/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.23it/s, loss=1.38]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.96it/s, val_loss=1.44]


Average Train Loss: 1.3084
Average Val Loss: 1.3009
=> Saved new best model

--- Epoch 18/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.31it/s, loss=0.993]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.31it/s, val_loss=0.925]


Average Train Loss: 1.2645
Average Val Loss: 1.2515
=> Saved new best model

--- Epoch 19/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.15it/s, loss=1.17]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.86it/s, val_loss=1.04]


Average Train Loss: 1.2489
Average Val Loss: 1.2832

--- Epoch 20/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 12.79it/s, loss=1.27]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.32it/s, val_loss=0.853]


Average Train Loss: 1.2013
Average Val Loss: 1.1155
=> Saved new best model

--- Epoch 21/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.47it/s, loss=1.39]
Validation: 100%|██████████| 13/13 [00:00<00:00, 23.80it/s, val_loss=0.869]


Average Train Loss: 1.1883
Average Val Loss: 1.1284

--- Epoch 22/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.26it/s, loss=1.1]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.14it/s, val_loss=0.852]


Average Train Loss: 1.1773
Average Val Loss: 1.1518

--- Epoch 23/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.88it/s, loss=1.64]
Validation: 100%|██████████| 13/13 [00:00<00:00, 26.29it/s, val_loss=0.904]


Average Train Loss: 1.1559
Average Val Loss: 1.1125
=> Saved new best model

--- Epoch 24/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.57it/s, loss=1.53]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.04it/s, val_loss=0.852]


Average Train Loss: 1.1436
Average Val Loss: 1.1276

--- Epoch 25/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.16it/s, loss=0.84]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.37it/s, val_loss=0.748]


Average Train Loss: 1.1343
Average Val Loss: 1.1131

--- Epoch 26/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.05it/s, loss=1.37]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.75it/s, val_loss=0.728]


Average Train Loss: 1.1207
Average Val Loss: 1.1060
=> Saved new best model

--- Epoch 27/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.54it/s, loss=1]
Validation: 100%|██████████| 13/13 [00:00<00:00, 24.99it/s, val_loss=0.804]


Average Train Loss: 1.1175
Average Val Loss: 1.0860
=> Saved new best model

--- Epoch 28/30 ---


Training: 100%|██████████| 115/115 [00:10<00:00, 11.20it/s, loss=1.36]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.58it/s, val_loss=0.874]


Average Train Loss: 1.1062
Average Val Loss: 1.1363

--- Epoch 29/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 13.93it/s, loss=1]
Validation: 100%|██████████| 13/13 [00:00<00:00, 27.18it/s, val_loss=1.04]


Average Train Loss: 1.1011
Average Val Loss: 1.0977

--- Epoch 30/30 ---


Training: 100%|██████████| 115/115 [00:08<00:00, 14.29it/s, loss=1.22]
Validation: 100%|██████████| 13/13 [00:00<00:00, 25.32it/s, val_loss=0.872]


Average Train Loss: 1.0881
Average Val Loss: 1.0662
=> Saved new best model

--- Loading best model for final testing ---

--- Saving predictions for test_set set ---


Saving test_set Predictions: 100%|██████████| 52/52 [00:50<00:00,  1.02it/s]

--- V3 Training Complete ---





In [None]:
# =================================================================================
# 0. SETUP AND IMPORTS
# =================================================================================
!pip install -q monai pandas

import os
import cv2
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from monai.metrics import (
    compute_dice,
    compute_iou,
    compute_hausdorff_distance
)

# =================================================================================
# 1. CONFIGURATION
# =================================================================================
# --- PATHS FOR V3 ---
# Ground-truth masks: the masks folder of the COMBINED dataset
GT_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/combined_data/masks/"

# Predicted masks: the test-set output folder written by the V3 training run
PRED_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v3/pred_masks_unsqueezed/test_set/"

# --- CORRECTED CLASS MAP ---
# Class index -> display name, in index order; the class count follows from it
CLASS_MAP = dict(enumerate(["Background", "Stem", "Leaf", "Root", "Seed"]))
NUM_CLASSES = len(CLASS_MAP)

# =================================================================================
# 2. HELPER FUNCTIONS
# =================================================================================
def to_one_hot(mask, num_classes):
    """Convert a (H, W) class-index mask to a (1, C, H, W) one-hot tensor.

    Args:
        mask: Integer array of class indices with shape (H, W). It is not
            modified; the clean-up below happens on a copy.
        num_classes: Number of classes C. Labels ``>= num_classes`` are treated
            as background (class 0).

    Returns:
        float64 tensor of shape (1, C, H, W) with exactly one 1.0 per pixel
        along the class axis.
    """
    # Work on a copy so the caller's array keeps its original labels
    labels = np.array(mask)
    labels[labels >= num_classes] = 0

    # Row k of the identity matrix is the one-hot vector of class k
    one_hot = np.eye(num_classes)[labels]       # (H, W, C)
    one_hot = np.transpose(one_hot, (2, 0, 1))  # (C, H, W)
    return torch.from_numpy(one_hot).unsqueeze(0)  # (1, C, H, W)

# =================================================================================
# 3. MAIN ANALYSIS LOOP
# =================================================================================
def run_analysis():
    """Score the saved V3 test predictions against the ground-truth masks.

    For every ``.png`` in ``PRED_MASK_DIR`` the ground-truth mask with the same
    filename is loaded from ``GT_MASK_DIR``; both are one-hot encoded and
    per-class Dice, IoU and 95th-percentile Hausdorff distance are computed
    (background included). The per-class averages over all scored files are
    printed as a markdown table and written to ``v3_results_summary.csv``.

    Files with no matching ground truth, or whose masks cannot be read, are
    reported and left out of the averages.
    """
    print("Starting V3 Analysis...")
    print(f"GT Directory: {GT_MASK_DIR}")
    print(f"Pred Directory: {PRED_MASK_DIR}")

    results_list = []

    # A wrong path should reach the helpful message below rather than crash in
    # os.listdir; sorting makes the processing order deterministic.
    if os.path.isdir(PRED_MASK_DIR):
        pred_files = sorted(f for f in os.listdir(PRED_MASK_DIR) if f.endswith('.png'))
    else:
        pred_files = []

    if len(pred_files) == 0:
        print("Error: No prediction files found! Check your PRED_MASK_DIR path.")
        return

    for filename in tqdm(pred_files):
        pred_path = os.path.join(PRED_MASK_DIR, filename)

        # The V3 script saves predictions as "original_name_mask.png", so the
        # ground-truth mask is expected under the identical filename.
        gt_path = os.path.join(GT_MASK_DIR, filename)

        if not os.path.exists(gt_path):
            print(f"Skipping {filename}: GT mask not found.")
            continue

        # Load Masks
        gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
        pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None; say so instead of
        # silently shrinking the set of files the averages are based on.
        if gt_mask is None or pred_mask is None:
            print(f"Skipping {filename}: could not read GT or predicted mask.")
            continue

        # Resize GT if dimensions don't match (safety check for unsqueezing issues).
        # The prediction is assumed to have the correct (unsqueezed) size, so the
        # GT is resized to match; nearest-neighbour keeps the labels intact.
        if gt_mask.shape != pred_mask.shape:
            gt_mask = cv2.resize(gt_mask, (pred_mask.shape[1], pred_mask.shape[0]), interpolation=cv2.INTER_NEAREST)

        # Convert to One-Hot
        gt_onehot = to_one_hot(gt_mask, NUM_CLASSES)
        pred_onehot = to_one_hot(pred_mask, NUM_CLASSES)

        # --- Calculate Metrics ---
        dice = compute_dice(pred_onehot, gt_onehot, include_background=True)
        iou = compute_iou(pred_onehot, gt_onehot, include_background=True)
        hd95 = compute_hausdorff_distance(pred_onehot, gt_onehot, include_background=True, percentile=95)

        # Store results: one row per file, three columns per class.
        # Each metric is indexed as [batch item 0, class i].
        file_metrics = {'filename': filename}
        for i in range(NUM_CLASSES):
            c_name = CLASS_MAP[i]
            file_metrics[f"{c_name}_Dice"] = dice[0, i].item()
            file_metrics[f"{c_name}_IOU"] = iou[0, i].item()
            file_metrics[f"{c_name}_HD95"] = hd95[0, i].item()

        results_list.append(file_metrics)

    # =================================================================================
    # 4. REPORTING
    # =================================================================================
    if not results_list:
        print("No results generated.")
        return

    df = pd.DataFrame(results_list)

    # Calculate Averages (numeric columns only, i.e. everything but 'filename')
    overall_stats = df.mean(numeric_only=True)

    print("\n\n--- V3 Overall Average Statistics (Test Set) ---")
    summary_data = []
    for i in range(NUM_CLASSES):
        c_name = CLASS_MAP[i]
        summary_data.append({
            "Class": c_name,
            "Dice (↑)": overall_stats.get(f"{c_name}_Dice"),
            "IOU (↑)": overall_stats.get(f"{c_name}_IOU"),
            "HD95 (↓)": overall_stats.get(f"{c_name}_HD95"),
        })

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_markdown(index=False, floatfmt=".4f"))

    # Optional: Save to CSV for comparison later
    summary_df.to_csv("/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/v3_results_summary.csv", index=False)
    print("\nSummary saved to v3_results_summary.csv")

# Entry point: run the evaluation when this cell is executed.
if __name__ == "__main__":
    run_analysis()

Starting V3 Analysis...
GT Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/combined_data/masks/
Pred Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v3/pred_masks_unsqueezed/test_set/


100%|██████████| 52/52 [00:04<00:00, 12.43it/s]



--- V3 Overall Average Statistics (Test Set) ---
| Class      |   Dice (↑) |   IOU (↑) |   HD95 (↓) |
|:-----------|-----------:|----------:|-----------:|
| Background |     0.9942 |    0.9885 |     7.1458 |
| Stem       |     0.7149 |    0.5774 |     7.1714 |
| Leaf       |     0.8360 |    0.7231 |     7.7421 |
| Root       |     0.8103 |    0.6908 |    10.4944 |
| Seed       |     0.6350 |    0.5143 |    20.8207 |

Summary saved to v3_results_summary.csv



