In [None]:
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from xml.etree import ElementTree as ET

In [None]:
class barcodeDataset(torch.utils.data.Dataset):
    """Single-class object-detection dataset backed by image files and XML labels.

    Samples are read from ``<root>/images`` and ``<root>/labels``.  The two
    directory listings are sorted independently and paired by position, so
    both folders must contain matching entries in the same sort order.

    Each label file is parsed as XML; every ``object`` element must contain a
    ``bndbox`` child with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` entries.

    Args:
        root: Dataset directory holding the ``images`` and ``labels`` folders.
        transforms: Callable taking and returning ``(img, target)``, or
            ``None`` to return the RGB PIL image and the target unchanged.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that image i and label i describe the same sample.
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))
        self.labels = sorted(os.listdir(os.path.join(root, "labels")))

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, target)``.

        ``target`` is a dict with:
            boxes:    float32 tensor of shape (N, 4) as [xmin, ymin, xmax, ymax]
            labels:   int64 tensor of N ones (there is only one class)
            image_id: tensor holding ``idx``
            area:     box areas, (ymax - ymin) * (xmax - xmin)
            iscrowd:  int64 tensor of N zeros (no instance is marked as crowd)

        N may be 0 when the label file contains no ``object`` element.
        """
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        label_path = os.path.join(self.root, "labels", self.labels[idx])

        img = Image.open(img_path).convert("RGB")

        annotation = ET.parse(label_path).getroot()

        # Parse coordinates as Python floats.  float16 cannot represent every
        # integer above 2048, so it would silently shift boxes on large images.
        boxes = []
        for obj in annotation.iter("object"):
            bndbox = obj.find("bndbox")
            xmin, ymin, xmax, ymax = (
                float(bndbox.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )
            boxes.append([xmin, ymin, xmax, ymax])
        num_objs = len(boxes)

        # reshape keeps the (N, 4) layout even when N == 0; without it an
        # annotation-free image produces shape (0,) and the column indexing
        # used for the area below raises IndexError.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of samples, taken from the image directory listing."""
        return len(self.imgs)

In [None]:
import transforms as T

def get_transform(train):
    """Compose the transform pipeline applied to each dataset sample.

    ``T.ToTensor()`` is always applied first.  When ``train`` is truthy,
    ``T.RandomHorizontalFlip(0.5)`` is appended (0.5 is presumably the flip
    probability; confirm against the local ``transforms`` module).

    Returns:
        The ``T.Compose`` wrapping the selected transforms in order.
    """
    pipeline = [T.ToTensor()]
    if train:
        pipeline += [T.RandomHorizontalFlip(0.5)]
    return T.Compose(pipeline)

In [None]:
# Build the full dataset from the local "data" folder with the training transforms,
# then display how many samples it holds.
dataset = barcodeDataset(root='data', transforms=get_transform(train=True))
len(dataset)

In [None]:
# Split to Train/Test
TRAIN_FRACTION = 0.8
BATCH_SIZE = 64
SPLIT_SEED = 42  # fixed seed so the same samples land in each subset on every run

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_split = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# `dataset` was built with train=True, so a plain split would also apply the
# random horizontal flip to the held-out samples.  Re-wrap the held-out
# indices around a dataset that uses the evaluation transforms instead.
test_dataset = torch.utils.data.Subset(
    barcodeDataset(dataset.root, get_transform(train=False)),
    test_split.indices,
)


def collate_detection_batch(batch):
    """Regroup a list of (img, target) samples into (images, targets) tuples.

    The default collate function tries to stack every field into one tensor,
    which fails for detection data: each target holds a different number of
    boxes, and images may differ in size.
    """
    return tuple(zip(*batch))


# DataLoader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_detection_batch,
)
# No shuffling for evaluation: order does not matter and a fixed order is reproducible.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_detection_batch,
)

In [None]:
# Fetch the first training sample and display the size of its image tensor.
first_sample = train_dataset[0]
first_sample[0].size()