In [20]:
import torch
import os

In [21]:
from torch.utils.data import Dataset, ConcatDataset

class YoloDataset(Dataset):
    """One saved ``(image, boxes)`` pair exposed as a one-item dataset with a grid target.

    The file at ``path`` is expected to unpack into two tensors: the image and
    either a single box (a flat tensor of four values) or several boxes (a
    tensor with one 4-value row per box). Each box is ``(x, y, w, h)``; the
    cell it lands in is computed as ``int(S * coord)``, so coordinates are
    presumably normalized to ``[0, 1]`` -- TODO confirm against the code that
    writes the files.

    Args:
        path: File to read with ``torch.load``.
        S: Number of grid cells per side.
        B: Stored on the instance; not used by this class.
        C: Number of extra channels; the target has ``C + 5`` channels.
    """

    def __init__(self, path, S=7, B=2, C = 0):
        super().__init__()
        # NOTE(review): torch.load unpickles the file, so only load trusted data.
        self.X , self.y = torch.load(path)
        self.S = S
        self.B = B
        self.C = C

    def __len__(self):
        """Return 1: the file holds exactly one image.

        ``len(self.X)`` would be the size of the image's first dimension (its
        channel count for a ``[3, H, W]`` tensor), which made wrappers such as
        ``ConcatDataset`` count the same image several times.
        """
        return 1

    def __getitem__(self, index):
        """Return ``(image, labels)`` for the single stored sample.

        ``labels`` has shape ``(S, S, C + 5)``. For each cell that receives a
        box, channel 0 is set to 1 and channels 1:5 hold
        ``(x_cell, y_cell, w_cell, h_cell)``; every other entry stays 0. When
        two boxes fall into the same cell only the first one is kept.

        Raises:
            IndexError: If ``index`` is neither ``0`` nor ``-1``. This also lets
                plain ``for item in dataset`` iteration terminate.
        """
        if index not in (0, -1):
            raise IndexError(f"index {index} is out of range for a dataset of length 1")

        labels = torch.zeros((self.S, self.S, self.C + 5))
        for box in self._box_rows():
            self._encode_box(labels, box)
        return self.X, labels

    def _box_rows(self):
        """Return the stored boxes as an iterable of 4-value tensors (possibly empty)."""
        boxes = self.y
        if boxes.dim() == 1:
            # A flat tensor counts as one box only when it has exactly four values.
            return [boxes] if len(boxes) == 4 else []
        if boxes.dim() > 1:
            return boxes
        return []

    def _encode_box(self, labels, box):
        """Write one ``(x, y, w, h)`` box into ``labels`` unless its cell is already taken."""
        x, y, w, h = box.tolist()
        last = self.S - 1
        # Clamp the cell index: a coordinate of exactly 1.0 would otherwise map
        # to row/column S (out of bounds), and a negative one would wrap around.
        i = min(max(int(self.S * y), 0), last)
        j = min(max(int(self.S * x), 0), last)

        if labels[i, j, 0] == 0:
            x_cell, y_cell = self.S * x - j, self.S * y - i
            w_cell, h_cell = w * self.S, h * self.S
            labels[i, j, 0] = 1
            labels[i, j, 1:5] = torch.tensor([x_cell, y_cell, w_cell, h_cell])

In [22]:
# Directory with one serialized sample per file.
# NOTE(review): absolute, machine-specific path.
PATH = '/home/kuba/Documents/data/raw/single-face-tensors/train'
all_train_datasets = []

# Wrap each file (in sorted name order) in its own YoloDataset; files that
# fail to load are reported and skipped.
for file in sorted(os.listdir(PATH)):
    try:
        dataset = YoloDataset(f"{PATH}/{file}")
    except Exception as e:
        print(f"Error loading {file}: {str(e)}")
    else:
        all_train_datasets.append(dataset)

# Present all per-file datasets as a single dataset.
combined = ConcatDataset(all_train_datasets)

In [23]:
# Load one serialized file directly to look at the raw tensors it contains.
# NOTE(review): absolute, machine-specific path.
sample = '/home/kuba/Documents/data/raw/single-face-tensors/train/1253221dcb65bd1d.pt'
X, y = torch.load(sample)

In [24]:
# Shapes of the two raw tensors loaded from the sample file.
X.shape, y.shape

(torch.Size([3, 224, 224]), torch.Size([4]))

In [25]:
# Wrap the same sample file in a YoloDataset and fetch item 0.
# This rebinds `dataset`, which the loading loop above also assigns.
dataset = YoloDataset(sample)
sig = dataset[0]

In [26]:
# Unpack the (image, labels) pair returned by the dataset.
# NOTE(review): this rebinds `y`, which previously held the raw box tensor from torch.load.
x,y = sig

In [29]:
# `y` is now the label grid unpacked from `sig`, not the raw box tensor.
y.shape

torch.Size([7, 7, 5])

In [28]:
# NOTE(review): capital `X` is the raw tensor from torch.load; the image
# returned by the dataset is bound to lowercase `x`.
X.shape

torch.Size([3, 224, 224])