## Object Detection With PyTorch and COCO Dataset

### To-Do List:
- Fix collate_fn and dataset classes
    - Dataset class needs to return a tuple of images and labels
    - collate_fn's return value is wrong.
- Check TVTensors
- The labels don't match when the image is flipped with the transforms. Fix it.
- Add a training loop.
- Add a testing loop.

In [11]:
from pathlib import Path

import torch
import torchvision
import torchvision.transforms.v2 as transformsV2

from PIL import Image

EPOCHS = 10
BATCH_SIZE = 5
NUM_WORKERS = 0
LEARNING_RATE = 0.001

classes = ["__background__", "cup"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights)
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=5,
                                            gamma=0.1)

# Paths
path_to_dataset_root = Path("../datasets/cup_dataset")
path_to_coco = path_to_dataset_root / "coco.json"
path_to_train = path_to_dataset_root / "train"
path_to_validation = path_to_dataset_root / "validation"
path_to_rest = path_to_dataset_root / "cup"

# Transforms
train_transform = transformsV2.Compose(transforms=[
    torchvision.transforms.ToTensor(),
    transformsV2.RandomPhotometricDistort(p=1)
])

test_transform = transformsV2.Compose(transforms=[
    transformsV2.Compose(transforms=[transformsV2.ToImage(), transformsV2.ToDtype(torch.float32, scale=True)]),
])

In [36]:
def collate_fn(batch):
    """
    Custom collate function to handle variable-sized inputs in the DataLoader.

    Args:
        batch (list): A list of tuples (image_tensor, labels).

    Returns:
        Tuple(torch.Tensor, List): A tuple containing a batch of images and a list of labels.
    """
    # Separate the images and the labels
    images, labels = zip(*batch)
    
    # Stack images into a single tensor
    images = torch.stack(images)

    # Initialize lists for boxes and classes
    boxes = []
    classes = []

    # Iterate through each label and append to respective lists
    for label in labels:
        for box in label["boxes"]:
            boxes.append(box)
        for _class in label["labels"]:
            classes.append(_class)
            
    print(boxes, classes)

    return images, {'boxes': boxes, 'labels': classes}

def collate_fn(batch):
    # TODO: Get the biggest tuple's length and set it as the maximum length, and pad the rest of the batches with 0
    # [(1,1), (1,1,1), (1,1,1,1)] -> [(1,1,0,0), (1,1,1,0), (1,1,1,1)]
    for item in batch:
        image = item["image"]
        labels = item["labels"]
        max_len = len(labels)
    print(image, labels)


In [29]:
import torch.utils.data
import torchvision.transforms.functional


class COCODataset(torch.utils.data.Dataset):
    """
    A Dataset class that utilises the CocoDetection class and grabs annotation data from a given coco file.
    """
    def __init__(self,
                 path: Path,
                 torch_coco_object: torchvision.datasets.CocoDetection,
                 transforms: torchvision.transforms) -> None:
        self.path = path
        self.coco: torchvision.datasets.CocoDetection = torch_coco_object.coco
        self.transforms: torchvision.transforms.v2 = transforms
        self.all_images = list(path.glob("*.jpg"))
        
    def __len__(self) -> int:
        """
        Returns the length of the dataset.
        
        Returns:
            (int): The length of the dataset.
        """
        return len(self.all_images)
  
    
    def __getitem__(self, idx) -> tuple[torch.Tensor, dict]:
        """
        Returns an image from the dataset and its annotations.
        
        Args:
            idx (int): Index of the image to be grabbed.
            
        Returns:
            (tuple(torch.Tensor, dict)): A tuple containing the image and it's annotation data.
        """
        # Get the image name
        image_name: str = self.all_images[idx].name
        
        # Get the id from the image name
        image_id: int = int(image_name.rstrip(".jpg"))
        
        # Get the batch using the image id
        image = Image.open(self.path / image_name)
        
        # Initialise empty boxes and categories lists.
        boxes, categories = [], []
        
        # For every annotation in the coco annotations
        for index in self.coco.anns:
            # If the image_id generated from the image name matches the image_id of the annotation
            if self.coco.anns[index]["image_id"] == image_id:
                # Add the bounding box coordinates to the boxes list.
                boxes.append(self.coco.anns[index]["bbox"])
                
                # Add the category to the categories list. 
                categories.append(self.coco.anns[index]["category_id"])

        targets = {
            "boxes": boxes,
            "categories": categories
        }
        
        sample = {
            "image": image,
            "labels": targets
        }
        
        # If there are any transforms provided, apply them on transforms and targets
        if self.transforms is not None:
            image, targets = self._apply_transforms(image=image, targets=sample["labels"])
            sample["image"] = image
            return sample

        # Else return the image and targets as is.
        return sample
    
        
    def _apply_transforms(self, image, targets):
        targets["boxes"] = torch.tensor(targets["boxes"], dtype=torch.float32)
        # Apply horizontal flip
        if torch.rand(1).item() < 0.5:
            image = torchvision.transforms.functional.hflip(image)
            width, _ = image.size
            targets["boxes"][:, 0] = width - targets["boxes"][:, 0] - targets["boxes"][:, 2]
        
        # Apply vertical flip
        if torch.rand(1).item() < 0.5:
            image = torchvision.transforms.functional.vflip(image)
            _, height = image.size
            targets["boxes"][:, 1] = height - targets["boxes"][:, 1] - targets["boxes"][:, 3]
        
        # Other transformations can be applied similarly
        image = self.transforms(image)
        return image, targets


