In [1]:
# =================================================================================
# 0. SETUP AND IMPORTS
# =================================================================================
# Force upgrade to fix potential torch attribute errors
!pip install -q segmentation-models-pytorch albumentations monai

import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tqdm import tqdm
import random

from monai.losses import DiceLoss



[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m90.3 MB/s[0m eta [36m0:00:00[0m
[?25h



In [10]:
# =================================================================================
# 1. CONFIGURATION (V5: SPLIT LEAF)
# =================================================================================
class Config:
    # -- Base Paths --
    BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/"

    # Path to ORIGINAL images (JPGs)
    # We use the combined_data images folder
    IMAGE_DIR = os.path.join(BASE_PATH, "images/")

    # Path to NEW CSV MASKS
    # Update this to where you uploaded the "split_masks" folder
    MASK_CSV_DIR = os.path.join(BASE_PATH, "augmented masks v3/split_masks")
    # QC Report & Splits
    QC_REPORT_CSV = os.path.join(BASE_PATH, "missing_classes_from_mask.csv")
    SPLIT_CSV = os.path.join(BASE_PATH, "dataset_split.csv")

    # -- Output Paths --
    OUTPUT_DIR = os.path.join(BASE_PATH, "outputs_v5_split_leaf_final")
    OUTPUT_MASK_DIR = os.path.join(OUTPUT_DIR, "pred_masks")
    COLOR_MASK_DIR = os.path.join(OUTPUT_DIR, "color_masks")

    # -- Model Hyperparameters --
    ARCHITECTURE = 'unetplusplus'
    ENCODER = 'resnet34'
    ENCODER_WEIGHTS = 'imagenet'
    LEARNING_RATE = 1e-4
    OPTIMIZER = 'AdamW'

    # -- Training Settings --
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 4
    NUM_EPOCHS = 40

    IMAGE_HEIGHT = 512
    IMAGE_WIDTH = 512

    # 0=Bg, 1=Root, 2=Unused, 3=Stem, 4=Seed, 5=Left, 6=Right
    NUM_CLASSES = 7

    # -- LOSS CONFIGURATION --
    # Weights mapped to the new indices
    CLASS_WEIGHTS = torch.tensor([
        1.0,  # 0: Background
        10.0,  # 1: Root
        0.0,  # 2: UNUSED (Weight 0 so model doesn't care)
        10.0,  # 3: Stem
        15.0, # 4: Seed
        7.0,  # 5: Left Leaf
        7.0   # 6: Right Leaf
    ], device=DEVICE)

    # -- Visualization --
    COLOR_MAP = {
        0: (0, 0, 0),        # Bg
        1: (255, 255, 0),    # Root (Yellow)
        2: (0, 0, 0),        # Unused
        3: (139, 69, 19),    # Stem (Brown)
        4: (255, 0, 0),      # Seed (Red)
        5: (0, 255, 0),      # Left Leaf (Green)
        6: (0, 128, 0),      # Right Leaf (Dark Green)
    }

os.makedirs(Config.OUTPUT_MASK_DIR, exist_ok=True)
os.makedirs(Config.COLOR_MASK_DIR, exist_ok=True)



In [14]:
def get_splits_v5(image_dir, mask_dir, split_csv_path, qc_csv_path):
    print("--- Configuring V5 Data Splits ---")
    if os.path.exists(qc_csv_path):
        df_qc = pd.read_csv(qc_csv_path)
        # Find rows where Mask_correct is explicitly FALSE
        bad_rows = df_qc[df_qc['Mask_correct'].astype(str).str.upper() == 'FALSE']
        excluded_mask_filenames = set(bad_rows['filename'].tolist())
        print(f"QC Report loaded. Found {len(excluded_mask_filenames)} bad masks to exclude.")
    else:
        print("Warning: QC CSV not found. No masks will be blacklisted.")

    # Load Split Map
    split_map = {}
    if os.path.exists(split_csv_path):
        try:
            df_split = pd.read_csv(split_csv_path)
        except:
            df_split = pd.read_excel(split_csv_path.replace('.csv', '.xlsx'))
        for idx, row in df_split.iterrows():
            split_map[row['img_name']] = row['set'].lower().strip()

    all_images = sorted([f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))])

    train_files, val_files, test_files = [], [], []

    for img_name in all_images:
      # Construct mask name for QC check
        mask_name = os.path.splitext(img_name)[0] + "_mask.png"

        # --- CHECK 1: Is it mislabeled? ---
        if mask_name in excluded_mask_filenames:
            continue # SKIP THIS IMAGE COMPLETELY
        # Check if Mask file exists (CSV or TSV)
        base_name = os.path.splitext(img_name)[0]
        possible_names = [
            base_name + ".csv", base_name + "_mask.csv",
            base_name + ".tsv", base_name + "_mask.tsv"
        ]

        mask_path = None
        for name in possible_names:
            p = os.path.join(mask_dir, name)
            if os.path.exists(p):
                mask_path = p
                break

        if mask_path is None:
            continue

        # Split assignment
        if img_name in split_map:
            assigned_set = split_map[img_name]
            if assigned_set == 'train': train_files.append(img_name)
            elif assigned_set == 'val': val_files.append(img_name)
            elif assigned_set == 'test': test_files.append(img_name)
        else:
            train_files.append(img_name) # Synthetic/Extra to Train

    print(f"Train: {len(train_files)} | Val: {len(val_files)} | Test: {len(test_files)}")
    return train_files, val_files, test_files


In [4]:
class RobustMaskDataset(Dataset):
    def __init__(self, image_dir, mask_dir, image_filenames, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = image_filenames

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)

        # Load Image
        image = np.array(Image.open(img_path).convert("RGB"))

        # Load Mask (Try CSV/TSV)
        base_name = os.path.splitext(img_name)[0]
        possible_names = [
            base_name + ".csv", base_name + "_mask.csv",
            base_name + ".tsv", base_name + "_mask.tsv"
        ]
        mask_path = None
        for name in possible_names:
            p = os.path.join(self.mask_dir, name)
            if os.path.exists(p):
                mask_path = p
                break

        # --- ROBUST LOADING LOGIC ---
        if mask_path is None:
            mask = np.zeros(image.shape[:2], dtype=np.float32)
        else:
            try:
                # Determine delimiter based on extension or sniffing
                delimiter = ',' if mask_path.endswith('.csv') else '\t'

                # If reading fails with delimiter, try whitespace ' '
                try:
                    df = pd.read_csv(mask_path, header=None, sep=delimiter)
                    # Check if it loaded as one column (parsing error)
                    if df.shape[1] == 1:
                         # Fallback to space separated
                         df = pd.read_csv(mask_path, header=None, sep='\s+')
                except:
                    # Final fallback: space/tab/arbitrary whitespace
                    df = pd.read_csv(mask_path, header=None, sep='\s+')

                mask = df.values.astype(np.float32)
            except Exception as e:
                print(f"Error reading mask {mask_path}: {e}")
                mask = np.zeros(image.shape[:2], dtype=np.float32)

        # --- SMART PADDING LOGIC ---
        # The mask is the "Ground Truth" dimension.
        # If image is smaller, pad it to match mask.
        h_img, w_img = image.shape[:2]
        h_mask, w_mask = mask.shape[:2]

        if (h_img != h_mask) or (w_img != w_mask):
            # Create a padded canvas for the image
            padded_image = np.zeros((h_mask, w_mask, 3), dtype=np.uint8)

            # Paste image at top-left (0,0)
            # Clip dimensions to avoid errors if image is somehow larger
            h_paste = min(h_img, h_mask)
            w_paste = min(w_img, w_mask)

            padded_image[:h_paste, :w_paste, :] = image[:h_paste, :w_paste, :]
            image = padded_image

            # NOTE: We assume top-left alignment based on typical canvas generation.

        original_height, original_width = h_mask, w_mask

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask.long(), (original_height, original_width)


  df = pd.read_csv(mask_path, header=None, sep='\s+')
  df = pd.read_csv(mask_path, header=None, sep='\s+')


In [5]:
class WeightedDiceLoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights
        self.dice_loss = DiceLoss(softmax=True, to_onehot_y=True, include_background=True, reduction='none')

    def forward(self, preds, targets):
        targets = targets.unsqueeze(1)
        loss_per_class = self.dice_loss(preds, targets)
        weighted_loss = loss_per_class * self.class_weights
        return weighted_loss.mean()

def mask_to_rgb(mask_tensor, color_map):
    mask = mask_tensor.cpu().numpy().squeeze()
    rgb_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for class_idx, color in color_map.items():
        rgb_mask[mask == class_idx] = color
    return Image.fromarray(rgb_mask)

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader, desc="Training")
    total_loss = 0
    for batch_idx, (data, targets, _) in enumerate(loop):
        data = data.to(device=Config.DEVICE)
        targets = targets.to(device=Config.DEVICE)
        with torch.amp.autocast('cuda'):
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    return total_loss / len(loader)

def eval_fn(loader, model, loss_fn):
    model.eval()
    total_loss = 0
    loop = tqdm(loader, desc="Validation")
    with torch.no_grad():
        for data, targets, _ in loop:
            data = data.to(device=Config.DEVICE)
            targets = targets.to(device=Config.DEVICE)
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            total_loss += loss.item()
            loop.set_postfix(val_loss=loss.item())
    model.train()
    return total_loss / len(loader)

def save_predictions_fn(loader, model, folder_basename=""):
    print(f"\n--- Saving predictions for {folder_basename} set ---")
    model.eval()
    output_mask_dir = os.path.join(Config.OUTPUT_MASK_DIR, folder_basename)
    color_mask_dir = os.path.join(Config.COLOR_MASK_DIR, folder_basename)
    os.makedirs(output_mask_dir, exist_ok=True)
    os.makedirs(color_mask_dir, exist_ok=True)

    for idx in tqdm(range(len(loader.dataset))):
        img_tensor, _, (original_h, original_w) = loader.dataset[idx]
        with torch.no_grad():
            img_tensor = img_tensor.to(Config.DEVICE).unsqueeze(0)
            preds = model(img_tensor)
            final_mask_tensor = torch.argmax(preds, dim=1).squeeze(0)

        pred_mask_np = final_mask_tensor.cpu().numpy().astype(np.uint8)
        resized_mask = cv2.resize(pred_mask_np, (original_w, original_h), interpolation=cv2.INTER_NEAREST)

        original_filename = loader.dataset.images[idx]
        name_only = os.path.splitext(original_filename)[0]

        Image.fromarray(resized_mask).save(os.path.join(output_mask_dir, name_only + "_mask.png"))
        mask_to_rgb(torch.from_numpy(resized_mask), Config.COLOR_MAP).save(os.path.join(color_mask_dir, name_only + "_mask.png"))
    model.train()

In [6]:
# Augmentations
train_transform = A.Compose([
    A.Resize(height=Config.IMAGE_HEIGHT, width=Config.IMAGE_WIDTH),
    A.Rotate(limit=35, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(height=Config.IMAGE_HEIGHT, width=Config.IMAGE_WIDTH),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

In [15]:
def main():
    print(f"Using device: {Config.DEVICE}")

    # 1. Splits
    train_files, val_files, test_files = get_splits_v5(Config.IMAGE_DIR, Config.MASK_CSV_DIR, Config.SPLIT_CSV, Config.QC_REPORT_CSV)

    if len(train_files) == 0:
        print("Error: No matching mask files found! Check MASK_CSV_DIR path.")
        return

    # 2. Loaders
    train_dataset = RobustMaskDataset(Config.IMAGE_DIR, Config.MASK_CSV_DIR, train_files, train_transform)
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
    val_dataset = RobustMaskDataset(Config.IMAGE_DIR, Config.MASK_CSV_DIR, val_files, val_transform)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)
    test_dataset = RobustMaskDataset(Config.IMAGE_DIR, Config.MASK_CSV_DIR, test_files, val_transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # 3. Model
    model = smp.create_model(
        arch=Config.ARCHITECTURE,
        encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS,
        in_channels=3,
        classes=Config.NUM_CLASSES
    ).to(Config.DEVICE)

    loss_fn = WeightedDiceLoss(class_weights=Config.CLASS_WEIGHTS)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scaler = torch.amp.GradScaler('cuda')
    best_val_loss = float('inf')

    # 4. Train
    for epoch in range(Config.NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{Config.NUM_EPOCHS} ---")
        train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        val_loss = eval_fn(val_loader, model, loss_fn)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(Config.BASE_PATH, "best_model_v5_robust.pth"))
            print("=> Saved new best model")

    # 5. Save
    print("\n--- Testing Best Model ---")
    model.load_state_dict(torch.load(os.path.join(Config.BASE_PATH, "best_model_v5_robust.pth")))
    save_predictions_fn(test_loader, model, folder_basename="test_set")
    print("--- V5 Robust Training Complete ---")
if __name__ == "__main__":
    main()

Using device: cuda
--- Configuring V5 Data Splits ---
QC Report loaded. Found 32 bad masks to exclude.
Train: 389 | Val: 47 | Test: 47


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]


--- Epoch 1/40 ---


Training: 100%|██████████| 98/98 [05:56<00:00,  3.64s/it, loss=6.33]
Validation: 100%|██████████| 12/12 [00:42<00:00,  3.57s/it, val_loss=6.24]


Train Loss: 6.5468 | Val Loss: 6.1245
=> Saved new best model

--- Epoch 2/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.20it/s, loss=6.01]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.27it/s, val_loss=5.8]


Train Loss: 6.0280 | Val Loss: 5.5089
=> Saved new best model

--- Epoch 3/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.24it/s, loss=4.96]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.33it/s, val_loss=5.14]


