In [2]:
import os
import shutil
from datasets import load_dataset, concatenate_datasets, DatasetDict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face Hub identifier of the captioned CelebA-HQ dataset used throughout this notebook.
DATASET_SOURCE = "Ryan-sjtu/celebahq-caption"

# Load every split of the dataset (only its "train" split is used below).
dataset = load_dataset(path=DATASET_SOURCE)

In [3]:
# Target split fractions, expressed relative to the full "train" split of the source dataset.
train_percent = 0.45
test_percent = 0.05

# Absolute row counts those fractions correspond to (informational; the split cell below works
# from the fractions, so its actual sizes can differ by a row or two per class due to rounding).
train_size, test_size = (
    int(len(dataset["train"]) * fraction) for fraction in (train_percent, test_percent)
)
print(f"Train size: {train_size}, Test size: {test_size}")

Train size: 13500, Test size: 1500


In [4]:
def label_dataset(batch):
    """Attach a binary ``label`` column computed from each row's ``text``.

    ``batch`` is a batched ``Dataset.map`` input: a dict mapping column names to lists of values.
    A row is labelled 1 when its lower-cased ``text`` contains the substring ``"woman"`` and 0
    otherwise. The dict is updated in place and also returned.

    NOTE(review): this is a plain substring test, so the plural "women" (which does not contain
    "woman") and synonyms such as "girl" or "lady" are labelled 0 -- confirm that is intended.
    """
    labels = []
    for caption in batch['text']:
        labels.append(int('woman' in caption.lower()))
    batch['label'] = labels
    return batch

# Add the binary `label` column derived from each caption, then drop the caption (`text`) column.
ds = dataset['train'].map(
    function=label_dataset,
    batched=True,
    batch_size=32,
    remove_columns=['text'],
    num_proc=4,
)

Map (num_proc=4): 100%|██████████| 30000/30000 [00:07<00:00, 4255.21 examples/s] 


In [5]:
# Display the labelled dataset (its features and row count).
ds

Dataset({
    features: ['image', 'label'],
    num_rows: 30000
})

In [6]:
def filter_batch(batch, value):
    """Select rows whose ``label`` equals ``value``.

    Accepts either layout:

    * a column-oriented batch -- a mapping of column name -> list of values, which is what
      ``Dataset.filter(..., batched=True)`` passes. Returns a boolean mask with one entry per row,
      the shape ``Dataset.filter`` expects from a batched predicate.
    * an iterable of row dicts. Returns the list of rows whose ``label`` equals ``value``
      (the original behaviour).

    Args:
        batch: column mapping with a ``'label'`` list, or an iterable of dicts with a ``'label'`` key.
        value: label value to keep.

    Returns:
        ``list[bool]`` for a column mapping, otherwise ``list`` of the matching row dicts.
    """
    from collections.abc import Mapping

    if isinstance(batch, Mapping):
        # Iterating a mapping yields column *names* (strings), so the row-wise branch below would
        # raise TypeError on ``name['label']``; read the label column directly instead.
        return [label == value for label in batch['label']]
    return [row for row in batch if row['label'] == value]

In [7]:
# Partition the labelled rows by class (0 / 1) so each class can be split separately below.
_filter_opts = dict(batched=True, batch_size=32, num_proc=4)
dataset_0 = ds.filter(lambda batch: [label == 0 for label in batch['label']], **_filter_opts)
dataset_1 = ds.filter(lambda batch: [label == 1 for label in batch['label']], **_filter_opts)
print(f"Dataset 0 size: {len(dataset_0)}, Dataset 1 size: {len(dataset_1)}")

Dataset 0 size: 10966, Dataset 1 size: 19034


