In [25]:
import numpy as np
import torch
from datasets import load_dataset, Dataset as HfDataset
from torchvision import transforms
from functools import partial


url = "Ryan-sjtu/celebahq-caption"
image_size = 512
batch_size = 12
num_workers = 1

dataset = load_dataset(url, split="train")

In [27]:
train_test_split = dataset.train_test_split(test_size=0.2, shuffle=True)

train_dataset = train_test_split['train']
val_dataset = train_test_split['test']

In [28]:
len(train_dataset), len(val_dataset)

(24000, 6000)

In [29]:
class CustomResizeAndCrop:
    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, image):
        width, height = image.size
        if width > height:
            scale = self.target_size / width
        else:
            scale = self.target_size / height

        new_width = int(width * scale)
        new_height = int(height * scale)

        resized_image = transforms.Resize((new_height, new_width))(image)
        final_image = transforms.CenterCrop(self.target_size)(resized_image)

        return final_image
    
def preprocess_celebahq_caption(sample, transform):
    prefix = "a photography of"
    image = transform(sample["image"])
    text = sample["text"].lower().removeprefix(prefix).strip()

    return image, text


def collate_celebahq_caption(samples):
    images, texts = zip(*samples)
    
    images = torch.stack(images)
    texts = np.stack(texts)

    return images, texts

In [30]:
from torch.utils.data import Dataset, DataLoader

class PytorchHuggingFaceDataset(Dataset):
    def __init__(self, hf_dataset, preprocess_fn):
        """
        Custom Dataset to apply transformations in batch.
        
        Args:
            hf_dataset: Hugging Face dataset.
            preprocess_fn: Function to preprocess a batch of data.
        """
        self.hf_dataset = hf_dataset
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        data_point = self.hf_dataset[idx]
        image, text = self.preprocess_fn(data_point)

        return image, text

In [31]:
transform_fn = transforms.Compose(
            [
                CustomResizeAndCrop(target_size=image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]
        )

preprocess_batch = partial(preprocess_celebahq_caption, transform=transform_fn)
collate_fn = collate_celebahq_caption

In [32]:
val_ds = PytorchHuggingFaceDataset(val_dataset, preprocess_batch)

dl = DataLoader(
            dataset=val_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
            collate_fn=collate_fn,
        )

In [34]:
for ds_imgs, ds_txts in dl:
    break

RuntimeError: DataLoader worker (pid(s) 26987) exited unexpectedly

In [None]:
ds_imgs.shape, 

(torch.Size([5, 3, 512, 512]), (5,))

In [None]:
ds_imgs.shape

torch.Size([5, 3, 512, 512])

In [None]:
ds_txts

array(['a woman with a very long blond hair',
       'a woman with blonde hair and blue eyes smiling',
       'a woman with a hat and a necklace',
       'a woman with wet hair in a pool',
       'a man with a suit and tie smiling'], dtype='<U46')