Data augmentation techniques such as flipping, scaling, noise addition, random rotation, zooming, shifting, and brightness variation have to be implemented using Python libraries such as Albumentations, YOLOv8, or similar image-processing tools.



In [5]:
import os
from torch.utils.data import Dataset
from PIL import Image
import torch

class CustomDataset(Dataset):
    """Dataset pairing every image in ``img_dir`` with a text label file.

    For an image ``<stem>.<ext>`` the labels are read from
    ``<annotations_dir>/<stem>.txt``. Each non-blank line of a label file is
    parsed as whitespace-separated numbers: the first is the class id, the
    remaining values are the bounding-box coordinates (presumably the
    YOLO ``cx cy w h`` layout -- TODO confirm against the label files).

    Args:
        annotations_dir: Directory containing the ``.txt`` label files.
        img_dir: Directory containing the images.
        transform: Optional callable applied to the RGB PIL image only.
        target_transform: Optional callable applied to the boxes tensor and,
            separately, to the labels tensor.
    """

    # File extensions (lower-case) treated as images when scanning ``img_dir``.
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, annotations_dir, img_dir, transform=None, target_transform=None):
        self.annotations_dir = annotations_dir
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Sort so that an index maps to the same sample on every run and
        # platform (os.listdir order is arbitrary), and skip entries that are
        # not image files (sub-directories, .DS_Store, stray notes, ...).
        self.img_files = sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.img_files)

    def _label_path(self, img_name):
        """Return the label-file path that belongs to image ``img_name``."""
        # Swap only the extension; a plain str.replace('jpg', 'txt') would also
        # rewrite "jpg" inside the stem and would miss .JPG / .png / .jpeg.
        stem, _ext = os.path.splitext(img_name)
        return os.path.join(self.annotations_dir, stem + ".txt")

    @staticmethod
    def _read_annotations(label_path):
        """Parse one label file into ``(boxes, labels)`` tensors.

        Returns:
            boxes: ``float32`` tensor with one row per annotated object.
            labels: ``long`` tensor holding the class id of each object.

        Raises:
            FileNotFoundError: If ``label_path`` does not exist.
            ValueError: If a field cannot be parsed as a number.
        """
        boxes = []
        labels = []
        with open(label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Blank (e.g. trailing) lines carry no annotation.
                    continue
                class_id, *bbox = map(float, fields)
                labels.append(int(class_id))
                boxes.append(bbox)

        if not labels:
            # Keep a 2-D shape for images without objects so downstream code
            # can still index columns; assumes 4 coordinates per box.
            boxes_tensor = torch.zeros((0, 4), dtype=torch.float32)
        else:
            boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        return boxes_tensor, labels_tensor

    def __getitem__(self, idx):
        """Load the ``idx``-th sample.

        Returns:
            A tuple ``(image, target)`` where ``image`` is the RGB image
            (after ``transform`` if one was given) and ``target`` is a dict
            with the keys ``'boxes'`` and ``'labels'``.
        """
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the block exits.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        boxes, labels = self._read_annotations(self._label_path(img_name))

        # Compare against None: the truth value of a callable object is not a
        # reliable "was it provided" signal.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            boxes = self.target_transform(boxes)
            labels = self.target_transform(labels)

        return image, {'boxes': boxes, 'labels': labels}


In [6]:
import torch
from torchvision.transforms import v2
from torchvision import tv_tensors

# Height and width of the synthetic sample used to smoke-test the pipeline.
H, W = 32, 32

# Define image transformations
transforms = v2.Compose([
    v2.ToImage(),

    v2.RandomResizedCrop(size=(224, 224), antialias=True),  # Or Resize(antialias=True)
    v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Synthetic uint8 image plus three random boxes. The second coordinate pair is
# offset by the first so that every box is a valid XYXY rectangle.
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = torch.randint(0, H // 2, size=(3, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))

# Passing image and boxes together lets the v2 transforms update both jointly.
img, boxes = transforms(img, boxes)
# And you can pass arbitrary input structures
output_dict = transforms({"image": img, "boxes": boxes})

# NOTE(review): CustomDataset calls `transform` with the image only, so the
# random crop above will not adjust the boxes read from the label files --
# confirm this is intended before training on the result.
WeedDataset = CustomDataset(
    annotations_dir='Data/luxeed_heatmaps/data/labels',
    img_dir='Data/luxeed_heatmaps/data/images',
    transform=transforms,
)