In [8]:
def split_dataset(dataset, train_size, test_size, seed=None):
    """Carve ``dataset`` into train / leftover / test subsets.

    ``train_size`` and ``test_size`` are fractions of the *whole* ``dataset`` (e.g. 0.45 and 0.05).
    The test subset is split off first. ``train_size`` is then rescaled to the remaining
    ``1 - test_size`` share, so the train subset is still about ``train_size`` of the original rows.

    Args:
        dataset: object exposing ``train_test_split`` (a Hugging Face ``datasets.Dataset`` here).
        train_size: fraction of the full dataset to return as the train subset.
        test_size: fraction of the full dataset to return as the test subset; must be in (0, 1).
        seed: optional seed forwarded to both ``train_test_split`` calls. Pass an int to make the
            split reproducible; ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``(train, leftover, test)``, where ``leftover`` holds the rows assigned to neither split
        (empty when ``train_size + test_size == 1``).

    Raises:
        ValueError: if ``test_size`` is not in (0, 1), ``train_size`` is not positive, or
            ``train_size + test_size`` exceeds 1.
    """
    import math

    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be a fraction in (0, 1), got {test_size}")
    # Fraction of the post-test remainder that must go to train to hit `train_size` of the total.
    train_fraction = train_size / (1 - test_size)
    if train_fraction <= 0 or (train_fraction > 1 and not math.isclose(train_fraction, 1.0)):
        raise ValueError(
            f"train_size ({train_size}) must be positive and train_size + test_size "
            f"({train_size + test_size}) must not exceed 1"
        )

    outer = dataset.train_test_split(test_size=test_size, seed=seed)
    remainder, test_split = outer['train'], outer['test']

    if math.isclose(train_fraction, 1.0):
        # All remaining rows are training data; train_test_split rejects train_size == 1.0.
        return remainder, remainder.select([]), test_split

    inner = remainder.train_test_split(train_size=train_fraction, seed=seed)
    return inner['train'], inner['test'], test_split

# Split each class on its own so train and test keep the class balance of the source data.
(train_dataset_0, _, test_dataset_0), (train_dataset_1, _, test_dataset_1) = (
    split_dataset(class_subset, train_percent, test_percent)
    for class_subset in (dataset_0, dataset_1)
)

# Merge the per-class pieces and shuffle (fixed seed) so the two classes are interleaved.
train_dataset = concatenate_datasets([train_dataset_0, train_dataset_1]).shuffle(seed=0)
test_dataset = concatenate_datasets([test_dataset_0, test_dataset_1]).shuffle(seed=0)

final_datasets = DatasetDict(train=train_dataset, test=test_dataset)

In [9]:
# Display the final train/test DatasetDict (features and row count of each split).
final_datasets

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 13499
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1501
    })
})

In [3]:
save_dir = "../output/celebahq/test/all"

# Start from an empty folder: remove any export left by an earlier run, then (re)create it.
had_previous_export = os.path.exists(save_dir)
if had_previous_export:
    shutil.rmtree(save_dir)
    print(f"Deleted {save_dir}")
os.makedirs(save_dir, exist_ok=True)

Deleted ../output/celebahq/test/all


In [4]:
def create_save_image_function(save_directory):
    """Build a ``Dataset.map`` callback that writes each example's image into ``save_directory``.

    The returned callable takes ``(example, index)`` (use it with ``with_indices=True``). It saves
    ``example['image']`` as ``image_{index}.png`` inside ``save_directory`` and hands the example
    back unchanged.
    """
    def save_image(example, index):
        # Imported inside the callback, as in the original.
        # NOTE(review): presumably so it stays self-contained in num_proc worker processes.
        import os
        picture = example['image']
        destination = os.path.join(save_directory, f'image_{index}.png')
        picture.save(destination)
        return example  # Return the unmodified example
    return save_image

save_image_function = create_save_image_function(save_dir)

# NOTE(review): this cell needs `test_dataset` from the split/merge cell above. The recorded
# NameError presumably comes from running it on a fresh kernel without re-running that cell.
# `map` is used only for its side effect (one PNG per row); the returned Dataset is just displayed.
test_dataset.map(
    function=save_image_function,
    with_indices=True,
    batched=False,
    num_proc=6,
)

NameError: name 'test_dataset' is not defined