In [11]:
import gzip
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim

In [12]:
train = gzip.open('train-images-idx3-ubyte.gz', 'rb')
train_label = gzip.open('train-labels-idx1-ubyte.gz', 'rb')
test = gzip.open('t10k-images-idx3-ubyte.gz', 'rb')
test_label = gzip.open('t10k-labels-idx1-ubyte.gz', 'rb')

train.read(16)
train_label.read(8)
test.read(16)
test_label.read(8)

train_image_data = np.frombuffer(train.read(), dtype=np.uint8).reshape(-1, 28, 28)
train_label_data = np.frombuffer(train_label.read(), dtype=np.uint8)

test_image_data = np.frombuffer(test.read(), dtype=np.uint8).reshape(-1, 28, 28)
test_label_data = np.frombuffer(test_label.read(), dtype=np.uint8)

# Ensure train_image_data and train_label_data have compatible shapes
train_image_data = train_image_data[:len(train_label_data)] # This will keep the first 12000 entries from the image data, or
# train_label_data = train_label_data[:len(train_image_data)] # will keep the first 48000 labels, based on your desired approach.

train_images, val_images, train_labels, val_labels = train_test_split(
    train_image_data, train_label_data, test_size=0.2, random_state=42 # Added random_state for reproducibility.
)

In [13]:
transform = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Convert data to PyTorch tensors
train_images = torch.tensor(train_images, dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.long)  # Labels should be long tensors
val_images = torch.tensor(val_images, dtype=torch.float32)
val_labels = torch.tensor(val_labels, dtype=torch.long)
test_image_data = torch.tensor(test_image_data, dtype=torch.float32)
test_label_data = torch.tensor(test_label_data, dtype=torch.long)

# Create TensorDatasets
train_dataset = TensorDataset(train_images, train_labels)
val_dataset = TensorDataset(val_images, val_labels)
test_dataset = TensorDataset(test_image_data, test_label_data)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [14]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = CNN()

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [16]:
for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images = images.unsqueeze(1).float()
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.unsqueeze(1).float()
            output = model(images)
            val_loss += criterion(output, labels).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    val_loss /= len(val_loader.dataset)

    print('Epoch: {} \tValidation Loss: {:.6f} \tValidation Accuracy: {:.0f}%'.format(epoch, val_loss, 100. * correct / len(val_loader.dataset)))

Epoch: 0 	Validation Loss: 0.006219 	Validation Accuracy: 86%
Epoch: 1 	Validation Loss: 0.005054 	Validation Accuracy: 88%
Epoch: 2 	Validation Loss: 0.005176 	Validation Accuracy: 88%
Epoch: 3 	Validation Loss: 0.005487 	Validation Accuracy: 88%
Epoch: 4 	Validation Loss: 0.004785 	Validation Accuracy: 89%
Epoch: 5 	Validation Loss: 0.004794 	Validation Accuracy: 90%
Epoch: 6 	Validation Loss: 0.005137 	Validation Accuracy: 89%
Epoch: 7 	Validation Loss: 0.005217 	Validation Accuracy: 89%
Epoch: 8 	Validation Loss: 0.005519 	Validation Accuracy: 89%
Epoch: 9 	Validation Loss: 0.005380 	Validation Accuracy: 90%


Sources:
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html