In [None]:
#!/usr/bin/env python3
import os
import json
import random

# --- Configuration: edit these to match your dataset layout ---
DATASET_DIR = 'data'              # folder that holds the images and the annotation file
INPUT_FILE  = 'annotations.json'  # single COCO-style file containing every image
TEST_PCT    = 0.10                # fraction of images written to the test split
VAL_PCT     = 0.10                # fraction of images written to the val split
RANDOM_SEED = 42                  # fixed seed so the shuffle is reproducible

# --- Read the combined annotation file ---
annotations_path = os.path.join(DATASET_DIR, INPUT_FILE)
with open(annotations_path, 'r') as f:
    full = json.load(f)

# 'images' is mandatory; every other section defaults to empty when absent.
images = full['images']
annotations, scene_annotations, licenses, categories, scene_categories = (
    full.get(section, [])
    for section in ('annotations', 'scene_annotations', 'licenses',
                    'categories', 'scene_categories')
)
info = full.get('info', None)

# --- Shuffle in place, then slice into test / val / train ---
random.seed(RANDOM_SEED)
random.shuffle(images)
N = len(images)
n_test = int(N * TEST_PCT + 0.5)   # round half up
n_val  = int(N * VAL_PCT  + 0.5)

val_end = n_test + n_val
test_images, val_images, train_images = (
    images[:n_test],
    images[n_test:val_end],
    images[val_end:],
)

# --- Image-id sets, used for O(1) membership tests when splitting annotations ---
test_ids  = {entry['id'] for entry in test_images}
val_ids   = {entry['id'] for entry in val_images}
train_ids = {entry['id'] for entry in train_images}

# 5. Partition annotations
def split_anns(anns, idset):
    """Return the annotations whose ``image_id`` is a member of ``idset``.

    Input order is preserved; ``anns`` itself is not modified.
    """
    selected = []
    for record in anns:
        if record['image_id'] in idset:
            selected.append(record)
    return selected

# Split both annotation kinds with the same (train, val, test) id sets.
train_anns, val_anns, test_anns = (
    split_anns(annotations, ids) for ids in (train_ids, val_ids, test_ids)
)
train_scene_anns, val_scene_anns, test_scene_anns = (
    split_anns(scene_annotations, ids) for ids in (train_ids, val_ids, test_ids)
)

# 6. Helper to dump one subset
def dump_subset(name, imgs, anns, scene_anns):
    """Write one split to ``annotations_<name>.json`` inside ``DATASET_DIR``.

    The shared sections (``info``, ``licenses``, ``categories``,
    ``scene_categories``) are taken from the module-level variables loaded from
    the full annotation file; only the images and the two annotation lists are
    specific to the split. A one-line summary is printed after writing.
    """
    payload = dict(
        info=info,
        licenses=licenses,
        categories=categories,
        scene_categories=scene_categories,
        images=imgs,
        annotations=anns,
        scene_annotations=scene_anns,
    )
    target_path = os.path.join(DATASET_DIR, f'annotations_{name}.json')
    with open(target_path, 'w') as handle:
        json.dump(payload, handle)
    print(f'Wrote {name}: {len(imgs)} images, {len(anns)} masks')

# 7. Write three files
for subset_name, subset_images, subset_anns, subset_scene_anns in (
    ('train', train_images, train_anns, train_scene_anns),
    ('val',   val_images,   val_anns,   val_scene_anns),
    ('test',  test_images,  test_anns,  test_scene_anns),
):
    dump_subset(subset_name, subset_images, subset_anns, subset_scene_anns)


In [None]:
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, random

dataset_dir = 'data'
ann_file    = os.path.join(dataset_dir, 'annotations_train.json')
coco        = COCO(ann_file)
img_ids     = coco.getImgIds()[:3]   # preview only the first three images

for img_id in img_ids:
    # --- load image + anns ---
    img = coco.loadImgs(img_id)[0]
    I   = Image.open(os.path.join(dataset_dir, img['file_name']))
    ann_ids = coco.getAnnIds(imgIds=[img_id])
    anns    = coco.loadAnns(ann_ids)

    # --- print how many + their categories ---
    print(f"\nImage {img_id} ({img['file_name']}) has {len(anns)} masks:")
    for a in anns:
        cat = coco.loadCats(a['category_id'])[0]
        print(f"  • {cat['name']} (super: {cat['supercategory']})")

    # --- set up figure: original on the left, overlay on the right ---
    fig, (ax1, ax2) = plt.subplots(1,2, figsize=(12,6))
    fig.suptitle(f"Image {img_id}", fontsize=14)

    ax1.imshow(I)
    ax1.set_title("Original")
    ax1.axis('off')

    ax2.imshow(I)
    ax2.set_title("Colour‑coded masks")
    ax2.axis('off')

    # --- overlay each mask in its own random colour ---
    for ann in anns:
        mask = coco.annToMask(ann)  # H×W binary
        color = (random.random(), random.random(), random.random())
        # Build an H×W×4 RGBA layer whose alpha is 0.5 inside the mask and 0
        # elsewhere. Drawing an opaque RGB layer with a global alpha would
        # blend every non-mask pixel with black and darken the photo a little
        # more with each additional mask.
        overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.float32)
        overlay[..., :3] = color
        overlay[..., 3]  = mask * 0.5
        ax2.imshow(overlay)

    plt.tight_layout()
    plt.show()


