In [1]:
from IPython.display import clear_output
!pip install datasets
clear_output()

In [2]:
import torch

from PIL import Image
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
# Transformation pipeline to preprocess images: resize, convert to tensor, and normalize.
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((128, 128)), # Set size as necessary
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Custom function to process and batch data.
def collate_fn(batch):
    """Custom collate function to handle batch processing."""
    images = []
    labels = []
    for item in batch:
        image = item['image']
        image = transform(image)
        images.append(image)
        labels.append(item['label'])
    # Return a batch of images and labels as tensors.
    return torch.stack(images), torch.tensor(labels)

In [5]:
# Load ImageNet-1K with streaming to handle the huge dataset efficiently
train_dataset = load_dataset('imagenet-1k', split='train', streaming=True, trust_remote_code=True)
test_dataset = load_dataset('imagenet-1k', split='test', streaming=True, trust_remote_code=True)

# Create a DataLoader from the streaming dataset.
def get_dataloader(train_dataset, batch_size: int, num_workers: int):
    return DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)

# Dataloader for training data.
train_dataloader = get_dataloader(train_dataset, batch_size=64, num_workers=2)

# Dataloader for testing data.
test_dataloader = get_dataloader(test_dataset, batch_size=64, num_workers=2)

In [6]:
for images, labels in train_dataloader:
    # training code
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of labels: {labels}")

Batch of images shape: torch.Size([64, 3, 64, 64])
Batch of labels: tensor([726, 917,  13, 939,   6, 983, 655, 579, 702, 845,  69, 822, 575, 906,
        752, 219, 192, 191, 292, 848, 108, 372, 765, 473, 525, 639, 686,  99,
        127,  76, 905, 550,  30, 634, 907, 979, 718, 154, 914, 293,   9, 922,
        130,  33, 968, 719, 653, 840, 139, 198, 236, 304, 547, 940, 215, 853,
        805,  28, 104,  67, 311, 429, 941, 950])
Batch of images shape: torch.Size([64, 3, 64, 64])
Batch of labels: tensor([486, 521, 871,  63, 650, 733, 482, 403, 397, 209,  21,  34, 248, 753,
        243, 858, 142, 434,   0, 766, 830, 222, 318, 800, 191, 983, 997, 610,
        991, 517, 848, 116, 148, 697, 933, 309, 297, 724, 357, 381, 653, 882,
        255, 823, 882, 208, 878, 927, 223, 911, 675,  44, 573, 964, 370, 141,
        134, 321, 789, 316, 631, 277, 917, 872])
Batch of images shape: torch.Size([64, 3, 64, 64])
Batch of labels: tensor([603, 971, 486, 504, 497, 670, 459, 559, 940,   9, 829, 888, 773, 7

KeyboardInterrupt: 