In [None]:
!pip install -q segmentation-models-pytorch albumentations monai

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m2.4/2.7 MB[0m [31m72.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tqdm import tqdm
import random

from monai.losses import DiceLoss, HausdorffDTLoss



In [None]:
class Config:
    """Static settings for the V4 high-resolution segmentation run.

    Every value is a class attribute read as ``Config.NAME``; the class is never
    instantiated.
    """

    # ---- Input data -------------------------------------------------------
    BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/"
    DATA_PATH = os.path.join(BASE_PATH, "combined_data")
    IMAGE_DIR = os.path.join(DATA_PATH, "images")
    MASK_DIR = os.path.join(DATA_PATH, "masks")
    # Manual QC sheet; rows whose Mask_correct is FALSE are dropped from the splits.
    QC_REPORT_CSV = os.path.join(BASE_PATH, "problematic_images_inspection.csv")

    # ---- Outputs (V4) -----------------------------------------------------
    OUTPUT_DIR = os.path.join(BASE_PATH, "outputs_v4_highres")
    OUTPUT_MASK_DIR = os.path.join(OUTPUT_DIR, "pred_masks_unsqueezed")  # class-index PNGs
    COLOR_MASK_DIR = os.path.join(OUTPUT_DIR, "color_masks_unsqueezed")  # colour-coded PNGs

    # ---- Hold-out split sizes (number of images) --------------------------
    TEST_SIZE, VAL_SIZE = 52, 51

    # ---- Model / optimiser ------------------------------------------------
    ARCHITECTURE = 'unetplusplus'  # U-Net++ (nested skip connections) for finer detail
    ENCODER, ENCODER_WEIGHTS = 'resnet34', 'imagenet'
    OPTIMIZER, LEARNING_RATE = 'AdamW', 1e-4

    # ---- Training ---------------------------------------------------------
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 2     # kept small: 512x512 inputs need much more memory
    NUM_EPOCHS = 35    # higher resolution may need longer to converge
    IMAGE_HEIGHT = IMAGE_WIDTH = 512  # doubled input resolution to preserve thin roots

    # ---- Classes & loss ---------------------------------------------------
    NUM_CLASSES = 5
    # Per-class loss weights, indexed by class id: background, stem, leaf, root, seed.
    CLASS_WEIGHTS = torch.tensor([1.0, 7.0, 5.0, 5.0, 7.0], device=DEVICE)
    HYBRID_WEIGHT_DICE, HYBRID_WEIGHT_HD = 1.0, 0.0  # HD term is skipped while its weight is 0

    # ---- Visualisation: class id -> RGB colour -----------------------------
    COLOR_MAP = dict(enumerate([
        (0, 0, 0),        # background - black
        (139, 69, 19),    # stem - brown
        (0, 255, 0),      # leaf - green
        (255, 255, 0),    # root - yellow
        (255, 0, 0),      # seed - red
    ]))


In [None]:
# Make sure both top-level prediction folders exist before anything is written.
for _output_dir in (Config.OUTPUT_MASK_DIR, Config.COLOR_MASK_DIR):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
def get_filtered_data_splits(image_dir, qc_csv_path, val_size, test_size, seed=42):
    """Split the QC-approved images into reproducible train/val/test filename lists.

    Candidate images are the ``.jpg``/``.png`` files in ``image_dir``. If
    ``qc_csv_path`` exists, an image is excluded when its mask name
    (``<stem>_mask.png``) appears in the CSV's ``filename`` column with
    ``Mask_correct`` equal to FALSE (case-insensitive). The survivors are shuffled
    with ``seed``. The first ``test_size`` go to test, the next ``val_size`` to
    validation, and the remainder to training.

    Args:
        image_dir: Folder containing the images.
        qc_csv_path: Path to the QC report CSV; ignored if it does not exist.
        val_size: Number of validation images.
        test_size: Number of test images.
        seed: Shuffle seed. The default of 42 reproduces the historical split.

    Returns:
        ``(train_files, val_files, test_files)`` as lists of file names.

    Raises:
        ValueError: If filtering leaves ``test_size + val_size`` images or fewer,
            which would otherwise produce an empty training set.
    """
    all_images = sorted(f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png')))

    excluded_filenames = set()
    if os.path.exists(qc_csv_path):
        qc_df = pd.read_csv(qc_csv_path)
        # astype(str) handles both a parsed boolean column (False -> "False") and raw text.
        bad_rows = qc_df[qc_df['Mask_correct'].astype(str).str.upper() == 'FALSE']
        excluded_filenames = set(bad_rows['filename'].tolist())

    valid_images = [
        img_name for img_name in all_images
        if os.path.splitext(img_name)[0] + "_mask.png" not in excluded_filenames
    ]

    n_holdout = test_size + val_size
    if len(valid_images) <= n_holdout:
        raise ValueError(
            f"Only {len(valid_images)} usable images after QC filtering; need more than "
            f"test_size + val_size = {n_holdout} to leave a non-empty training set."
        )

    # A private RNG yields exactly the permutation that random.seed(seed) followed by
    # random.shuffle() would, without reseeding the global `random` module.
    random.Random(seed).shuffle(valid_images)

    test_files = valid_images[:test_size]
    val_files = valid_images[test_size:n_holdout]
    train_files = valid_images[n_holdout:]
    return train_files, val_files, test_files


In [None]:
class PlantDataset(Dataset):
    """Image / class-index-mask pairs for segmentation.

    For an image ``<stem>.<ext>``, the mask is read from ``mask_dir/<stem>_mask.png``.

    Args:
        image_dir: Folder containing the RGB images.
        mask_dir: Folder containing the grayscale label masks.
        image_filenames: Image file names (not paths) that make up this dataset.
        transform: Optional albumentations pipeline, called as
            ``transform(image=..., mask=...)``. When omitted, the image is returned
            as the HWC uint8 NumPy array read from disk.

    Each item is ``(image, mask, (original_height, original_width))``:
    - ``mask`` is always an int64 tensor, with labels >= ``Config.NUM_CLASSES`` mapped to 0.
    - The original size is recorded before any transform, so predictions can be resized back.
    """

    def __init__(self, image_dir, mask_dir, image_filenames, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = image_filenames

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_name = os.path.splitext(img_name)[0] + "_mask.png"
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Context managers close the underlying file handles deterministically.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.float32)

        original_height, original_width = image.shape[:2]

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # With no transform (or one without ToTensorV2) the mask is still a NumPy
        # array, which has no .long(); as_tensor is a no-op for tensors.
        mask = torch.as_tensor(mask)
        mask[mask >= Config.NUM_CLASSES] = 0
        return image, mask.long(), (original_height, original_width)
# ImageNet channel statistics. The encoder is initialised with ImageNet weights
# (Config.ENCODER_WEIGHTS = 'imagenet'), and pretrained encoders expect inputs
# normalised the way they were pretrained. mean=0/std=1 only rescales to [0, 1].
# NOTE: checkpoints trained with the previous mean=0/std=1 preprocessing are not
# compatible with these transforms; retrain before reusing them.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random geometric augmentation (applied jointly to the
# image and mask), normalisation, then conversion to CHW tensors.
train_transform = A.Compose([
    A.Resize(height=Config.IMAGE_HEIGHT, width=Config.IMAGE_WIDTH),
    A.Rotate(limit=35, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

# Deterministic pipeline for validation/test: resize plus the same normalisation.
val_transform = A.Compose([
    A.Resize(height=Config.IMAGE_HEIGHT, width=Config.IMAGE_WIDTH),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

In [None]:
class WeightedHybridLoss(nn.Module):
    """Class-weighted Dice loss, optionally plus a class-weighted Hausdorff-DT loss.

    Args:
        class_weights: 1-D tensor with one weight per class (length C).
        weight_dice: Multiplier for the weighted Dice term.
        weight_hd: Multiplier for the weighted Hausdorff-DT term. The HD loss is
            only evaluated when this is > 0.
    """

    def __init__(self, class_weights, weight_dice=1.0, weight_hd=0.0):
        super().__init__()
        self.w_dice = weight_dice
        self.w_hd = weight_hd
        self.class_weights = class_weights
        # reduction='none' keeps one value per (sample, class) so class weights can be applied.
        self.dice_loss = DiceLoss(softmax=True, to_onehot_y=True, include_background=True, reduction='none')
        self.hd_loss = HausdorffDTLoss(softmax=True, to_onehot_y=True, include_background=True, reduction='none')

    def _weighted_mean(self, per_class_loss):
        """Weight ``per_class_loss`` along dim 1 (classes), then average all entries."""
        # MONAI's reduction='none' result is shaped (B, C, 1, ..., 1). Multiplying it by
        # a bare (C,) tensor broadcasts on the LAST axis, giving (B, C, 1, C), and the
        # mean collapses to mean(loss) * mean(weights), so no per-class weighting
        # happens. Reshape to (1, C, 1, ..., 1) so each class gets its own weight.
        weights = self.class_weights.to(per_class_loss.device)
        weights = weights.view(1, -1, *([1] * (per_class_loss.ndim - 2)))
        return (per_class_loss * weights).mean()

    def forward(self, preds_logits, targets_idx):
        """Return the scalar weighted loss.

        Args:
            preds_logits: Raw logits of shape (B, C, H, W). Softmax is applied inside
                the MONAI losses.
            targets_idx: Integer class-index map of shape (B, H, W). A channel dim is
                added here because ``to_onehot_y`` expects (B, 1, H, W).
        """
        targets_idx = targets_idx.unsqueeze(1)
        total_loss = self.w_dice * self._weighted_mean(self.dice_loss(preds_logits, targets_idx))
        if self.w_hd > 0:
            weighted_loss_hd = self._weighted_mean(self.hd_loss(preds_logits, targets_idx))
            total_loss = total_loss + self.w_hd * weighted_loss_hd
        return total_loss


In [None]:
def mask_to_rgb(mask_tensor, color_map):
    """Render a class-index mask as an RGB PIL image.

    Args:
        mask_tensor: Tensor of class indices; singleton dims are squeezed away.
        color_map: Mapping of class index -> (R, G, B). Pixels whose index is not a
            key stay black.

    Returns:
        A ``PIL.Image`` built from a uint8 (H, W, 3) array.
    """
    labels = mask_tensor.cpu().numpy().squeeze()
    canvas = np.zeros(labels.shape + (3,), dtype=np.uint8)
    for label, rgb in color_map.items():
        canvas[labels == label] = rgb
    return Image.fromarray(canvas)

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        loader: DataLoader yielding ``(images, masks, sizes)`` batches.
        model: Network being trained; switched to train mode here.
        optimizer: Optimiser stepping the model parameters.
        loss_fn: Callable ``loss_fn(logits, masks)`` returning a scalar tensor.
        scaler: ``torch.amp.GradScaler`` used for mixed-precision backward/step.

    Returns:
        Average batch loss over the epoch (float).
    """
    # Set train mode explicitly instead of relying on a previous eval_fn call.
    model.train()
    # Autocast is only used on CUDA; a hard-coded 'cuda' autocast on a CPU run just
    # warns and disables itself.
    device_type = torch.device(Config.DEVICE).type
    use_amp = device_type == "cuda"

    loop = tqdm(loader, desc="Training")
    total_loss = 0.0
    for data, targets, _ in loop:
        data = data.to(device=Config.DEVICE)
        targets = targets.to(device=Config.DEVICE)
        with torch.amp.autocast(device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)
    return total_loss / len(loader)

def eval_fn(loader, model, loss_fn):
    """Return the mean per-batch loss over ``loader``, computed without gradients.

    The model is put in eval mode for the pass and switched back to train mode
    before returning.
    """
    model.eval()
    progress = tqdm(loader, desc="Validation")
    running = 0
    with torch.no_grad():
        for images, masks, _sizes in progress:
            logits = model(images.to(device=Config.DEVICE))
            batch_loss = loss_fn(logits, masks.to(device=Config.DEVICE)).item()
            running += batch_loss
            progress.set_postfix(val_loss=batch_loss)
    model.train()
    return running / len(loader)

def save_predictions_fn(loader, model, folder_basename=""):
    """Predict every item of ``loader.dataset`` and write masks at original resolution.

    For each image, two PNGs named ``<stem>_mask.png`` are written:
    - the class-index mask, under ``Config.OUTPUT_MASK_DIR/<folder_basename>``;
    - a colour-coded version, under ``Config.COLOR_MASK_DIR/<folder_basename>``.

    Predictions are resized back to the size recorded by the dataset, using
    nearest-neighbour interpolation so labels stay integral. The model is returned
    to train mode afterwards.
    """
    print(f"\n--- Saving predictions for {folder_basename} set ---")
    model.eval()
    dataset = loader.dataset
    index_dir = os.path.join(Config.OUTPUT_MASK_DIR, folder_basename)
    rgb_dir = os.path.join(Config.COLOR_MASK_DIR, folder_basename)
    for target_dir in (index_dir, rgb_dir):
        os.makedirs(target_dir, exist_ok=True)

    for idx in tqdm(range(len(dataset)), desc=f"Saving {folder_basename} Predictions"):
        image, _, (height, width) = dataset[idx]
        with torch.no_grad():
            logits = model(image.to(Config.DEVICE).unsqueeze(0))
            label_map = logits.argmax(dim=1).squeeze(0)

        label_np = label_map.cpu().numpy().astype(np.uint8)
        full_res = cv2.resize(label_np, (width, height), interpolation=cv2.INTER_NEAREST)

        out_name = os.path.splitext(dataset.images[idx])[0] + "_mask.png"
        Image.fromarray(full_res).save(os.path.join(index_dir, out_name))
        mask_to_rgb(torch.from_numpy(full_res), Config.COLOR_MAP).save(os.path.join(rgb_dir, out_name))
    model.train()

In [None]:
def main():
    """Train the V4 model, keep the best-validation weights, and export test predictions.

    Steps:
    1. Build QC-filtered splits and their data loaders.
    2. Create the smp model and train it for ``Config.NUM_EPOCHS`` epochs, saving
       the weights whenever the validation loss improves.
    3. Reload the best weights and write test-set masks with ``save_predictions_fn``.

    Raises:
        RuntimeError: If no epoch produced a finite validation loss, so no checkpoint
            was written by this run.
    """
    print(f"Using device: {Config.DEVICE}")
    print(f"Dataset Path: {Config.DATA_PATH}")

    checkpoint_path = os.path.join(Config.BASE_PATH, "best_model_v4_highres.pth")
    use_cuda = torch.device(Config.DEVICE).type == "cuda"

    # 1. Splits
    train_files, val_files, test_files = get_filtered_data_splits(
        Config.IMAGE_DIR, Config.QC_REPORT_CSV, Config.VAL_SIZE, Config.TEST_SIZE
    )

    # 2. Loaders
    train_dataset = PlantDataset(Config.IMAGE_DIR, Config.MASK_DIR, train_files, train_transform)
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
    val_dataset = PlantDataset(Config.IMAGE_DIR, Config.MASK_DIR, val_files, val_transform)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)
    test_dataset = PlantDataset(Config.IMAGE_DIR, Config.MASK_DIR, test_files, val_transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # 3. Model (architecture chosen by Config.ARCHITECTURE, e.g. 'unetplusplus')
    model = smp.create_model(
        arch=Config.ARCHITECTURE,
        encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS,
        in_channels=3,
        classes=Config.NUM_CLASSES,
    ).to(Config.DEVICE)

    loss_fn = WeightedHybridLoss(
        class_weights=Config.CLASS_WEIGHTS,
        weight_dice=Config.HYBRID_WEIGHT_DICE,
        weight_hd=Config.HYBRID_WEIGHT_HD
    )

    # Honour Config.OPTIMIZER (e.g. 'AdamW') instead of hard-coding the class.
    optimizer_cls = getattr(torch.optim, Config.OPTIMIZER)
    optimizer = optimizer_cls(model.parameters(), lr=Config.LEARNING_RATE)
    # Gradient scaling only applies to CUDA mixed precision; disable it elsewhere.
    scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)
    best_val_loss = float('inf')

    # 4. Train
    for epoch in range(Config.NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{Config.NUM_EPOCHS} ---")
        train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        val_loss = eval_fn(val_loader, model, loss_fn)
        print(f"Average Train Loss: {train_loss:.4f}")
        print(f"Average Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("=> Saved new best model")

    # 5. Reload the best weights and export test predictions
    if best_val_loss == float('inf'):
        # Nothing was saved in this run (e.g. every val loss was NaN). Loading now would
        # silently pick up a stale checkpoint left by an earlier run.
        raise RuntimeError(
            "No finite validation loss was observed; refusing to load a possibly stale checkpoint."
        )
    print("\n--- Loading best model for final testing ---")
    model.load_state_dict(torch.load(checkpoint_path, map_location=Config.DEVICE, weights_only=True))
    save_predictions_fn(test_loader, model, folder_basename="test_set")
    print("--- V4 Training Complete ---")


In [None]:
# Entry point: runs the whole training + test-prediction pipeline. Inside a notebook
# kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    main()

Using device: cuda
Dataset Path: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/combined_data


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]


--- Epoch 1/35 ---


Training: 100%|██████████| 246/246 [02:45<00:00,  1.48it/s, loss=3.52]
Validation: 100%|██████████| 26/26 [01:11<00:00,  2.76s/it, val_loss=3.17]


Average Train Loss: 3.8404
Average Val Loss: 3.1450
=> Saved new best model

--- Epoch 2/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.65it/s, loss=2.25]
Validation: 100%|██████████| 26/26 [00:00<00:00, 27.99it/s, val_loss=2.02]


Average Train Loss: 2.8135
Average Val Loss: 2.2798
=> Saved new best model

--- Epoch 3/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.60it/s, loss=2.36]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.44it/s, val_loss=1.59]


Average Train Loss: 2.1375
Average Val Loss: 1.8113
=> Saved new best model

--- Epoch 4/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.47it/s, loss=1.2]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.22it/s, val_loss=1.48]


Average Train Loss: 1.9288
Average Val Loss: 1.7495
=> Saved new best model

--- Epoch 5/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.65it/s, loss=1.46]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.85it/s, val_loss=0.736]


Average Train Loss: 1.5069
Average Val Loss: 1.2571
=> Saved new best model

--- Epoch 6/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.64it/s, loss=0.952]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.92it/s, val_loss=0.724]


Average Train Loss: 1.2769
Average Val Loss: 1.2161
=> Saved new best model

--- Epoch 7/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.56it/s, loss=1.66]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.67it/s, val_loss=0.785]


Average Train Loss: 1.2132
Average Val Loss: 1.2807

--- Epoch 8/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.73it/s, loss=1.16]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.49it/s, val_loss=0.625]


Average Train Loss: 1.1792
Average Val Loss: 1.2155
=> Saved new best model

--- Epoch 9/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.58it/s, loss=1.19]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.54it/s, val_loss=0.499]


Average Train Loss: 1.1706
Average Val Loss: 1.1832
=> Saved new best model

--- Epoch 10/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.55it/s, loss=0.842]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.85it/s, val_loss=0.452]


Average Train Loss: 1.1482
Average Val Loss: 1.1726
=> Saved new best model

--- Epoch 11/35 ---


Training: 100%|██████████| 246/246 [00:17<00:00, 13.71it/s, loss=2.25]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.23it/s, val_loss=0.566]


Average Train Loss: 1.1477
Average Val Loss: 1.1713
=> Saved new best model

--- Epoch 12/35 ---


Training: 100%|██████████| 246/246 [00:17<00:00, 13.78it/s, loss=1.93]
Validation: 100%|██████████| 26/26 [00:00<00:00, 30.36it/s, val_loss=0.581]


Average Train Loss: 1.1044
Average Val Loss: 1.1531
=> Saved new best model

--- Epoch 13/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.48it/s, loss=0.64]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.90it/s, val_loss=0.514]


Average Train Loss: 1.1041
Average Val Loss: 1.0921
=> Saved new best model

--- Epoch 14/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.55it/s, loss=1.94]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.93it/s, val_loss=0.553]


Average Train Loss: 1.0936
Average Val Loss: 1.1056

--- Epoch 15/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.80it/s, loss=0.637]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.09it/s, val_loss=0.661]


Average Train Loss: 1.0738
Average Val Loss: 1.0978

--- Epoch 16/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.76it/s, loss=0.737]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.63it/s, val_loss=0.697]


Average Train Loss: 1.0752
Average Val Loss: 1.1358

--- Epoch 17/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.78it/s, loss=1.27]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.52it/s, val_loss=0.631]


Average Train Loss: 1.0755
Average Val Loss: 1.1061

--- Epoch 18/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.71it/s, loss=1.34]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.21it/s, val_loss=0.603]


Average Train Loss: 1.0652
Average Val Loss: 1.0781
=> Saved new best model

--- Epoch 19/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.63it/s, loss=1.37]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.46it/s, val_loss=0.574]


Average Train Loss: 1.0487
Average Val Loss: 1.0777
=> Saved new best model

--- Epoch 20/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.55it/s, loss=0.699]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.21it/s, val_loss=0.574]


Average Train Loss: 1.0527
Average Val Loss: 1.0826

--- Epoch 21/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.68it/s, loss=1.04]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.91it/s, val_loss=0.715]


Average Train Loss: 1.0349
Average Val Loss: 1.0940

--- Epoch 22/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.77it/s, loss=1.21]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.01it/s, val_loss=0.526]


Average Train Loss: 1.0478
Average Val Loss: 1.0588
=> Saved new best model

--- Epoch 23/35 ---


Training: 100%|██████████| 246/246 [00:17<00:00, 13.69it/s, loss=0.758]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.80it/s, val_loss=0.473]


Average Train Loss: 1.0430
Average Val Loss: 1.1134

--- Epoch 24/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.80it/s, loss=1.4]
Validation: 100%|██████████| 26/26 [00:00<00:00, 30.09it/s, val_loss=0.525]


Average Train Loss: 1.0251
Average Val Loss: 1.0928

--- Epoch 25/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.68it/s, loss=1.28]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.64it/s, val_loss=0.42]


Average Train Loss: 1.0313
Average Val Loss: 1.0777

--- Epoch 26/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.75it/s, loss=1.18]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.53it/s, val_loss=0.41]


Average Train Loss: 1.0187
Average Val Loss: 1.0527
=> Saved new best model

--- Epoch 27/35 ---


Training: 100%|██████████| 246/246 [00:18<00:00, 13.61it/s, loss=1.06]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.46it/s, val_loss=0.539]


Average Train Loss: 0.9961
Average Val Loss: 1.0920

--- Epoch 28/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.80it/s, loss=0.813]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.47it/s, val_loss=0.516]


Average Train Loss: 1.0004
Average Val Loss: 1.1259

--- Epoch 29/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.78it/s, loss=0.503]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.70it/s, val_loss=0.691]


Average Train Loss: 1.0081
Average Val Loss: 1.1583

--- Epoch 30/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.75it/s, loss=0.668]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.51it/s, val_loss=0.431]


Average Train Loss: 0.9900
Average Val Loss: 1.0572

--- Epoch 31/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.80it/s, loss=1.27]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.44it/s, val_loss=0.597]


Average Train Loss: 0.9925
Average Val Loss: 1.1396

--- Epoch 32/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.72it/s, loss=0.936]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.87it/s, val_loss=0.465]


Average Train Loss: 0.9892
Average Val Loss: 1.0651

--- Epoch 33/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.71it/s, loss=0.647]
Validation: 100%|██████████| 26/26 [00:00<00:00, 29.74it/s, val_loss=0.499]


Average Train Loss: 0.9667
Average Val Loss: 1.1017

--- Epoch 34/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.83it/s, loss=1.03]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.63it/s, val_loss=0.506]


Average Train Loss: 0.9764
Average Val Loss: 1.0808

--- Epoch 35/35 ---


Training: 100%|██████████| 246/246 [00:15<00:00, 15.74it/s, loss=0.695]
Validation: 100%|██████████| 26/26 [00:00<00:00, 28.71it/s, val_loss=0.602]


Average Train Loss: 0.9769
Average Val Loss: 1.1050

--- Loading best model for final testing ---

--- Saving predictions for test_set set ---


Saving test_set Predictions: 100%|██████████| 52/52 [00:13<00:00,  3.96it/s]

--- V4 Training Complete ---





In [None]:
# =================================================================================
# 0. SETUP AND IMPORTS
# =================================================================================
!pip install -q monai pandas

import os
import cv2
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from monai.metrics import (
    compute_dice,
    compute_iou,
    compute_hausdorff_distance
)

# =================================================================================
# 1. CONFIGURATION
# =================================================================================
# --- Paths for the V4 evaluation ---
# Ground-truth masks from the combined dataset.
GT_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/combined_data/masks/"

# Class-index masks written by the V4 training run for the test split.
PRED_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v4_highres/pred_masks_unsqueezed/test_set/"

# Class index -> display name; indices must match the label values stored in the masks.
CLASS_MAP = dict(enumerate(["Background", "Stem", "Leaf", "Root", "Seed"]))
NUM_CLASSES = len(CLASS_MAP)

# =================================================================================
# 2. HELPER FUNCTIONS
# =================================================================================
def to_one_hot(mask, num_classes):
    """Convert an (H, W) integer label mask to a (1, C, H, W) one-hot tensor.

    Labels outside ``[0, num_classes)`` are treated as background (class 0). The
    caller's array is not modified.

    Args:
        mask: 2-D integer array of class indices.
        num_classes: Number of classes C.

    Returns:
        ``torch.Tensor`` of shape (1, C, H, W) with dtype float64.
    """
    labels = np.asarray(mask)
    # np.where returns a new array, so the input is left untouched. Negative labels
    # are caught too; otherwise np.eye indexing would silently wrap them around.
    labels = np.where((labels >= 0) & (labels < num_classes), labels, 0)
    one_hot = np.eye(num_classes)[labels]  # (H, W, C)
    one_hot = np.transpose(one_hot, (2, 0, 1))  # (C, H, W)
    return torch.from_numpy(one_hot).unsqueeze(0)  # (1, C, H, W)

# =================================================================================
# 3. MAIN ANALYSIS LOOP
# =================================================================================
def _evaluate_file(filename):
    """Score one prediction PNG against its ground truth; return a metrics dict or None."""
    pred_path = os.path.join(PRED_MASK_DIR, filename)
    # The training notebook saves predictions as "<image stem>_mask.png", the same name
    # its dataset uses for the ground-truth mask, so the GT is looked up by identical name.
    gt_path = os.path.join(GT_MASK_DIR, filename)

    if not os.path.exists(gt_path):
        print(f"Skipping {filename}: GT mask not found.")
        return None

    gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    if gt_mask is None or pred_mask is None:
        return None

    # Safety check: predictions were already resized to the original image size, so
    # on a mismatch the GT is brought onto the prediction grid (nearest keeps labels).
    if gt_mask.shape != pred_mask.shape:
        gt_mask = cv2.resize(gt_mask, (pred_mask.shape[1], pred_mask.shape[0]), interpolation=cv2.INTER_NEAREST)

    gt_onehot = to_one_hot(gt_mask, NUM_CLASSES)
    pred_onehot = to_one_hot(pred_mask, NUM_CLASSES)

    dice = compute_dice(pred_onehot, gt_onehot, include_background=True)
    iou = compute_iou(pred_onehot, gt_onehot, include_background=True)
    hd95 = compute_hausdorff_distance(pred_onehot, gt_onehot, include_background=True, percentile=95)

    file_metrics = {'filename': filename}
    for class_idx in range(NUM_CLASSES):
        c_name = CLASS_MAP[class_idx]
        file_metrics[f"{c_name}_Dice"] = dice[0, class_idx].item()
        file_metrics[f"{c_name}_IOU"] = iou[0, class_idx].item()
        file_metrics[f"{c_name}_HD95"] = hd95[0, class_idx].item()
    return file_metrics


def run_analysis():
    """Compute per-class Dice, IoU and HD95 over the V4 test predictions and summarise them.

    Every ``.png`` in ``PRED_MASK_DIR`` is scored against the same-named mask in
    ``GT_MASK_DIR``. Per-class means are printed as a markdown table and saved to
    ``v4_results_summary.csv``. NaN values are skipped by the means. Infinite HD95
    values are also excluded, and their count is reported.
    """
    print("Starting V4 Analysis...")
    print(f"GT Directory: {GT_MASK_DIR}")
    print(f"Pred Directory: {PRED_MASK_DIR}")

    # Sorted so the per-file processing order is deterministic across runs.
    pred_files = sorted(f for f in os.listdir(PRED_MASK_DIR) if f.endswith('.png'))
    if not pred_files:
        print("Error: No prediction files found! Check your PRED_MASK_DIR path.")
        return

    results_list = []
    for filename in tqdm(pred_files):
        file_metrics = _evaluate_file(filename)
        if file_metrics is not None:
            results_list.append(file_metrics)

    # =================================================================================
    # 4. REPORTING
    # =================================================================================
    if not results_list:
        print("No results generated.")
        return

    df = pd.DataFrame(results_list)

    # A single inf (possible when a class appears in only one of GT / prediction;
    # depends on the MONAI version, so confirm) would make the whole column mean inf.
    # Treat such values as missing, but report how many were dropped so failures stay visible.
    hd_cols = [f"{CLASS_MAP[i]}_HD95" for i in range(NUM_CLASSES)]
    n_inf_hd = int(np.isinf(df[hd_cols].to_numpy(dtype=float)).sum())
    df = df.replace([np.inf, -np.inf], np.nan)

    overall_stats = df.mean(numeric_only=True)

    print("\n\n--- V4 Overall Average Statistics (Test Set) ---")
    summary_data = []
    for class_idx in range(NUM_CLASSES):
        c_name = CLASS_MAP[class_idx]
        summary_data.append({
            "Class": c_name,
            "Dice (↑)": overall_stats.get(f"{c_name}_Dice"),
            "IOU (↑)": overall_stats.get(f"{c_name}_IOU"),
            "HD95 (↓)": overall_stats.get(f"{c_name}_HD95"),
        })

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_markdown(index=False, floatfmt=".4f"))
    if n_inf_hd:
        print(f"Note: {n_inf_hd} infinite HD95 value(s) were excluded from the HD95 means.")

    # Save to CSV for comparison with other versions.
    summary_df.to_csv("/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/v4_results_summary.csv", index=False)
    print("\nSummary saved to v4_results_summary.csv")

# Entry point for the evaluation cell. Inside a notebook kernel __name__ is "__main__",
# so executing this cell scores the saved test-set predictions.
if __name__ == "__main__":
    run_analysis()

Starting V4 Analysis...
GT Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/combined_data/masks/
Pred Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v4_highres/pred_masks_unsqueezed/test_set/


100%|██████████| 52/52 [00:03<00:00, 13.09it/s]



--- V4 Overall Average Statistics (Test Set) ---
| Class      |   Dice (↑) |   IOU (↑) |   HD95 (↓) |
|:-----------|-----------:|----------:|-----------:|
| Background |     0.9939 |    0.9880 |     7.2059 |
| Stem       |     0.7138 |    0.5842 |     7.1228 |
| Leaf       |     0.8482 |    0.7461 |     9.0360 |
| Root       |     0.8111 |    0.6964 |    16.6752 |
| Seed       |     0.6610 |    0.5534 |     9.1982 |

Summary saved to v4_results_summary.csv