In [None]:
import json
import os
from collections import Counter
import matplotlib.pyplot as plt

# === Location of the full annotation file ===
dataset_dir = 'data'
ann_file    = os.path.join(dataset_dir, 'annotations.json')

# Load the annotations
with open(ann_file, 'r') as f:
    data = json.load(f)

# Category id -> category record
cats = {c['id']: c for c in data['categories']}

# Tally masks by category name
cat_counts = Counter()
for ann in data['annotations']:
    category_name = cats[ann['category_id']]['name']
    cat_counts[category_name] += 1

# Highest count first (most_common() sorts by count, descending)
cat_items = cat_counts.most_common()
names, counts = zip(*cat_items)

# Summary figures
total_masks    = sum(counts)
num_categories = len(names)
print(f"Total masks in dataset: {total_masks}")
print(f"Number of categories   : {num_categories}")

# Horizontal bar chart, largest bar at the top
plt.figure(figsize=(8, 12))
plt.barh(names, counts)
plt.gca().invert_yaxis()
plt.title("Masks per Category")
plt.xlabel("Number of Masks")
plt.tight_layout()
plt.show()


In [None]:
import json
import os
from collections import Counter
import matplotlib.pyplot as plt

# === Location of the full annotation file ===
dataset_dir = 'data'
ann_file    = os.path.join(dataset_dir, 'annotations.json')

# Load the annotations
with open(ann_file, 'r') as f:
    data = json.load(f)

# Category id -> category record
cats = {c['id']: c for c in data['categories']}

# Tally masks by the supercategory of their category
super_counts = Counter()
for ann in data['annotations']:
    supercat = cats[ann['category_id']]['supercategory']
    super_counts[supercat] += 1

# Highest count first (most_common() sorts by count, descending)
items = super_counts.most_common()
names, counts = zip(*items)

# Summary figures
total_masks   = sum(counts)
num_supercats = len(names)
print(f"Total masks in dataset      : {total_masks}")
print(f"Number of supercategories   : {num_supercats}\n")

# Horizontal bar chart, largest bar at the top
plt.figure(figsize=(8, 6))
plt.barh(names, counts)
plt.gca().invert_yaxis()
plt.title("Masks per Supercategory")
plt.xlabel("Number of Masks")
plt.tight_layout()
plt.show()


In [None]:
pip install torch

In [None]:
import sys
# Run pip through the interpreter that backs this kernel (sys.executable).
!{sys.executable} -m pip install albumentations

In [None]:
pip install torchvision

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.transforms.functional as F

def clip_boxes_to_image(boxes, img_w, img_h):
    """Clamp ``[x_min, y_min, x_max, y_max]`` boxes to the image bounds.

    x coordinates are limited to ``[0, img_w - 1]`` and y coordinates to
    ``[0, img_h - 1]``. Returns a new list of 4-element lists; the input is
    not modified.
    """
    def _clamp(value, upper):
        return max(0, min(value, upper))

    return [
        [
            _clamp(left, img_w - 1),
            _clamp(top, img_h - 1),
            _clamp(right, img_w - 1),
            _clamp(bottom, img_h - 1),
        ]
        for left, top, right, bottom in boxes
    ]

class TacoDataset(Dataset):
    """Early draft of the dataset: only the constructor is defined.

    NOTE: ``TacoDataset`` is defined again further down in this same cell
    (including ``__len__`` and ``__getitem__``); that later definition replaces
    this one when the cell runs.
    """

    def __init__(self, images_dir, annotation_path, transforms=None, max_resolution=(4000, 4000)):
        self.images_dir = images_dir
        self.coco = COCO(annotation_path)
        self.transforms = transforms
        self.max_resolution = max_resolution
        # Keep only images whose recorded width/height fit within max_resolution.
        self.image_ids = [img_id for img_id in self.coco.imgs
                          if self.coco.imgs[img_id]['width'] <= max_resolution[0] and
                             self.coco.imgs[img_id]['height'] <= max_resolution[1]]
        print(f"✅ Loaded {len(self.image_ids)} images under resolution {max_resolution}")
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.transforms.functional as F

def clip_boxes_to_image(boxes, img_w, img_h):
    """Clamp ``[x_min, y_min, x_max, y_max]`` boxes to the image bounds.

    x coordinates are limited to ``[0, img_w - 1]`` and y coordinates to
    ``[0, img_h - 1]``. Returns a new list of 4-element lists; the input is
    not modified.
    """
    def _clamp(value, upper):
        return max(0, min(value, upper))

    return [
        [
            _clamp(left, img_w - 1),
            _clamp(top, img_h - 1),
            _clamp(right, img_w - 1),
            _clamp(bottom, img_h - 1),
        ]
        for left, top, right, bottom in boxes
    ]

