In [2]:
!pip install -U torch torchvision
!pip install -U segmentation-models-pytorch timm
!pip install -U albumentations opencv-python
!pip install -U numpy tqdm matplotlib

Collecting torch
  Downloading torch-2.9.1-cp313-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.24.1-cp313-cp313-macosx_12_0_arm64.whl.metadata (5.9 kB)
Downloading torch-2.9.1-cp313-none-macosx_11_0_arm64.whl (74.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.5/74.5 MB[0m [31m3.5 MB/s[0m  [33m0:00:21[0mm0:00:01[0m00:01[0m
[?25hDownloading torchvision-0.24.1-cp313-cp313-macosx_12_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m3.6 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: torch, torchvision
[2K  Attempting uninstall: torch
[2K    Found existing installation: torch 2.7.0
[2K    Uninstalling torch-2.7.0:━━━━━━━━━━━━━━━━━━━[0m [32m0/2[0m [torch]
[2K      Successfully uninstalled torch-2.7.0━━[0m [32m0/2[0m [torch]
[2K  Attempting uninstall: torchvision━━━━━━━━━━━━━[0m [32m0/2[0m [torch]
[2K    Found

In [1]:
import numpy as np
np.array([1, 2, 3])

array([1, 2, 3])

In [None]:
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp


# ============================================================
# 0) CONFIG
# ============================================================
BASE_DIR = "/Users/akhilgattu/Desktop/VLM_project/"  # change if needed
SEG_DIR = os.path.join(BASE_DIR, "A. Segmentation")

# Images (jpg)
TRAIN_IMG_DIR = os.path.join(SEG_DIR, "1. Original Images", "a. Training Set")

# GT root has lesion subfolders
TRAIN_GT_ROOT = os.path.join(SEG_DIR, "2. All Segmentation Groundtruths", "a. Training Set")

# Output dataset
OUT_DIR = "data_idrid_multiclass"
OUT_TRAIN_IMG = os.path.join(OUT_DIR, "train", "images")
OUT_TRAIN_MSK = os.path.join(OUT_DIR, "train", "masks")
OUT_VAL_IMG   = os.path.join(OUT_DIR, "val", "images")
OUT_VAL_MSK   = os.path.join(OUT_DIR, "val", "masks")

DEBUG_DIR = "debug_dataset_preview"
CKPT_PATH = "checkpoints/unetpp_effb3_idrid_5class.pth"

IMG_SIZE = 512
BATCH_SIZE = 4
EPOCHS = 40
LR = 3e-4
VAL_RATIO = 0.2
SEED = 42

NUM_CLASSES = 6  # 0..5 (0 is BG)


# ============================================================
# 1) IDRiD CLASS MAP + GT SUBFOLDERS
# ============================================================
# Your folder names (exactly as in your screenshot)
GT_SUBFOLDERS = {
    "MA": "1. Microaneurysms",
    "HE": "2. Haemorrhages",
    "EX": "3. Hard Exudates",
    "SE": "4. Soft Exudates",
    "OD": "5. Optic Disc",
}

CLASS_TO_ID = {
    "MA": 1,
    "HE": 2,
    "EX": 3,
    "SE": 4,
    "OD": 5,
}

ID_TO_CLASS = {
    0: "BG",
    1: "MA",
    2: "HE",
    3: "EX",
    4: "SE",
    5: "OD",
}

# Colors for debug overlays (BGR)
COLORS = {
    0: (0, 0, 0),
    1: (0, 0, 255),       # MA red
    2: (0, 165, 255),     # HE orange
    3: (0, 255, 255),     # EX yellow
    4: (255, 0, 255),     # SE magenta
    5: (255, 255, 0),     # OD cyan
}


# ============================================================
# 2) UTILS
# ============================================================
def ensure_dirs():
    os.makedirs(OUT_TRAIN_IMG, exist_ok=True)
    os.makedirs(OUT_TRAIN_MSK, exist_ok=True)
    os.makedirs(OUT_VAL_IMG, exist_ok=True)
    os.makedirs(OUT_VAL_MSK, exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs(DEBUG_DIR, exist_ok=True)


def list_jpg_images(folder):
    files = glob(os.path.join(folder, "*.jpg"))
    files.sort()
    return files


def find_mask_path(gt_root, base_name, lesion_code):
    """
    Given base_name like: IDRiD_01
    Return path like:
      TRAIN_GT_ROOT / "1. Microaneurysms" / "IDRiD_01_MA.tif"
    """
    sub = GT_SUBFOLDERS[lesion_code]
    folder = os.path.join(gt_root, sub)

    # official naming is usually base_LESION.tif
    p = os.path.join(folder, f"{base_name}_{lesion_code}.tif")
    if os.path.exists(p):
        return p

    return None


def build_multiclass_mask(image_hw, gt_root, base_name):
    """
    Merge all lesion binary masks into one multiclass mask (HxW).
    Priority high overwrites low:
      MA > HE > EX > SE > OD
    """
    H, W = image_hw
    mask = np.zeros((H, W), dtype=np.uint8)

    # low->high so higher overwrites later
    priority = ["OD", "SE", "EX", "HE", "MA"]

    for lesion in priority:
        mp = find_mask_path(gt_root, base_name, lesion)
        if mp is None:
            continue

        m = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if m is None:
            continue

        # IDRiD masks are usually 0 / 255
        m_bin = (m > 0).astype(np.uint8)
        cls_id = CLASS_TO_ID[lesion]
        mask[m_bin == 1] = cls_id

    return mask


def colorize_mask(mask):
    h, w = mask.shape
    out = np.zeros((h, w, 3), dtype=np.uint8)
    for cid, bgr in COLORS.items():
        out[mask == cid] = bgr
    return out


def overlay(image_bgr, color_mask_bgr, alpha=0.45):
    return cv2.addWeighted(image_bgr, 1 - alpha, color_mask_bgr, alpha, 0)


def sanity_check_masks(masks_dir, n=20):
    msks = glob(os.path.join(masks_dir, "*.png"))
    msks.sort()
    msks = msks[:n]

    print(f"\nSanity check -> {masks_dir} (showing {len(msks)})")
    for mp in msks:
        m = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if m is None:
            print("Cannot read:", mp)
            continue
        uniq = np.unique(m)
        nonzero = int((m > 0).sum())
        print(os.path.basename(mp), "| unique:", uniq, "| nonzero_pixels:", nonzero)


def debug_preview(images_dir, masks_dir, n=8):
    imgs = list_jpg_images(images_dir)[:n]
    print(f"\nSaving {len(imgs)} debug previews in: {DEBUG_DIR}")

    for ip in imgs:
        base = os.path.splitext(os.path.basename(ip))[0]
        mp = os.path.join(masks_dir, f"{base}.png")

        img = cv2.imread(ip, cv2.IMREAD_COLOR)
        msk = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)

        if img is None or msk is None:
            continue

        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        msk = cv2.resize(msk, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        col = colorize_mask(msk)
        over = overlay(img, col, alpha=0.45)

        cv2.imwrite(os.path.join(DEBUG_DIR, f"{base}_gt_overlay.png"), over)


def prepare_train_val(force_rebuild=False):
    """
    Create OUT_DIR train/val with masks merged into multiclass png.
    """
    if not force_rebuild:
        if len(os.listdir(OUT_TRAIN_MSK)) > 0 and len(os.listdir(OUT_VAL_MSK)) > 0:
            print("Prepared dataset already exists. Skipping rebuild.")
            return

    # optionally wipe old
    if force_rebuild:
        for p in [OUT_TRAIN_IMG, OUT_TRAIN_MSK, OUT_VAL_IMG, OUT_VAL_MSK]:
            for f in glob(os.path.join(p, "*")):
                os.remove(f)

    img_paths = list_jpg_images(TRAIN_IMG_DIR)
    if len(img_paths) == 0:
        raise RuntimeError(f"No .jpg images found in: {TRAIN_IMG_DIR}")

    # split
    rng = np.random.default_rng(SEED)
    idx = np.arange(len(img_paths))
    rng.shuffle(idx)

    n_val = int(len(img_paths) * VAL_RATIO)
    val_idx = idx[:n_val]
    tr_idx = idx[n_val:]

    train_paths = [img_paths[i] for i in tr_idx]
    val_paths   = [img_paths[i] for i in val_idx]

    print(f"Total: {len(img_paths)} | Train: {len(train_paths)} | Val: {len(val_paths)}")

    def write_one(ip, out_img_dir, out_msk_dir):
        img = cv2.imread(ip, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {ip}")

        H, W = img.shape[:2]
        base = os.path.splitext(os.path.basename(ip))[0]

        msk = build_multiclass_mask((H, W), TRAIN_GT_ROOT, base)

        cv2.imwrite(os.path.join(out_img_dir, os.path.basename(ip)), img)
        cv2.imwrite(os.path.join(out_msk_dir, f"{base}.png"), msk)

    for ip in tqdm(train_paths, desc="Preparing TRAIN"):
        write_one(ip, OUT_TRAIN_IMG, OUT_TRAIN_MSK)

    for ip in tqdm(val_paths, desc="Preparing VAL"):
        write_one(ip, OUT_VAL_IMG, OUT_VAL_MSK)


# ============================================================
# 3) DATASET + AUGS
# ============================================================
class MultiClassSegDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.images = list_jpg_images(images_dir)
        self.masks = []
        for ip in self.images:
            base = os.path.splitext(os.path.basename(ip))[0]
            self.masks.append(os.path.join(masks_dir, f"{base}.png"))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        ip = self.images[idx]
        mp = self.masks[idx]

        img = cv2.imread(ip, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {ip}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Missing mask: {mp}")

        if self.transform:
            out = self.transform(image=img, mask=mask)
            img = out["image"]
            mask = out["mask"]

        mask = torch.as_tensor(mask, dtype=torch.long)
        return img, mask


def get_train_tfms(sz=512):
    return A.Compose([
        A.Resize(sz, sz),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(p=0.3),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.ShiftScaleRotate(
            shift_limit=0.05, scale_limit=0.10, rotate_limit=20,
            border_mode=cv2.BORDER_CONSTANT,
            value=0, mask_value=0, p=0.5
        ),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


def get_val_tfms(sz=512):
    return A.Compose([
        A.Resize(sz, sz),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


# ============================================================
# 4) LOSS
# ============================================================
class DiceLossMulticlass(nn.Module):
    def __init__(self, num_classes, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.softmax(logits, dim=1)
        targets_1h = F.one_hot(targets, num_classes=self.num_classes).permute(0, 3, 1, 2).float()

        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_1h, dims)
        union = torch.sum(probs, dims) + torch.sum(targets_1h, dims)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()


class FocalLossMulticlass(nn.Module):
    def __init__(self, gamma=2.0, alpha=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # tensor [C]

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, weight=self.alpha, reduction="none")
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()


class DiceFocalLoss(nn.Module):
    def __init__(self, num_classes, dice_w=0.5, focal_w=0.5, gamma=2.0, alpha=None):
        super().__init__()
        self.dice = DiceLossMulticlass(num_classes)
        self.focal = FocalLossMulticlass(gamma=gamma, alpha=alpha)
        self.dice_w = dice_w
        self.focal_w = focal_w

    def forward(self, logits, targets):
        return self.dice_w * self.dice(logits, targets) + self.focal_w * self.focal(logits, targets)


@torch.no_grad()
def mean_dice_no_bg(logits, targets, num_classes):
    preds = torch.argmax(logits, dim=1)
    dices = []
    for c in range(1, num_classes):
        p = (preds == c).float()
        t = (targets == c).float()
        inter = (p * t).sum()
        denom = p.sum() + t.sum()
        d = (2 * inter + 1.0) / (denom + 1.0)
        dices.append(d)
    return torch.stack(dices).mean().item()


# ============================================================
# 5) TRAIN
# ============================================================
def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def train():
    device = get_device()
    print("Device:", device)

    train_ds = MultiClassSegDataset(OUT_TRAIN_IMG, OUT_TRAIN_MSK, transform=get_train_tfms(IMG_SIZE))
    val_ds   = MultiClassSegDataset(OUT_VAL_IMG,   OUT_VAL_MSK,   transform=get_val_tfms(IMG_SIZE))

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b3",
        encoder_weights="imagenet",
        in_channels=3,
        classes=NUM_CLASSES,
        activation=None
    ).to(device)

    # Strong class weights (BG extremely downweighted)
    alpha = torch.tensor([0.03, 2.5, 2.5, 2.5, 2.5, 1.8], dtype=torch.float32).to(device)

    criterion = DiceFocalLoss(
        num_classes=NUM_CLASSES,
        dice_w=0.5,
        focal_w=0.5,
        gamma=2.0,
        alpha=alpha
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

    best_dice = -1e9

    for ep in range(1, EPOCHS + 1):
        # ---------------- TRAIN ----------------
        model.train()
        tr_loss = 0.0

        for imgs, masks in tqdm(train_loader, desc=f"Train {ep}/{EPOCHS}", leave=False):
            imgs = imgs.to(device)
            masks = masks.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(imgs)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()

        tr_loss /= max(1, len(train_loader))

        # ---------------- VAL ----------------
        model.eval()
        va_loss = 0.0
        va_dice = 0.0

        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc=f"Val {ep}/{EPOCHS}", leave=False):
                imgs = imgs.to(device)
                masks = masks.to(device)

                logits = model(imgs)
                loss = criterion(logits, masks)

                va_loss += loss.item()
                va_dice += mean_dice_no_bg(logits, masks, NUM_CLASSES)

        va_loss /= max(1, len(val_loader))
        va_dice /= max(1, len(val_loader))

        print(f"Epoch {ep:02d} | TrainLoss={tr_loss:.4f} | ValLoss={va_loss:.4f} | ValDice(no-bg)={va_dice:.4f}")

        if va_dice > best_dice:
            best_dice = va_dice
            torch.save(model.state_dict(), CKPT_PATH)
            print(f"Saved BEST -> {CKPT_PATH} (dice={best_dice:.4f})")


# ============================================================
# RUN
# ============================================================
if __name__ == "__main__":
    ensure_dirs()

    # Prepare dataset (train/val from training set)
    prepare_train_val(force_rebuild=True)

    # Sanity check (should show non-zero pixels)
    sanity_check_masks(OUT_TRAIN_MSK, n=10)
    sanity_check_masks(OUT_VAL_MSK, n=10)

    # Save GT overlay previews so you visually confirm masks are correct
    debug_preview(OUT_TRAIN_IMG, OUT_TRAIN_MSK, n=8)
    debug_preview(OUT_VAL_IMG, OUT_VAL_MSK, n=8)

    print("\nStarting training...\n")
    train()


Total: 54 | Train: 44 | Val: 10


Preparing TRAIN: 100%|██████████| 44/44 [00:03<00:00, 11.83it/s]
Preparing VAL: 100%|██████████| 10/10 [00:00<00:00, 11.20it/s]



Sanity check -> data_idrid_multiclass/train/masks (showing 10)
IDRiD_01.png | unique: [0 1 2 3 5] | nonzero_pixels: 369508
IDRiD_02.png | unique: [0 1 2 3 5] | nonzero_pixels: 365750
IDRiD_03.png | unique: [0 1 2 3 4 5] | nonzero_pixels: 717840
IDRiD_04.png | unique: [0 1 2 3 5] | nonzero_pixels: 249335
IDRiD_05.png | unique: [0 1 2 3 5] | nonzero_pixels: 238343
IDRiD_07.png | unique: [0 1 2 3 5] | nonzero_pixels: 417478
IDRiD_09.png | unique: [0 1 2 3 5] | nonzero_pixels: 633525
IDRiD_10.png | unique: [0 1 2 3 5] | nonzero_pixels: 1146350
IDRiD_11.png | unique: [0 1 2 3 5] | nonzero_pixels: 572246
IDRiD_12.png | unique: [0 1 2 3 5] | nonzero_pixels: 374786

Sanity check -> data_idrid_multiclass/val/masks (showing 10)
IDRiD_06.png | unique: [0 1 2 3 5] | nonzero_pixels: 366131
IDRiD_08.png | unique: [0 1 2 3 4 5] | nonzero_pixels: 305854
IDRiD_18.png | unique: [0 1 2 3 4 5] | nonzero_pixels: 491177
IDRiD_24.png | unique: [0 1 2 3 5] | nonzero_pixels: 247043
IDRiD_28.png | unique: [0 1

  A.ShiftScaleRotate(
                                                           

Epoch 01 | TrainLoss=0.5237 | ValLoss=0.4904 | ValDice(no-bg)=0.0372
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.0372)


                                                           

Epoch 02 | TrainLoss=0.4818 | ValLoss=0.4696 | ValDice(no-bg)=0.0785
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.0785)


                                                           

Epoch 03 | TrainLoss=0.4612 | ValLoss=0.4514 | ValDice(no-bg)=0.1260
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.1260)


                                                           

Epoch 04 | TrainLoss=0.4463 | ValLoss=0.4352 | ValDice(no-bg)=0.1755
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.1755)


                                                           

Epoch 05 | TrainLoss=0.4318 | ValLoss=0.4282 | ValDice(no-bg)=0.1790
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.1790)


                                                           

Epoch 06 | TrainLoss=0.4234 | ValLoss=0.4200 | ValDice(no-bg)=0.1551


                                                           

Epoch 07 | TrainLoss=0.4117 | ValLoss=0.4090 | ValDice(no-bg)=0.1891
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.1891)


                                                           

Epoch 08 | TrainLoss=0.4019 | ValLoss=0.4051 | ValDice(no-bg)=0.2016
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.2016)


                                                           

Epoch 09 | TrainLoss=0.3912 | ValLoss=0.3922 | ValDice(no-bg)=0.2143
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.2143)


                                                            

Epoch 10 | TrainLoss=0.3838 | ValLoss=0.3852 | ValDice(no-bg)=0.2213
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.2213)


                                                            

Epoch 11 | TrainLoss=0.3738 | ValLoss=0.3762 | ValDice(no-bg)=0.2766
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.2766)


                                                            

Epoch 12 | TrainLoss=0.3628 | ValLoss=0.3666 | ValDice(no-bg)=0.2948
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.2948)


                                                            

Epoch 13 | TrainLoss=0.3534 | ValLoss=0.3576 | ValDice(no-bg)=0.2976
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.2976)


                                                            

Epoch 14 | TrainLoss=0.3429 | ValLoss=0.3504 | ValDice(no-bg)=0.3280
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.3280)


                                                            

Epoch 15 | TrainLoss=0.3329 | ValLoss=0.3445 | ValDice(no-bg)=0.3455
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.3455)


                                                            

Epoch 16 | TrainLoss=0.3199 | ValLoss=0.3316 | ValDice(no-bg)=0.3919
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.3919)


                                                            

Epoch 17 | TrainLoss=0.3092 | ValLoss=0.3206 | ValDice(no-bg)=0.4286
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4286)


                                                            

Epoch 18 | TrainLoss=0.2968 | ValLoss=0.3123 | ValDice(no-bg)=0.4319
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4319)


                                                            

Epoch 19 | TrainLoss=0.2778 | ValLoss=0.3020 | ValDice(no-bg)=0.4556
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4556)


                                                            

Epoch 20 | TrainLoss=0.2681 | ValLoss=0.2939 | ValDice(no-bg)=0.4199


                                                            

Epoch 21 | TrainLoss=0.2623 | ValLoss=0.2786 | ValDice(no-bg)=0.4751
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4751)


                                                            

Epoch 22 | TrainLoss=0.2540 | ValLoss=0.2729 | ValDice(no-bg)=0.4365


                                                            

Epoch 23 | TrainLoss=0.2496 | ValLoss=0.2812 | ValDice(no-bg)=0.4521


                                                            

Epoch 24 | TrainLoss=0.2569 | ValLoss=0.2735 | ValDice(no-bg)=0.4628


                                                            

Epoch 25 | TrainLoss=0.2528 | ValLoss=0.2785 | ValDice(no-bg)=0.4262


                                                            

Epoch 26 | TrainLoss=0.2355 | ValLoss=0.2640 | ValDice(no-bg)=0.4812
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4812)


                                                            

Epoch 27 | TrainLoss=0.2275 | ValLoss=0.2570 | ValDice(no-bg)=0.4954
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4954)


                                                            

Epoch 28 | TrainLoss=0.2258 | ValLoss=0.2626 | ValDice(no-bg)=0.4733


                                                            

Epoch 29 | TrainLoss=0.2369 | ValLoss=0.2626 | ValDice(no-bg)=0.4703


                                                            

Epoch 30 | TrainLoss=0.2171 | ValLoss=0.2518 | ValDice(no-bg)=0.5003
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.5003)


                                                            

Epoch 31 | TrainLoss=0.2203 | ValLoss=0.2507 | ValDice(no-bg)=0.4981


                                                            

Epoch 32 | TrainLoss=0.2143 | ValLoss=0.2487 | ValDice(no-bg)=0.5024
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.5024)


                                                            

Epoch 33 | TrainLoss=0.2176 | ValLoss=0.2475 | ValDice(no-bg)=0.5077
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.5077)


                                                            

Epoch 34 | TrainLoss=0.2053 | ValLoss=0.2543 | ValDice(no-bg)=0.5050


                                                            

Epoch 35 | TrainLoss=0.2084 | ValLoss=0.2478 | ValDice(no-bg)=0.5104
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.5104)


                                                            

Epoch 36 | TrainLoss=0.2009 | ValLoss=0.2565 | ValDice(no-bg)=0.5026


                                                            

Epoch 37 | TrainLoss=0.2093 | ValLoss=0.2528 | ValDice(no-bg)=0.5241
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.5241)


                                                            

Epoch 38 | TrainLoss=0.2081 | ValLoss=0.2436 | ValDice(no-bg)=0.5204


                                                            

Epoch 39 | TrainLoss=0.2165 | ValLoss=0.2363 | ValDice(no-bg)=0.5385
Saved BEST -> checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.5385)


                                                            

Epoch 40 | TrainLoss=0.2169 | ValLoss=0.2431 | ValDice(no-bg)=0.5137




In [19]:
!pwd

/Users/akhilgattu/Desktop/VLM_project


In [9]:
import os
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2


# --------------------------
# CONFIG
# --------------------------
CKPT_PATH = "/Users/akhilgattu/Desktop/VLM_project/checkpoints/unetpp_effb3_idrid_5class.pth"
IMG_DIR   = "/Users/akhilgattu/Desktop/VLM_project/A. Segmentation/1. Original Images/b. Testing Set"

OUT_DIR = "visual_results"
IMG_SIZE = 512
NUM_CLASSES = 6  # 0..5


# --------------------------
# CLASS COLORS (BGR)
# --------------------------
COLORS = {
    0: (0, 0, 0),         # BG - black
    1: (0, 0, 255),       # MA - red
    2: (0, 165, 255),     # HE - orange
    3: (0, 255, 255),     # EX - yellow
    4: (255, 0, 255),     # SE - magenta
    5: (255, 255, 0),     # OD - cyan
}

CLASS_NAMES = {
    0: "BG",
    1: "MA",
    2: "HE",
    3: "EX",
    4: "SE",
    5: "OD",
}


# --------------------------
# TRANSFORM
# --------------------------
def get_tfms(img_size=512):
    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


# --------------------------
# UTILS
# --------------------------
def list_images(folder):
    exts = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
    files = [f for f in os.listdir(folder) if f.lower().endswith(exts)]
    files.sort()
    return [os.path.join(folder, f) for f in files]


def colorize_mask(mask):
    """mask: HxW int (0..NUM_CLASSES-1)"""
    h, w = mask.shape
    color = np.zeros((h, w, 3), dtype=np.uint8)
    for k, bgr in COLORS.items():
        color[mask == k] = bgr
    return color


def overlay(image_bgr, color_mask_bgr, alpha=0.45):
    return cv2.addWeighted(image_bgr, 1 - alpha, color_mask_bgr, alpha, 0)


def draw_boundaries(base_bgr, mask, thickness=2):
    """
    Draw class boundaries so segmentation pops.
    """
    out = base_bgr.copy()
    for cls_id in range(1, NUM_CLASSES):  # skip BG
        bin_mask = (mask == cls_id).astype(np.uint8) * 255
        contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(out, contours, -1, (255, 255, 255), thickness)
    return out


def put_legend(img_bgr, mask, top_k=6):
    """
    Add legend with class names + pixel %.
    """
    out = img_bgr.copy()
    h, w = out.shape[:2]

    # compute class stats
    total = h * w
    counts = [(cid, int((mask == cid).sum())) for cid in range(NUM_CLASSES)]
    counts.sort(key=lambda x: x[1], reverse=True)

    x0, y0 = 15, 25
    line_h = 22

    cv2.putText(out, "Legend:", (x0, y0),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

    y = y0 + 25

    # show top_k classes excluding BG if you want
    shown = 0
    for cid, c in counts:
        if shown >= top_k:
            break

        pct = (c / total) * 100.0
        name = CLASS_NAMES.get(cid, str(cid))
        bgr = COLORS.get(cid, (255, 255, 255))

        # color box
        cv2.rectangle(out, (x0, y - 14), (x0 + 16, y + 2), bgr, -1)

        # text
        text = f"{name}: {pct:.2f}%"
        cv2.putText(out, text, (x0 + 25, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)

        y += line_h
        shown += 1

    return out


def label_regions(img_bgr, mask):
    """
    Put class labels at the center of each segmented blob.
    """
    out = img_bgr.copy()

    for cls_id in range(1, NUM_CLASSES):  # skip BG
        bin_mask = (mask == cls_id).astype(np.uint8) * 255

        # clean tiny noise
        bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))

        contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < 200:  # ignore tiny blobs
                continue

            M = cv2.moments(cnt)
            if M["m00"] == 0:
                continue
            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])

            label = CLASS_NAMES.get(cls_id, str(cls_id))

            # shadow text for visibility
            cv2.putText(out, label, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 4, cv2.LINE_AA)
            cv2.putText(out, label, (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

    return out


# --------------------------
# MAIN
# --------------------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)

    # Correct device selection for Mac MPS
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print("Device:", device)

    model = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b3",
        encoder_weights=None,
        in_channels=3,
        classes=NUM_CLASSES,
        activation=None,
    ).to(device)

    ckpt = torch.load(CKPT_PATH, map_location=device)

    # If checkpoint is {"model": state_dict} or {"state_dict": ...}
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    elif isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt

    # remove "module." if trained with DataParallel
    new_state = {}
    for k, v in state_dict.items():
        nk = k.replace("module.", "")
        new_state[nk] = v

    model.load_state_dict(new_state, strict=True)
    model.eval()

    tfms = get_tfms(IMG_SIZE)

    img_paths = list_images(IMG_DIR)
    if len(img_paths) == 0:
        raise RuntimeError(f"No images found in {IMG_DIR}")

    print(f"Found {len(img_paths)} images.")

    for ip in img_paths:
        fname = os.path.splitext(os.path.basename(ip))[0]

        img_bgr = cv2.imread(ip, cv2.IMREAD_COLOR)
        if img_bgr is None:
            print(f"Skipping (cannot read): {ip}")
            continue

        img_bgr = cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE))
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        out = tfms(image=img_rgb)
        x = out["image"].unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(x)  # [1,C,H,W]
            pred = torch.argmax(logits, dim=1)  # [1,H,W]
            pred_mask = pred.squeeze(0).detach().cpu().numpy().astype(np.uint8)

        pred_color = colorize_mask(pred_mask)
        pred_overlay = overlay(img_bgr, pred_color, alpha=0.45)

        # make it visually strong
        pred_overlay = draw_boundaries(pred_overlay, pred_mask, thickness=2)
        pred_overlay = label_regions(pred_overlay, pred_mask)
        pred_overlay = put_legend(pred_overlay, pred_mask)

        # Save outputs
        # Best format for raw mask is PNG (lossless)
        cv2.imwrite(os.path.join(OUT_DIR, f"{fname}_mask.png"), pred_mask)

        # Color mask
        cv2.imwrite(os.path.join(OUT_DIR, f"{fname}_color.png"), pred_color)

        # Overlay with labels
        cv2.imwrite(os.path.join(OUT_DIR, f"{fname}_overlay_labeled.png"), pred_overlay)

    print(f"\nSaved results in: {OUT_DIR}")


if __name__ == "__main__":
    main()


Device: mps
Found 27 images.

Saved results in: visual_results
