In [77]:
import os

import pandas as pd
import torch
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader, Dataset
from torchvision.io import decode_image

In [78]:
class CustomImageDataset(Dataset):
    """Image-classification dataset backed by an annotations CSV.

    CSV rows are read positionally: column 0 is the image id (the file name
    without its extension) and column 1 is the label. Each image is loaded
    from ``<img_dir>/<id>.jpg``.

    Args:
        annotations_file: anything ``pd.read_csv`` accepts (path or file-like).
        img_dir: directory containing the ``.jpg`` files.
        transform: optional callable applied to the decoded image tensor.
        target_transform: optional callable applied to the raw label.
        read_mode: optional value forwarded as ``mode=`` to ``decode_image``.
            ``None`` (default) keeps the previous behaviour of calling
            ``decode_image`` with its own default mode, which preserves each
            file's channel count. Pass ``torchvision.io.ImageReadMode.RGB`` to
            force 3 channels so that a folder mixing grayscale/RGBA/RGB files
            can still be batched by ``DataLoader``.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None, read_mode=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.read_mode = read_mode

    def __len__(self):
        """Number of rows in the annotations CSV."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` with transforms applied."""
        img_id = self.img_labels.iloc[idx, 0]
        label = self.img_labels.iloc[idx, 1]
        img_path = os.path.join(self.img_dir, f"{img_id}.jpg")
        if self.read_mode is None:
            image = decode_image(img_path)
        else:
            image = decode_image(img_path, mode=self.read_mode)
        # Explicit None checks: only skip a transform when none was supplied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

In [79]:
# Brand name (as written in the annotations CSV) -> integer class index.
# Module-level so it is built once, not on every call, and can be inverted
# later to decode model predictions back into brand names.
BRAND_TO_IDX = {
    "adidas": 0,
    "converse": 1,
    "nike": 2,
}


def transform_target(label):
    """Map a brand-name label to its integer class index.

    Args:
        label: brand name read from the annotations CSV, e.g. ``"nike"``.

    Returns:
        int: the index for ``label`` in ``BRAND_TO_IDX``.

    Raises:
        KeyError: if ``label`` is not a known brand. The message names the
            offending label and lists the valid ones, so bad CSV rows are easy
            to diagnose.
    """
    try:
        return BRAND_TO_IDX[label]
    except KeyError:
        raise KeyError(
            f"unknown brand label {label!r}; expected one of {sorted(BRAND_TO_IDX)}"
        ) from None

# Image preprocessing applied to each decoded image tensor:
#   1. resize to exactly 240x240 (a (h, w) tuple does not preserve aspect ratio);
#   2. convert to float32; scale=True rescales integer pixel values to [0, 1].
transform = T.Compose(
    [
        T.Resize(size=(240, 240)),
        T.ToDtype(dtype=torch.float32, scale=True),
    ]
)

In [80]:
# Root of the dataset: data/{train,test}/{annotations.csv,images/}.
# os.path.join builds the paths with the OS separator; the previous
# hard-coded "\\" separators only worked on Windows.
DATA_DIR = "data"

training_data = CustomImageDataset(
    annotations_file=os.path.join(DATA_DIR, "train", "annotations.csv"),
    img_dir=os.path.join(DATA_DIR, "train", "images"),
    transform=transform,
    target_transform=transform_target,
)

testing_data = CustomImageDataset(
    annotations_file=os.path.join(DATA_DIR, "test", "annotations.csv"),
    img_dir=os.path.join(DATA_DIR, "test", "images"),
    transform=transform,
    target_transform=transform_target,
)

In [75]:
# Batch size shared by both loaders.
BATCH_SIZE = 64
# Seed for the training shuffle so Restart & Run All reproduces the same
# batch order; the test loader is not shuffled and needs no seed.
SEED = 42

train_dataloader = DataLoader(
    training_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_dataloader = DataLoader(testing_data, batch_size=BATCH_SIZE, shuffle=False)

tensor(255, dtype=torch.uint8)
torch.Size([64])