class TacoDataset(Dataset):
    """Second draft of the dataset: only the constructor is defined.

    NOTE: ``TacoDataset`` is defined once more further down in this same cell
    (including ``__len__`` and ``__getitem__``); that later definition replaces
    this one when the cell runs.
    """

    def __init__(self, images_dir, annotation_path, transforms=None, max_resolution=(4000, 4000)):
        self.images_dir     = images_dir
        self.coco           = COCO(annotation_path)
        self.transforms     = transforms
        self.max_resolution = max_resolution
        # only keep images under max_resolution
        self.image_ids = [
            img_id for img_id, info in self.coco.imgs.items()
            if info['width'] <= max_resolution[0] and info['height'] <= max_resolution[1]
        ]
        print(f"✅ Loaded {len(self.image_ids)} images under resolution {max_resolution}")
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
import torchvision.transforms.functional as F

def clip_boxes_to_image(boxes, img_w, img_h):
    """Clamp ``[x_min, y_min, x_max, y_max]`` boxes to the image bounds.

    x coordinates are limited to ``[0, img_w - 1]`` and y coordinates to
    ``[0, img_h - 1]``. Returns a new list of 4-element lists; the input is
    not modified.
    """
    def _clamp(value, upper):
        return max(0, min(value, upper))

    return [
        [
            _clamp(left, img_w - 1),
            _clamp(top, img_h - 1),
            _clamp(right, img_w - 1),
            _clamp(bottom, img_h - 1),
        ]
        for left, top, right, bottom in boxes
    ]

class TacoDataset(Dataset):
    """Instance-segmentation dataset backed by a COCO-format annotation file.

    Each item is ``(image, target)``. ``image`` is a float tensor scaled to
    ``[0, 1]``; ``target`` is a dict with the keys ``boxes`` (N×4, xyxy),
    ``labels`` (N), ``masks`` (N×H×W, uint8), ``image_id``, ``area`` (N) and
    ``iscrowd`` (N). All per-instance entries always have the same length N,
    including N == 0 for images without usable annotations.

    Args:
        images_dir: Directory that ``file_name`` entries are relative to.
        annotation_path: Path of the COCO-format JSON file.
        transforms: Optional callable invoked as
            ``transforms(image=..., masks=..., bboxes=..., category_ids=...)``
            and expected to return a dict with ``image`` (tensor), ``bboxes``,
            ``category_ids`` and ``masks``.
        max_resolution: ``(max_width, max_height)``; larger images are skipped.

    NOTE(review): labels are the raw ``category_id`` values. torchvision
    detection models reserve label 0 for background — confirm the annotation
    file contains no category with id 0.
    """

    def __init__(self, images_dir, annotation_path, transforms=None, max_resolution=(4000, 4000)):
        self.images_dir     = images_dir
        self.coco           = COCO(annotation_path)
        self.transforms     = transforms
        self.max_resolution = max_resolution
        # Keep only images whose recorded width/height fit within max_resolution.
        self.image_ids = [
            img_id for img_id, info in self.coco.imgs.items()
            if info['width']  <= max_resolution[0]
            and info['height'] <= max_resolution[1]
        ]
        print(f"✅ Loaded {len(self.image_ids)} images under resolution {max_resolution}")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        img_id   = self.image_ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.images_dir, img_info['file_name'])
        # Context manager so the underlying file handle is released promptly.
        with Image.open(img_path) as pil_img:
            image_np = np.array(pil_img.convert("RGB"))
        img_h, img_w = image_np.shape[:2]

        # load annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns    = self.coco.loadAnns(ann_ids)

        # Convert xywh -> xyxy, skipping annotations with non-positive extent.
        raw_boxes, kept_anns = [], []
        for ann in anns:
            x, y, w, h = ann['bbox']
            if x + w <= x or y + h <= y:
                continue
            raw_boxes.append([x, y, x + w, y + h])
            kept_anns.append(ann)

        # Clamp to the image and build every per-instance list in lock-step so
        # boxes, labels, masks, areas and iscrowd can never get out of sync.
        boxes, labels, masks, areas, iscrowd = [], [], [], [], []
        for box, ann in zip(clip_boxes_to_image(raw_boxes, img_w, img_h), kept_anns):
            if box[2] <= box[0] or box[3] <= box[1]:
                continue  # collapsed by clipping (annotation lies outside the image)
            m = self.coco.annToMask(ann)
            if m.shape != (img_h, img_w):
                # Annotation size disagrees with the decoded image: resize the mask.
                m_img     = Image.fromarray(m.astype(np.uint8))
                m_resized = F.resize(m_img, (img_h, img_w), interpolation=Image.NEAREST)
                m         = np.array(m_resized)
            boxes.append(box)
            labels.append(ann['category_id'])
            masks.append(m)
            # Area is the annotation's own value, i.e. in original-image pixels.
            areas.append(ann['area'])
            iscrowd.append(ann.get('iscrowd', 0))

        # Apply transforms if any; on failure fall back to the untransformed sample.
        image    = None
        min_side = 0.0   # untransformed boxes only need a positive extent
        if self.transforms:
            try:
                transformed = self.transforms(
                    image=image_np,
                    masks=masks,
                    bboxes=boxes,
                    category_ids=labels
                )
                t_boxes  = transformed['bboxes']
                t_labels = transformed['category_ids']
                t_masks  = transformed['masks']
                # Masks cannot be matched back to boxes if the transform dropped any box.
                if not (len(t_boxes) == len(t_labels) == len(t_masks)):
                    raise ValueError("transform returned mismatched numbers of boxes, labels and masks")
                image = transformed['image'].float() / 255.0
                boxes, labels, masks = t_boxes, t_labels, t_masks
                min_side = 1.0   # after a transform, also drop boxes of 1 px or less
            except Exception as e:
                print(f"⚠️ Transform failed on image_id {img_id}: {e}")
                image = None
        if image is None:
            image = F.to_tensor(Image.fromarray(image_np))

        # Convert to tensors; every shape is well defined even with zero instances.
        num_instances = len(boxes)
        boxes  = torch.as_tensor(np.asarray(boxes, dtype=np.float32).reshape(-1, 4))
        labels = torch.as_tensor(np.asarray(labels, dtype=np.int64).reshape(-1))
        if num_instances > 0:
            masks = torch.stack([torch.as_tensor(m, dtype=torch.uint8) for m in masks])
        else:
            out_h, out_w = image.shape[-2:]
            masks = torch.zeros((0, out_h, out_w), dtype=torch.uint8)
        areas   = torch.as_tensor(areas,   dtype=torch.float32)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)

        # Remove tiny/degenerate boxes together with all their per-instance data.
        if num_instances > 0:
            widths  = boxes[:, 2] - boxes[:, 0]
            heights = boxes[:, 3] - boxes[:, 1]
            keep    = (widths > min_side) & (heights > min_side)
            if not bool(keep.all()):
                boxes   = boxes[keep]
                labels  = labels[keep]
                masks   = masks[keep]
                areas   = areas[keep]
                iscrowd = iscrowd[keep]

        target = {
            'boxes':    boxes,
            'labels':   labels,
            'masks':    masks,
            'image_id': torch.tensor([img_id]),
            'area':     areas,
            'iscrowd':  iscrowd
        }

        return image, target



