# Working with the CelebA-HQ (Small) Dataset

In this notebook, we demonstrate how to load, explore, and visualize the CelebA-HQ (small) dataset with the Hugging Face `datasets` library, then preprocess the images with `torchvision` and wrap them in a PyTorch `DataLoader`.

In [None]:
# Step 1: Install Required Libraries
!pip install datasets matplotlib torch torchvision

### Step 2: Load the Dataset

We use the `eurecom-ds/celeba-hq-small` dataset hosted on the Hugging Face Hub: https://huggingface.co/datasets/eurecom-ds/celeba-hq-small

In [None]:
from datasets import load_dataset

# Load the CelebA-HQ (small) dataset by its Hugging Face Hub identifier
dataset = load_dataset(path='eurecom-ds/celeba-hq-small')

# Print a summary of the loaded object so the available splits are visible
print(dataset)

### Step 3: Explore the Dataset

In [None]:
# Pull the 'train' split out of the loaded dataset
train_dataset = dataset['train']

# Peek at a single record to see which fields each example carries
first_example = train_dataset[0]
print(first_example)

### Step 4: Visualize Some Images

In [None]:
import matplotlib.pyplot as plt

def show_images(dataset, num_images=5):
    """Display the first images of ``dataset`` side by side in one row.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items are
            mappings with an ``'image'`` entry that ``plt.imshow`` can render.
        num_images: Maximum number of images to show. Requests larger than
            the dataset are clamped to its length so indexing never runs
            past the last example.
    """
    # Never ask for more panels than there are examples; this also keeps the
    # grid exactly as wide as the number of images actually drawn.
    count = min(num_images, len(dataset))

    plt.figure(figsize=(15, 5))
    for position in range(count):
        # Subplot indices are 1-based.
        plt.subplot(1, count, position + 1)
        plt.imshow(dataset[position]['image'])
        plt.axis('off')
    plt.show()

# Show some images from the training dataset (uses the default `num_images`).
# Run this before the preprocessing step rebinds `train_dataset`: after that,
# each example's 'image' is the output of `transform` rather than the raw image.
show_images(train_dataset)

### Step 5: Preprocess the Images

In [None]:
from torchvision import transforms
import torch

# Preprocessing pipeline, applied to each image in this order:
#   1. resize to 128x128,
#   2. convert to a tensor,
#   3. normalize each of the three channels as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Apply the transformation to the dataset
def preprocess(batch):
    """Turn every image in a batch of examples into a normalized tensor.

    Args:
        batch: Mapping from column name to a list of values.
            ``batch['image']`` must hold PIL images.

    Returns:
        The same mapping, with ``batch['image']`` replaced by the list of
        outputs of ``transform``. All other columns are left untouched.
    """
    # Force three channels first: the Normalize step in `transform` is
    # configured with three means/stds and fails on grayscale or RGBA inputs.
    batch['image'] = [transform(image.convert('RGB')) for image in batch['image']]
    return batch

# Rebind `train_dataset` to a dataset whose examples go through `preprocess`.
# NOTE(review): per the `datasets` docs, `with_transform` applies the function
# lazily when examples are accessed — confirm for the installed version.
# From here on, `train_dataset` no longer refers to the untransformed split.
train_dataset = train_dataset.with_transform(preprocess)

### Step 6: Create DataLoaders

In [None]:
from torch.utils.data import DataLoader

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Iterate through the DataLoader; only the first batch is inspected as a
# sanity check of the collated shapes.
for batch in train_loader:
    images = batch['image']
    # NOTE(review): the dataset card for celeba-hq-small appears to list only
    # an 'image' column — confirm. Treat 'attributes' as optional so this
    # check reports the image shape instead of raising KeyError.
    labels = batch.get('attributes')
    if labels is None:
        print(images.shape)
    else:
        print(images.shape, labels.shape)
    break