In [2]:
import torch
from datasets import load_dataset
import torchvision.transforms as transforms

In [3]:
class DataManager:
    """
    Manager for Tiny ImageNet (zh-plus/tiny-imagenet) via Hugging Face Datasets.
    """
    def __init__(self, batch_size, num_workers=2):
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Transforms cho train và valid
        self.transform_train = transforms.Compose([
            transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.transform_valid = transforms.Compose([
            transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def get_loaders(self):
        # 1) Load dataset
        ds = load_dataset("zh-plus/tiny-imagenet")
        train_ds, valid_ds = ds["train"], ds["valid"]

        # 2) Nếu nhãn là string, ánh xạ sang số nguyên
        label_feature = train_ds.features["label"]
        class2idx = None
        if hasattr(label_feature, 'names'):
            class2idx = {name: idx for idx, name in enumerate(label_feature.names)}

        # 3) preprocess per-sample cho train
        def preprocess_train(ex):
            img = ex["image"]
            img = self.transform_train(img)
            lbl = ex["label"]
            if class2idx and isinstance(lbl, str):
                lbl = class2idx[lbl]
            return {"image": img, "label": lbl}

        # 4) preprocess per-sample cho valid
        def preprocess_valid(ex):
            img = ex["image"]
            img = self.transform_valid(img)
            lbl = ex["label"]
            if class2idx and isinstance(lbl, str):
                lbl = class2idx[lbl]
            return {"image": img, "label": lbl}

        train_ds = train_ds.map(preprocess_train, batched=False)
        valid_ds = valid_ds.map(preprocess_valid, batched=False)

        # 5) Chuyển sang tensor cho PyTorch
        train_ds.set_format(type="torch", columns=["image", "label"])
        valid_ds.set_format(type="torch", columns=["image", "label"])

        # 6) Tạo DataLoader
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers
        )

        return train_loader, valid_loader

In [7]:
data_manager = DataManager(2, 2)
trainloader, testloader = data_manager.get_loaders()

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating train split: 100%|██████████| 100000/100000 [00:02<00:00, 39498.64 examples/s]
Generating valid split: 100%|██████████| 10000/10000 [00:00<00:00, 31081.21 examples/s]
Map: 100%|██████████| 100000/100000 [03:19<00:00, 502.11 examples/s]
Map: 100%|██████████| 10000/10000 [00:37<00:00, 266.29 examples/s]


In [9]:
for i in testloader:
    print(i)
    break

{'image': tensor([[[[2.0777, 2.1290, 2.1462,  ..., 2.2318, 2.2318, 2.2318],
          [2.1633, 2.1462, 2.1633,  ..., 2.2318, 2.2318, 2.2318],
          [2.1975, 2.1975, 2.1633,  ..., 2.2147, 2.2147, 2.2147],
          ...,
          [1.2728, 1.9578, 2.2318,  ..., 1.6838, 1.6667, 1.5982],
          [2.0434, 2.0777, 1.8037,  ..., 2.1462, 2.1119, 1.9920],
          [2.0092, 1.9407, 2.0263,  ..., 2.2318, 2.2318, 2.1975]],

         [[2.4111, 2.4111, 2.4286,  ..., 2.4111, 2.4111, 2.4111],
          [2.4286, 2.4286, 2.4111,  ..., 2.4111, 2.4111, 2.4111],
          [2.4286, 2.4286, 2.4111,  ..., 2.3936, 2.3936, 2.3936],
          ...,
          [1.4482, 2.1485, 2.4286,  ..., 1.8508, 1.8333, 1.7633],
          [2.4286, 2.4286, 2.1835,  ..., 2.3235, 2.2885, 2.2185],
          [2.4286, 2.4286, 2.4286,  ..., 2.4111, 2.4111, 2.4286]],

         [[2.5703, 2.5877, 2.6051,  ..., 2.6226, 2.6226, 2.6226],
          [2.6226, 2.6051, 2.6051,  ..., 2.6226, 2.6226, 2.6226],
          [2.6400, 2.6400, 2.605