In [None]:
# Training-time augmentation pipeline (albumentations)
def get_train_transform():
    """Build the training pipeline: resize to 512×512, random horizontal flip,
    random brightness/contrast, then conversion to a tensor.

    Boxes are declared as ``pascal_voc`` with labels in ``category_ids``.
    """
    steps = [
        A.Resize(512, 512),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        ToTensorV2(),
    ]
    box_config = A.BboxParams(format='pascal_voc', label_fields=['category_ids'])
    return A.Compose(steps, bbox_params=box_config)


def get_val_transform():
    """Build the evaluation pipeline: resize to 256×256, then convert to a tensor.

    ``bbox_params`` is required because the dataset always calls the pipeline
    with ``bboxes=`` and ``category_ids=``. Without it the call raises, and the
    dataset's fallback path then returns the image without resizing it.
    """
    return A.Compose([
        A.Resize(256, 256),
        ToTensorV2()
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids']))

# Dataset & DataLoader setup
# One dataset per split; only the training split gets augmentation.
train_dataset = TacoDataset('data', 'data/annotations_train.json', transforms=get_train_transform())
val_dataset   = TacoDataset('data', 'data/annotations_val.json',   transforms=get_val_transform())
test_dataset  = TacoDataset('data', 'data/annotations_test.json',  transforms=get_val_transform())

def collate_fn(batch):
    """Transpose a batch of ``(image, target)`` samples into
    ``(images, targets)`` tuples without stacking, so images of different sizes
    can share a batch.
    """
    columns = zip(*batch)
    return tuple(columns)

# Only the training loader shuffles; val/test run one image at a time in order.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

print(f"\n📦 Train: {len(train_dataset)} images")
print(f"📦 Val:   {len(val_dataset)} images")
print(f"📦 Test:  {len(test_dataset)} images")


In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import random
import cv2

def show_image_with_boxes_masks(image, boxes, masks, title=""):
    """Display ``image`` with green box outlines and masks blended 50/50 with red.

    Works on a copy, so the caller's array is left untouched. ``boxes`` are
    ``(x1, y1, x2, y2)`` and are truncated to integers for drawing; a mask
    covers every pixel where its value is greater than zero.
    """
    canvas = image.copy()

    for x1, y1, x2, y2 in (map(int, box) for box in boxes):
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0,255,0), 2)

    red = np.array([255, 0, 0])
    for mask in masks:
        region = mask > 0
        canvas[region] = canvas[region] * 0.5 + red * 0.5  # overlay red mask

    plt.figure(figsize=(8,8))
    plt.imshow(canvas.astype(np.uint8))
    plt.title(title)
    plt.axis("off")
    plt.show()


