In [60]:
!pip install datasets



In [61]:
from collections import Counter

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from datasets import load_dataset

In [62]:
# Download the celebrity-faces dataset, keep only its 'train' split, and
# return rows as torch tensors.
dataset = load_dataset("theneuralmaze/celebrity_faces")['train'].with_format('torch')
print(dataset)

# Base preprocessing: cast images to float32. For integer input (presumably
# uint8 as produced by the torch formatter -- confirm), ConvertImageDtype also
# rescales pixel values into [0, 1].
transformer = transforms.Compose([transforms.ConvertImageDtype(torch.float32)])


def transform(example):
    """Apply `transformer` to one example's 'image' field.

    The example dict is updated in place and returned, the shape expected by
    `Dataset.map(..., batched=False)`.
    """
    converted = transformer(example['image'])
    example['image'] = converted
    return example

Dataset({
    features: ['image', 'label'],
    num_rows: 3000
})


In [63]:
# Evaluation-time pipeline: float32 conversion followed by random colour jitter.
# NOTE(review): colour jitter is normally a training-time augmentation --
# confirm it is intended for the test split (e.g. as a robustness check).
# NOTE(review): ColorJitter expects channel-first [..., 3, H, W] tensors --
# confirm the torch-formatted images are CHW.
test_transformer = transforms.Compose(
    [
        transforms.ConvertImageDtype(torch.float32),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.3,
            saturation=0.2,
            hue=0.1,
        ),
    ]
)


def test_transform(example):
    """Apply `test_transformer` to one example's 'image' field.

    Mutates and returns the example dict, as `Dataset.map(..., batched=False)`
    expects.
    """
    jittered = test_transformer(example['image'])
    example['image'] = jittered
    return example

In [64]:
# Labels are the celebrities' names, so the distinct labels are the distinct people.
unique_labels = set(dataset['label'])
num_unique_labels = len(unique_labels)
print(f"Number of celebrities: {num_unique_labels}")

Number of calebrities: 105


In [65]:
# Number of photos per celebrity (label -> count), computed on the full,
# unfiltered dataset; the filtering cell below reads this mapping.
label_counts = Counter(dataset['label'])
print("Number of photos of a given celebrity:")
label_counts

Number of photos of a given celebrity:


Counter({'Jason Momoa': 42,
         'Krysten Ritter': 25,
         'alycia dabnem carey': 37,
         'Millie Bobby Brown': 28,
         'Miley Cyrus': 28,
         'Maisie Williams': 32,
         'Taylor Swift': 30,
         'Logan Lerman': 38,
         'Brenton Thwaites': 40,
         'Tom Holland': 30,
         'Alexandra Daddario': 36,
         'Nadia Hilker': 18,
         'Mark Zuckerberg': 16,
         'Keanu Reeves': 33,
         'Zac Efron': 39,
         'elizabeth olsen': 40,
         'Inbar Lavi': 23,
         'Henry Cavil': 30,
         'Gwyneth Paltrow': 29,
         'Anne Hathaway': 38,
         'Chris Evans': 20,
         'Hugh Jackman': 23,
         'Morena Baccarin': 34,
         'Bobby Morley': 24,
         'Alex Lawther': 24,
         'Ben Affleck': 19,
         'gal gadot': 33,
         'scarlett johansson': 35,
         'Anthony Mackie': 22,
         'Leonardo DiCaprio': 48,
         'Mark Ruffalo': 32,
         'Emma Watson': 45,
         'Jessica Barden': 21,
  

In [66]:
# Celebrities with fewer photos than this are dropped from the dataset.
MIN_PHOTOS_PER_CELEBRITY = 13


def filter_labels(example, min_count=MIN_PHOTOS_PER_CELEBRITY):
    """Return True if the example's celebrity has at least `min_count` photos.

    Counts come from `label_counts`, which was built over the unfiltered
    `dataset`.
    """
    return label_counts[example['label']] >= min_count


filtered_dataset = dataset.filter(filter_labels)

Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

In [67]:
SPLIT_SEED = 42       # fixes the train/test partition across kernel restarts
TEST_FRACTION = 0.2   # share of examples held out for evaluation
BATCH_SIZE = 1        # DataLoader's default, made explicit so it is visible/tunable

# Convert every image to float32 once, up front.
ds = filtered_dataset.map(transform, batched=False)

# Seeded so "Restart & Run All" reproduces the same split.
split = ds.train_test_split(test_size=TEST_FRACTION, seed=SPLIT_SEED)
train_ds, test_ds = split['train'], split['test']

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

Map:   0%|          | 0/2978 [00:00<?, ? examples/s]

In [68]:
# Apply the colour-jitter pipeline to the held-out split. Start from
# split['test'] (not test_ds) so re-running this cell does not jitter the
# images a second time. Images are already float32 from `transform`, so the
# dtype conversion inside test_transform is a no-op here.
test_ds = split['test'].map(test_transform, batched=False)

# test_loader was built in the previous cell from the un-jittered test_ds;
# rebinding the name does not update the loader, so rebuild it.
test_loader = DataLoader(test_ds, shuffle=False)

Map:   0%|          | 0/596 [00:00<?, ? examples/s]