# Inspect ImageNet2012 torchvision Dataset

In [5]:
from PIL import Image
import torch
import torchvision

Ref for data preparation: https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh

In [6]:
dataset_train = torchvision.datasets.ImageNet(root='../data/imagenet2012/', split='train', transform=torchvision.transforms.ToTensor())
dataset_val = torchvision.datasets.ImageNet(root='../data/imagenet2012/', split='val', transform=torchvision.transforms.ToTensor())

In [7]:
print(f"Len of train dataset: {len(dataset_train)}")
print(f"Len of val dataset: {len(dataset_val)}")
print(f"Number of classes: {len(dataset_train.classes)}")
print(f"\nFirst 10th classes")
for i in range(10):
    print(f"Class {i}: {dataset_train.classes[i]}, ID={dataset_train.class_to_idx[dataset_train.classes[i][0]]}")
print(f"\nSample")
x, y = dataset_train[2]
print(f"Image: {x} ({type(x)=}), hw={x.size}")  # PIL Image with non-fixed size
print(f"Label: {y} ({type(y)=})")

Len of train dataset: 1281167
Len of val dataset: 50000
Number of classes: 1000

First 10th classes
Class 0: ('tench', 'Tinca tinca'), ID=0
Class 1: ('goldfish', 'Carassius auratus'), ID=1
Class 2: ('great white shark', 'white shark', 'man-eater', 'man-eating shark', 'Carcharodon carcharias'), ID=2
Class 3: ('tiger shark', 'Galeocerdo cuvieri'), ID=3
Class 4: ('hammerhead', 'hammerhead shark'), ID=4
Class 5: ('electric ray', 'crampfish', 'numbfish', 'torpedo'), ID=5
Class 6: ('stingray',), ID=6
Class 7: ('cock',), ID=7
Class 8: ('hen',), ID=8
Class 9: ('ostrich', 'Struthio camelus'), ID=9

Sample
Image: tensor([[[0.1216, 0.1412, 0.1373,  ..., 0.2667, 0.3137, 0.3059],
         [0.1373, 0.1529, 0.1529,  ..., 0.2431, 0.2314, 0.2471],
         [0.1529, 0.1529, 0.1451,  ..., 0.1804, 0.2314, 0.2118],
         ...,
         [0.3137, 0.2706, 0.2549,  ..., 0.4196, 0.3020, 0.1333],
         [0.3176, 0.2902, 0.2667,  ..., 0.4157, 0.2980, 0.1333],
         [0.2902, 0.3137, 0.2941,  ..., 0.4039, 0.

In [8]:
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=1, shuffle=False)
print(f"Example imgs batch: {next(iter(dataloader_train))[0].size()}")
print(f"Example labels batch: {next(iter(dataloader_train))[1].size()}")

Example imgs batch: torch.Size([1, 3, 250, 250])
Example labels batch: torch.Size([1])