# Choose one training sample at random
index = random.randint(0, len(train_dataset) - 1)

# Same sample, read through a dataset that applies no transforms
untransformed_dataset = TacoDataset(
    images_dir='data',
    annotation_path='data/annotations_train.json',
    transforms=None
)
original   = untransformed_dataset[index]
orig_img   = original[0].permute(1, 2, 0).numpy() * 255   # CHW in [0, 1] -> HWC in [0, 255]
orig_boxes = original[1]['boxes'].numpy()
orig_masks = original[1]['masks'].numpy()

# Same sample, read through the augmenting training dataset
augmented = train_dataset[index]
aug_img   = augmented[0].permute(1, 2, 0).numpy() * 255
aug_boxes = augmented[1]['boxes'].numpy()
aug_masks = augmented[1]['masks'].numpy()

# Show both versions, unaugmented first
show_image_with_boxes_masks(orig_img, orig_boxes, orig_masks, title="Before Augmentation")
show_image_with_boxes_masks(aug_img, aug_boxes, aug_masks, title="After Augmentation")


In [None]:
# Sanity-check the loader output using the first batch only.
for imgs, targets in train_loader:
    print("Batch size:", len(imgs))
    # Print the shape, not the tensor itself, to match the label.
    print("1st image shape:", imgs[0].shape)
    print("1st target keys:", targets[0].keys())
    print("1st target values:", targets[0].values())
    break

In [None]:
pip install scikit-image

In [None]:
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torch
import torch.optim as optim

# --- Mask R-CNN with heads resized for this dataset ---
def get_model(num_classes):
    """Return a pretrained Mask R-CNN (ResNet-50 FPN) whose box and mask
    predictors are replaced by fresh heads with ``num_classes`` outputs.
    """
    model = maskrcnn_resnet50_fpn(pretrained=True)
    heads = model.roi_heads

    # New box classifier/regressor on top of the existing box features
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # New mask predictor; 256 is the width of its hidden layer
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, 256, num_classes)

    return model


In [None]:
from torch.optim.lr_scheduler import OneCycleLR

# --- Device and model ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 1 + len(train_dataset.coco.getCatIds())  # background + one class per category
model = get_model(num_classes).to(device)

# --- Optimizer: SGD with momentum over the trainable parameters only ---
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(
    trainable_params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# --- One-cycle LR schedule spanning num_epochs * steps_per_epoch scheduler steps ---
steps_per_epoch = len(train_loader)
num_epochs = 10
lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=0.005,
    total_steps=num_epochs * steps_per_epoch,
    pct_start=0.1,            # first 10% of the steps ramp the LR up
    anneal_strategy='cos',
    div_factor=25.0,          # initial LR = max_lr / div_factor
    final_div_factor=10000.0,
)


In [None]:
# from skimage import measure

# def binary_mask_to_polygon(mask, tolerance=2):
#     if mask.shape[0] < 2 or mask.shape[1] < 2:
#         print(f"⚠️ Skipping too-small mask with shape: {mask.shape}")
#         return []  # Skip invalid masks

#     contours = measure.find_contours(mask, 0.5)
#     segmentations = []

#     for contour in contours:
#         contour = np.flip(contour, axis=1)  # (y,x) → (x,y)
#         segmentation = contour.ravel().tolist()
#         if len(segmentation) >= 6:  # must have at least 3 points
#             segmentations.append(segmentation)

#     return segmentations



In [None]:
# from pycocotools.cocoeval import COCOeval
# import numpy as np

# def evaluate_model(model, data_loader, coco_gt, device, coco_category_ids):
#     model.eval()
#     coco_results = []

#     with torch.no_grad():
#         for images, targets in data_loader:
#             images = [img.to(device) for img in images]
#             outputs = model(images)

#             for i, output in enumerate(outputs):
#                 img_id = int(targets[i]["image_id"].item())
#                 boxes = output["boxes"].detach().cpu().numpy()
#                 scores = output["scores"].detach().cpu().numpy()
#                 labels = output["labels"].detach().cpu().numpy()
#                 masks = output["masks"].detach().cpu().numpy()[:, 0, :, :] 

#                 for j in range(len(boxes)):
#                     # if scores[j] < 0.05:
#                     #     continue

#                     x1, y1, x2, y2 = boxes[j]
#                     if x2 <= x1 or y2 <= y1:
#                         continue

#                     polygons = binary_mask_to_polygon(masks[j])
#                     if not polygons:
#                         continue

#                     coco_results.append({
#                         "image_id": img_id,
#                         "category_id": int(coco_category_ids.get(labels[j], labels[j])),  # map label
#                         "bbox": [x1, y1, x2 - x1, y2 - y1],
#                         "score": 1.0,
#                         "segmentation": polygons
#                     })

