In [None]:
import sys

# Environment setup cell.
# `{sys.executable} -m pip` installs into the interpreter that runs this kernel,
# not into whatever `pip` happens to be first on PATH.
# NOTE(review): no versions are pinned, so re-running may install newer releases.

# Install PyTorch with CUDA 11.8 (wheels from the official cu118 index)
!{sys.executable} -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install all other packages: arrays/plotting/image I/O, OpenCV, COCO tooling
# (pycocotools), augmentation (albumentations), scikit-image, pandas and
# YOLOv8 (ultralytics).
!{sys.executable} -m pip install \
  numpy matplotlib pillow \
  opencv-python \
  pycocotools \
  albumentations \
  scikit-image \
  pandas \
  ultralytics

In [None]:
import os
import gc
import json
import random
import cv2
import torch
torch.cuda.empty_cache()
import torchvision

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from datetime import datetime

from PIL import Image
from collections import Counter

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils

import albumentations as A
from albumentations.pytorch import ToTensorV2
from ultralytics import YOLO

In [None]:
# Use the default CUDA device when PyTorch can see one; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# 1. User parameters
dataset_dir    = 'data'                  # ← change this
INPUT_FILE     = 'annotations.json'      # ← your single annotations file
TEST_PCT       = 0.10                    # 10% test
VAL_PCT        = 0.10                    # 10% val
RANDOM_SEED    = 42                      # for reproducibility

# 2. Load the full COCO-style annotation JSON
with open(os.path.join(dataset_dir, INPUT_FILE), 'r') as f:
    full = json.load(f)

images      = full['images']               # required
annotations = full.get('annotations', [])  # optional keys default to empty
categories  = full.get('categories', [])

# 3. Seeded in-place shuffle, then carve the list into test | val | train
random.seed(RANDOM_SEED)
random.shuffle(images)
N = len(images)
n_test, n_val = (int(N * pct + 0.5) for pct in (TEST_PCT, VAL_PCT))  # round half up

test_images, val_images, train_images = (
    images[:n_test],
    images[n_test:n_test + n_val],
    images[n_test + n_val:],
)

# 4. Image-id sets for O(1) membership checks when partitioning annotations
test_ids, val_ids, train_ids = (
    {img['id'] for img in subset}
    for subset in (test_images, val_images, train_images)
)
# 5. Partition annotations
def split_annotations(annotations, ids):
    """Return, in original order, the annotations whose ``image_id`` is contained in ``ids``."""
    return list(filter(lambda ann: ann['image_id'] in ids, annotations))

train_anns, val_anns, test_anns = (
    split_annotations(annotations, subset_ids)
    for subset_ids in (train_ids, val_ids, test_ids)
)

# 6. Helper to dump one subset
def create_subset(name, imgs, annotations):
    """Write one split to ``<dataset_dir>/annotations_<name>.json`` and report its size.

    Every split reuses the module-level ``categories`` list, so all files share one label space.
    """
    path = os.path.join(dataset_dir, f'annotations_{name}.json')
    payload = dict(categories=categories, images=imgs, annotations=annotations)
    with open(path, 'w') as fh:
        json.dump(payload, fh)
    print(f'Wrote {name}: {len(imgs)} images, {len(annotations)} masks')

In [None]:
# 7. Write three files
create_subset('train', train_images, train_anns)
create_subset('val', val_images, val_anns)
create_subset('test', test_images, test_anns)

In [None]:
# Preview the first three training images with their instance masks.
ann_file    = os.path.join(dataset_dir, 'annotations_train.json')
coco        = COCO(ann_file)
img_ids     = coco.getImgIds()[:3]

for img_id in img_ids:
    # --- load image + anns ---
    img = coco.loadImgs(img_id)[0]
    I   = Image.open(os.path.join(dataset_dir, img['file_name']))
    ann_ids = coco.getAnnIds(imgIds=[img_id])
    anns    = coco.loadAnns(ann_ids)

    # --- print how many + their categories ---
    print(f"\nImage {img_id} ({img['file_name']}) has {len(anns)} masks:")
    for a in anns:
        cat = coco.loadCats(a['category_id'])[0]
        print(f"  • {cat['name']} (super: {cat['supercategory']})")

    # --- set up figure ---
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle(f"Image {img_id}", fontsize=14)

    ax1.imshow(I)
    ax1.set_title("Original")
    ax1.axis('off')

    ax2.imshow(I)
    ax2.set_title("Colour‑coded masks")
    ax2.axis('off')

    # --- overlay each mask in its own random colour ---
    for ann in anns:
        mask = coco.annToMask(ann)  # H×W binary
        color = (random.random(), random.random(), random.random())
        # RGBA layer: the colour everywhere, but opacity only inside the mask. An opaque
        # RGB layer drawn with alpha=0.5 would also darken every pixel outside the mask,
        # and the darkening compounds with each additional annotation.
        rgba_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.float32)
        rgba_mask[..., :3] = color
        rgba_mask[..., 3] = mask * 0.5
        ax2.imshow(rgba_mask)

    plt.tight_layout()
    plt.show()

In [None]:
# === CONFIGURE THIS ===
ann_file    = os.path.join(dataset_dir, 'annotations.json')

# 1. Load your annotations
with open(ann_file, 'r') as f:
    data = json.load(f)

# 2. Build a lookup for categories
cats = {c['id']: c for c in data['categories']}

# 3. Count up masks per category
cat_counts = Counter()
for ann in data['annotations']:
    cat_counts[cats[ann['category_id']]['name']] += 1

# sort descending
cat_items = sorted(cat_counts.items(), key=lambda x: x[1], reverse=True)
names, counts = zip(*cat_items)

# 3a. Print total masks and number of distinct categories
total_masks      = sum(counts)
num_categories   = len(names)
print(f"Total masks in dataset: {total_masks}")
print(f"Number of categories   : {num_categories}")

