## Object Detection with PyTorch and a COCO Dataset

### To-Do List:
- The bounding boxes (and their labels) no longer line up with the objects when the image is flipped by the transforms. Fix it.
- Add a training loop.
- Add a testing loop.

In [23]:
from pathlib import Path
from typing import Callable, Optional

import torch
import torchvision
import torchvision.transforms.v2 as transformsV2
from torchvision import tv_tensors

from PIL import Image

EPOCHS = 10
BATCH_SIZE = 5
NUM_WORKERS = 0
LEARNING_RATE = 0.001
SEED = 42

# Make weight initialisation and the random augmentations reproducible across re-runs.
torch.manual_seed(SEED)

# Index 0 is the background class; torchvision detection models count it in num_classes.
classes = ["__background__", "cup"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Size the box predictor for our classes (without this the head is built for 91 COCO categories),
# and move the model to the target device before the optimizer captures its parameters.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(num_classes=len(classes)).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Paths
path_to_dataset_root = Path("../datasets/cup_dataset")
path_to_coco = path_to_dataset_root / "coco.json"
path_to_images = path_to_dataset_root / "train"
path_to_test = path_to_dataset_root / "test"
path_to_rest = path_to_dataset_root / "cup"

# Transforms
train_transform = transformsV2.Compose(transforms=[
    # PIL image -> float32 image tensor scaled to [0, 1]; boxes and labels are left unchanged.
    transformsV2.ToImage(),
    transformsV2.ToDtype(torch.float32, scale=True),
    # Flips are disabled for now: v2 transforms only move boxes that are wrapped in
    # tv_tensors.BoundingBoxes, so plain-list targets would stay where they were.
    # transformsV2.RandomHorizontalFlip(p=1),
    # transformsV2.RandomVerticalFlip(p=1),
    transformsV2.RandomPhotometricDistort(p=1),
])

In [24]:
class COCODataset(torch.utils.data.Dataset):
    """
    A Dataset class that pairs the ``*.jpg`` images in a folder with their annotations from a COCO file.

    Annotations are read from the parsed COCO index (``.coco``) of an existing
    ``torchvision.datasets.CocoDetection`` object. Each item is ``(image, targets)``, where
    ``targets["boxes"]`` is a ``tv_tensors.BoundingBoxes`` in COCO ``XYWH`` format and
    ``targets["labels"]`` is an int64 tensor of category ids. Because the boxes are a tv_tensor,
    v2 transforms (flips, crops, resizes, ...) move them together with the image.

    NOTE(review): torchvision's Faster R-CNN expects ``XYXY`` boxes for training; convert them
    (e.g. with ``transformsV2.ConvertBoundingBoxFormat("XYXY")``) before feeding the model.
    """
    def __init__(self,
                 path: Path,
                 torch_coco_object: torchvision.datasets.CocoDetection,
                 transforms: Optional[Callable] = None) -> None:
        """
        Args:
            path (Path): Folder containing the ``*.jpg`` images. File names are expected to be ``<image_id>.jpg``.
            torch_coco_object (torchvision.datasets.CocoDetection): Object whose ``.coco`` index holds the annotations.
            transforms (Callable, optional): Transform called as ``transforms(image, targets)``. Defaults to None.
        """
        self.path = path
        # The pycocotools COCO index (not a CocoDetection); queried per image in __getitem__.
        self.coco = torch_coco_object.coco
        self.transforms = transforms
        # Sort so a given index always refers to the same file; glob order depends on the filesystem.
        self.all_images = sorted(path.glob("*.jpg"))

    def __len__(self) -> int:
        """
        Returns the length of the dataset.

        Returns:
            (int): The number of ``*.jpg`` images found in ``path``.
        """
        # The pycocotools COCO object does not support len(); the dataset is the set of image files.
        return len(self.all_images)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]:
        """
        Returns an image from the dataset and its annotations.

        Args:
            idx (int): Index of the image to be grabbed.

        Returns:
            (tuple(torch.Tensor, dict)): The (optionally transformed) image and a dict with
            ``"boxes"`` (BoundingBoxes, XYWH, shape ``(N, 4)``) and ``"labels"`` (int64, shape ``(N,)``).
            ``N`` may be 0 for an image without annotations. Without transforms the image is a PIL image.
        """
        image_path = self.all_images[idx]

        # Path.stem removes just the extension; str.rstrip(".jpg") would strip any trailing '.', 'j', 'p', 'g'.
        image_id = int(image_path.stem)

        # Load the pixels and close the file; convert so grayscale/RGBA files also yield 3 channels.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Direct per-image lookup instead of scanning every annotation in the file.
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=[image_id]))

        # reshape keeps the (0, 4) shape when the image has no annotations.
        boxes = torch.as_tensor([ann["bbox"] for ann in annotations], dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor([ann["category_id"] for ann in annotations], dtype=torch.int64)

        targets = {
            # Wrapping as BoundingBoxes lets v2 transforms apply geometric changes to the boxes too.
            "boxes": tv_tensors.BoundingBoxes(boxes, format="XYWH", canvas_size=(image.height, image.width)),
            "labels": labels,
        }

        # If there are any transforms provided, apply them on both the targets and the image
        if self.transforms is not None:
            return self.transforms(image, targets)

        # Else return the image and targets as is.
        return image, targets

In [25]:
def draw_from_coco(path_to_images: Path,
                   path_to_coco: Path,
                   transform: Optional[Callable],
                   classes: list,
                   idx: int) -> None:
    """
    Draws an image's COCO annotations onto it and shows the result.

    Args:
        path_to_images (Path): Path to the folder that contains the annotated images.
        path_to_coco (Path): Path to the annotation file.
        transform (Callable, optional): Transforms applied to the image and its targets before drawing.
        classes (list): The list of class names, indexed by category id.
        idx (int): The index of the image to be drawn on.
    """
    # Only the parsed annotation index (.coco) of this object is used; COCODataset applies the transforms.
    torch_coco_object = torchvision.datasets.CocoDetection(root=path_to_images, annFile=path_to_coco)

    coco_dataset = COCODataset(path=path_to_images, torch_coco_object=torch_coco_object, transforms=transform)
    image, anns = coco_dataset[idx]

    # draw_bounding_boxes needs a (C, H, W) tensor, and older torchvision versions only accept uint8.
    # to_image also covers the no-transform case where the dataset returns a PIL image.
    image = transformsV2.functional.to_image(image)
    image = transformsV2.functional.to_dtype(image, torch.uint8, scale=True)

    # COCO boxes are [x, y, width, height]; draw_bounding_boxes expects [x1, y1, x2, y2].
    # reshape keeps an empty annotation list as a valid (0, 4) tensor instead of crashing.
    boxes_xywh = torch.as_tensor(anns["boxes"], dtype=torch.float32).reshape(-1, 4)
    boxes_xyxy = torchvision.ops.box_convert(boxes=boxes_xywh, in_fmt="xywh", out_fmt="xyxy")

    labels = [classes[int(label)] for label in anns["labels"]]

    drawn_image = torchvision.utils.draw_bounding_boxes(image=image,
                                                        boxes=boxes_xyxy,
                                                        labels=labels,
                                                        colors=(255, 0, 0))

    pil_image = transformsV2.ToPILImage()(drawn_image)
    pil_image.show()
    
# Visual sanity check: draw the transformed annotations of the image at index 0.
draw_from_coco(
    path_to_images=path_to_images,
    path_to_coco=path_to_coco,
    transform=train_transform,
    classes=classes,
    idx=0,
)

loading annotations into memory...
Done (t=0.42s)
creating index...
index created!
