## Data augmentation
A notebook for visually checking the augmentations applied to training images and their bounding-box labels.

In [None]:
from pathlib import Path

import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torch import Tensor
from torchvision.transforms.functional import pil_to_tensor
from torchvision.ops import box_convert
from torchvision.tv_tensors import BoundingBoxes

from ssd.data import DataAugmenter, LetterboxTransform, SSDDataset
from ssd.structs import FrameLabels

### Define constants

In [None]:
# Target size passed to both LetterboxTransform and DataAugmenter.
IMAGE_WIDTH = IMAGE_HEIGHT = 300

# dtype/device handed to the label reader and the letterbox transform.
DTYPE, DEVICE = torch.float32, torch.device("cpu")

# A single COCO train2017 sample: the image and its matching label file.
IMAGE_FILE = Path("/mnt/data/datasets/object_detection/coco/images/train2017/000000484814.jpg")
LABEL_FILE = Path("/mnt/data/datasets/object_detection/coco/labels/train2017/000000484814.txt")

### Show the labels on the letterboxed image

In [None]:
# Load the image. The context manager closes the underlying file once the
# pixels have been copied into a tensor. Converting to RGB guarantees three
# channels even for greyscale/palette images (COCO contains some greyscale
# images); a single-channel (H, W, 1) array would be rejected by plt.imshow
# in the display cell.
with Image.open(IMAGE_FILE) as image:
    image_tensor = pil_to_tensor(image.convert("RGB"))

# Load the label file for the same sample.
objects = SSDDataset.read_label_file(LABEL_FILE, DEVICE, DTYPE)

# Letterbox to the IMAGE_WIDTH x IMAGE_HEIGHT target; the transform returns
# both the transformed image and the correspondingly adjusted labels.
transform = LetterboxTransform(IMAGE_WIDTH, IMAGE_HEIGHT)
image_tensor, objects = transform(image_tensor, objects, DEVICE)

In [None]:
# Display the letterboxed labels on the letterboxed image.
# Build an HWC uint8 array for OpenCV/matplotlib. Clamp first: casting
# floats outside [0, 255] to uint8 is undefined. The final copy gives
# cv2.rectangle (which draws in place) its own writable buffer.
im = image_tensor.permute((1, 2, 0)).clamp(0, 255).to(torch.uint8).cpu().numpy().copy()

# The labels are treated as normalised centre-format boxes (cx, cy, w, h):
# convert to corner format (x1, y1, x2, y2), then scale each coordinate to
# pixels with its own axis size instead of a hard-coded 300.
scale = torch.tensor(
    [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_HEIGHT],
    dtype=objects.boxes.dtype,
    device=objects.boxes.device,
)
boxes = box_convert(objects.boxes, "cxcywh", "xyxy") * scale

# Convert to Python ints once, rather than per row, for cv2.
for x1, y1, x2, y2 in boxes.to(torch.int).cpu().tolist():
    im = cv2.rectangle(im, (x1, y1), (x2, y2), (255, 0, 0), 1)

plt.imshow(im)

### Augment the image and its labels

In [None]:
augmenter = DataAugmenter(IMAGE_WIDTH, IMAGE_HEIGHT)

# Declared ahead of the tuple unpacking so both results carry type hints.
trans_image: Tensor
trans_objects: FrameLabels

# The augmenter is fed pixel values divided by 255 (presumably it expects the
# [0, 1] range — confirm against DataAugmenter); the output image is scaled
# back up by 255 for display.
unit_range_image = image_tensor.div(255)
trans_image, trans_objects = augmenter(unit_range_image, objects)
trans_image = trans_image.mul(255)

In [None]:
# Display the augmented labels on the augmented image.
# Clamp before the uint8 cast: trans_image is the augmenter output multiplied
# by 255, and converting floats outside [0, 255] to uint8 is undefined.
# The copy gives cv2.rectangle its own writable buffer to draw into.
im = trans_image.permute((1, 2, 0)).clamp(0, 255).to(torch.uint8).cpu().numpy().copy()

# NOTE(review): unlike the letterbox cell, these boxes are not scaled by the
# image size, so this assumes DataAugmenter returns pixel-unit boxes —
# confirm. If they are normalised, multiply by
# (IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_HEIGHT) here.
boxes = box_convert(trans_objects.boxes, "cxcywh", "xyxy")

# Convert to Python ints once, rather than per row, for cv2.
for x1, y1, x2, y2 in boxes.to(torch.int).cpu().tolist():
    im = cv2.rectangle(im, (x1, y1), (x2, y2), (255, 0, 0), 1)

plt.imshow(im)