Train Loss: 5.2069 | Val Loss: 4.5043
=> Saved new best model

--- Epoch 4/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.21it/s, loss=4.46]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.29it/s, val_loss=4.5]


Train Loss: 4.1032 | Val Loss: 3.6643
=> Saved new best model

--- Epoch 5/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.21it/s, loss=4.44]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.28it/s, val_loss=3.87]


Train Loss: 3.4077 | Val Loss: 3.2105
=> Saved new best model

--- Epoch 6/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.20it/s, loss=3.16]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.27it/s, val_loss=3.61]


Train Loss: 3.0859 | Val Loss: 3.2387

--- Epoch 7/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.47it/s, loss=2.73]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.15it/s, val_loss=3.07]


Train Loss: 2.9322 | Val Loss: 2.7952
=> Saved new best model

--- Epoch 8/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.26it/s, loss=2.35]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.28it/s, val_loss=2.89]


Train Loss: 2.7820 | Val Loss: 2.6280
=> Saved new best model

--- Epoch 9/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.23it/s, loss=2.71]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.30it/s, val_loss=3.04]


Train Loss: 2.7004 | Val Loss: 2.6428

--- Epoch 10/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.49it/s, loss=2.98]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.39it/s, val_loss=3.28]


Train Loss: 2.6307 | Val Loss: 2.7281

