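# Create CIFAR-10 datasets

Download CIFAR-10, attach the preprocessing pipelines defined below (original and 30°-rotated variants, for both train and test), and pickle each split/variant into the project's `data/` directory.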

In [17]:
import torchvision
from torchvision import transforms
import os
import shutil
import pickle

In [18]:
# All paths hang off the parent of the notebook's working directory, so the
# notebook is expected to be launched from a sub-folder of the project root.
ROOT_PATH = os.path.split(os.getcwd())[0]
DATA_PATH = os.path.join(ROOT_PATH, "data")
TRAIN_DATA_PATH, TEST_DATA_PATH = (
    os.path.join(DATA_PATH, split_name) for split_name in ("train", "test")
)

In [19]:
# Per-channel normalisation statistics shared by every pipeline below.
# NOTE(review): these std values are the commonly copied CIFAR-10 constants;
# the per-pixel std of the training set is reportedly closer to
# (0.2470, 0.2435, 0.2616) — verify before changing, since altering them
# changes every saved dataset.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Angle used for the "rotated" variants. RandomRotation is given the
# degenerate range (d, d), so every image is rotated by exactly d degrees.
ROTATION_DEGREES = 30


def make_cifar_transform(rotate_degrees=None, blur=False):
    """Build a CIFAR-10 preprocessing pipeline.

    Steps, in order: optional fixed-angle rotation, 24px centre crop,
    resize back to 32px, tensor conversion, optional Gaussian blur,
    per-channel normalisation with CIFAR10_MEAN / CIFAR10_STD.

    Args:
        rotate_degrees: if not None, prepend
            ``RandomRotation((rotate_degrees, rotate_degrees))``.
        blur: if True, apply ``GaussianBlur(kernel_size=(3, 7),
            sigma=(1.1, 2.2))`` to the tensor (used for the train pipelines).

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    steps = []
    if rotate_degrees is not None:
        steps.append(transforms.RandomRotation((rotate_degrees, rotate_degrees)))
    steps += [
        transforms.CenterCrop(24),
        transforms.Resize(size=32),
        transforms.ToTensor(),
    ]
    if blur:
        # Blur runs after ToTensor, i.e. on the tensor rather than the PIL image.
        steps.append(transforms.GaussianBlur(kernel_size=(3, 7), sigma=(1.1, 2.2)))
    steps.append(transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD))
    return transforms.Compose(steps)


transform_train = make_cifar_transform(blur=True)
transform_test = make_cifar_transform()
transform_rotate_train = make_cifar_transform(rotate_degrees=ROTATION_DEGREES, blur=True)
transform_rotate_test = make_cifar_transform(rotate_degrees=ROTATION_DEGREES)

In [20]:
# Load both CIFAR-10 splits (no transform attached). download=True lets
# torchvision fetch the archive into DATA_PATH; train first, then test.
train_set, test_set = (
    torchvision.datasets.CIFAR10(root=DATA_PATH, train=is_train, download=True)
    for is_train in (True, False)
)

# Human-readable class names, taken from the training split.
classes = train_set.classes


# Function to save dataset
def save_dataset(base_dir, train=True, transform_name=None, transform=None, out_dir=None):
    """Instantiate one CIFAR-10 split and pickle the dataset object to disk.

    The pickled object is the ``torchvision.datasets.CIFAR10`` instance
    itself, with ``transform`` attached; the transform is presumably applied
    when items are read from the unpickled dataset rather than baked in here.
    Only unpickle files you created yourself — ``pickle.load`` can execute
    arbitrary code.

    Args:
        base_dir: ``root`` passed to CIFAR10 (where the raw download lives).
        train: True for the training split, False for the test split.
        transform_name: label used in the output file name; required when
            ``transform`` is given, ignored otherwise.
        transform: transform attached to the dataset, or None.
        out_dir: directory for the ``.pkl`` file. Defaults to ``DATA_PATH``
            (the previous, hard-wired behaviour).

    Returns:
        Path of the written file, ``<out_dir>/<train|test>_<variant>.pkl``,
        where variant is ``transform_name`` or ``"original"`` without a transform.

    Raises:
        ValueError: if ``transform`` is given without a ``transform_name``.
    """
    # Work out the file name before the (possibly slow) download so that bad
    # arguments fail fast. Plain variables instead of nested quotes inside an
    # f-string, which is a SyntaxError before Python 3.12.
    split = "train" if train else "test"
    if transform:
        if not transform_name:
            raise ValueError("transform_name is required when a transform is given")
        variant = transform_name
    else:
        variant = "original"

    if out_dir is None:
        out_dir = DATA_PATH
    os.makedirs(out_dir, exist_ok=True)
    file_path = os.path.join(out_dir, f"{split}_{variant}.pkl")

    dataset = torchvision.datasets.CIFAR10(
        root=base_dir, train=train, download=True, transform=transform
    )
    with open(file_path, "wb") as f:
        pickle.dump(dataset, f)
    return file_path


def remove_leftover(data_path=DATA_PATH):
    """Delete the raw CIFAR-10 download artefacts from ``data_path``.

    Removes the ``cifar-10-python.tar.gz`` archive and the extracted
    ``cifar-10-batches-py`` directory, printing one status line for each.
    A missing item is reported as skipped instead of raising; any other
    OS error propagates. ``data_path`` defaults to ``DATA_PATH`` as it was
    when this function was defined.
    """
    archive_file = os.path.join(data_path, "cifar-10-python.tar.gz")
    batches_dir = os.path.join(data_path, "cifar-10-batches-py")

    # The downloaded archive is a single file.
    try:
        os.remove(archive_file)
    except FileNotFoundError:
        print(f"File not found (skipped): {archive_file}")
    else:
        print(f"Removed file: {archive_file}")

    # The extracted batches form a directory tree, removed recursively.
    try:
        shutil.rmtree(batches_dir)
    except FileNotFoundError:
        print(f"Directory not found (skipped): {batches_dir}")
    else:
        print(f"Removed directory: {batches_dir}")


# Pickle every (split, variant) combination, in this order, then delete the
# raw download that the CIFAR10 constructors left behind in DATA_PATH.
for is_train, variant, pipeline in (
    (True, "original", transform_train),
    (True, "rotated", transform_rotate_train),
    (False, "original", transform_test),
    (False, "rotated", transform_rotate_test),
):
    save_dataset(
        base_dir=DATA_PATH,
        train=is_train,
        transform_name=variant,
        transform=pipeline,
    )
remove_leftover()

Removed file: d:\DATA SCIENCE\GIT\replicate_moe\data\cifar-10-python.tar.gz
Removed directory: d:\DATA SCIENCE\GIT\replicate_moe\data\cifar-10-batches-py
