In [53]:
import numpy as np
import torch
from datasets import load_dataset, Dataset as HfDataset
from torchvision import transforms
from functools import partial


url = "Ryan-sjtu/celebahq-caption"
image_size = 512
batch_size = 5
num_workers = 0

dataset = load_dataset(url, split="train")

In [54]:
train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True)

train_dataset = train_test_split['train']
val_dataset = train_test_split['test']

In [55]:
len(train_dataset), len(val_dataset)

(24000, 6000)

In [62]:
class CustomResizeAndCrop:
    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, image):
        width, height = image.size
        if width > height:
            scale = self.target_size / width
        else:
            scale = self.target_size / height

        new_width = int(width * scale)
        new_height = int(height * scale)

        resized_image = transforms.Resize((new_height, new_width))(image)
        final_image = transforms.CenterCrop(self.target_size)(resized_image)

        return final_image
    
def preprocess_celebahq_caption(samples, transform):
    prefix = "a photography of"
    samples["text"] = [i.lower().removeprefix(prefix).strip() for i in samples["text"]]
    samples["image"] = [transform(i) for i in samples["image"]]

    return samples


def collate_celebahq_caption(samples):
    print(f"samples type: {type(samples)}")
    images = torch.stack([sample["image"] for sample in samples])
    texts = np.stack([sample["text"] for sample in samples])

    return images, texts

In [63]:
from torch.utils.data import Dataset, DataLoader

class BatchHuggingFaceDataset(Dataset):
    def __init__(self, hf_dataset, preprocess_fn, batch_size):
        """
        Custom Dataset to apply transformations in batch.
        
        Args:
            hf_dataset: Hugging Face dataset.
            preprocess_fn: Function to preprocess a batch of data.
            batch_size: Number of items in each batch.
        """
        self.hf_dataset = hf_dataset
        self.preprocess_fn = preprocess_fn
        self.batch_size = batch_size

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        start_idx = idx * self.batch_size
        end_idx = min(start_idx + self.batch_size, len(self.hf_dataset))

        batch = HfDataset.from_dict(self.hf_dataset[start_idx:end_idx])
        processed_batch = batch.map(self.preprocess_fn, batch_size=self.batch_size, batched=True)

        return processed_batch

In [64]:
transform_fn = transforms.Compose(
            [
                CustomResizeAndCrop(target_size=image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]
        )

preprocess_batch = partial(preprocess_celebahq_caption, transform=transform_fn)
collate_fn = collate_celebahq_caption

In [65]:
val_ds = BatchHuggingFaceDataset(
    val_dataset, preprocess_batch, batch_size
)

dl = DataLoader(
            dataset=val_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
            collate_fn=collate_fn,
        )

In [66]:
for batch_data in dl:
    break

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map: 100%|██████████| 5/5 [00:03<00:00,  1.30 examples/s]
Map: 100%|██████████| 5/5 [00:03<00:00,  1.30 examples/s]
Map: 100%|██████████| 5/5 [00:03<00:00,  1.29 examples/s]
Map: 100%|██████████| 5/5 [00:03<00:00,  1.32 examples/s]
Map: 100%|██████████| 5/5 [00:03<00:00,  1.31 examples/s]


samples type: <class 'list'>


TypeError: expected Tensor as element 0 in argument 0, but got list

In [None]:
batch_data

NameError: name 'batch_data' is not defined