In [74]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset


class Imagenette(Dataset):
    def __init__(
        self,
        annotations_file,
        img_dir,
        transform=None,
        target_transform=None,
        valid=False,
        label_noise=0,
    ):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.valid = valid

        self.img_labels = self.img_labels[self.img_labels["is_valid"] == valid]

        # 0 = 0% noise, 1 = 1% noise, 2 = 5% noise, 3 = 25% noise, 4 = 50% noise
        if label_noise < 0:
            label_noise = 0
        elif label_noise > 4:
            label_noise = 4
        self.label_noise = label_noise + 1

        self.classes = {
            label: i
            for i, label in enumerate(
                self.img_labels.iloc[:, self.label_noise].unique()
            )
        }

    def __len__(self) -> int:
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, self.label_noise]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        label = self.classes[label]
        return image, label

In [76]:

training_data = Imagenette(
    "../data/raw/imagenette2-320/noisy_imagenette.csv", "../data/raw/imagenette2-320/", valid=False, label_noise=2)

In [77]:
from torch.utils.data import DataLoader
training_generator = DataLoader(training_data)

array(['n02979186', 'n03417042', 'n03394916', 'n03888257', 'n02102040',
       'n01440764', 'n03000684', 'n03028079', 'n03425413', 'n03445777'],
      dtype=object)

In [79]:
for x, y in training_generator:
    print(x,y)
    break

tensor([[[[177, 176, 175,  ..., 215, 220, 169],
          [177, 176, 175,  ..., 225, 185, 162],
          [177, 177, 176,  ..., 188, 126, 147],
          ...,
          [127, 118, 125,  ..., 146, 140, 134],
          [ 91,  83,  89,  ..., 145, 138, 132],
          [105,  97, 102,  ..., 143, 134, 128]],

         [[177, 176, 175,  ..., 215, 220, 169],
          [177, 176, 175,  ..., 225, 185, 162],
          [177, 177, 176,  ..., 188, 126, 147],
          ...,
          [127, 118, 125,  ..., 109, 104,  98],
          [ 91,  83,  89,  ..., 108, 102,  96],
          [105,  97, 102,  ..., 106,  98,  92]],

         [[177, 176, 175,  ..., 215, 220, 169],
          [177, 176, 175,  ..., 225, 185, 162],
          [177, 177, 176,  ..., 188, 126, 147],
          ...,
          [127, 118, 125,  ...,  82,  78,  72],
          [ 91,  83,  89,  ...,  81,  76,  70],
          [105,  97, 102,  ...,  79,  72,  66]]]], dtype=torch.uint8) tensor([0])
