In [6]:
import os

import torch
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

# Export the MNIST training set to two folders of sequentially numbered PNGs:
# 99% of the images go to `train_dir`, the remaining 1% to `val_dir`.
# Labels are not written out; only the images are saved.

# Directories
# NOTE(review): the downloader below writes to ./dataset/MNIST/raw. On a
# case-insensitive filesystem "dataset/mnist" and "dataset/MNIST" are the same
# directory, so the PNGs would share a folder with raw/ -- confirm that any
# consumer of train_dir tolerates this.
train_dir = "dataset/mnist"
val_dir = "dataset/mnist_val"

# Fixed seed so that re-running this cell reproduces the same train/val split
# instead of silently shuffling images between the two folders.
SPLIT_SEED = 42


def _save_images(loader, out_dir):
    """Write every image yielded by *loader* into *out_dir* as a numbered PNG.

    Args:
        loader: Iterable yielding ``(image, label)`` pairs where ``image`` is a
            tensor with a leading batch dimension of size 1. The label is
            ignored.
        out_dir: Existing directory that receives ``00000.png``,
            ``00001.png``, ... in iteration order.
    """
    # Build the converter once rather than on every iteration.
    to_pil = transforms.ToPILImage()
    for index, (image, _label) in enumerate(loader):
        image_path = os.path.join(out_dir, f"{index:05d}.png")  # Sequentially numbered
        # Drop only the batch dimension; the channel dimension is kept so that
        # ToPILImage receives a (C, H, W) tensor.
        to_pil(image.squeeze(0)).save(image_path)


# Create directories if they don't exist
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Download the MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = datasets.MNIST(root="./dataset", train=True, transform=transform, download=True)

# Calculate split sizes
total_size = len(mnist_dataset)
val_size = int(0.01 * total_size)  # 1% for validation
train_size = total_size - val_size  # Remaining 99% for training

# Split dataset (seeded generator => deterministic membership of each subset)
train_dataset, val_dataset = random_split(
    mnist_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Save train dataset
# shuffle=False keeps file numbering aligned with the subset's own ordering.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
_save_images(train_loader, train_dir)

# Save validation dataset
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
_save_images(val_loader, val_dir)

print(f"Training data saved in {train_dir}")
print(f"Validation data saved in {val_dir}")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 12.2MB/s]


Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 359kB/s]


Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.17MB/s]


Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 13.0MB/s]


Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Training data saved in dataset/mnist
Validation data saved in dataset/mnist_val
