In [None]:
from pathlib import Path
from typing import Callable, Optional

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [9]:
# Expected image layout under the data/ directory (one sub-folder per class):
#   train/REAL/*.jpg
#   train/FAKE/*.jpg
#   test/REAL/*.jpg
#   test/FAKE/*.jpg


class CIFAKE_Dataset(Dataset):
    """CIFAKE real-vs-fake image dataset read from one folder per class.

    Images are globbed from ``<root_path>/<real_path>/*.jpg`` (label 0, REAL) and
    ``<root_path>/<fake_path>/*.jpg`` (label 1, FAKE). Indices ``0 .. n_real - 1``
    address the REAL images; the remaining indices address the FAKE images.
    """

    def __init__(
        self,
        root_path: str,
        real_path: str,
        fake_path: str,
        transform: Optional[Callable[..., torch.Tensor]] = None,
    ) -> None:
        """Index the REAL and FAKE images under ``root_path``.

        Args:
            root_path: Directory containing the two class sub-folders.
            real_path: Sub-folder (relative to ``root_path``) holding real images.
            fake_path: Sub-folder (relative to ``root_path``) holding fake images.
            transform: Optional callable applied to each RGB ``PIL.Image``; it must
                return a tensor. When omitted, images are converted to a float32
                ``(3, H, W)`` tensor scaled to ``[0, 1]``.
        """
        self.root_path = root_path
        self.real_path = real_path
        self.fake_path = fake_path
        self.transform = transform
        root = Path(root_path)
        # glob() order is filesystem-dependent; sort so index -> file is the same on every run/machine.
        self.real_img_paths = sorted(root.joinpath(real_path).glob("*.jpg"))
        self.fake_img_paths = sorted(root.joinpath(fake_path).glob("*.jpg"))

    def __len__(self) -> int:
        """Return the total number of REAL plus FAKE images."""
        return len(self.real_img_paths) + len(self.fake_img_paths)

    @staticmethod
    def _to_tensor(image: "Image.Image") -> torch.Tensor:
        """Convert an RGB PIL image to a float32 ``(3, H, W)`` tensor in ``[0, 1]``."""
        width, height = image.size
        # RGB raw bytes are packed row-major as (H, W, 3). bytearray gives torch.frombuffer
        # a writable buffer (a read-only bytes object triggers a warning).
        pixels = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
        return pixels.view(height, width, 3).permute(2, 0, 1).contiguous().float().div(255)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        """Return ``(image_tensor, label)`` for ``idx``; label is 0 for REAL, 1 for FAKE.

        Negative indices count from the end of the whole dataset, like a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        if idx < 0:
            idx += size

        n_real = len(self.real_img_paths)
        if idx < n_real:
            img_path, label = self.real_img_paths[idx], 0  # REAL
        else:
            img_path, label = self.fake_img_paths[idx - n_real], 1  # FAKE

        # The context manager closes the file; convert() returns a new, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # The default DataLoader collate cannot batch PIL images, so always return a tensor.
        if self.transform is not None:
            return self.transform(image), label
        return self._to_tensor(image), label

In [12]:
cifake_train = CIFAKE_Dataset("data/train", "REAL", "FAKE")
cifake_test = CIFAKE_Dataset("data/test", "REAL", "FAKE")

# Path.glob() silently matches nothing for a wrong or missing folder; fail here with a clear
# message instead of later inside the DataLoader.
for dataset in (cifake_train, cifake_test):
    assert dataset.real_img_paths, f"no *.jpg files under {dataset.root_path}/{dataset.real_path}"
    assert dataset.fake_img_paths, f"no *.jpg files under {dataset.root_path}/{dataset.fake_path}"

In [15]:
import multiprocessing
import os

# multiprocessing.cpu_count() reports every CPU on the machine, not the ones this process is
# allowed to run on. Where available, use the CPU affinity set (respects taskset / cpusets).
if hasattr(os, "sched_getaffinity"):
    max_num_workers = len(os.sched_getaffinity(0))
else:
    max_num_workers = multiprocessing.cpu_count()
max_num_workers

16

In [16]:
BATCH_SIZE = 32
SEED = 42

# NOTE: on platforms where worker processes are started with "spawn" (Windows; macOS by
# default), workers may be unable to import a Dataset class defined inside a notebook.
# If loading fails there, set num_workers=0.
train_loader = DataLoader(
    cifake_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=max_num_workers,
    # Seeded so the shuffled order (and anything flattened from this loader) is reproducible.
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(
    cifake_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=max_num_workers,
)

In [None]:
# Flatten the DataLoaders into per-sample lists so the data can be used with ML algorithms
# that do not accept a DataLoader.
def flatten_loader(loader: DataLoader) -> tuple[list[torch.Tensor], list[int]]:
    """Collect every sample yielded by ``loader`` into two flat lists.

    Args:
        loader: Iterable of ``(images, labels)`` batches, where ``images`` is a batched
            tensor (batch dimension first) and ``labels`` a 1-D tensor of class ids.

    Returns:
        ``(images, labels)``: one tensor per sample with the batch dimension removed, and the
        labels as plain Python ints, in the order the loader yields them (shuffled for a
        loader built with ``shuffle=True``).
    """
    images: list[torch.Tensor] = []
    labels: list[int] = []
    for batch_images, batch_labels in loader:
        images.extend(batch_images.unbind(0))
        # .tolist() yields Python ints instead of 0-d tensors.
        labels.extend(batch_labels.tolist())
    return images, labels


train_images, train_labels = flatten_loader(train_loader)
test_images, test_labels = flatten_loader(test_loader)

In [None]:
# Show a compact summary instead of echoing every tensor and label in the flattened lists.
{
    "n_train": len(train_images),
    "n_test": len(test_images),
    "image_shape": tuple(train_images[0].shape) if train_images else None,
    "train_fake_count": sum(int(label) for label in train_labels),  # label 1 = FAKE
}