#     if not coco_results:
#         print("⚠️ No valid predictions to evaluate.")
#         return

#     coco_dt = coco_gt.loadRes(coco_results)
#     coco_eval = COCOeval(coco_gt, coco_dt, iouType='segm')
#     coco_eval.evaluate()
#     coco_eval.accumulate()
#     coco_eval.summarize()


In [None]:
import torch
import numpy as np
from datetime import datetime
def calculate_iou(mask1: torch.Tensor, mask2: torch.Tensor) -> float:
    """
    Intersection-over-union of two masks, treating non-zero entries as set.

    Returns 0.0 when both masks are empty (the union has no pixels).
    """
    union = torch.logical_or(mask1, mask2).sum().item()
    if union <= 0:
        return 0.0
    overlap = torch.logical_and(mask1, mask2).sum().item()
    return overlap / union

In [None]:
def evaluate_segmentation_metrics(
    model,
    data_loader,
    device,
    score_thresh: float = 0.5,
    mask_thresh: float = 0.5,
    iou_threshold: float = 0.5,
) -> dict:
    """
    Evaluate Mask R-CNN on a dataset, computing precision, recall, mean IoU, and class accuracy.

    - Filters out predictions whose score is not above ``score_thresh``.
    - Binarizes predicted masks at ``mask_thresh``.
    - Matches each ground-truth mask, in order, to the unmatched prediction of
      the same class with the highest IoU; the pair counts as a true positive
      when that IoU is at least ``iou_threshold``.
    - Handles images with zero GT or zero predictions.
    - Class accuracy = (# correct GT–pred matches) / (# GT + # unmatched preds),
      counted over every image, including those with no GT or no predictions.

    Args:
        model: Detection model; called as ``model(images)`` in eval mode and
            expected to return per-image dicts with ``scores``, ``masks`` and ``labels``.
        data_loader: Iterable of ``(images, targets)`` batches; targets carry
            ``masks`` and ``labels`` tensors.
        device: Device that images and targets are moved to.
        score_thresh: Minimum (exclusive) prediction score.
        mask_thresh: Threshold applied to predicted mask probabilities.
        iou_threshold: Minimum IoU for a match to count as a true positive.

    Returns:
        dict with ``mean_iou`` (mean IoU over true positives, 0.0 if none),
        ``precision``, ``recall`` and ``cls_accuracy``.
    """
    model.eval()

    tp, fp, fn = 0, 0, 0
    all_ious   = []
    cls_correct, cls_total = 0, 0

    with torch.no_grad():
        for images, targets in data_loader:
            images  = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)

            for i in range(len(images)):
                # --- Ground truth ---
                gt_masks  = targets[i]['masks'].bool()   # [G, H, W]
                gt_labels = targets[i]['labels']         # [G]

                # --- Predictions (score + mask threshold) ---
                scores     = outputs[i]['scores']        # [P]
                keep_pred  = scores > score_thresh
                pred_masks = outputs[i]['masks'][keep_pred, 0] > mask_thresh  # [P, H, W]
                pred_labels= outputs[i]['labels'][keep_pred]                # [P]

                num_gt   = gt_masks.shape[0]
                num_pred = pred_masks.shape[0]

                # --- Special cases ---
                if num_gt == 0:
                    # no GT => all preds are false positives (and count as unmatched preds)
                    fp        += num_pred
                    cls_total += num_pred
                    continue
                if num_pred == 0:
                    # no preds => all GT are false negatives (and still count as GT)
                    fn        += num_gt
                    cls_total += num_gt
                    continue

                # --- Greedy 1:1 matching by IoU & class ---
                matched_pred = set()
                for gt_m, gt_cls in zip(gt_masks, gt_labels):
                    best_iou = 0.0
                    best_pi  = -1
                    for pi, (pm, p_cls) in enumerate(zip(pred_masks, pred_labels)):
                        if pi in matched_pred or p_cls != gt_cls:
                            continue
                        iou = calculate_iou(gt_m, pm)
                        if iou > best_iou:
                            best_iou = iou
                            best_pi  = pi

                    cls_total += 1
                    if best_iou >= iou_threshold:
                        # True positive, and the class was right by construction
                        tp += 1
                        all_ious.append(best_iou)
                        matched_pred.add(best_pi)
                        cls_correct += 1
                    else:
                        # Missed this GT
                        fn += 1

                # Unmatched preds are false positives + classification errors
                num_unmatched = num_pred - len(matched_pred)
                fp += num_unmatched
                cls_total += num_unmatched

    # ——— Metrics (1e-6 guards against division by zero) ———
    precision    = tp / (tp + fp + 1e-6)
    recall       = tp / (tp + fn + 1e-6)
    mean_iou     = np.mean(all_ious) if all_ious else 0.0
    cls_accuracy = cls_correct / (cls_total + 1e-6)

    return {
        "mean_iou":      mean_iou,
        "precision":     precision,
        "recall":        recall,
        "cls_accuracy":  cls_accuracy
    }


