# Download and Explore ImageNet Dataset

This notebook shows how to obtain the ImageNet dataset, load it from a local directory, and visualize sample images.

**Note:** ImageNet requires registration and acceptance of terms of use. You'll need to obtain credentials from https://image-net.org/

## Install Required Libraries

In [None]:
!pip install torch torchvision matplotlib pillow numpy

## Import Libraries

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import random

## Download ImageNet Dataset

The ImageNet dataset must be downloaded manually because of its licensing terms. The steps are:

1. Register at https://image-net.org/
2. Download the ILSVRC2012 dataset
3. Extract to a local directory

Once the dataset is on disk, torchvision can load it, either with its dedicated `ImageNet` dataset class or with the generic `ImageFolder` loader used in this notebook.

In [None]:
# Set the path where ImageNet dataset will be stored
# If you already have ImageNet downloaded, set this to your dataset path
imagenet_path = Path('../data/imagenet')
imagenet_path.mkdir(parents=True, exist_ok=True)

print(f"ImageNet data directory: {imagenet_path}")
print("\nNote: You need to manually download ImageNet from https://image-net.org/")
print("The dataset should be organized with train/ and val/ subdirectories.")
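
Before loading anything, it can help to confirm that the extracted data has the expected layout. The following is a minimal sketch using only the standard library; the helper name `check_imagenet_layout` is hypothetical, and it assumes each split is stored as one subdirectory per class.

```python
from pathlib import Path


def check_imagenet_layout(root, splits=("train", "val")):
    """Return {split: number of immediate subdirectories}, or None for a missing split."""
    root = Path(root)
    report = {}
    for split in splits:
        split_dir = root / split
        if not split_dir.is_dir():
            # The split directory is absent, so the dataset is not ready yet
            report[split] = None
            continue
        # Count class folders only; stray files are ignored
        report[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return report
```

Calling `check_imagenet_layout(imagenet_path)` returns a small dictionary that can be printed or checked. For a copy of ILSVRC2012 that has been sorted into one folder per class, both counts would be 1000.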

## Alternative: Use a Subset or Sample Dataset

For demonstration purposes, we can use Tiny ImageNet, a smaller subset with 200 classes and 64×64 images, if it is already in your workspace.

In [None]:
# Prefer Tiny ImageNet if it is present in the workspace; otherwise fall back to the full ImageNet path
tiny_imagenet_path = Path('../data/tiny-imagenet-200')

if tiny_imagenet_path.exists():
    print(f"Found Tiny ImageNet at: {tiny_imagenet_path}")
    data_path = tiny_imagenet_path / 'train'
else:
    print("Tiny ImageNet not found. Please download the full ImageNet dataset.")
    data_path = imagenet_path / 'train'

## Load Dataset

In [None]:
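# Define transforms: resize the shorter side to 256, center-crop to 224x224,
# and convert to a float tensor with values in [0, 1]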
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Load dataset using ImageFolder (works for both ImageNet and Tiny ImageNet)
try:
    dataset = datasets.ImageFolder(root=str(data_path), transform=transform)
    print("Dataset loaded successfully!")
    print(f"Number of classes: {len(dataset.classes)}")
    print(f"Number of images: {len(dataset)}")
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Please ensure the dataset is properly downloaded and extracted.")
    # Re-raise so later cells do not fail with a confusing NameError on `dataset`
    raise
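
`ImageFolder` infers labels from the directory structure: each subdirectory of the root is treated as one class, and classes are numbered in sorted order of their names. The sketch below mimics that mapping with the standard library only. It is an illustration of the idea, not torchvision's actual implementation, and the function name `sketch_class_to_idx` is hypothetical.

```python
from pathlib import Path


def sketch_class_to_idx(root):
    """Map each immediate subdirectory name of `root` to an integer label, in sorted order."""
    class_names = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(class_names)}
```

The result should correspond to `dataset.class_to_idx`, which is why `dataset.classes[label]` recovers the folder name for a given label.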

## Display Sample Images

In [None]:
def imshow(img, title=None):
    """Display a tensor as an image."""
    img = img.numpy().transpose((1, 2, 0))
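    # No Normalize transform is applied above, so values are already in [0, 1];
    # clipping only guards against small numerical overshoot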
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title:
        plt.title(title)
    plt.axis('off')

# Get random sample images
num_samples = 6
indices = random.sample(range(len(dataset)), num_samples)

# Create a grid of images
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, img_idx in enumerate(indices):
    img, label = dataset[img_idx]
    class_name = dataset.classes[label]
    
    plt.sca(axes[idx])
    imshow(img, title=f"Class: {class_name}")

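plt.tight_layout()
# Make sure the output directory exists before saving the figure
Path('../images').mkdir(parents=True, exist_ok=True)
plt.savefig('../images/imagenet_samples.png', dpi=150, bbox_inches='tight')
plt.show()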

print(f"Displayed {num_samples} random images from the dataset")
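
The transform in this notebook does not normalize the images, so `imshow` can plot them directly. If a `Normalize` step were added to the pipeline, the images would need to be denormalized before display. The following is a minimal sketch, assuming the commonly used ImageNet channel means and standard deviations; the helper name `denormalize` is hypothetical.

```python
import numpy as np

# Commonly used ImageNet channel statistics in RGB order (an assumption here;
# use whatever values the Normalize transform was given)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def denormalize(chw, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo per-channel normalization on a (C, H, W) array; return (H, W, C) clipped to [0, 1]."""
    hwc = np.asarray(chw, dtype=np.float64).transpose(1, 2, 0)
    # Broadcasting applies std and mean along the last (channel) axis
    return np.clip(hwc * std + mean, 0.0, 1.0)
```

With a normalized tensor `img`, `plt.imshow(denormalize(img.numpy()))` would then show the image in its original colors.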

## Display Images from Specific Classes

In [None]:
# Show a few images from the first class
first_class_idx = 0
class_name = dataset.classes[first_class_idx]

# Find all images from this class
class_indices = [i for i, (_, label) in enumerate(dataset.imgs) if label == first_class_idx]

# Display up to 4 images from this class
num_display = min(4, len(class_indices))
fig, axes = plt.subplots(1, num_display, figsize=(15, 4))

if num_display == 1:
    axes = [axes]

for idx in range(num_display):
    img, _ = dataset[class_indices[idx]]
    plt.sca(axes[idx])
    imshow(img, title=f"{class_name}")

plt.tight_layout()
plt.show()

print(f"Displayed {num_display} images from class: {class_name}")
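
With ImageNet-style folders, `dataset.classes` holds WordNet IDs (for example `n01443537`) rather than readable names. Tiny ImageNet ships a `words.txt` file that maps IDs to names. The sketch below parses such a file, assuming one tab-separated `wnid` and description per line; the function name `load_wnid_names` is hypothetical, and the file location should be checked against your copy of the dataset.

```python
from pathlib import Path


def load_wnid_names(words_file):
    """Parse tab-separated 'wnid, description' lines into {wnid: first listed name}."""
    names = {}
    for line in Path(words_file).read_text(encoding="utf-8").splitlines():
        wnid, sep, description = line.partition("\t")
        if not sep:
            continue  # skip blank or malformed lines
        # Descriptions may list several synonyms separated by commas; keep the first
        names[wnid.strip()] = description.split(",")[0].strip()
    return names
```

A title could then be built with `names.get(class_name, class_name)`, which falls back to the raw ID when no name is found.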

## Dataset Statistics

In [None]:
# Calculate images per class
from collections import Counter

labels = [label for _, label in dataset.imgs]
label_counts = Counter(labels)

print("Dataset Statistics:")
print(f"Total classes: {len(dataset.classes)}")
print(f"Total images: {len(dataset)}")
print(f"Average images per class: {len(dataset) / len(dataset.classes):.1f}")
print("\nFirst 5 classes:")
for i, class_name in enumerate(dataset.classes[:5]):
    print(f"  {i}: {class_name} - {label_counts[i]} images")
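
The average alone can hide class imbalance, so it may be worth looking at the smallest and largest classes as well. The following is a small sketch built on `Counter`; the helper name `summarize_class_balance` is hypothetical.

```python
from collections import Counter


def summarize_class_balance(labels):
    """Return the number of classes and the min, max, and mean images per class."""
    counts = Counter(labels)
    if not counts:
        raise ValueError("labels is empty")
    sizes = sorted(counts.values())
    return {
        "num_classes": len(counts),
        "min": sizes[0],
        "max": sizes[-1],
        "mean": sum(sizes) / len(sizes),
    }
```

Passing the `labels` list from the cell above, as in `summarize_class_balance(labels)`, gives a quick check of whether every class has a similar number of images.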