In [None]:
# Cell 1 — Imports & paths

%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import random
from typing import List, Dict

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import cv2
import matplotlib.pyplot as plt

# Seed the Python and PyTorch RNGs up front so that shuffling and weight
# initialisation start from the same state on every Restart & Run All.
# (This does not by itself make every GPU kernel deterministic.)
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# The notebook is expected to live one directory below the project root.
PROJECT_ROOT = Path("..").resolve()

# Dataset layout: <root>/mug_coco_yolo/{images,labels}/{train2017,val2017}
DATA_ROOT = PROJECT_ROOT / "mug_coco_yolo"
TRAIN_IMAGES_DIR = DATA_ROOT / "images" / "train2017"
TRAIN_LABELS_DIR = DATA_ROOT / "labels" / "train2017"
VAL_IMAGES_DIR = DATA_ROOT / "images" / "val2017"
VAL_LABELS_DIR = DATA_ROOT / "labels" / "val2017"

print("Train images:", TRAIN_IMAGES_DIR)
print("Train labels:", TRAIN_LABELS_DIR)
print("Val images:", VAL_IMAGES_DIR)
print("Val labels:", VAL_LABELS_DIR)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


In [None]:
# Cell 2 — Dataset class for YOLO-format labels

class MugYoloDataset(Dataset):
    """
    Pairs ``*.jpg`` images with YOLO-format label files and returns data for
    Faster R-CNN:
      image: FloatTensor [3, H, W] with values in [0, 1]
      target: dict with:
          - boxes: FloatTensor [N, 4] (x1, y1, x2, y2 in pixels); the shape is
            [0, 4] when the image has no (valid) annotations
          - labels: LongTensor [N] (1 for mug)
          - image_id: LongTensor [1] holding the dataset index

    Args:
        images_dir: directory scanned (non-recursively) for ``*.jpg`` files.
        labels_dir: directory holding one ``<image stem>.txt`` file per image;
            a missing file means "no objects".
        transforms: optional callable applied to the image tensor only. The
            boxes are not passed to it, so a geometric transform would leave
            the boxes out of sync with the image.
    """

    def __init__(self, images_dir: Path, labels_dir: Path, transforms=None):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.transforms = transforms

        # Sorted so that index -> image is stable across runs.
        self.image_paths = sorted(list(self.images_dir.glob("*.jpg")))
        print(f"Found {len(self.image_paths)} images in {images_dir}")

    def __len__(self):
        return len(self.image_paths)

    def _load_yolo_labels(self, label_path: Path, img_w: int, img_h: int):
        """
        Parse a YOLO label file into pixel-space corner boxes.

        Each usable line is ``<class> <xc> <yc> <w> <h>`` with the four
        geometry values normalised by image size. The class id is read but
        ignored (every box is treated as the single foreground class). Lines
        that do not have exactly five fields are skipped.

        Boxes are clipped to the image rectangle, and boxes that end up with
        non-positive width or height are dropped, because torchvision's
        detection models reject such boxes during training.

        Returns:
            list of ``[x1, y1, x2, y2]`` float lists; empty when the file is
            missing or contains no usable box.
        """
        boxes = []
        if not label_path.exists():
            return boxes

        max_x, max_y = float(img_w), float(img_h)
        with open(label_path, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) != 5:
                    continue
                _cls, xc, yc, w, h = parts
                xc = float(xc) * img_w
                yc = float(yc) * img_h
                w = float(w) * img_w
                h = float(h) * img_h

                # Centre/size -> corners, clipped to the image bounds.
                x1 = min(max(xc - w / 2, 0.0), max_x)
                y1 = min(max(yc - h / 2, 0.0), max_y)
                x2 = min(max(xc + w / 2, 0.0), max_x)
                y2 = min(max(yc + h / 2, 0.0), max_y)

                # Skip degenerate boxes (zero or negative area after clipping).
                if x2 <= x1 or y2 <= y1:
                    continue
                boxes.append([x1, y1, x2, y2])
        return boxes

    def _build_target(self, boxes, idx):
        """
        Build the detection target dict from a list of corner boxes.

        The boxes tensor is always reshaped to ``[N, 4]`` so that an image
        without annotations yields ``[0, 4]`` rather than ``[0]``.
        """
        boxes_tensor = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Single foreground class: every box gets label 1 (0 is background).
        labels_tensor = torch.ones((boxes_tensor.shape[0],), dtype=torch.int64)
        image_id = torch.tensor([idx], dtype=torch.int64)
        return {
            "boxes": boxes_tensor,
            "labels": labels_tensor,
            "image_id": image_id,
        }

    def __getitem__(self, idx):
        """
        Load image ``idx`` and its boxes.

        Raises:
            RuntimeError: if OpenCV cannot decode the image file.
        """
        img_path = self.image_paths[idx]
        label_path = self.labels_dir / (img_path.stem + ".txt")

        img = cv2.imread(str(img_path))
        if img is None:
            raise RuntimeError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR; the model expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_h, img_w = img.shape[:2]

        boxes = self._load_yolo_labels(label_path, img_w, img_h)
        target = self._build_target(boxes, idx)

        img_tensor = transforms.ToTensor()(img)  # [0,1] float

        if self.transforms is not None:
            img_tensor = self.transforms(img_tensor)

        return img_tensor, target


In [None]:
# Cell 3 — Collate function for variable-size targets