In [None]:
def train_and_evaluate(model, train_loader, val_loader, test_loader, optimizer, lr_scheduler, device, num_epochs=10, print_interval=40, step_scheduler_per_batch=True):
    """Train ``model``, evaluate periodically, then test and save the weights.

    Every epoch runs one pass over ``train_loader``. Batches whose loss is
    NaN/inf, or that raise ``RuntimeError``, are skipped. On even-numbered
    epochs (0, 2, ...) segmentation metrics are printed for the train and val
    loaders. After the last epoch the test loader is evaluated and the state
    dict is written to ``maskrcnn_taco_baseline.pth``.

    Args:
        model: Detection model returning a dict of losses in train mode.
        train_loader, val_loader, test_loader: Loaders yielding ``(images, targets)``.
        optimizer: Optimizer over the model parameters.
        lr_scheduler: Learning-rate scheduler.
        device: Device that images and targets are moved to.
        num_epochs: Number of passes over ``train_loader``.
        print_interval: Print the batch loss every this many steps.
        step_scheduler_per_batch: When True (default) the scheduler is stepped
            after every optimizer step, which is what a schedule sized as
            ``num_epochs * len(train_loader)`` steps (e.g. OneCycleLR) needs.
            Pass False for schedulers that are meant to step once per epoch.
    """
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        completed_steps = 0   # batches that actually contributed to total_loss
        print(f"\n🚀 Starting Epoch {epoch+1}/{num_epochs} at {datetime.now().strftime('%H:%M:%S')}")

        for step, (images, targets) in enumerate(train_loader):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            try:
                loss_dict = model(images, targets)
                loss = sum(part for part in loss_dict.values())

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"⚠️ Invalid loss at step {step+1}, skipping")
                    continue

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if step_scheduler_per_batch:
                    # Advance the schedule once per optimizer update.
                    lr_scheduler.step()
                total_loss += loss.item()
                completed_steps += 1
            except RuntimeError as e:
                print(f"❌ CUDA error at step {step+1}: {e}")
                torch.cuda.empty_cache()
                continue

            if (step + 1) % print_interval == 0 or (step + 1) == len(train_loader):
                print(f"  🔁 Step {step+1}/{len(train_loader)} — Batch Loss: {loss.item():.4f}")

            del images, targets
            torch.cuda.empty_cache()

        if not step_scheduler_per_batch:
            lr_scheduler.step()
        # Average over the batches that were actually used, not the skipped ones.
        avg_loss = total_loss / max(completed_steps, 1)
        print(f"\n📘 Epoch {epoch+1} — Average Epoch Loss: {avg_loss:.4f}")
        if epoch % 2 == 0:
            for name, loader in [("train", train_loader), ("val", val_loader)]:
                print(f"\n🔍 Evaluating on {name} set...")
                try:
                    metrics = evaluate_segmentation_metrics(model, loader, device)
                    print(f"📏 {name.upper()} → IoU: {metrics['mean_iou']:.4f} | "
                          f"Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f} | "
                          f"Class Acc: {metrics['cls_accuracy']:.4f}")
                except Exception as e:
                    # Best effort: a failed evaluation must not abort training.
                    print(f"❌ Evaluation failed on {name}: {e}")

    print("\n🔍 Final Evaluation on TEST set...")
    metrics_test = evaluate_segmentation_metrics(model, test_loader, device)
    print(f"📏 TEST → IoU: {metrics_test['mean_iou']:.4f} | "
          f"Precision: {metrics_test['precision']:.4f} | Recall: {metrics_test['recall']:.4f} | "
          f"Class Acc: {metrics_test['cls_accuracy']:.4f}")

    torch.save(model.state_dict(), "maskrcnn_taco_baseline.pth")
    print("✅ Model saved to maskrcnn_taco_baseline.pth")


In [None]:
train_and_evaluate(
    model,
    train_loader,
    val_loader,
    test_loader,
    optimizer,
    lr_scheduler,
    device,
    num_epochs=10,
    print_interval=40,
)

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from IPython.display import Image as IPImage, display

