# In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
import os

# In [2]:
class CocoDetectionDataset(Dataset):
    """Detection dataset backed by a COCO-style annotation JSON file.

    Indexing yields an ``(image, target)`` pair. ``image`` is an RGB PIL image
    and ``target`` is a dict with:

    * ``'boxes'``: float32 tensor of shape ``(N, 4)``, one row per annotation,
      copied verbatim from the annotation's ``'bbox'`` entry (COCO format:
      ``[x, y, width, height]``). Shape is ``(0, 4)`` when the image has no
      annotations.
    * ``'labels'``: int64 tensor of shape ``(N,)`` holding ``'category_id'``.
    * ``'image_id'``: one-element tensor holding the image's ``'id'``.

    NOTE(review): boxes are NOT converted to ``[x1, y1, x2, y2]``. If the
    consumer expects corner format (torchvision detection models do), the
    conversion has to happen in ``transforms`` -- confirm against the caller.

    Args:
        root: Directory that the ``'file_name'`` entries are relative to.
        annFile: Path to the annotation JSON; it must contain ``'images'``
            and ``'annotations'`` lists.
        transforms: Optional callable invoked as ``transforms(img, target)``
            that returns the transformed ``(img, target)`` pair.
    """

    def __init__(self, root, annFile, transforms=None):
        self.root = root
        self.transforms = transforms
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(annFile, 'r', encoding='utf-8') as f:
            self.coco = json.load(f)
        self.imgs = self.coco['images']
        self.anns = self.coco['annotations']
        # Group annotations by image id once so __getitem__ is a dict lookup
        # instead of a scan over every annotation.
        self.img_id_to_anns = {}
        for ann in self.anns:
            self.img_id_to_anns.setdefault(ann['image_id'], []).append(ann)

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.imgs)

    def _build_target(self, img_info):
        """Build the target dict for one ``self.imgs`` entry (no image I/O)."""
        annots = self.img_id_to_anns.get(img_info['id'], [])

        if annots:
            boxes = torch.as_tensor(
                [ann['bbox'] for ann in annots], dtype=torch.float32
            )
        else:
            # torch.as_tensor([]) would be 1-D with shape (0,); keep the
            # (N, 4) layout so images without annotations have the same rank
            # as every other sample.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        labels = torch.as_tensor(
            [ann['category_id'] for ann in annots], dtype=torch.int64
        )

        return {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([img_info['id']]),
        }

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(img, target)``.

        Images without annotations are still returned, with empty ``'boxes'``
        and ``'labels'`` tensors.
        """
        img_info = self.imgs[idx]
        img_path = os.path.join(self.root, img_info['file_name'])
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        target = self._build_target(img_info)

        if self.transforms:
            img, target = self.transforms(img, target)

        return img, target