In [None]:
# 4. Plot horizontal bar chart
plt.figure(figsize=(8, 12))
plt.barh(names, counts)
plt.gca().invert_yaxis()   # largest at the top
plt.xlabel("Number of Masks")
plt.title("Masks per Category")
plt.tight_layout()
plt.show()

In [None]:
# === CONFIGURE THIS ===
ann_file    = os.path.join(dataset_dir, 'annotations.json')

# 1. Load your annotations
with open(ann_file, 'r') as f:
    data = json.load(f)

# 2. Build a lookup for categories
cats = {c['id']: c for c in data['categories']}

# 3. Count up masks per supercategory
super_counts = Counter()
for ann in data['annotations']:
    supercat = cats[ann['category_id']]['supercategory']
    super_counts[supercat] += 1

# 3a. Sort descending
items = sorted(super_counts.items(), key=lambda x: x[1], reverse=True)
names, counts = zip(*items)

# 3b. Print totals
total_masks        = sum(counts)
num_supercats      = len(names)
print(f"Total masks in dataset      : {total_masks}")
print(f"Number of supercategories   : {num_supercats}\n")

In [None]:
# 4. Plot horizontal bar chart
# Horizontal bar chart of masks per supercategory, most frequent at the top
fig, ax = plt.subplots(figsize=(8, 6))
ax.barh(names, counts)
ax.invert_yaxis()
ax.set_xlabel("Number of Masks")
ax.set_title("Masks per Supercategory")
fig.tight_layout()
plt.show()

In [None]:
def clip_boxes_to_image(boxes, img_w, img_h):
    """Clamp pascal_voc boxes ``[x_min, y_min, x_max, y_max]`` into the pixel grid.

    x values are limited to ``[0, img_w - 1]`` and y values to ``[0, img_h - 1]``.
    A new list of 4-element lists is returned; the input is left untouched. Boxes that
    become degenerate (e.g. lying fully outside the image) are kept, not removed.
    """
    def clamp(value, upper):
        return max(0, min(value, upper))

    x_hi, y_hi = img_w - 1, img_h - 1
    return [
        [clamp(x0, x_hi), clamp(y0, y_hi), clamp(x1, x_hi), clamp(y1, y_hi)]
        for x0, y0, x1, y1 in boxes
    ]