def print_predictions_on_test_set(
    model,
    test_loader,
    dataset,
    device,
    score_thresh: float = 0.5,
    mask_thresh:  float = 0.5,
    save_dir:     str   = "predicted_masks",
    show_masks:   bool  = True
):
    """
    Runs the model on the test_loader, then for each image:
     - Filters predictions by score_thresh
     - Binarizes masks at mask_thresh
     - Overlays masks in red (when show_masks is True) and draws boxes in yellow
     - Saves the figure to ``save_dir/img_<image_id>.png`` and displays it inline

    Args:
        model: Detection model; called as ``model(images)`` in eval mode.
        test_loader: Loader yielding ``(images, targets)``; targets need ``image_id``.
        dataset: Dataset exposing a ``coco`` attribute, used for category names.
        device: Device the images are moved to.
        score_thresh: Predictions scoring below this value are not drawn.
        mask_thresh: Threshold applied to predicted mask probabilities.
        save_dir: Output directory (created if missing).
        show_masks: Set to False to draw boxes and captions only.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)

    # Map COCO category IDs to names
    category_id_to_name = {
        cat['id']: cat['name']
        for cat in dataset.coco.loadCats(dataset.coco.getCatIds())
    }

    with torch.no_grad():
        for images, targets in test_loader:
            # Move to device
            images = [img.to(device) for img in images]
            outputs = model(images)

            for img_tensor, output, target in zip(images, outputs, targets):
                img_id  = int(target['image_id'].item())
                boxes   = output['boxes'].cpu().numpy()
                scores  = output['scores'].cpu().numpy()
                labels  = output['labels'].cpu().numpy()
                masks   = (output['masks'][:, 0].cpu().numpy() > mask_thresh)

                # Convert tensor to H×W×3 uint8 image (assumes values in [0, 1])
                img_np = img_tensor.cpu().permute(1, 2, 0).numpy()
                img_np = (img_np * 255).astype(np.uint8)

                # Create a figure & axis
                fig, ax = plt.subplots(figsize=(6, 6))
                ax.imshow(img_np)

                any_pred = False
                for box, score, label_id, mask in zip(boxes, scores, labels, masks):
                    if score < score_thresh:
                        continue
                    any_pred = True

                    # Draw bounding box in yellow
                    x1, y1, x2, y2 = box
                    rect = patches.Rectangle(
                        (x1, y1),
                        x2 - x1,
                        y2 - y1,
                        linewidth=2,
                        edgecolor="yellow",
                        facecolor="none"
                    )
                    ax.add_patch(rect)

                    if show_masks:
                        # RGBA layer that is transparent outside the mask, so
                        # drawing it does not wash out masks drawn earlier
                        # (a full-frame copy at alpha 0.5 would).
                        red_overlay = np.zeros(mask.shape + (4,), dtype=np.float32)
                        red_overlay[..., 0] = 1.0
                        red_overlay[..., 3] = mask * 0.5
                        ax.imshow(red_overlay)

                    # Print caption
                    label_name = category_id_to_name.get(int(label_id), f"Class {label_id}")
                    ax.text(
                        x1,
                        y1 - 5,
                        f"{label_name}: {score:.2f}",
                        color="yellow",
                        fontsize=12,
                        backgroundcolor="black",
                        alpha=0.7
                    )

                if not any_pred:
                    ax.text(
                        10,
                        20,
                        "⚠️ No predictions above threshold",
                        color="red",
                        fontsize=14,
                        backgroundcolor="white"
                    )

                ax.axis("off")

                # Save and display; close the figure so they do not pile up in memory
                out_path = os.path.join(save_dir, f"img_{img_id}.png")
                fig.savefig(out_path, bbox_inches="tight", pad_inches=0)
                plt.close(fig)

                display(IPImage(out_path))


In [None]:
# The threshold keyword is `score_thresh`; the function defined above has no
# `score_threshold` or `show_masks` parameters, so passing them raises TypeError.
print_predictions_on_test_set(model, test_loader, test_dataset, device, score_thresh=0.5)

In [None]:
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils
import numpy as np

def evaluate_ground_truth_as_prediction(coco_gt):
    """Sanity-check the evaluation pipeline by scoring the ground truth against itself.

    Every annotation that has ``segmentation``, ``bbox`` and ``category_id`` is
    re-submitted as a detection with score 1.0, then evaluated with COCOeval
    using ``iouType='segm'``. The summary is printed; nothing is returned.
    If there is no usable annotation, a warning is printed and evaluation is skipped.

    Args:
        coco_gt: A loaded ``pycocotools.coco.COCO`` ground-truth object.
    """
    coco_results = []

    for img_id in coco_gt.getImgIds():
        anns = coco_gt.loadAnns(coco_gt.getAnnIds(imgIds=img_id))
        for ann in anns:
            if 'segmentation' not in ann or 'bbox' not in ann or 'category_id' not in ann:
                continue

            # COCO format expects score when evaluating detections
            coco_results.append({
                "image_id": img_id,
                "category_id": ann["category_id"],
                "bbox": ann["bbox"],
                "score": 1.0,  # max confidence
                "segmentation": ann["segmentation"]
            })

    # Nothing to load: skip instead of handing an empty list to loadRes.
    if not coco_results:
        print("⚠️ No valid predictions to evaluate.")
        return

    # Convert to COCO predictions format
    coco_dt = coco_gt.loadRes(coco_results)

    # Run evaluation
    coco_eval = COCOeval(coco_gt, coco_dt, iouType='segm')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()


In [None]:
# Score the validation ground truth against itself with COCOeval (segm).
evaluate_ground_truth_as_prediction(val_dataset.coco)
