In [2]:
!pip install -U torch torchvision
!pip install -U segmentation-models-pytorch timm
!pip install -U albumentations opencv-python
!pip install -U numpy tqdm matplotlib

Collecting torch
  Downloading torch-2.9.1-cp313-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.24.1-cp313-cp313-macosx_12_0_arm64.whl.metadata (5.9 kB)
Downloading torch-2.9.1-cp313-none-macosx_11_0_arm64.whl (74.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.5/74.5 MB[0m [31m3.5 MB/s[0m  [33m0:00:21[0mm0:00:01[0m00:01[0m
[?25hDownloading torchvision-0.24.1-cp313-cp313-macosx_12_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m3.6 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: torch, torchvision
[2K  Attempting uninstall: torch
[2K    Found existing installation: torch 2.7.0
[2K    Uninstalling torch-2.7.0:━━━━━━━━━━━━━━━━━━━[0m [32m0/2[0m [torch]
[2K      Successfully uninstalled torch-2.7.0━━[0m [32m0/2[0m [torch]
[2K  Attempting uninstall: torchvision━━━━━━━━━━━━━[0m [32m0/2[0m [torch]
[2K    Found

In [24]:
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp


# =========================
# 0) SET YOUR BASE DIR
# =========================
BASE_DIR = "/Users/akhilgattu/Desktop/VLM_project/"  # change if needed

SEG_DIR = os.path.join(BASE_DIR, "A. Segmentation")

TRAIN_IMG_DIR = os.path.join(SEG_DIR, "1. Original Images", "a. Training Set")
TEST_IMG_DIR  = os.path.join(SEG_DIR, "1. Original Images", "b. Testing Set")

TRAIN_GT_DIR = os.path.join(SEG_DIR, "2. All Segmentation Groundtruths", "a. Training Set")
TEST_GT_DIR  = os.path.join(SEG_DIR, "2. All Segmentation Groundtruths", "b. Testing Set")

# output prepared dataset
OUT_DIR = "data_idrid_multiclass"
OUT_TRAIN_IMG = os.path.join(OUT_DIR, "train", "images")
OUT_TRAIN_MSK = os.path.join(OUT_DIR, "train", "masks")
OUT_VAL_IMG   = os.path.join(OUT_DIR, "val", "images")
OUT_VAL_MSK   = os.path.join(OUT_DIR, "val", "masks")


# =========================
# 1) CLASS MAP (IDRiD)
# =========================
CLASS_TO_ID = {
    "MA": 1,
    "HE": 2,
    "EX": 3,
    "SE": 4,
    "OD": 5,
}
NUM_CLASSES = 6  # including background 0


# =========================
# 2) HELPERS
# =========================
def ensure_dirs():
    os.makedirs(OUT_TRAIN_IMG, exist_ok=True)
    os.makedirs(OUT_TRAIN_MSK, exist_ok=True)
    os.makedirs(OUT_VAL_IMG, exist_ok=True)
    os.makedirs(OUT_VAL_MSK, exist_ok=True)


def list_images(folder):
    exts = ["*.jpg", "*.png", "*.jpeg", "*.tif", "*.tiff"]
    files = []
    for e in exts:
        files += glob(os.path.join(folder, e))
    files.sort()
    return files


def find_mask(gt_root, base_name, lesion_code):
    """
    IDRiD GT names look like:
      IDRiD_01_MA.tif
      IDRiD_01_HE.tif
      IDRiD_01_EX.tif
      IDRiD_01_SE.tif
      IDRiD_01_OD.tif

    base_name: "IDRiD_01"
    lesion_code: "MA", "HE", "EX", "SE", "OD"
    """
    patterns = [
        os.path.join(gt_root, f"{base_name}_{lesion_code}.tif"),
        os.path.join(gt_root, f"{base_name}_{lesion_code}.png"),
        os.path.join(gt_root, f"{base_name}_{lesion_code}.jpg"),
        os.path.join(gt_root, f"{base_name}_{lesion_code}.jpeg"),
        os.path.join(gt_root, f"{base_name}_{lesion_code}.tiff"),
    ]
    for p in patterns:
        if os.path.exists(p):
            return p
    return None


def build_multiclass_mask(image_shape_hw, gt_root, base_name):
    """
    Merge all binary lesion masks into ONE multiclass mask.
    Priority order (if overlap happens):
       MA > HE > EX > SE > OD
    (you can change it; overlaps are usually minimal)
    """
    H, W = image_shape_hw
    mask = np.zeros((H, W), dtype=np.uint8)

    priority = ["OD", "SE", "EX", "HE", "MA"]  # low->high, so MA overwrites others

    for lesion in priority:
        mp = find_mask(gt_root, base_name, lesion)
        if mp is None:
            continue

        m = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if m is None:
            continue

        # binarize (IDRiD masks are 0/255 typically)
        m_bin = (m > 0).astype(np.uint8)
        cls_id = CLASS_TO_ID[lesion]

        # assign this class where lesion exists
        mask[m_bin == 1] = cls_id

    return mask


def prepare_split(images_dir, gts_dir, out_img_dir, out_msk_dir):
    img_paths = list_images(images_dir)
    if len(img_paths) == 0:
        raise RuntimeError(f"No images found in: {images_dir}")

    print(f"Found {len(img_paths)} images in {images_dir}")

    for ip in tqdm(img_paths, desc=f"Preparing {os.path.basename(images_dir)}"):
        img = cv2.imread(ip, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {ip}")

        H, W = img.shape[:2]
        base = os.path.splitext(os.path.basename(ip))[0]  # e.g. IDRiD_01

        multiclass_mask = build_multiclass_mask((H, W), gts_dir, base)

        # save image as PNG/JPG (keep original extension if you want)
        out_img_path = os.path.join(out_img_dir, os.path.basename(ip))
        out_msk_path = os.path.join(out_msk_dir, f"{base}.tif")  # single-channel PNG

        cv2.imwrite(out_img_path, img)
        cv2.imwrite(out_msk_path, multiclass_mask)


# =========================
# 3) DATASET + AUGS
# =========================
class MultiClassSegDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.images = list_images(images_dir)
        self.masks = []
        for ip in self.images:
            base = os.path.splitext(os.path.basename(ip))[0]
            mp = os.path.join(masks_dir, f"{base}.tif")
            self.masks.append(mp)

        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        ip = self.images[idx]
        mp = self.masks[idx]

        img = cv2.imread(ip, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Missing mask: {mp}")

        if self.transform:
            out = self.transform(image=img, mask=mask)
            img = out["image"]
            mask = out["mask"]

        mask = torch.as_tensor(mask, dtype=torch.long)
        return img, mask


def get_train_tfms(sz=512):
    return A.Compose([
        A.Resize(sz, sz),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(p=0.3),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.ShiftScaleRotate(
            shift_limit=0.05, scale_limit=0.10, rotate_limit=20,
            border_mode=cv2.BORDER_CONSTANT,
            value=0, mask_value=0, p=0.5
        ),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


def get_val_tfms(sz=512):
    return A.Compose([
        A.Resize(sz, sz),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


# =========================
# 4) LOSSES (Dice + Focal)
# =========================
class DiceLossMulticlass(nn.Module):
    def __init__(self, num_classes, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.softmax(logits, dim=1)
        targets_1h = F.one_hot(targets, num_classes=self.num_classes).permute(0, 3, 1, 2).float()

        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_1h, dims)
        union = torch.sum(probs, dims) + torch.sum(targets_1h, dims)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()


class FocalLossMulticlass(nn.Module):
    def __init__(self, gamma=2.0, alpha=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, weight=self.alpha, reduction="none")
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()


class DiceFocalLoss(nn.Module):
    def __init__(self, num_classes, dice_w=0.6, focal_w=0.4, gamma=2.0, alpha=None):
        super().__init__()
        self.dice = DiceLossMulticlass(num_classes)
        self.focal = FocalLossMulticlass(gamma=gamma, alpha=alpha)
        self.dice_w = dice_w
        self.focal_w = focal_w

    def forward(self, logits, targets):
        return self.dice_w * self.dice(logits, targets) + self.focal_w * self.focal(logits, targets)


@torch.no_grad()
def mean_dice_no_bg(logits, targets, num_classes):
    preds = torch.argmax(logits, dim=1)
    dices = []
    for c in range(1, num_classes):
        p = (preds == c).float()
        t = (targets == c).float()
        inter = (p * t).sum()
        denom = p.sum() + t.sum()
        d = (2 * inter + 1.0) / (denom + 1.0)
        dices.append(d)
    return torch.stack(dices).mean().item()


# =========================
# 5) TRAINING
# =========================
def train():
    device = torch.device("mps" if torch.mps.is_available() else "cpu")
    print("Device:", device)

    img_size = 512
    batch_size = 4
    epochs = 35
    lr = 3e-4

    train_ds = MultiClassSegDataset(
        os.path.join(OUT_DIR, "train", "images"),
        os.path.join(OUT_DIR, "train", "masks"),
        transform=get_train_tfms(img_size),
    )
    val_ds = MultiClassSegDataset(
        os.path.join(OUT_DIR, "val", "images"),
        os.path.join(OUT_DIR, "val", "masks"),
        transform=get_val_tfms(img_size),
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    model = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b3",
        encoder_weights="imagenet",
        in_channels=3,
        classes=NUM_CLASSES,
        activation=None
    ).to(device)

    criterion = DiceFocalLoss(NUM_CLASSES, dice_w=0.6, focal_w=0.4, gamma=2.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    best = -1e9
    os.makedirs("checkpoints", exist_ok=True)

    for ep in range(1, epochs + 1):
        model.train()
        tr_loss = 0.0

        for imgs, masks in tqdm(train_loader, desc=f"Train {ep}/{epochs}", leave=False):
            imgs = imgs.to(device)
            masks = masks.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(imgs)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()

        tr_loss /= max(1, len(train_loader))

        # val
        model.eval()
        va_loss = 0.0
        va_dice = 0.0

        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc=f"Val {ep}/{epochs}", leave=False):
                imgs = imgs.to(device)
                masks = masks.to(device)

                logits = model(imgs)
                loss = criterion(logits, masks)
                va_loss += loss.item()
                va_dice += mean_dice_no_bg(logits, masks, NUM_CLASSES)

        va_loss /= max(1, len(val_loader))
        va_dice /= max(1, len(val_loader))

        print(f"Epoch {ep:02d} | TrainLoss={tr_loss:.4f} | ValLoss={va_loss:.4f} | ValDice(no-bg)={va_dice:.4f}")

        if va_dice > best:
            best = va_dice
            torch.save(model.state_dict(), "checkpoints/unetpp_effb3_idrid_5class.pth")
            print(f"Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice={best:.4f})")


# =========================
# RUN
# =========================
if __name__ == "__main__":
    ensure_dirs()

    # prepare dataset only if not prepared already
    if len(os.listdir(OUT_TRAIN_MSK)) == 0:
        print("\nPreparing TRAIN split...")
        prepare_split(TRAIN_IMG_DIR, TRAIN_GT_DIR, OUT_TRAIN_IMG, OUT_TRAIN_MSK)

    if len(os.listdir(OUT_VAL_MSK)) == 0:
        print("\nPreparing VAL split (using IDRiD test folder as val)...")
        prepare_split(TEST_IMG_DIR, TEST_GT_DIR, OUT_VAL_IMG, OUT_VAL_MSK)

    print("\nStarting training...")
    train()



Preparing TRAIN split...
Found 54 images in /Users/akhilgattu/Desktop/VLM_project/A. Segmentation/1. Original Images/a. Training Set


Preparing a. Training Set: 100%|██████████| 54/54 [00:03<00:00, 14.48it/s]



Preparing VAL split (using IDRiD test folder as val)...
Found 27 images in /Users/akhilgattu/Desktop/VLM_project/A. Segmentation/1. Original Images/b. Testing Set


Preparing b. Testing Set: 100%|██████████| 27/27 [00:01<00:00, 13.52it/s]
  A.ShiftScaleRotate(



Starting training...
Device: mps


                                                           

Epoch 01 | TrainLoss=1.0481 | ValLoss=1.0577 | ValDice(no-bg)=0.0003
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.0003)


                                                           

Epoch 02 | TrainLoss=0.8185 | ValLoss=0.7461 | ValDice(no-bg)=0.0029
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.0029)


                                                           

Epoch 03 | TrainLoss=0.6836 | ValLoss=0.6323 | ValDice(no-bg)=0.0892
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.0892)


                                                           

Epoch 04 | TrainLoss=0.6053 | ValLoss=0.5850 | ValDice(no-bg)=0.3872
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.3872)


                                                           

Epoch 05 | TrainLoss=0.5643 | ValLoss=0.5568 | ValDice(no-bg)=0.3767


                                                           

Epoch 06 | TrainLoss=0.5429 | ValLoss=0.5411 | ValDice(no-bg)=0.2771


                                                           

Epoch 07 | TrainLoss=0.5315 | ValLoss=0.5290 | ValDice(no-bg)=0.4299
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.4299)


                                                           

Epoch 08 | TrainLoss=0.5251 | ValLoss=0.5243 | ValDice(no-bg)=0.8324
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=0.8324)


                                                             

Epoch 09 | TrainLoss=0.5209 | ValLoss=0.5201 | ValDice(no-bg)=1.0000
Saved BEST model: checkpoints/unetpp_effb3_idrid_5class.pth (dice=1.0000)


                                                            

Epoch 10 | TrainLoss=0.5183 | ValLoss=0.5173 | ValDice(no-bg)=1.0000


                                                            

Epoch 11 | TrainLoss=0.5163 | ValLoss=0.5151 | ValDice(no-bg)=1.0000


                                                            

Epoch 12 | TrainLoss=0.5147 | ValLoss=0.5136 | ValDice(no-bg)=1.0000


                                                           

KeyboardInterrupt: 

In [19]:
!pwd

/Users/akhilgattu/Desktop/VLM_project


In [22]:
import os
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp

import albumentations as A
from albumentations.pytorch import ToTensorV2


# --------------------------
# CONFIG
# --------------------------
CKPT_PATH = "/Users/akhilgattu/Desktop/VLM_project/checkpoints/unetpp_effb3_idrid_5class.pth"

# Put your test images here (any folder)
# Example: use official IDRiD test images
IMG_DIR = "/Users/akhilgattu/Desktop/VLM_project/A. Segmentation/1. Original Images/b. Testing Set"

OUT_DIR = "visual_results"
IMG_SIZE = 512

NUM_CLASSES = 6  # 0..5


# --------------------------
# CLASS COLORS (BGR)
# 0=BG, 1=MA, 2=HE, 3=EX, 4=SE, 5=OD
# --------------------------
COLORS = {
    0: (0, 0, 0),         # BG - black
    1: (0, 0, 255),       # MA - red
    2: (0, 165, 255),     # HE - orange
    3: (0, 255, 255),     # EX - yellow
    4: (255, 0, 255),     # SE - magenta
    5: (255, 255, 0),     # OD - cyan
}

CLASS_NAMES = {
    0: "BG",
    1: "MA",
    2: "HE",
    3: "EX",
    4: "SE",
    5: "OD",
}


# --------------------------
# TRANSFORM (same as training normalize)
# --------------------------
def get_tfms(img_size=512):
    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


# --------------------------
# UTILS
# --------------------------
def list_images(folder):
    exts = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
    files = [f for f in os.listdir(folder) if f.lower().endswith(exts)]
    files.sort()
    return [os.path.join(folder, f) for f in files]


def colorize_mask(mask_0_5):
    """mask_0_5: HxW int values 0..5"""
    h, w = mask_0_5.shape
    color = np.zeros((h, w, 3), dtype=np.uint8)
    for k, bgr in COLORS.items():
        color[mask_0_5 == k] = bgr
    return color


def overlay(image_bgr, color_mask_bgr, alpha=0.45):
    return cv2.addWeighted(image_bgr, 1 - alpha, color_mask_bgr, alpha, 0)


# --------------------------
# MAIN
# --------------------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Load model
    model = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b3",
        encoder_weights=None,  # IMPORTANT: weights come from checkpoint
        in_channels=3,
        classes=NUM_CLASSES,
        activation=None
    ).to(device)

    state = torch.load(CKPT_PATH, map_location=device)
    model.load_state_dict(state)
    model.eval()

    tfms = get_tfms(IMG_SIZE)

    img_paths = list_images(IMG_DIR)
    if len(img_paths) == 0:
        raise RuntimeError(f"No images found in {IMG_DIR}")

    print(f"Found {len(img_paths)} images.")

    for ip in img_paths:
        fname = os.path.splitext(os.path.basename(ip))[0]

        # read original image
        img_bgr = cv2.imread(ip, cv2.IMREAD_COLOR)
        if img_bgr is None:
            print(f"Skipping (cannot read): {ip}")
            continue

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # transform
        out = tfms(image=img_rgb)
        x = out["image"].unsqueeze(0).to(device)  # [1,3,H,W]

        # predict
        with torch.no_grad():
            logits = model(x)                      # [1,6,H,W]
            pred = torch.argmax(logits, dim=1)     # [1,H,W]
            pred_mask = pred.squeeze(0).cpu().numpy().astype(np.uint8)

        # colorize + overlay
        pred_color = colorize_mask(pred_mask)
        pred_overlay = overlay(
            cv2.resize(img_bgr, (IMG_SIZE, IMG_SIZE)),
            pred_color,
            alpha=0.45
        )

        # save outputs
        # 1) raw mask (0..5)
        cv2.imwrite(os.path.join(OUT_DIR, f"{fname}_pred_mask.png"), pred_mask)

        # 2) color mask
        cv2.imwrite(os.path.join(OUT_DIR, f"{fname}_pred_color.png"), pred_color)

        # 3) overlay
        cv2.imwrite(os.path.join(OUT_DIR, f"{fname}_overlay.png"), pred_overlay)

    print(f"\nSaved results in: {OUT_DIR}")


if __name__ == "__main__":
    main()


Device: cpu
Found 27 images.

Saved results in: visual_results