torch_coco_object = torchvision.datasets.CocoDetection(root=path_to_train, annFile=path_to_coco, transform=train_transform)

train_dataset = COCODataset(path=path_to_train,
                            torch_coco_object=torch_coco_object,
                            transforms=train_transform)
validation_dataset = COCODataset(path=path_to_validation,
                           torch_coco_object=torch_coco_object,
                           transforms=test_transform)

train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               collate_fn=collate_fn)

validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=False,
                                               collate_fn=collate_fn)

loading annotations into memory...
Done (t=0.44s)
creating index...
index created!


In [21]:
def draw_from_coco(coco_dataset: COCODataset,
                   classes: list,
                   idx: int) -> None:
    """
    Draws to an image using its annotations grabbed from a COCO annotation file.
    
    Args:
        path_to_images (Path): Path to the images that contains the annotated images.
        path_to_coco (Path): Path to the annotation file.
        transform (torchvision.transforms.v2): Transforms to be applied on the images.
        classes (list): The list of classes.
        idx (int): The index of the image to be drawn on.
    """

    # Grab a batch from the dataset
    batch = coco_dataset.__getitem__(idx)    
    image = batch["image"]
    anns = batch["labels"]
    
    # Change the format of the bounding box coordinates and append them into a box list
    boxes, labels = [], []
    for box in anns["boxes"]:
        bbox = torchvision.ops.box_convert(boxes=torch.as_tensor(box),
                                        in_fmt="xywh",
                                        out_fmt="xyxy")
        boxes.append(bbox)

    # Create a labels list from the aformentioned annotations
    labels = [classes[label] for label in anns["labels"]]
    
    # Stack the boxes 
    boxes = torch.stack(boxes)
    
    # Draw the labels onto the image
    drawn_image = torchvision.utils.draw_bounding_boxes(image=image,
                                                        boxes=boxes,
                                                        labels=labels,
                                                        colors=(255,0,0))
    # Turn the drawn tensor image into a PIL image
    pil_image = torchvision.transforms.v2.ToPILImage()(drawn_image)
    
    # Show the image
    Image._show(pil_image)

In [37]:
batch = [train_dataset.__getitem__(0), train_dataset.__getitem__(1)]
collate_fn(batch)

tensor([[[0.3400, 0.3441, 0.3483,  ..., 0.3449, 0.3573, 0.3696],
         [0.3400, 0.3400, 0.3400,  ..., 0.3407, 0.3490, 0.3531],
         [0.3359, 0.3359, 0.3318,  ..., 0.3325, 0.3366, 0.3366],
         ...,
         [0.3755, 0.3796, 0.3837,  ..., 0.3959, 0.3916, 0.3872],
         [0.3920, 0.3961, 0.4044,  ..., 0.3746, 0.3787, 0.3829],
         [0.3713, 0.3796, 0.3879,  ..., 0.3457, 0.3540, 0.3663]],

        [[0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
         [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
         [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009],
         ...,
         [0.0009, 0.0009, 0.0009,  ..., 0.0142, 0.0098, 0.0055],
         [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0011],
         [0.0009, 0.0009, 0.0009,  ..., 0.0009, 0.0009, 0.0009]],

        [[0.4454, 0.4498, 0.4541,  ..., 0.4396, 0.4526, 0.4657],
         [0.4454, 0.4454, 0.4454,  ..., 0.4352, 0.4439, 0.4483],
         [0.4411, 0.4411, 0.4367,  ..., 0.4265, 0.4309, 0.

In [73]:
def train_step(model, data_loader, optimizer, device):
    """
    Performs one training step for the object detection model.

    Args:
        model (torch.nn.Module): The Faster R-CNN model.
        data_loader (torch.utils.data.DataLoader): The DataLoader providing the training data.
        optimizer (torch.optim.Optimizer): The optimizer for the model.
        device (torch.device): The device (CPU or GPU) to run the training on.

    Returns:
        float: The total loss for the step.
    """
    model.train()  # Set the model to training mode
    total_loss = 0.0
    
    for images, targets in data_loader:
        # Move images and targets to the device
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        # Forward pass
        loss_dict = model(images, targets)
        
        # Calculate total loss
        losses = sum(loss for loss in loss_dict.values())
        total_loss += losses.item()
        
        # Backpropagation
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
    
    return total_loss

loss = train_step(model=model,
           data_loader=train_dataloader,
           optimizer=optimizer,
           device=device)

TypeError: expected Tensor as element 0 in argument 0, but got dict