--- Epoch 11/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=4.53]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.26it/s, val_loss=2.62]


Train Loss: 2.5961 | Val Loss: 2.5999
=> Saved new best model

--- Epoch 12/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.24it/s, loss=2.58]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.35it/s, val_loss=2.73]


Train Loss: 2.5364 | Val Loss: 2.4950
=> Saved new best model

--- Epoch 13/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.26it/s, loss=2.03]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.40it/s, val_loss=2.66]


Train Loss: 2.4967 | Val Loss: 2.5709

--- Epoch 14/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.49it/s, loss=2.43]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.31it/s, val_loss=2.91]


Train Loss: 2.4554 | Val Loss: 2.5745

--- Epoch 15/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.47it/s, loss=2.22]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.43it/s, val_loss=3.29]


Train Loss: 2.5569 | Val Loss: 2.6843

--- Epoch 16/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.49it/s, loss=4.67]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.31it/s, val_loss=3.01]


Train Loss: 2.4695 | Val Loss: 2.5671

--- Epoch 17/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.46it/s, loss=2.25]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.41it/s, val_loss=2.9]


Train Loss: 2.4791 | Val Loss: 2.5465

--- Epoch 18/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=4.28]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.30it/s, val_loss=2.57]


Train Loss: 2.4510 | Val Loss: 2.4173
=> Saved new best model

