In [4]:
import os
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pickle


def load_cifar10_batch(file):
    """
    Read one pickled CIFAR-10 batch file and return its images and labels.

    Args:
        file (str): Path to a batch file such as ``data_batch_1`` or ``test_batch``.

    Returns:
        tuple: ``(data, labels)``. ``data`` is the ``b'data'`` array reshaped to
        ``(N, 3, 32, 32)`` and transposed to channels-last ``(N, 32, 32, 3)``;
        ``labels`` is the ``b'labels'`` entry exactly as stored in the file.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only call
        this on batch files obtained from a trusted source.
    """
    with open(file, 'rb') as handle:
        raw = pickle.load(handle, encoding='bytes')
        flat, stored_labels = raw[b'data'], raw[b'labels']
        # Each row holds three 32x32 channel planes back to back; move the
        # channel axis last to get (N, H, W, C).
        images = flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, stored_labels


class CIFAR10BinaryDataset(Dataset):
    """
    Map-style dataset over the pickled CIFAR-10 batch files found in ``root``.

    Items are ``(image, label)`` pairs: ``image`` is a PIL image built from the
    stored pixels (or the output of ``transform`` when one is given) and
    ``label`` is the corresponding entry of the batch label list.
    """

    def __init__(self, root, is_train=True, transform=None):
        """
        Dataset class for CIFAR-10 binary files.

        Args:
            root (str): Directory containing ``data_batch_1`` .. ``data_batch_5``
                and ``test_batch``.
            is_train (bool): If True, load the five training batches. Else, load
                the test batch.
            transform (callable, optional): Transformation applied to each PIL
                image in ``__getitem__``.

        Raises:
            FileNotFoundError: If an expected batch file is missing from ``root``.
        """
        self.root = root
        self.transform = transform

        # Both splits go through the same loading loop, so ``self.data`` is
        # always one ndarray and ``self.labels`` always one list.
        if is_train:
            batch_names = [f'data_batch_{i}' for i in range(1, 6)]
        else:
            batch_names = ['test_batch']

        images, labels = [], []
        for name in batch_names:
            batch_images, batch_labels = load_cifar10_batch(os.path.join(root, name))
            images.append(batch_images)
            labels.extend(batch_labels)

        self.data = np.concatenate(images, axis=0)
        self.labels = labels

        print(f'Loaded {"training" if is_train else "test"} data: {len(self.data)} samples')

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """
        Return ``(image, label)`` for position ``index``.

        The stored (H, W, C) array is converted to a PIL image before
        ``transform`` runs. This presumably expects uint8 pixels, as in the
        official CIFAR-10 pickles.
        """
        image = self.data[index]
        label = self.labels[index]

        image = Image.fromarray(image)
        # Compare against None explicitly: a transform object that defines
        # __len__/__bool__ could otherwise be skipped for being falsy.
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def build_transform(is_train, args):
    """
    Return the torchvision preprocessing pipeline for one CIFAR-10 split.

    The training pipeline pads by 4 pixels, takes a random 32x32 crop and
    applies a random horizontal flip. Both splits then convert to a tensor and
    normalise each channel with the CIFAR-10 statistics below.

    Args:
        is_train (bool): Select the augmented training pipeline when True.
        args: Not read by this function; kept for call-site compatibility.
    """
    channel_mean = [0.4914, 0.4822, 0.4465]  # CIFAR-10 mean
    channel_std = [0.247, 0.243, 0.261]  # CIFAR-10 std

    if is_train:
        augmentations = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        augmentations = []

    # Tensor conversion and normalisation are shared by every split.
    finishing_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(augmentations + finishing_steps)


def build_dataset(is_train, args):
    """
    Create the CIFAR-10 dataset for the requested split.

    Args:
        is_train (bool): True for the training batches, False for the test batch.
        args: Configuration object. ``args.data_path`` is the directory holding
            the batch files, and the whole object is forwarded to
            ``build_transform``.

    Returns:
        CIFAR10BinaryDataset: The split with its matching transform attached.
    """
    split_transform = build_transform(is_train, args)
    return CIFAR10BinaryDataset(
        root=args.data_path,
        is_train=is_train,
        transform=split_transform,
    )


# Example Arguments
class Args:
    """
    Minimal stand-in for a parsed command-line namespace.

    ``build_dataset`` reads ``data_path``. The remaining fields are not read by
    the dataset/transform code in this cell; they are kept for any later code
    that expects them.
    """

    # Directory with the CIFAR-10 batch files. Set the CIFAR10_DATA_PATH
    # environment variable to point elsewhere without editing the code; the
    # original machine-specific location remains the default.
    data_path = os.environ.get(
        'CIFAR10_DATA_PATH',
        'C:/Users/ensin/OneDrive/Documenten/Universiteit/Thesis/cifar_alt',
    )
    input_size = 32  # CIFAR-10 images are 32x32
    color_jitter = 0.4
    aa = None
    reprob = 0.0
    remode = 'pixel'
    recount = 1


# Build the training split first, then the validation (test-batch) split.
args = Args()
train_dataset, val_dataset = (
    build_dataset(is_train=flag, args=args) for flag in (True, False)
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")


Loaded training data: 50000 samples
Loaded test data: 10000 samples
Train dataset size: 50000
Validation dataset size: 10000
