In [1]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import os
from torchvision.datasets import ImageFolder

class DatasetLoader:
    """Build a shuffled ``DataLoader`` for CIFAR-10 or a local FFHQ folder.

    Args:
        dataset_name: ``"cifar10"`` or ``"ffhq"`` (case-insensitive).
        batch_size: Number of samples per batch.
        data_path: Root directory for datasets. CIFAR-10 is downloaded here if
            missing; FFHQ is expected in the sub-directory ``<data_path>/ffhq``.
        image_size: Side length (pixels) that FFHQ images are resized to; the
            output is square.
    """

    def __init__(self, dataset_name="cifar10", batch_size=64,
                 data_path="../data", image_size=256):
        self.dataset_name = dataset_name.lower()
        self.batch_size = batch_size
        self.data_path = data_path
        self.image_size = image_size

    def get_dataset(self):
        """Return a shuffled ``DataLoader`` for the configured dataset.

        Note that, despite the name, this returns a ``DataLoader`` rather than
        a ``Dataset``.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
            FileNotFoundError: If FFHQ is requested but ``<data_path>/ffhq``
                is not an existing directory.
        """
        loaders = {
            "cifar10": self._load_cifar10,
            "ffhq": self._load_ffhq,
        }
        try:
            load = loaders[self.dataset_name]
        except KeyError:
            raise ValueError(
                f"only cifar10 or ffhq (got {self.dataset_name!r})"
            ) from None
        return load()

    @staticmethod
    def _normalize():
        """Return the transform that rescales ``ToTensor`` output to [-1, 1]."""
        # ToTensor yields floats in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
        # Both datasets yield RGB images, so mean/std are given per channel
        # explicitly instead of relying on broadcasting a single value.
        return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def _make_loader(self, dataset):
        """Wrap *dataset* in a shuffled DataLoader with the configured batch size."""
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

    def _load_cifar10(self):
        """Return a DataLoader over the CIFAR-10 training split (downloaded if absent)."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            self._normalize(),
        ])
        dataset = datasets.CIFAR10(
            root=self.data_path, train=True, download=True, transform=transform
        )
        return self._make_loader(dataset)

    def _load_ffhq(self):
        """Return a DataLoader over the images found under ``<data_path>/ffhq``.

        Raises:
            FileNotFoundError: If ``<data_path>/ffhq`` is not a directory.
        """
        ffhq_path = os.path.join(self.data_path, "ffhq")

        # Validate before building any transforms so a bad path fails fast;
        # isdir (unlike exists) also rejects a stray *file* named "ffhq".
        if not os.path.isdir(ffhq_path):
            raise FileNotFoundError(f"FFHQ dataset not found in {ffhq_path}")

        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            self._normalize(),
        ])

        # ImageFolder treats each sub-directory of ffhq_path as a class, so the
        # images must live inside at least one sub-directory.
        dataset = ImageFolder(root=ffhq_path, transform=transform)
        return self._make_loader(dataset)

# Smoke test: build the CIFAR-10 loader and inspect the shape of one batch.
torch_loader = DatasetLoader(dataset_name="cifar10", batch_size=64)
dataloader = torch_loader.get_dataset()

# Only the first batch is needed; next(iter(...)) replaces the for/break idiom.
images, labels = next(iter(dataloader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])


100%|██████████| 170M/170M [00:44<00:00, 3.87MB/s] 


torch.Size([64, 3, 32, 32]) torch.Size([64])
