In [None]:
!pip install -q -U segmentation-models-pytorch albumentations monai

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.7/2.7 MB[0m [31m96.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m60.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tqdm import tqdm
import random

from monai.losses import DiceLoss



In [None]:
class Config:
    """Central settings for the V6 run: data locations, model/optimizer choice,
    training schedule, loss weighting and visualization colours.

    Used purely through class attributes (e.g. ``Config.DEVICE``); the cell also
    creates the prediction output folders as a side effect.
    """
    # -- Base Paths --
    BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/"

    # Path to ORIGINAL images (JPGs) and GT masks
    IMAGE_DIR = os.path.join(BASE_PATH, "images")

    # Ground-truth class-index masks, named "<image stem>_mask.png"
    MASK_DIR = os.path.join(BASE_PATH,"masks")

    # Path to DISTANCE WEIGHT MAPS: one "<stem>.csv", "<stem>_weights.csv" or "<stem>.tsv" per image
    DIST_WEIGHT_DIR = os.path.join(BASE_PATH,"augmented masks v3/distance_weights")

    # Data QC Report & Splits
    # QC report rows with Mask_correct == FALSE are excluded from every split.
    QC_REPORT_CSV = os.path.join(BASE_PATH, "missing_classes_from_mask.csv")
    # Split sheet with columns "img_name" and "set" (train / val / test).
    SPLIT_CSV = os.path.join(BASE_PATH, "dataset_split.csv")

    # -- Output --
    OUTPUT_DIR = os.path.join(BASE_PATH, "outputs_v6_final")
    # Raw predicted class-index masks (uint8 PNG).
    OUTPUT_MASK_DIR = os.path.join(OUTPUT_DIR, "pred_masks")
    # RGB visualizations of the predictions, coloured with COLOR_MAP.
    COLOR_MASK_DIR = os.path.join(OUTPUT_DIR, "color_masks")

    # -- Model Hyperparameters --
    ARCHITECTURE = 'unetplusplus'
    ENCODER = 'resnet34'
    ENCODER_WEIGHTS = 'imagenet'
    LEARNING_RATE = 1e-4
    # Informational only: main() constructs torch.optim.AdamW directly.
    OPTIMIZER = 'AdamW'

    # -- Training Settings --
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 4
    NUM_EPOCHS = 40

    # Resize target used by both the train and val/test transforms.
    IMAGE_HEIGHT = 512
    IMAGE_WIDTH = 512

    NUM_CLASSES = 5

    # -- LOSS CONFIGURATION --
    # Weights for 5 classes: [Bg, Stem, Leaf, Root, Seed]
    # Multiplies each class's (1 - Dice) term in BoundaryWeightedDiceLoss; created directly on DEVICE.
    CLASS_WEIGHTS = torch.tensor([
        1.0,  # Background
        9.0,  # Stem
        7.0,  # Leaf
        9.0,  # Root
        10.0  # Seed
    ], device=DEVICE)

    # Boundary Importance Scale
    # Per-pixel loss weight is 1 + BOUNDARY_SCALE * distance-map value.
    BOUNDARY_SCALE = 2.0

    # -- Visualization --
    # Class index -> RGB colour used by mask_to_rgb.
    COLOR_MAP = {
        0: (0, 0, 0),        # background
        1: (139, 69, 19),    # stem - brown
        2: (0, 255, 0),      # leaf - green
        3: (255, 255, 0),    # root - yellow
        4: (255, 0, 0),      # seed - red
    }

# Create the output folders up front so save_predictions_fn can write into them.
os.makedirs(Config.OUTPUT_MASK_DIR, exist_ok=True)
os.makedirs(Config.COLOR_MASK_DIR, exist_ok=True)


In [None]:
def get_splits_v6_original(image_dir, mask_dir, dist_dir, split_csv_path, qc_csv_path):
    """Build the train/val/test image lists for the V6 run.

    An image listed in the split sheet is kept only when all of these hold:
      * ``<image_dir>/<img_name>`` exists,
      * its mask ``<mask_dir>/<stem>_mask.png`` exists,
      * a distance map ``<stem>.csv``, ``<stem>_weights.csv`` or ``<stem>.tsv``
        exists in ``dist_dir``,
      * the mask is not flagged ``Mask_correct == FALSE`` in the QC report.

    Rows with a blank ``img_name`` or ``set`` cell, or an unknown set label,
    are skipped.

    Args:
        image_dir, mask_dir, dist_dir: Directories searched for each file.
        split_csv_path: CSV with ``img_name`` and ``set`` columns (``train`` /
            ``val`` / ``test``, case- and whitespace-insensitive). If it exists
            but cannot be parsed, the sibling ``.xlsx`` file is read instead.
        qc_csv_path: Optional QC CSV with ``filename`` and ``Mask_correct``
            columns. It is ignored if falsy or missing.

    Returns:
        ``(train_files, val_files, test_files)``: lists of image filenames in
        sheet order.

    Raises:
        FileNotFoundError: If ``split_csv_path`` does not exist.
    """
    print("--- Configuring V6 Splits (Original Masks + Distance) ---")

    # Masks flagged as mislabeled by the QC report.
    excluded_masks = set()
    if qc_csv_path and os.path.exists(qc_csv_path):
        df_qc = pd.read_csv(qc_csv_path)
        bad_rows = df_qc[df_qc['Mask_correct'].astype(str).str.upper() == 'FALSE']
        excluded_masks = set(bad_rows['filename'].tolist())

    if not os.path.exists(split_csv_path):
        raise FileNotFoundError("Split CSV missing")
    try:
        df_split = pd.read_csv(split_csv_path)
    except Exception:
        # Not a parseable CSV (e.g. an Excel file saved with a .csv name): fall
        # back to the .xlsx export. Bare `except:` would also swallow Ctrl-C.
        df_split = pd.read_excel(split_csv_path.replace('.csv', '.xlsx'))

    splits = {'train': [], 'val': [], 'test': []}
    dist_suffixes = (".csv", "_weights.csv", ".tsv")

    for raw_name, raw_set in zip(df_split['img_name'], df_split['set']):
        # Blank cells come back as NaN floats; calling .lower() on them would crash.
        if pd.isna(raw_name) or pd.isna(raw_set):
            continue
        img_name = str(raw_name)
        set_type = str(raw_set).lower().strip()
        if set_type not in splits:
            continue

        if not os.path.exists(os.path.join(image_dir, img_name)):
            continue

        base_name = os.path.splitext(img_name)[0]

        # PNG mask (V4 style)
        mask_name = base_name + "_mask.png"
        if not os.path.exists(os.path.join(mask_dir, mask_name)):
            continue

        # A distance map is mandatory for V6.
        if not any(os.path.exists(os.path.join(dist_dir, base_name + sfx)) for sfx in dist_suffixes):
            continue

        if mask_name in excluded_masks:
            continue

        splits[set_type].append(img_name)

    train_files, val_files, test_files = splits['train'], splits['val'], splits['test']
    print(f"Train: {len(train_files)} | Val: {len(val_files)} | Test: {len(test_files)}")
    return train_files, val_files, test_files


In [None]:
class BoundaryDiceDataset(Dataset):
    """Segmentation dataset yielding ``(image, mask, dist_map, original_size)``.

    For each image ``<stem>.<ext>`` in ``image_dir`` it loads:
      * the class-index mask ``<stem>_mask.png`` from ``mask_dir`` (read as
        8-bit grayscale),
      * a per-pixel distance weight map from ``dist_dir``. The first existing
        file among ``<stem>.csv``, ``<stem>_weights.csv`` and ``<stem>.tsv``
        is used.

    A missing or unreadable weight map falls back to all zeros. The weight map
    is cropped or zero-padded (top-left anchored) to the image size before
    ``transform`` runs. Mask labels ``>= num_classes`` are mapped to 0.
    ``original_size`` is the pre-transform ``(height, width)``.
    """

    # Candidate weight-map filename suffixes, tried in order.
    DIST_SUFFIXES = (".csv", "_weights.csv", ".tsv")

    def __init__(self, image_dir, mask_dir, dist_dir, image_filenames, transform=None, num_classes=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.dist_dir = dist_dir
        self.transform = transform
        self.images = image_filenames
        # None -> use Config.NUM_CLASSES when a sample is fetched (previous behaviour).
        self.num_classes = num_classes

    def __len__(self):
        return len(self.images)

    def _find_dist_path(self, base_name):
        """Return the first existing weight-map path for ``base_name``, or None."""
        for suffix in self.DIST_SUFFIXES:
            candidate = os.path.join(self.dist_dir, base_name + suffix)
            if os.path.exists(candidate):
                return candidate
        return None

    def _load_csv(self, path):
        """Parse a delimited numeric grid into a float32 array.

        Tries ',' for ``.csv`` files (otherwise tab) first, then any run of
        whitespace. This covers tab-separated files saved with a ``.csv``
        name. Returns None, with a warning, if ``path`` is None or no delimiter
        yields an all-numeric table.
        """
        if path is None:
            return None
        primary_sep = ',' if path.endswith('.csv') else '\t'
        for sep in (primary_sep, r'\s+'):
            try:
                df = pd.read_csv(path, header=None, sep=sep)
                return df.values.astype(np.float32)
            except (ValueError, OSError):
                # ValueError covers ParserError, EmptyDataError, UnicodeDecodeError
                # and non-numeric cells failing the float conversion.
                continue
        import warnings
        warnings.warn(f"Could not parse distance map {path!r}; using zero weights.")
        return None

    def __getitem__(self, index):
        img_name = self.images[index]
        base_name = os.path.splitext(img_name)[0]

        # 1. Image as an RGB uint8 array.
        image = np.array(Image.open(os.path.join(self.image_dir, img_name)).convert("RGB"))

        # 2. Mask (PNG, original V4 style).
        # NOTE(review): convert("L") assumes the PNG stores raw class indices;
        # a palette ('P') PNG would be mapped through luminance. Confirm the mask format.
        mask_path = os.path.join(self.mask_dir, base_name + "_mask.png")
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)

        # 3. Distance weight map. Zeros mean a uniform weight of 1 in the loss.
        dist_map = self._load_csv(self._find_dist_path(base_name))
        if dist_map is None:
            dist_map = np.zeros_like(mask)

        # 4. Crop / zero-pad the weight map so it matches the image size.
        h_img, w_img = image.shape[:2]
        if dist_map.shape[:2] != (h_img, w_img):
            fitted = np.zeros((h_img, w_img), dtype=np.float32)
            h_c = min(h_img, dist_map.shape[0])
            w_c = min(w_img, dist_map.shape[1])
            fitted[:h_c, :w_c] = dist_map[:h_c, :w_c]
            dist_map = fitted

        original_size = (h_img, w_img)

        # 5. Augmentations.
        if self.transform:
            aug = self.transform(image=image, mask=mask, dist_map=dist_map)
            image, mask, dist_map = aug['image'], aug['mask'], aug['dist_map']

        # Without a tensor-producing transform the targets are still numpy arrays;
        # as_tensor is a no-op for tensors, and .long()/.float() below need tensors.
        mask = torch.as_tensor(mask)
        dist_map = torch.as_tensor(dist_map)

        # Map out-of-range labels to background.
        num_classes = self.num_classes if self.num_classes is not None else Config.NUM_CLASSES
        mask[mask >= num_classes] = 0

        return image, mask.long(), dist_map.float(), original_size

In [None]:
# Both pipelines end by scaling pixels to [0, 1] (x / 255 with mean 0, std 1) and
# converting the image and all mask targets to tensors.
# NOTE(review): the encoder uses imagenet weights, but ImageNet mean/std are not applied here.
def _tensor_tail():
    return [
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]


# The distance-weight map is registered as an extra 'mask' target, so it receives
# exactly the same spatial ops as the ground-truth mask.
train_transform = A.Compose(
    [
        A.Resize(Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH),
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        *_tensor_tail(),
    ],
    additional_targets={'dist_map': 'mask'},
)

# Deterministic: resize only (used for validation and test).
val_transform = A.Compose(
    [A.Resize(Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH), *_tensor_tail()],
    additional_targets={'dist_map': 'mask'},
)


In [None]:
class BoundaryWeightedDiceLoss(nn.Module):
    """Soft multi-class Dice loss with per-pixel boundary weighting and per-class weights.

    Every pixel contributes to the soft intersection and union with weight
    ``1 + boundary_scale * dist_map``. The per-class loss ``1 - Dice``
    (computed per sample) is multiplied by ``class_weights`` and averaged over
    batch and classes.

    Args:
        class_weights: 1-D tensor with one weight per class, or None for uniform
            weights. It is moved to the loss's device/dtype on each call.
        boundary_scale: Multiplier applied to the distance map.
        smooth: Additive smoothing for the Dice numerator and denominator.
    """

    def __init__(self, class_weights, boundary_scale=2.0, smooth=1e-5):
        super().__init__()
        self.class_weights = class_weights
        self.boundary_scale = boundary_scale
        self.smooth = smooth

    def forward(self, preds, targets, dist_maps):
        """Compute the weighted Dice loss.

        Args:
            preds: Logits of shape (B, C, H, W). C determines the number of classes.
            targets: Integer labels of shape (B, H, W), each in ``[0, C)``.
            dist_maps: Per-pixel weights, either (B, H, W) or (B, 1, H, W).

        Returns:
            Scalar tensor: the mean of the class-weighted ``1 - Dice`` values.
        """
        num_classes = preds.shape[1]

        # Softmax in float32 so half-precision logits don't lose accuracy in the sums.
        probs = torch.softmax(preds.float(), dim=1)

        # (B, H, W) -> (B, C, H, W), with C taken from the logits rather than global config.
        targets_one_hot = (
            torch.nn.functional.one_hot(targets.long(), num_classes=num_classes)
            .permute(0, 3, 1, 2)
            .to(probs.dtype)
        )

        # Boundary weight map of shape (B, 1, H, W), broadcast over classes:
        # 1.0 away from boundaries, up to 1.0 + scale on them.
        if dist_maps.dim() == targets.dim():
            dist_maps = dist_maps.unsqueeze(1)
        spatial_weights = 1.0 + self.boundary_scale * dist_maps.to(probs.dtype)

        # Weighted soft intersection and union per (sample, class).
        numerator = 2.0 * (probs * targets_one_hot * spatial_weights).sum(dim=(2, 3))
        denominator = ((probs + targets_one_hot) * spatial_weights).sum(dim=(2, 3))

        dice_loss = 1.0 - (numerator + self.smooth) / (denominator + self.smooth)

        # Weight the LOSS (not the Dice score) so that up-weighted classes are penalized more.
        if self.class_weights is not None:
            weights = torch.as_tensor(self.class_weights, device=dice_loss.device, dtype=dice_loss.dtype)
            dice_loss = dice_loss * weights

        return dice_loss.mean()

def mask_to_rgb(mask_tensor, color_map):
    """Paint a class-index mask as an RGB uint8 image.

    Args:
        mask_tensor: A torch tensor or numpy array of class indices. Singleton
            dimensions are squeezed away; tensors are moved to CPU first.
        color_map: A dict mapping class index to an ``(R, G, B)`` tuple. Pixels
            whose class is not in the map stay black.

    Returns:
        A numpy array of shape ``(*mask.shape, 3)`` with dtype uint8.
    """
    labels = (mask_tensor.cpu().numpy() if torch.is_tensor(mask_tensor) else mask_tensor).squeeze()
    canvas = np.zeros(labels.shape + (3,), dtype=np.uint8)
    for label, colour in color_map.items():
        canvas[labels == label] = colour
    return canvas

def save_predictions_fn(loader, model, folder_basename=""):
    """Predict every sample of ``loader.dataset`` and write the masks to disk.

    Samples are fetched one at a time straight from the dataset, so ``loader``'s
    batching is not used. Each argmax prediction is resized (nearest neighbour)
    back to the sample's original size. It is written twice, both times as
    ``<image stem>_mask.png``:
      * as a raw class-index PNG under ``Config.OUTPUT_MASK_DIR/<folder_basename>``,
      * as an RGB visualization under ``Config.COLOR_MASK_DIR/<folder_basename>``.

    The model is switched to eval mode for inference and back to train mode afterwards.
    """
    print(f"\n--- Saving predictions for {folder_basename} set ---")
    model.eval()

    index_dir = os.path.join(Config.OUTPUT_MASK_DIR, folder_basename)
    colour_dir = os.path.join(Config.COLOR_MASK_DIR, folder_basename)
    for out_dir in (index_dir, colour_dir):
        os.makedirs(out_dir, exist_ok=True)

    dataset = loader.dataset
    for i in tqdm(range(len(dataset))):
        image, _, _, (height, width) = dataset[i]
        with torch.no_grad():
            logits = model(image.to(Config.DEVICE).unsqueeze(0))
            label_map = logits.argmax(dim=1).squeeze(0)

        # cv2.resize takes (width, height).
        labels = cv2.resize(
            label_map.cpu().numpy().astype(np.uint8),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )

        out_name = os.path.splitext(dataset.images[i])[0] + "_mask.png"
        Image.fromarray(labels).save(os.path.join(index_dir, out_name))
        Image.fromarray(mask_to_rgb(labels, Config.COLOR_MAP)).save(os.path.join(colour_dir, out_name))
    model.train()

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        loader: Yields ``(images, masks, dist_maps, original_sizes)`` batches.
        model: Segmentation network returning per-class logits.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(logits, masks, dist_maps) -> scalar tensor``.
        scaler: ``torch.amp.GradScaler`` used for the scaled backward pass and step.

    Returns:
        The mean of the per-batch loss values, or ``nan`` if ``loader`` yields no batches.
    """
    # Set train mode explicitly instead of relying on eval/save helpers to restore it.
    model.train()
    device_type = torch.device(Config.DEVICE).type

    loop = tqdm(loader, desc="Training")
    total_loss = 0.0
    num_batches = 0
    for data, targets, dist, _ in loop:
        data = data.to(Config.DEVICE)
        targets = targets.to(Config.DEVICE)
        dist = dist.to(Config.DEVICE)

        # Mixed precision only on CUDA. A hard-coded 'cuda' autocast on a CPU-only
        # machine only warns and disables itself, so gate it explicitly.
        with torch.amp.autocast(device_type=device_type, enabled=(device_type == "cuda")):
            out = model(data)
            loss = loss_fn(out, targets, dist)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    return total_loss / num_batches if num_batches else float("nan")

def eval_fn(loader, model, loss_fn):
    """Return the mean of ``loss_fn`` over all batches in ``loader``, without gradients.

    Each batch is ``(images, masks, dist_maps, _)`` and is moved to
    ``Config.DEVICE``. The model runs in eval mode and is put back into train
    mode before returning.
    """
    model.eval()
    running = 0
    progress = tqdm(loader, desc="Validation")
    with torch.no_grad():
        for batch in progress:
            images, labels, weights, _ = batch
            images = images.to(Config.DEVICE)
            labels = labels.to(Config.DEVICE)
            weights = weights.to(Config.DEVICE)
            batch_loss = loss_fn(model(images), labels, weights).item()
            running += batch_loss
            progress.set_postfix(val_loss=batch_loss)
    model.train()
    return running / len(loader)


In [None]:
def main(seed=42):
    """Train the V6 boundary-weighted model, keep the best checkpoint, and predict the test set.

    Steps:
      * Build the splits and datasets.
      * Train for ``Config.NUM_EPOCHS`` epochs, saving the weights whenever the
        validation loss improves.
      * Reload the best checkpoint and write the test-set predictions via
        ``save_predictions_fn``.

    Args:
        seed: Seed for Python's ``random``, NumPy and torch RNGs so runs are reproducible.
            NOTE(review): newer albumentations versions keep their own RNG
            (``A.Compose(seed=...)``); confirm the augmentations are covered.
    """
    print(f"Using device: {Config.DEVICE}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds all CUDA devices

    train_files, val_files, test_files = get_splits_v6_original(
        Config.IMAGE_DIR, Config.MASK_DIR, Config.DIST_WEIGHT_DIR,
        Config.SPLIT_CSV, Config.QC_REPORT_CSV
    )

    if not train_files:
        print("Error: No matching data found!")
        return
    # Without validation data eval_fn would divide by zero, and no checkpoint
    # would ever be saved for the test phase to load.
    if not val_files:
        print("Error: No validation data found!")
        return

    train_ds = BoundaryDiceDataset(Config.IMAGE_DIR, Config.MASK_DIR, Config.DIST_WEIGHT_DIR, train_files, train_transform)
    val_ds = BoundaryDiceDataset(Config.IMAGE_DIR, Config.MASK_DIR, Config.DIST_WEIGHT_DIR, val_files, val_transform)
    test_ds = BoundaryDiceDataset(Config.IMAGE_DIR, Config.MASK_DIR, Config.DIST_WEIGHT_DIR, test_files, val_transform)

    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

    model = smp.create_model(
        arch=Config.ARCHITECTURE, encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS, in_channels=3, classes=Config.NUM_CLASSES
    ).to(Config.DEVICE)

    loss_fn = BoundaryWeightedDiceLoss(class_weights=Config.CLASS_WEIGHTS, boundary_scale=Config.BOUNDARY_SCALE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    # Loss scaling only makes sense with CUDA mixed precision.
    scaler = torch.amp.GradScaler('cuda', enabled=(Config.DEVICE == "cuda"))
    checkpoint_path = os.path.join(Config.BASE_PATH, "best_model_v6_boundary.pth")
    best_val = float('inf')

    for epoch in range(Config.NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{Config.NUM_EPOCHS} ---")
        avg_train = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        avg_val = eval_fn(val_loader, model, loss_fn)
        print(f"Train: {avg_train:.4f} | Val: {avg_val:.4f}")
        if avg_val < best_val:
            best_val = avg_val
            torch.save(model.state_dict(), checkpoint_path)
            print("=> Saved new best model")

    print("\n--- Testing ---")
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(checkpoint_path, map_location=Config.DEVICE, weights_only=True))
    save_predictions_fn(test_loader, model, folder_basename="test_set")

# Executing this cell runs the full pipeline: train, validate, then predict the test set.
if __name__ == "__main__":
    main()

## Performance Evaluation

In [None]:
!pip install -q monai pandas

import os
import cv2
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from monai.metrics import (
    compute_dice,
    compute_iou,
    compute_hausdorff_distance
)

# Ground Truth Masks
GT_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/masks/"

# Predicted masks written by save_predictions_fn for the V6 run (Config.OUTPUT_MASK_DIR/test_set).
# Files are named "<image stem>_mask.png", the same as the matching GT mask.
PRED_MASK_DIR = "/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v6_final/pred_masks/test_set/"

# --- CLASS MAP ---
# Class index -> display name; must follow the same order as the training classes
# [Bg, Stem, Leaf, Root, Seed].
NUM_CLASSES = 5
CLASS_MAP = {
    0: "Background",
    1: "Stem",
    2: "Leaf",
    3: "Root",
    4: "Seed",
}

def to_one_hot(mask, num_classes):
    """One-hot encode an integer label map for MONAI metrics.

    Labels outside ``[0, num_classes)`` are treated as background (0). The
    caller's array is NOT modified. Previously it was clipped in place.

    Args:
        mask: Integer array of shape (H, W), or any spatial shape.
        num_classes: Number of output channels.

    Returns:
        A float64 torch tensor of shape ``(1, num_classes, *mask.shape)``.
    """
    labels = np.asarray(mask)
    # np.where builds a new array; negative labels would otherwise wrap around in np.eye indexing.
    labels = np.where((labels >= 0) & (labels < num_classes), labels, 0).astype(np.intp)

    one_hot = np.eye(num_classes)[labels]      # (*spatial, C)
    one_hot = np.moveaxis(one_hot, -1, 0)      # (C, *spatial)
    return torch.from_numpy(one_hot).unsqueeze(0)

def _gt_filename_for(pred_filename):
    """Return the ground-truth mask filename that matches a prediction filename.

    save_predictions_fn writes predictions as ``<image stem>_mask.png``. That is
    already the GT mask name, so such names are used unchanged. A bare
    ``<image stem>.png`` gets ``_mask`` appended.
    """
    if pred_filename.endswith("_mask.png"):
        return pred_filename
    return os.path.splitext(pred_filename)[0] + "_mask.png"


def _load_mask_pair(gt_path, pred_path):
    """Read the GT and predicted masks as 8-bit grayscale.

    If the shapes differ, the GT is resized (nearest neighbour) to the
    prediction's size. Returns ``(gt_mask, pred_mask)``, or None if either
    file cannot be decoded.
    """
    gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    pred_mask = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    if gt_mask is None or pred_mask is None:
        return None
    if gt_mask.shape != pred_mask.shape:
        gt_mask = cv2.resize(gt_mask, (pred_mask.shape[1], pred_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
    return gt_mask, pred_mask


def _compute_file_metrics(filename, gt_mask, pred_mask):
    """Compute per-class Dice, IoU and HD95 for one mask pair.

    Returns a flat dict with keys ``filename`` and ``<Class>_<Metric>``.
    """
    gt_onehot = to_one_hot(gt_mask, NUM_CLASSES)
    pred_onehot = to_one_hot(pred_mask, NUM_CLASSES)

    dice = compute_dice(pred_onehot, gt_onehot, include_background=True)
    iou = compute_iou(pred_onehot, gt_onehot, include_background=True)
    hd95 = compute_hausdorff_distance(pred_onehot, gt_onehot, include_background=True, percentile=95)

    file_metrics = {'filename': filename}
    for i in range(NUM_CLASSES):
        c_name = CLASS_MAP[i]
        file_metrics[f"{c_name}_Dice"] = dice[0, i].item()
        file_metrics[f"{c_name}_IOU"] = iou[0, i].item()
        file_metrics[f"{c_name}_HD95"] = hd95[0, i].item()
    return file_metrics


def run_analysis():
    """Score every prediction in PRED_MASK_DIR against its GT mask.

    Steps:
      * Print a per-class table of mean Dice, IoU and HD95.
      * Save that table to ``v6_results_summary.csv``.

    NaN and infinite per-file values are excluded from the means.
    """
    print("Starting V4 opt Analysis...")
    print(f"GT Directory: {GT_MASK_DIR}")
    print(f"Pred Directory: {PRED_MASK_DIR}")

    results_list = []
    pred_files = [f for f in os.listdir(PRED_MASK_DIR) if f.endswith('.png')]

    if len(pred_files) == 0:
        print("Error: No prediction files found! Check your PRED_MASK_DIR path.")
        return

    for filename in tqdm(pred_files):
        pred_path = os.path.join(PRED_MASK_DIR, filename)
        # Predictions already carry the "_mask.png" suffix. The old
        # .replace(".png", "_mask.png") looked for "<stem>_mask_mask.png"
        # and skipped every file.
        gt_path = os.path.join(GT_MASK_DIR, _gt_filename_for(filename))

        if not os.path.exists(gt_path):
            print(f"Skipping {filename}: GT mask not found.")
            continue

        masks = _load_mask_pair(gt_path, pred_path)
        if masks is None:
            continue

        results_list.append(_compute_file_metrics(filename, *masks))

    if not results_list:
        print("No results generated.")
        return

    df = pd.DataFrame(results_list)

    # HD95 can be infinite when a class is absent from the GT or the prediction.
    # A single inf would make that class's mean inf, so infinities are counted
    # and then excluded from the averages as NaN.
    n_inf = int(np.isinf(df.select_dtypes(include="number").to_numpy()).sum())
    df = df.replace([np.inf, -np.inf], np.nan)
    if n_inf:
        print(f"Note: excluded {n_inf} infinite metric value(s) from the averages.")

    overall_stats = df.mean(numeric_only=True)

    print("\n\n--- Overall Average Statistics (Test Set) Distance weights ---")
    summary_data = []
    for i in range(NUM_CLASSES):
        c_name = CLASS_MAP[i]
        summary_data.append({
            "Class": c_name,
            "Dice (↑)": overall_stats.get(f"{c_name}_Dice"),
            "IOU (↑)": overall_stats.get(f"{c_name}_IOU"),
            "HD95 (↓)": overall_stats.get(f"{c_name}_HD95"),
        })

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_markdown(index=False, floatfmt=".4f"))

    # Saved to CSV for comparison with other runs.
    summary_df.to_csv("/content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/v6_results_summary.csv", index=False)
    print("\nSummary saved to v6_results_summary.csv")

# Executing this cell evaluates the saved test-set predictions against the GT masks.
if __name__ == "__main__":
    run_analysis()

Starting V4 opt Analysis...
GT Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/masks/
Pred Directory: /content/drive/MyDrive/Colab Notebooks/phenocyte_seg/phenocyte_seg/outputs_v6_final/pred_masks/test_set/


100%|██████████| 47/47 [00:04<00:00, 11.51it/s]




--- V6 opt Overall Average Statistics (Test Set) Distance weights ---
| Class      |   Dice (↑) |   IOU (↑) |   HD95 (↓) |
|:-----------|-----------:|----------:|-----------:|
| Background |     0.9947 |    0.9895 |     9.4172 |
| Stem       |     0.7515 |    0.6148 |     7.0896 |
| Leaf       |     0.8489 |    0.7431 |    10.1932 |
| Root       |     0.8050 |    0.6865 |    28.8622 |
| Seed       |     0.6207 |    0.5206 |     9.2642 |

Summary saved to v6_results_summary.csv
