In [4]:
import os
import torch
import torch.utils.data as data
from PIL import Image
from pycocotools.coco import COCO
from typing import Any


In [5]:
class CocoDataset(data.Dataset):
    def __init__(self, root, annotation_file, transforms=None) -> None:
        super().__init__()
        """
        Args:
            root (string): Directory path to the images.
            annotation_file (string): Path to the COCO JSON annotation file.
            transforms (callable, optional): Optional transforms to be applied
                on a sample.
        """
        # Initialize the COCO API
        self.coco = COCO(annotation_file)
        self.image_ids = list(self.coco.imgs.keys())
        self.root = root
        self.transforms = transforms

    def __len__(self):
        """
        Returns the total number of images in the dataset.
        """
        return len(self.image_ids)
    
    def __getitem__(self, idx) -> Any:
        """
        Retrieves an image and its corresponding annotations for a given index.

        Args:
            index (int): Index of the image to retrieve.

        Returns:
            tuple: (image, target) where:
                - image (PIL.Image): The image.
                - target (dict): A dictionary containing the annotations.
        """
        # Get image id
        image_id = self.image_ids[idx]

                # Load image metadata
        image_info = self.coco.loadImgs(image_id)[0]
        image_path = os.path.join(self.root, image_info['file_name'])

        # Open the image using PIL
        image = Image.open(image_path).convert('RGB')

        # Load annotations for the current image
        ann_ids = self.coco.getAnnIds(imgIds=image_id)
        annotations = self.coco.loadAnns(ann_ids)

        # Process annotations into a suitable format for PyTorch
        boxes = []
        labels = []
        for ann in annotations:
            # COCO bbox format is [x, y, width, height]
            x, y, w, h = ann['bbox']
            # Convert to [x_min, y_min, x_max, y_max] format for many models
            x_min, y_min, x_max, y_max = x, y, x + w, y + h
            boxes.append([x_min, y_min, x_max, y_max])
            labels.append(ann['category_id'])

        # Convert to PyTorch tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([image_id])
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # iscrowd = torch.as_tensor([ann["iscrowd"] for ann in annotations], dtype=torch.int64)
        # target["iscrowd"] = iscrowd

        # Apply transformations if provided
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target
    
def collate_fn(batch):
    """
    A custom collate function for object detection datasets.
    It takes a list of samples (image, target) and returns a batch.
    """
    return tuple(zip(*batch))

from torchvision import transforms

# Define a simple transform pipeline
# Note: For object detection, transforms like RandomCrop need to be applied
# to both the image and the bounding boxes.
# This simple example only shows basic image transforms.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])



In [9]:
# Create the custom dataset
image_dir = "Datasets/COCO/images/train2017"
annotation_file = "Datasets/COCO/annotations/instances_train2017.json"

print(os.path.isdir(image_dir))
print(os.path.isfile(annotation_file))

True
True


In [10]:
coco_dataset = CocoDataset(root=image_dir, annotation_file=annotation_file, transforms=transform)

# Create the DataLoader
data_loader = data.DataLoader(
    coco_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn  # Use the custom collate function
)

# Example of iterating through the DataLoader
for images, targets in data_loader:
    print("Shape of images:", images.shape)
    print("Shape of targets:", targets.shape)
    break

loading annotations into memory...
Done (t=28.64s)
creating index...
index created!


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/media/thanhnv154te/15a388e9-5b07-4e1c-b175-72e9839b5c872/workspace/dev_venv/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/media/thanhnv154te/15a388e9-5b07-4e1c-b175-72e9839b5c872/workspace/dev_venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/media/thanhnv154te/15a388e9-5b07-4e1c-b175-72e9839b5c872/workspace/dev_venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_2574243/3737425442.py", line 74, in __getitem__
    image, target = self.transforms(image, target)
TypeError: Compose.__call__() takes 2 positional arguments but 3 were given