--- Epoch 19/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.21it/s, loss=2.77]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.29it/s, val_loss=2.81]


Train Loss: 2.4440 | Val Loss: 2.5190

--- Epoch 20/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=2.71]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.15it/s, val_loss=2.43]


Train Loss: 2.3899 | Val Loss: 2.4160
=> Saved new best model

--- Epoch 21/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.21it/s, loss=2.96]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.24it/s, val_loss=2.51]


Train Loss: 2.3914 | Val Loss: 2.5609

--- Epoch 22/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.46it/s, loss=2.12]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.34it/s, val_loss=2.46]


Train Loss: 2.3954 | Val Loss: 2.3945
=> Saved new best model

--- Epoch 23/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.25it/s, loss=2.9]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.26it/s, val_loss=2.46]


Train Loss: 2.3351 | Val Loss: 2.4431

--- Epoch 24/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=2.26]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.36it/s, val_loss=2.36]


Train Loss: 2.3194 | Val Loss: 2.4007

--- Epoch 25/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.47it/s, loss=3.22]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.28it/s, val_loss=3.57]


Train Loss: 2.3727 | Val Loss: 2.6869

--- Epoch 26/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=2.58]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.31it/s, val_loss=2.42]


Train Loss: 2.3914 | Val Loss: 2.4167