def collate_fn(batch):
    """
    Turn a list of ``(image, target)`` samples into ``(images, targets)`` lists.

    Samples are kept as plain Python lists rather than being stacked, since
    the images and their per-image targets may differ in size.
    """
    # Transpose [(img, tgt), ...] into (imgs...), (tgts...).
    images, targets = zip(*batch)
    return [*images], [*targets]


In [None]:
# Cell 4 — Create datasets and dataloaders

# Datasets: no extra image transforms are applied on either split.
train_dataset = MugYoloDataset(
    images_dir=TRAIN_IMAGES_DIR,
    labels_dir=TRAIN_LABELS_DIR,
    transforms=None,
)
val_dataset = MugYoloDataset(
    images_dir=VAL_IMAGES_DIR,
    labels_dir=VAL_LABELS_DIR,
    transforms=None,
)

BATCH_SIZE = 4

# Training: shuffled mini-batches of BATCH_SIZE images.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

# Validation: one image at a time, in a fixed order.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
)


In [None]:
# Cell 5 — Visual sanity check: show one sample

import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Pull one batch from the training loader and describe its first sample.
images, targets = next(iter(train_loader))
first_image, first_target = images[0], targets[0]

print("Batch size:", len(images))
print("Image shape:", first_image.shape)
print("Target keys:", first_target.keys())
for caption, key in (("Boxes:", "boxes"), ("Labels:", "labels")):
    print(caption, first_target[key])

# plot the first image with boxes
img = first_image.permute(1, 2, 0).numpy()  # channels-first -> [H, W, C]
img_h, img_w = img.shape[:2]

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.imshow(img)
for box in first_target["boxes"]:
    x1, y1, x2, y2 = box.tolist()
    # Rectangle takes the top-left corner plus width/height, not two corners.
    rect = patches.Rectangle(
        xy=(x1, y1),
        width=x2 - x1,
        height=y2 - y1,
        linewidth=2,
        edgecolor='g',
        facecolor='none',
    )
    ax.add_patch(rect)
ax.axis("off")
plt.show()


In [None]:
# Cell 6 — Define the Faster R-CNN model (pretrained on COCO, fine-tune on mug)

num_classes = 2  # background + mug

# Start from the COCO-pretrained detector.
detection = torchvision.models.detection
model = detection.fasterrcnn_resnet50_fpn(weights="COCO_V1")

# Replace the pretrained box predictor with a fresh head sized for our
# classes, reusing the input width of the existing classification layer.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
    in_features,
    num_classes,
)

model.to(DEVICE)


In [None]:
# Cell 7 — Training loop (simple version)

import torch.optim as optim
from tqdm.auto import tqdm

LR = 0.005
NUM_EPOCHS = 10

# Only parameters that still require gradients are handed to the optimiser.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.SGD(
    params,
    lr=LR,
    momentum=0.9,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.1 after every 3 scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

def train_one_epoch(model, data_loader, optimizer, device, epoch):
    """
    Run one optimisation pass over ``data_loader`` and return the mean loss.

    Args:
        model: module that, in train mode, maps ``(images, targets)`` to a
            dict of loss tensors.
        data_loader: iterable of ``(images, targets)`` batches, where
            ``images`` is a list of tensors and ``targets`` a list of dicts
            of tensors. It does not need to define ``__len__``.
        optimizer: optimiser stepping the model's parameters.
        device: device the batch is moved to before the forward pass.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        The summed loss of each batch, averaged over the batches processed.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for imgs, targets in tqdm(data_loader, desc=f"Epoch {epoch}"):
        imgs = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the model returns a dict of individual loss terms;
        # the optimisation objective is their sum.
        loss_dict = model(imgs, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    # Average over the batches actually seen rather than len(data_loader),
    # which fails for loaders without a length and divides by zero when empty.
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute an average loss")
    return running_loss / num_batches

# Epochs are numbered from 1 so the progress output reads naturally.
for epoch in range(1, NUM_EPOCHS + 1):
    avg_loss = train_one_epoch(
        model=model,
        data_loader=train_loader,
        optimizer=optimizer,
        device=DEVICE,
        epoch=epoch,
    )
    # The scheduler is stepped once per epoch, after that epoch's updates.
    lr_scheduler.step()
    print(f"Epoch {epoch}/{NUM_EPOCHS} - Avg Loss: {avg_loss:.4f}")


In [None]:
# Cell 8 — Save trained Faster R-CNN weights

# Weights go under <project root>/outputs, created on demand.
OUTPUT_DIR = PROJECT_ROOT.joinpath("outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Only the state dict (parameters and buffers) is stored, not the module object.
faster_rcnn_path = OUTPUT_DIR.joinpath("fasterrcnn_mug.pth")
torch.save(model.state_dict(), faster_rcnn_path)

print("Saved Faster R-CNN weights to:", faster_rcnn_path)


In [None]:
# Cell 9 — Quick inference test on one validation image

# Switch to eval mode; the model is then called with images only.
model.eval()

# Take the first validation batch (batch size 1) and move it to the device.
imgs, targets = next(iter(val_loader))
imgs = [img.to(DEVICE) for img in imgs]

# No gradients are needed for a forward-only check.
with torch.no_grad():
    outputs = model(imgs)

first_output = outputs[0]
for caption, key in (
    ("Predicted boxes:", "boxes"),
    ("Scores:", "scores"),
    ("Labels:", "labels"),
):
    print(caption, first_output[key])