In [None]:
class TacoDataset(Dataset):
    """COCO-format instance-segmentation dataset returning torchvision Mask R-CNN targets.

    ``__getitem__`` returns ``(image, target)``. ``image`` is a float tensor (C, H, W).
    ``target`` holds:
      - ``boxes``: (N, 4) float32, pascal_voc pixels
      - ``labels``: (N,) int64, the raw COCO ``category_id``
      - ``masks``: (N, H, W) uint8
      - ``image_id``
      - ``area``: (N,) float32, pixel count of each returned mask
      - ``iscrowd``: (N,) int64

    Every per-instance field has the same length N.

    An image without usable annotations, or whose instances all vanish during
    augmentation, gets a single placeholder instance: box ``[0, 0, 1, 1]``, label 0
    and an all-zero mask.

    Args:
        images_dir: directory that the annotation ``file_name`` entries are relative to.
        annotation_path: path to a COCO annotation JSON file.
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=, masks=, bboxes=, category_ids=)``. Its ``image`` output is
            expected to be a tensor on a 0-255 scale (it is divided by 255 here).
        max_resolution: ``(max_width, max_height)``. Larger images (per the annotation
            metadata) are excluded.
    """

    def __init__(self, images_dir, annotation_path, transforms=None, max_resolution=(4000, 4000)):
        self.images_dir = images_dir
        self.coco = COCO(annotation_path)
        self.transforms = transforms
        self.max_resolution = max_resolution
        self.image_ids = [img_id for img_id in self.coco.imgs
                          if self.coco.imgs[img_id]['width'] <= max_resolution[0] and
                             self.coco.imgs[img_id]['height'] <= max_resolution[1]]
        print(f"Loaded {len(self.image_ids)} images under resolution {max_resolution}")

    def __getitem__(self, index):
        img_id = self.image_ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.images_dir, img_info['file_name'])
        image = np.array(Image.open(img_path).convert("RGB"))
        img_h, img_w = image.shape[:2]

        boxes, labels, masks, crowd = self._load_instances(img_id, img_h, img_w)

        if self.transforms:
            try:
                image_t, kept = self._run_transforms(image, boxes, labels, masks, crowd)
                tensors = self._to_tensors(*kept)
            except Exception as e:
                # Best effort: fall back to the untransformed image and drop tiny boxes.
                print(f"Transform failed on image_id {img_id} with error: {e}")
                image_t = F.to_tensor(Image.fromarray(image))
                tensors = self._to_tensors(
                    *self._drop_tiny_boxes(boxes, labels, masks, crowd, img_h, img_w))
        else:
            image_t = F.to_tensor(Image.fromarray(image))
            tensors = self._to_tensors(boxes, labels, masks, crowd)

        boxes_t, labels_t, masks_t, crowd_t = tensors
        target = {
            'boxes': boxes_t,
            'labels': labels_t,
            'masks': masks_t,
            'image_id': torch.tensor([img_id]),
            # Derived from the returned masks, so it has one entry per returned instance and
            # matches the output resolution (the raw annotation areas covered every raw
            # annotation at the original size).
            'area': masks_t.flatten(1).sum(dim=1).to(torch.float32),
            'iscrowd': crowd_t,
        }

        return image_t, target

    def __len__(self):
        return len(self.image_ids)

    def _load_instances(self, img_id, img_h, img_w):
        """Collect clipped boxes, labels, masks and iscrowd flags for one image (as lists)."""
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        boxes, labels, masks, crowd = [], [], [], []
        for ann in anns:
            x, y, w, h = ann['bbox']  # COCO xywh -> pascal_voc xyxy
            x_max, y_max = x + w, y + h
            if x_max <= x or y_max <= y:
                continue  # degenerate annotation box
            mask = self.coco.annToMask(ann)
            if mask.shape != (img_h, img_w):
                # Annotation metadata and decoded image disagree on size. Nearest-neighbour
                # resampling keeps the mask binary; PIL sizes are (width, height).
                mask = np.array(
                    Image.fromarray(mask.astype(np.uint8)).resize((img_w, img_h), Image.NEAREST))
            boxes.append([x, y, x_max, y_max])
            labels.append(ann['category_id'])
            masks.append(mask)
            crowd.append(ann.get('iscrowd', 0))

        boxes = clip_boxes_to_image(boxes, img_w, img_h)
        if not boxes:
            return self._placeholder(img_h, img_w)
        return boxes, labels, masks, crowd

    @staticmethod
    def _placeholder(img_h, img_w):
        """Single dummy instance used when an image has no usable annotation."""
        return [[0, 0, 1, 1]], [0], [np.zeros((img_h, img_w), dtype=np.uint8)], [0]

    def _run_transforms(self, image, boxes, labels, masks, crowd):
        """Apply ``self.transforms``; keep instances whose box stays wider and taller than 1 px.

        Each box is tagged with its source index as a 5th value. Albumentations carries
        extra bbox values through with the box. When the pipeline drops a box (e.g. one
        shifted out of frame), it still returns every mask; the index then re-aligns masks,
        labels and iscrowd with the surviving boxes instead of pairing them by position.
        """
        tagged = [list(box) + [idx] for idx, box in enumerate(boxes)]
        transformed = self.transforms(image=image, masks=masks, bboxes=tagged, category_ids=labels)

        kept_boxes, kept_labels, kept_masks, kept_crowd = [], [], [], []
        for box in transformed["bboxes"]:
            x_min, y_min, x_max, y_max = (float(v) for v in box[:4])
            src = int(box[4])
            if (x_max - x_min) > 1 and (y_max - y_min) > 1:
                kept_boxes.append([x_min, y_min, x_max, y_max])
                kept_labels.append(labels[src])
                kept_masks.append(transformed["masks"][src])
                kept_crowd.append(crowd[src])

        image_t = transformed['image'].float() / 255.0
        if not kept_boxes:
            h, w = image_t.shape[1:]
            return image_t, self._placeholder(h, w)
        return image_t, (kept_boxes, kept_labels, kept_masks, kept_crowd)

    @classmethod
    def _drop_tiny_boxes(cls, boxes, labels, masks, crowd, img_h, img_w):
        """Untransformed fallback: keep instances whose box is wider and taller than 1 px."""
        keep = [i for i, (x_min, y_min, x_max, y_max) in enumerate(boxes)
                if (x_max - x_min) > 1 and (y_max - y_min) > 1]
        if not keep:
            return cls._placeholder(img_h, img_w)
        return ([boxes[i] for i in keep], [labels[i] for i in keep],
                [masks[i] for i in keep], [crowd[i] for i in keep])

    @staticmethod
    def _to_tensors(boxes, labels, masks, crowd):
        """Convert non-empty per-instance lists into target tensors."""
        return (
            torch.as_tensor(boxes, dtype=torch.float32),
            torch.as_tensor(labels, dtype=torch.int64),
            # as_tensor accepts numpy masks as well as tensors (ToTensorV2 may return those)
            # without the copy warning that torch.tensor(tensor) emits.
            torch.stack([torch.as_tensor(m, dtype=torch.uint8) for m in masks]),
            torch.as_tensor(crowd, dtype=torch.int64),
        )

In [None]:
# ✅ New augmentations using Detectron2-style (simple and robust)
# def get_train_transform():
#     return A.Compose([
#         A.Resize(512, 512),
#         A.HorizontalFlip(p=0.5),
#         A.RandomBrightnessContrast(p=0.2),
#         ToTensorV2()
#     ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids']))

# taruns changes
def get_train_transform():
    """Training augmentation pipeline (albumentations) used with ``TacoDataset``.

    Geometric transforms act jointly on the image, the masks and the pascal_voc boxes.
    ``category_ids`` is the label field that travels with the boxes.

    There is deliberately no ``A.Normalize`` step. ``TacoDataset`` divides the
    ``ToTensorV2`` output by 255, and torchvision's Mask R-CNN applies ImageNet mean/std
    normalisation internally. Normalising here as well scaled the inputs three times,
    leaving the network with near-constant images.
    """
    return A.Compose([
        A.Resize(512, 512),  # or use A.LongestMaxSize(512) + A.PadIfNeeded if keeping aspect ratio
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=10, border_mode=0, p=0.5),

        # Sensor-style noise: at most one of these per sample.
        A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0)),
            A.ISONoise(p=1),
            A.MultiplicativeNoise(p=1),
        ], p=0.3),

        A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.3),
        A.RandomBrightnessContrast(p=0.3),

        # Keep 0-255 values: TacoDataset divides by 255 itself.
        ToTensorV2()
    ],
    bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids'])
    )