--- Epoch 27/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.50it/s, loss=2.08]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.28it/s, val_loss=2.57]


Train Loss: 2.3474 | Val Loss: 2.3602
=> Saved new best model

--- Epoch 28/40 ---


Training: 100%|██████████| 98/98 [00:29<00:00,  3.27it/s, loss=3.72]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.26it/s, val_loss=2.4]


Train Loss: 2.3445 | Val Loss: 2.3913

--- Epoch 29/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=1.7]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.22it/s, val_loss=2.52]


Train Loss: 2.3036 | Val Loss: 2.3028
=> Saved new best model

--- Epoch 30/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.25it/s, loss=1.71]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.32it/s, val_loss=2.41]


Train Loss: 2.2879 | Val Loss: 2.3330

--- Epoch 31/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=1.73]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.36it/s, val_loss=2.35]


Train Loss: 2.2743 | Val Loss: 2.4148

--- Epoch 32/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.49it/s, loss=3.71]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.25it/s, val_loss=2.43]


Train Loss: 2.2684 | Val Loss: 2.3150

--- Epoch 33/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.47it/s, loss=2.09]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.26it/s, val_loss=2.45]


Train Loss: 2.2504 | Val Loss: 2.3641

--- Epoch 34/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.47it/s, loss=1.91]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.25it/s, val_loss=2.66]


Train Loss: 2.3060 | Val Loss: 2.3086

--- Epoch 35/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=2.63]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.29it/s, val_loss=2.42]


Train Loss: 2.2540 | Val Loss: 2.3418

--- Epoch 36/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=1.9]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.31it/s, val_loss=2.38]


Train Loss: 2.2683 | Val Loss: 2.3043

--- Epoch 37/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=1.91]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.35it/s, val_loss=2.48]


Train Loss: 2.2143 | Val Loss: 2.3024
=> Saved new best model

--- Epoch 38/40 ---


Training: 100%|██████████| 98/98 [00:30<00:00,  3.23it/s, loss=1.82]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.28it/s, val_loss=2.53]


Train Loss: 2.2464 | Val Loss: 2.3090

--- Epoch 39/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=2.14]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.24it/s, val_loss=2.38]


Train Loss: 2.2214 | Val Loss: 2.3393

--- Epoch 40/40 ---


Training: 100%|██████████| 98/98 [00:28<00:00,  3.48it/s, loss=1.83]
Validation: 100%|██████████| 12/12 [00:02<00:00,  5.17it/s, val_loss=3.16]


Train Loss: 2.2396 | Val Loss: 2.5361

--- Testing Best Model ---

--- Saving predictions for test_set set ---


100%|██████████| 47/47 [01:57<00:00,  2.49s/it]

--- V5 Robust Training Complete ---





In [None]:
# =================================================================================
# 0. SETUP AND IMPORTS
# =================================================================================
!pip install -q monai pandas

import os
import cv2
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from monai.metrics import (
    compute_dice,
    compute_iou,
    compute_hausdorff_distance
)

# =================================================================================
# 1. CONFIGURATION
# =================================================================================
# --- PATHS FOR V5 ---
# Ground Truth Masks (The NEW CSV/TSV masks folder)
GT_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/augmented masks v3/split_masks"

# Predicted Masks (From V5 Output)
PRED_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v5_split_leaf_final/pred_masks/test_set/"

# --- CLASS MAP FOR V5 (7 Classes) ---
NUM_CLASSES = 7
CLASS_MAP = {
    0: "Background",
    1: "Root",
    2: "Unused",
    3: "Stem",
    4: "Seed",
    5: "Left Leaf",
    6: "Right Leaf"
}

# =================================================================================
# 2. HELPER FUNCTIONS
# =================================================================================
def to_one_hot(mask, num_classes):
    """Converts a (H, W) mask to (1, C, H, W) one-hot tensor."""
    mask[mask >= num_classes] = 0 # Safety clip
    one_hot = np.eye(num_classes)[mask]
    one_hot = np.transpose(one_hot, (2, 0, 1))
    return torch.from_numpy(one_hot).unsqueeze(0)

def load_gt_mask(filename):
    """Robustly loads GT mask from CSV or TSV."""
    base_name = os.path.splitext(filename)[0].replace("_mask", "") # Remove _mask from prediction name
    possible_names = [base_name + ".csv", base_name + ".tsv", base_name + "_mask.csv", base_name + "_mask.tsv"]

    mask_path = None
    for name in possible_names:
        p = os.path.join(GT_MASK_DIR, name)
        if os.path.exists(p):
            mask_path = p
            break

    if mask_path is None:
        return None

    try:
        delimiter = ',' if mask_path.endswith('.csv') else '\t'
        try:
            df = pd.read_csv(mask_path, header=None, sep=delimiter)
            if df.shape[1] == 1:
                 df = pd.read_csv(mask_path, header=None, delim_whitespace=True)
        except:
            df = pd.read_csv(mask_path, header=None, delim_whitespace=True)
        return df.values.astype(np.int64)
    except:
        return None

# =================================================================================
# 3. MAIN ANALYSIS LOOP
# =================================================================================
def run_analysis():
    print("Starting V5 Analysis...")
    print(f"Pred Directory: {PRED_MASK_DIR}")

    results_list = []
    pred_files = [f for f in os.listdir(PRED_MASK_DIR) if f.endswith('.png')]

    if len(pred_files) == 0:
        print("❌ Error: No prediction files found! Check your PRED_MASK_DIR path.")
        return

    for filename in tqdm(pred_files):
        pred_path = os.path.join(PRED_MASK_DIR, filename)

        # Load Prediction
        pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)

        # Load Ground Truth (Using robust loader)
        gt_mask = load_gt_mask(filename)

        if gt_mask is None or pred_mask is None:
            continue

        # Resize GT to match Pred (Using our Robust Training Logic: Pred is Image Size)
        # In training, we padded Image to match GT.
        # But prediction output was "unsqueezed" to original image size?
        # Wait, the V5 script output raw mask size (padded).
        # Let's ensure dimensions match by resizing GT to Pred if needed (or vice versa).
        # Safest bet: Resize GT to Pred size (Nearest Neighbor) to emulate "Original Image Space" evaluation
        if gt_mask.shape != pred_mask.shape:
             gt_mask = cv2.resize(gt_mask.astype(np.uint8), (pred_mask.shape[1], pred_mask.shape[0]), interpolation=cv2.INTER_NEAREST)

        # Convert to One-Hot
        gt_onehot = to_one_hot(gt_mask, NUM_CLASSES)
        pred_onehot = to_one_hot(pred_mask, NUM_CLASSES)

        # Metrics
        dice = compute_dice(pred_onehot, gt_onehot, include_background=True)
        iou = compute_iou(pred_onehot, gt_onehot, include_background=True)
        hd95 = compute_hausdorff_distance(pred_onehot, gt_onehot, include_background=True, percentile=95)

        # Store
        file_metrics = {'filename': filename}
        for i in range(NUM_CLASSES):
            if i == 2: continue # Skip Unused Class
            c_name = CLASS_MAP[i]
            file_metrics[f"{c_name}_Dice"] = dice[0, i].item()
            file_metrics[f"{c_name}_IOU"] = iou[0, i].item()
            file_metrics[f"{c_name}_HD95"] = hd95[0, i].item()

        results_list.append(file_metrics)

    # Report
    if not results_list:
        print("No results generated.")
        return

    df = pd.DataFrame(results_list)
    overall_stats = df.mean(numeric_only=True)

    print("\n\n--- V5 Overall Average Statistics (Test Set) ---")
    summary_data = []
    for i in range(NUM_CLASSES):
        if i == 2: continue
        c_name = CLASS_MAP[i]
        summary_data.append({
            "Class": c_name,
            "Dice (↑)": overall_stats.get(f"{c_name}_Dice"),
            "IOU (↑)": overall_stats.get(f"{c_name}_IOU"),
            "HD95 (↓)": overall_stats.get(f"{c_name}_HD95"),
        })

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_markdown(index=False, floatfmt=".4f"))

    summary_df.to_csv("/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/v5_results_summary.csv", index=False)
    print("\nSummary saved.")

if __name__ == "__main__":
    run_analysis()

Starting V5 Analysis...
Pred Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v5_split_leaf_final/pred_masks/test_set/


  df = pd.read_csv(mask_path, header=None, delim_whitespace=True)
  df = pd.read_csv(mask_path, header=None, delim_whitespace=True)
  df = pd.read_csv(mask_path, header=None, delim_whitespace=True)
  6%|▌         | 3/52 [00:01<00:23,  2.08it/s]