def get_val_transform():
    """Deterministic evaluation pipeline: resize to 512x512 and convert to a tensor.

    ``bbox_params`` makes albumentations rescale the pascal_voc boxes together with the
    image. Without it the boxes stayed in original-image pixel coordinates while the image
    and masks were resized to 512x512. ``ToTensorV2`` keeps 0-255 values because
    ``TacoDataset`` divides by 255 itself.
    """
    return A.Compose([
        A.Resize(512, 512),
        ToTensorV2()
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids']))

In [None]:
# Dataset & DataLoader setup
train_dataset = TacoDataset(
    images_dir='data',
    annotation_path='data/annotations_train.json',
    transforms=get_train_transform()
)

val_dataset = TacoDataset(
    images_dir='data',
    annotation_path='data/annotations_val.json',
    transforms=get_val_transform()
)

test_dataset = TacoDataset(
    images_dir='data',
    annotation_path='data/annotations_test.json',
    transforms=get_val_transform()
)

In [None]:
def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Detection images may differ in size, so they are grouped in tuples instead of stacked.
    """
    return tuple(tuple(column) for column in zip(*batch))

batch_size = 8
# Only the training split is shuffled; val/test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=do_shuffle, collate_fn=collate_fn)
    for ds, do_shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

print(f"\nTrain: {len(train_dataset)} images")
print(f"Val:   {len(val_dataset)} images")
print(f"Test:  {len(test_dataset)} images")

In [None]:
def show_image_with_boxes_masks(image, boxes, masks, title=""):
    """Plot ``image`` with green box outlines and a 50 % red tint over every mask pixel.

    Args:
        image: H x W x 3 array on a 0-255 scale (float or integer). It is not modified.
        boxes: iterable of ``[x1, y1, x2, y2]`` pixel coordinates, truncated to int for drawing.
        masks: iterable of H x W arrays; non-zero pixels are tinted red.
        title: figure title.
    """
    # Float32 working copy: blending stays exact and cv2 gets a fresh contiguous buffer.
    img = np.array(image, dtype=np.float32)
    for box in boxes:
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

    red = np.array([255, 0, 0], dtype=np.float32)
    for mask in masks:
        sel = mask > 0
        img[sel] = img[sel] * 0.5 + red * 0.5  # overlay red mask

    plt.figure(figsize=(8, 8))
    # Clip before the uint8 cast: out-of-range values (e.g. normalised inputs) would
    # otherwise wrap around and render as noise.
    plt.imshow(np.clip(img, 0, 255).astype(np.uint8))
    plt.title(title)
    plt.axis("off")
    plt.show()

In [None]:
# Pick a random image
index = random.randint(0, len(train_dataset) - 1)

# Load the same image without any transform. A fresh dataset over the same annotation
# file builds the same image_ids order, so `index` refers to the same image.
original = TacoDataset(
    images_dir=dataset_dir,
    annotation_path=os.path.join(dataset_dir, 'annotations_train.json'),
    transforms=None
)[index]

# CHW tensors in [0, 1] -> HWC arrays on a 0-255 scale for plotting
orig_img = original[0].permute(1, 2, 0).numpy() * 255
orig_boxes = original[1]['boxes'].numpy()
orig_masks = original[1]['masks'].numpy()

# Load with transform (random augmentations differ on every call)
augmented = train_dataset[index]
aug_img = augmented[0].permute(1, 2, 0).numpy() * 255
aug_boxes = augmented[1]['boxes'].numpy()
aug_masks = augmented[1]['masks'].numpy()

In [None]:
# Show before
show_image_with_boxes_masks(orig_img, orig_boxes, orig_masks, title="Before Augmentation")

# Show after
show_image_with_boxes_masks(aug_img, aug_boxes, aug_masks, title="After Augmentation")

In [None]:
# Pull one training batch and inspect its structure
images, targets = next(iter(train_loader))

print(f'Batch size: {len(images)}')
print(f'Image shape: {images[0].shape}')
print(f'Target keys: {targets[0].keys()}')

# Every field of the first sample's target dict
print("\n--- Target values ---")
for field, field_value in targets[0].items():
    print(f"{field}: {field_value}")

In [None]:
# --- Create the Mask R-CNN model ---
def get_model(num_classes):
    """Build Mask R-CNN on an ImageNet-pretrained ResNet-50-FPN backbone.

    The box and mask predictors are re-created with ``num_classes`` outputs
    (the mask predictor uses a 256-channel hidden layer).
    """
    model = MaskRCNN(resnet_fpn_backbone('resnet50', pretrained=True), num_classes=num_classes)

    heads = model.roi_heads
    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    mask_in = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in, 256, num_classes)

    return model

In [None]:
# --- Setup ---
# torchvision detection models reserve label 0 for background and require every target
# label to be < num_classes. TacoDataset uses raw COCO category ids as labels, so size the
# head from the largest id rather than the number of categories. With ids 1..N this gives
# N + 1 (the background slot that `len(...)` omitted), and gaps in the id range are safe.
# NOTE(review): if the category ids start at 0, category 0 still collides with background;
# that case needs the ids remapped to 1..N in the dataset.
num_classes = max(train_dataset.coco.getCatIds()) + 1
print(f'Classes: {num_classes}')

In [None]:
# base model
# Baseline ResNet-50-FPN Mask R-CNN, moved to the selected device
model = get_model(num_classes)
model = model.to(device)

print(f'Model is on device: {device}')

In [None]:
def get_model(num_classes, model_name="resnet50", yolo_variant="yolov8n-seg.pt"):
    """
    Return a detection model. Supports torchvision Mask R-CNN and YOLOv8.

    Parameters:
    - num_classes: int — size of the classification/mask heads of the torchvision models
    - model_name: str — options: resnet50, resnet101, mobilenet, custom_resnet50, yolo
        * resnet50 / resnet101: ImageNet-pretrained ResNet + FPN backbone
        * custom_resnet50: ResNet-50 + FPN with trainable_layers=3 passed explicitly
          (NOTE(review): verify whether this differs from the library default)
        * mobilenet: MobileNetV2 features (single 1280-channel map, no FPN)
        * yolo: Ultralytics YOLOv8 loaded from `yolo_variant`; num_classes is ignored
    - yolo_variant: str — YOLOv8 model file (used only if model_name="yolo")

    Returns:
    - model (torch.nn.Module or Ultralytics YOLO object)

    Raises:
    - ValueError for an unknown model_name
    - ImportError when "yolo" is requested without Ultralytics
    """
    if model_name == "resnet50":
        backbone = resnet_fpn_backbone('resnet50', pretrained=True)
        model = MaskRCNN(backbone, num_classes=num_classes)

    elif model_name == "resnet101":
        backbone = resnet_fpn_backbone('resnet101', pretrained=True)
        model = MaskRCNN(backbone, num_classes=num_classes)

    elif model_name == "mobilenet":
        # MaskRCNN's defaults assume a multi-level FPN backbone: the anchor generator has one
        # size group per FPN level, and the RoI poolers read FPN feature maps. MobileNetV2
        # features are a single map (exposed as '0'), so the anchor/feature-map count check
        # fails at forward time. Give it a single-level anchor generator and poolers.
        from torchvision.models.detection.rpn import AnchorGenerator
        from torchvision.ops import MultiScaleRoIAlign

        backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        backbone.out_channels = 1280
        anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                           aspect_ratios=((0.5, 1.0, 2.0),))
        box_roi_pool = MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)
        mask_roi_pool = MultiScaleRoIAlign(featmap_names=['0'], output_size=14, sampling_ratio=2)
        model = MaskRCNN(backbone, num_classes=num_classes,
                         rpn_anchor_generator=anchor_generator,
                         box_roi_pool=box_roi_pool,
                         mask_roi_pool=mask_roi_pool)

    elif model_name == "custom_resnet50":
        backbone = resnet_fpn_backbone("resnet50", pretrained=True, trainable_layers=3)
        model = MaskRCNN(backbone, num_classes=num_classes)

    elif model_name == "yolo":
        if YOLO is None:
            raise ImportError("Ultralytics is not installed. Run `pip install ultralytics` first.")
        model = YOLO(yolo_variant)

    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Only modify heads for torchvision models
    if model_name != "yolo":
        # Replace box head
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # Replace mask head (256-channel hidden layer)
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 256, num_classes)

    return model

In [None]:
# Torchvision Mask R-CNN with a ResNet-101-FPN backbone (replaces the baseline `model`)
model = get_model(num_classes=num_classes, model_name="resnet101")
model = model.to(device)
model

# YOLOv8 with Ultralytics
# model = get_model(num_classes=5, model_name="yolo", yolo_variant="yolov8s-seg.pt")

In [None]:
# SGD over the parameters that require gradients (frozen ones are excluded).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(
    trainable_params,
    lr=0.005,             # values tried: 0.001, 0.005, 0.0001
    momentum=0.9,
    weight_decay=0.0005,
)

# lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# Cosine decay; train_and_evaluate calls lr_scheduler.step() once per epoch.
# NOTE(review): T_max=16 while train_and_evaluate defaults to num_epochs=10, so the
# LR never reaches its minimum; confirm the intended schedule length.
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=16)

In [None]:
# from skimage import measure

# def binary_mask_to_polygon(mask, tolerance=2):
#     if mask.shape[0] < 2 or mask.shape[1] < 2:
#         print(f"⚠️ Skipping too-small mask with shape: {mask.shape}")
#         return []  # Skip invalid masks

#     contours = measure.find_contours(mask, 0.5)
#     segmentations = []

#     for contour in contours:
#         contour = np.flip(contour, axis=1)  # (y,x) → (x,y)
#         segmentation = contour.ravel().tolist()
#         if len(segmentation) >= 6:  # must have at least 3 points
#             segmentations.append(segmentation)

#     return segmentations

In [None]:
# from pycocotools.cocoeval import COCOeval
# import numpy as np

# def evaluate_model(model, data_loader, coco_gt, device, coco_category_ids):
#     model.eval()
#     coco_results = []

#     with torch.no_grad():
#         for images, targets in data_loader:
#             images = [img.to(device) for img in images]
#             outputs = model(images)

#             for i, output in enumerate(outputs):
#                 img_id = int(targets[i]["image_id"].item())
#                 boxes = output["boxes"].detach().cpu().numpy()
#                 scores = output["scores"].detach().cpu().numpy()
#                 labels = output["labels"].detach().cpu().numpy()
#                 masks = output["masks"].detach().cpu().numpy()[:, 0, :, :] 

#                 for j in range(len(boxes)):
#                     # if scores[j] < 0.05:
#                     #     continue

#                     x1, y1, x2, y2 = boxes[j]
#                     if x2 <= x1 or y2 <= y1:
#                         continue

#                     polygons = binary_mask_to_polygon(masks[j])
#                     if not polygons:
#                         continue

#                     coco_results.append({
#                         "image_id": img_id,
#                         "category_id": int(coco_category_ids.get(labels[j], labels[j])),  # map label
#                         "bbox": [x1, y1, x2 - x1, y2 - y1],
#                         "score": 1.0,
#                         "segmentation": polygons
#                     })

#     if not coco_results:
#         print("⚠️ No valid predictions to evaluate.")
#         return

#     coco_dt = coco_gt.loadRes(coco_results)
#     coco_eval = COCOeval(coco_gt, coco_dt, iouType='segm')
#     coco_eval.evaluate()
#     coco_eval.accumulate()
#     coco_eval.summarize()


In [None]:
def calculate_iou(mask1, mask2):
    """Intersection-over-union of two same-shaped binary masks; 0.0 when both are empty."""
    n_union = np.logical_or(mask1, mask2).sum()
    if not n_union > 0:
        return 0.0
    return np.logical_and(mask1, mask2).sum() / n_union

In [None]:
# def evaluate_segmentation_metrics(model, data_loader, device, iou_threshold=0.5):
#     model.eval()
#     all_ious = []
#     true_positives, false_positives, false_negatives = 0, 0, 0
#     cls_correct, cls_total = 0, 0
#
#     with torch.no_grad():
#         for images, targets in data_loader:
#             images = [img.to(device) for img in images]
#
#             outputs = model(images)
#             outputs = [{k: v.cpu() for k, v in out.items()} for out in outputs]
#
#             torch.cuda.empty_cache()
#
#             for i in range(len(images)):
#                 gt_masks = targets[i]['masks'].cpu().numpy().astype(bool)
#                 gt_labels = targets[i]['labels'].cpu().numpy()
#                 pred_masks = outputs[i]['masks'].numpy() > 0.5
#                 pred_masks = pred_masks[:, 0, :, :]
#                 pred_labels = outputs[i]['labels'].numpy()
#
#                 matched_gt = set()
#                 matched_pred = set()
#
#                 for gi, (gt, gt_cls) in enumerate(zip(gt_masks, gt_labels)):
#                     best_iou = 0
#                     best_pi = -1
#                     for pi, (pred, pred_cls) in enumerate(zip(pred_masks, pred_labels)):
#                         if pred_cls != gt_cls:
#                             continue  # only match masks of same category
#                         iou = calculate_iou(gt, pred)
#                         if iou > best_iou:
#                             best_iou = iou
#                             best_pi = pi
#
#                     if best_iou >= iou_threshold:
#                         matched_gt.add(gi)
#                         matched_pred.add(best_pi)
#                         true_positives += 1
#                         all_ious.append(best_iou)
#                         # Classification accuracy
#                         if pred_labels[best_pi] == gt_cls:
#                             cls_correct += 1
#                         cls_total += 1
#                     else:
#                         false_negatives += 1
#                         cls_total += 1
#
#                 for pi in range(len(pred_masks)):
#                     if pi not in matched_pred:
#                         false_positives += 1
#                         cls_total += 1
#
#             del images, targets, outputs
#             torch.cuda.empty_cache()
#
#     precision = true_positives / (true_positives + false_positives + 1e-6)
#     recall = true_positives / (true_positives + false_negatives + 1e-6)
#     mean_iou = np.mean(all_ious) if all_ious else 0.0
#     cls_acc = cls_correct / (cls_total + 1e-6)
#
#     return {
#         "mean_iou": mean_iou,
#         "precision": precision,
#         "recall": recall,
#         "cls_accuracy": cls_acc,
#         "num_predictions": true_positives + false_positives,
#         "num_ground_truth": true_positives + false_negatives,
#     }

In [None]:
# def train_and_evaluate(model, train_loader, val_loader, test_loader, optimizer, lr_scheduler, device, num_epochs=10, print_interval=20):
#     for epoch in range(num_epochs):
#         model.train()
#         total_loss = 0.0
#         print(f"\n\U0001F680 Starting Epoch {epoch+1}/{num_epochs} at {datetime.now().strftime('%H:%M:%S')}")
#
#         for step, (images, targets) in enumerate(train_loader):
#             images = [img.to(device) for img in images]
#             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
#
#             try:
#                 loss_dict = model(images, targets)
#                 loss = sum(loss for loss in loss_dict.values())
#
#                 if torch.isnan(loss) or torch.isinf(loss):
#                     print(f"\u26A0\ufe0f Invalid loss at step {step+1}, skipping")
#                     continue
#
#                 optimizer.zero_grad()
#                 loss.backward()
#                 optimizer.step()
#                 total_loss += loss.item()
#             except RuntimeError as e:
#                 print(f"\u274C CUDA error at step {step+1}: {e}")
#                 torch.cuda.empty_cache()
#                 continue
#
#             if (step + 1) % print_interval == 0 or (step + 1) == len(train_loader):
#                 print(f"  \U0001F501 Step {step+1}/{len(train_loader)} — Batch Loss: {loss.item():.4f}")
#
#             del images, targets
#             torch.cuda.empty_cache()
#
#         lr_scheduler.step()
#         avg_loss = total_loss / len(train_loader)
#         print(f"\n\U0001F4D8 Epoch {epoch+1} — Average Epoch Loss: {avg_loss:.4f}")
#         if epoch%2==0:
#             for name, loader in [("train", train_loader), ("val", val_loader)]:
#                 print(f"\n\U0001F50D Evaluating on {name} set...")
#                 try:
#                     metrics = evaluate_segmentation_metrics(model, loader, device)
#                     print(f"\U0001F4CF {name.upper()} → IoU: {metrics['mean_iou']:.4f} | "
#                           f"Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f} | "
#                           f"Class Acc: {metrics['cls_accuracy']:.4f}")
#                 except Exception as e:
#                     print(f"\u274C Evaluation failed on {name}: {e}")
#
#     print("\n\U0001F50D Final Evaluation on TEST set...")
#     metrics_test = evaluate_segmentation_metrics(model, test_loader, device)
#     print(f"\U0001F4CF TEST → IoU: {metrics_test['mean_iou']:.4f} | "
#           f"Precision: {metrics_test['precision']:.4f} | Recall: {metrics_test['recall']:.4f} | "
#           f"Class Acc: {metrics_test['cls_accuracy']:.4f}")
#
#     torch.save(model.state_dict(), "maskrcnn_taco_baseline.pth")
#     print("\u2705 Model saved to maskrcnn_taco_baseline.pth")


In [None]:
# taruns changes

def evaluate_segmentation_metrics(model, data_loader, device, iou_threshold=0.5, score_threshold=0.0):
    """Greedy, class-aware instance-mask matching over a detection data loader.

    For each ground-truth instance, the best-IoU prediction of the same label that has not
    already been matched is chosen. A pair with IoU >= ``iou_threshold`` is a true positive;
    otherwise the GT is a false negative. Predictions left unmatched are false positives.
    Ground-truth instances with an all-zero mask (placeholder targets or instances cropped
    away by resizing) are skipped, because no prediction could ever match them.

    Args:
        model: detection model. In eval mode it returns one dict per image with ``masks``
            (N, 1, H, W) probabilities, ``labels`` (N,) and, when ``score_threshold > 0``
            is used, ``scores`` (N,).
        data_loader: yields ``(images, targets)``; each target holds ``masks`` and ``labels``.
        device: device the images are moved to.
        iou_threshold: minimum mask IoU for a match.
        score_threshold: predictions scoring below this are ignored (0.0 keeps all).

    Returns:
        dict with ``mean_iou`` (over matched pairs), ``precision``, ``recall`` and
        ``cls_accuracy``. Since matching already requires equal labels, ``cls_accuracy``
        equals TP / (TP + FP + FN) rather than a pure classification accuracy.
    """
    model.eval()
    all_ious = []
    true_positives, false_positives, false_negatives = 0, 0, 0
    cls_correct, cls_total = 0, 0

    with torch.no_grad():
        progress_bar = tqdm(data_loader, desc="Evaluating", leave=False)

        for images, targets in progress_bar:
            images = [img.to(device) for img in images]
            outputs = model(images)
            outputs = [{k: v.cpu() for k, v in out.items()} for out in outputs]

            for target, output in zip(targets, outputs):
                gt_masks = target['masks'].cpu().numpy().astype(bool)
                gt_labels = target['labels'].cpu().numpy()
                pred_masks = output['masks'].numpy()[:, 0, :, :] > 0.5
                pred_labels = output['labels'].numpy()
                if score_threshold > 0:
                    keep = output['scores'].numpy() >= score_threshold
                    pred_masks, pred_labels = pred_masks[keep], pred_labels[keep]

                matched_pred = set()
                for gt, gt_cls in zip(gt_masks, gt_labels):
                    if not gt.any():
                        continue  # placeholder / fully-cropped instance: nothing to detect

                    best_iou, best_pi = 0, -1
                    for pi, (pred, pred_cls) in enumerate(zip(pred_masks, pred_labels)):
                        # A prediction may be matched to at most one ground-truth instance.
                        if pi in matched_pred or pred_cls != gt_cls:
                            continue
                        iou = calculate_iou(gt, pred)
                        if iou > best_iou:
                            best_iou, best_pi = iou, pi

                    if best_pi >= 0 and best_iou >= iou_threshold:
                        matched_pred.add(best_pi)
                        true_positives += 1
                        all_ious.append(best_iou)
                        cls_correct += 1  # labels are equal by construction of the match
                    else:
                        false_negatives += 1
                    cls_total += 1

                unmatched = len(pred_masks) - len(matched_pred)
                false_positives += unmatched
                cls_total += unmatched

            # 🧹 Clean up per batch
            del images, targets, outputs
            torch.cuda.empty_cache()

    precision = true_positives / (true_positives + false_positives + 1e-6)
    recall = true_positives / (true_positives + false_negatives + 1e-6)
    mean_iou = np.mean(all_ious) if all_ious else 0.0
    cls_acc = cls_correct / (cls_total + 1e-6)

    # 🧹 Final cleanup after full evaluation
    torch.cuda.empty_cache()

    return {
        "mean_iou": mean_iou,
        "precision": precision,
        "recall": recall,
        "cls_accuracy": cls_acc,
    }

In [None]:
# taruns changes

def train_and_evaluate(model, train_loader, val_loader, test_loader, optimizer, lr_scheduler, device,
                       num_epochs=10, save_path="maskrcnn_taco_baseline.pth"):
    """Train a detection/segmentation model, evaluate it every epoch, then test and save it.

    Each epoch makes one pass over ``train_loader``, calling ``model(images, targets)`` and
    summing the returned loss tensors. A batch is skipped (and counted as skipped) if its loss
    is NaN/Inf or if the forward/backward pass raises ``RuntimeError``. ``lr_scheduler`` is
    stepped once per epoch. After each epoch, ``evaluate_segmentation_metrics`` is run on the
    train and validation loaders. After the last epoch the model is evaluated on ``test_loader``
    and its ``state_dict`` is written to ``save_path``.

    Args:
        model: model called as ``model(images, targets)`` in train mode; it must return a
            mapping whose values are loss tensors (they are summed into the training loss).
        train_loader: iterable of ``(images, targets)`` batches, where ``images`` is a sequence
            of tensors and ``targets`` a sequence of dicts of tensors; both are moved to ``device``.
        val_loader: loader passed to ``evaluate_segmentation_metrics`` after every epoch
            (presumably the same batch format as ``train_loader`` — confirm against that helper).
        test_loader: loader passed to ``evaluate_segmentation_metrics`` once after training.
        optimizer: optimizer stepped once per successfully processed batch.
        lr_scheduler: scheduler stepped once per epoch.
        device: device the batch tensors are moved to.
        num_epochs: number of training epochs.
        save_path: file the trained ``state_dict`` is saved to.

    Returns:
        dict: the test-set metrics from ``evaluate_segmentation_metrics``
        (``mean_iou``, ``precision``, ``recall``, ``cls_accuracy``).
    """
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_steps = 0     # batches whose loss was actually optimized
        num_skipped = 0   # batches dropped for NaN/Inf loss or RuntimeError
        print(f"\nEpoch {epoch+1}/{num_epochs} started - {datetime.now().strftime('%H:%M:%S')}")

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Training]", leave=False)

        for step, (images, targets) in enumerate(progress_bar):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # Pre-bind so the cleanup in `finally` works even if the forward pass raises.
            loss_dict = loss = None

            try:
                loss_dict = model(images, targets)
                loss = sum(loss_dict.values())

                if not torch.isfinite(loss):
                    num_skipped += 1
                    progress_bar.write(f"Warning: Invalid loss at batch {step+1}, skipping batch.")
                    continue

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                total_loss += loss_value
                num_steps += 1
                progress_bar.set_postfix(loss=loss_value)

            except RuntimeError as e:
                # Best-effort: log and move on (e.g. a CUDA OOM on one oversized batch).
                num_skipped += 1
                progress_bar.write(f"RuntimeError at batch {step+1}: {e}")

            finally:
                # Release this batch's tensors on every path, including skipped batches.
                del images, targets, loss_dict, loss
                torch.cuda.empty_cache()

        lr_scheduler.step()

        # Average only over optimized batches: skipped batches add nothing to total_loss,
        # so counting them would bias the average low. NaN means nothing was trained.
        avg_loss = total_loss / num_steps if num_steps else float("nan")
        print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")
        if num_skipped:
            print(f"  {num_skipped} batch(es) skipped this epoch.")

        torch.cuda.empty_cache()

        for name, loader in [("Train", train_loader), ("Val", val_loader)]:
            print(f"\nEvaluating {name} set...")
            try:
                metrics = evaluate_segmentation_metrics(model, loader, device)
                print(f"{name} Metrics — IoU: {metrics['mean_iou']:.4f}, "
                      f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, "
                      f"Cls Acc: {metrics['cls_accuracy']:.4f}")
            except Exception as e:
                # Intermediate evaluation is diagnostic only; don't abort training over it.
                print(f"Evaluation error on {name}: {e}")

            torch.cuda.empty_cache()

    print("\nFinal Evaluation on TEST set:")
    metrics_test = evaluate_segmentation_metrics(model, test_loader, device)
    print(f"TEST Metrics — IoU: {metrics_test['mean_iou']:.4f}, "
          f"Precision: {metrics_test['precision']:.4f}, Recall: {metrics_test['recall']:.4f}, "
          f"Cls Acc: {metrics_test['cls_accuracy']:.4f}")

    torch.cuda.empty_cache()

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    return metrics_test

In [None]:
# Train for 10 epochs (train/val metrics after each), then run the final
# test-set evaluation and save the weights.
train_and_evaluate(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    device=device,
    num_epochs=10,
)

In [None]:
# def evaluate_ground_truth_as_prediction(coco_gt):
#     coco_results = []
#
#     for img_id in coco_gt.getImgIds():
#         anns = coco_gt.loadAnns(coco_gt.getAnnIds(imgIds=img_id))
#         for ann in anns:
#             if 'segmentation' not in ann or 'bbox' not in ann or 'category_id' not in ann:
#                 continue
#
#             # COCO format expects score when evaluating detections
#             coco_results.append({
#                 "image_id": img_id,
#                 "category_id": ann["category_id"],
#                 "bbox": ann["bbox"],
#                 "score": 1.0,  # max confidence
#                 "segmentation": ann["segmentation"]
#             })
#
#     # Convert to COCO predictions format
#     coco_dt = coco_gt.loadRes(coco_results)
#
#     # Run evaluation
#     coco_eval = COCOeval(coco_gt, coco_dt, iouType='segm')
#     coco_eval.evaluate()
#     coco_eval.accumulate()
#     coco_eval.summarize()

In [None]:
# evaluate_ground_truth_as_prediction(val_dataset.coco)

In [None]:
def evaluate_ground_truth_as_prediction(coco_gt, iou_type='segm'):
    """
    Evaluate ground truth annotations against themselves using COCO evaluation.

    Every usable annotation is re-submitted as a detection with score 1.0 and scored
    with COCOeval. This acts as a sanity check of the annotations and the evaluation
    pipeline.

    Args:
        coco_gt (COCO): Ground truth COCO object
        iou_type (str): Type of evaluation - 'bbox' or 'segm'. Every annotation needs
            'bbox' and 'category_id'. 'segmentation' is required only for 'segm', so
            box-only annotations still take part in a 'bbox' evaluation.

    Returns:
        COCOeval object with results (after evaluate/accumulate/summarize), or None
        when the dataset has no images or no usable annotations.
    """
    img_ids = coco_gt.getImgIds()
    if len(img_ids) == 0:
        print("No images found in COCO dataset.")
        return None

    # 'bbox' stays mandatory for both modes so every result has the same fields.
    # (COCO.loadRes looks at the first result to decide how to parse the whole list.)
    required_keys = ('bbox', 'category_id')
    if iou_type == 'segm':
        required_keys += ('segmentation',)

    coco_results = []
    for img_id in img_ids:
        anns = coco_gt.loadAnns(coco_gt.getAnnIds(imgIds=img_id))
        for ann in anns:
            if not all(k in ann for k in required_keys):
                continue

            result = {
                "image_id": img_id,
                "category_id": ann["category_id"],
                "bbox": ann["bbox"],
                "score": 1.0,  # perfect confidence
            }
            # Forward the mask when present; it is optional for bbox evaluation.
            if "segmentation" in ann:
                result["segmentation"] = ann["segmentation"]
            coco_results.append(result)

    if not coco_results:
        print("No valid annotations found to evaluate.")
        return None

    # Load results in COCO format
    coco_dt = coco_gt.loadRes(coco_results)

    # Perform evaluation
    coco_eval = COCOeval(coco_gt, coco_dt, iouType=iou_type)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    return coco_eval

In [None]:
# Sanity check: score the validation ground truth against itself (every annotation
# re-submitted as a score-1.0 detection). A clean dataset should come out at or near
# perfect; noticeably lower numbers suggest problems in the annotations themselves.
# To check a file on disk instead: coco_gt = COCO('data/annotations_val.json')
coco_gt = val_dataset.coco

# Keep both COCOeval objects (None if nothing was evaluable) for later inspection,
# e.g. gt_segm_eval.stats.
print('Evaluate masks (segmentation)')
gt_segm_eval = evaluate_ground_truth_as_prediction(coco_gt, iou_type='segm')

print('Evaluate bounding boxes')
gt_bbox_eval = evaluate_ground_truth_as_prediction(coco_gt, iou_type='bbox')