In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch import optim
import time

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)


BATCH_SIZE = 32
EPOCHS = 1
LR = 1e-2
MOMENTUM = 0.5

train_set = datasets.MNIST("data/", download=True, train=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True
)

test_set = datasets.MNIST("data/", download=True, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False
)


def count_digit_samples(dataset):
    counts = [0] * 10
    for _, label in dataset:
        counts[label] += 1
    return counts


def print_digit_counts(title, digit_counts):
    print(title)
    for i, count in enumerate(digit_counts):
        print(f"Digit {i}: {count}")


train_digit_counts = count_digit_samples(train_set)
test_digit_counts = count_digit_samples(test_set)

print(f"Trainset length: {len(train_set)}, Testset length: {len(test_set)}")
print_digit_counts("Trainset:", train_digit_counts)
print_digit_counts("Testset:", test_digit_counts)

Trainset length: 60000, Testset length: 10000
Trainset:
Digit 0: 5923
Digit 1: 6742
Digit 2: 5958
Digit 3: 6131
Digit 4: 5842
Digit 5: 5421
Digit 6: 5918
Digit 7: 6265
Digit 8: 5851
Digit 9: 5949
Testset:
Digit 0: 980
Digit 1: 1135
Digit 2: 1032
Digit 3: 1010
Digit 4: 982
Digit 5: 892
Digit 6: 958
Digit 7: 1028
Digit 8: 974
Digit 9: 1009


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.max_pool2d(x, kernel_size=2)

        x = self.dropout1(x)
        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

In [4]:
model = Net()
total_params = sum(p.numel() for p in model.parameters())

# Print the total number of parameters
print(f"Total number of parameters: {total_params}")

# Alternatively, print the number of parameters per layer
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} parameters")

Total number of parameters: 1199882
conv1.weight: 288 parameters
conv1.bias: 32 parameters
conv2.weight: 18432 parameters
conv2.bias: 64 parameters
fc1.weight: 1179648 parameters
fc1.bias: 128 parameters
fc2.weight: 1280 parameters
fc2.bias: 10 parameters


In [7]:
model = Net()

# Set the optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
criterion = nn.NLLLoss()

# Set the number of epochs to train for
n_epochs = EPOCHS
start_time = time.time()
for epoch in range(n_epochs):
    # Set the model to training mode
    model.train()

    running_loss = 0
    correct_predictions = 0

    # Loop over the training data
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move the data and target to the device
        data, target = data, target
        print(data.shape)
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Compute the loss
        loss = criterion(output, target)

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

        # Update the running loss and correct predictions
        running_loss += loss.item()
        _, predictions = torch.max(output.data, 1)
        correct_predictions += (predictions == target).sum().item()

    # Compute the average loss and accuracy for this epoch
    avg_loss = running_loss / len(train_loader.dataset)
    accuracy = correct_predictions / len(train_loader.dataset)

    # Print the loss for this epoch
    print(
        "Epoch: {} Loss: {:.6f} Accuracy: {:.2f}%".format(
            epoch + 1, avg_loss, accuracy * 100
        )
    )

training_time = time.time() - start_time
print(f"Training time: {training_time} seconds")

torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 2

KeyboardInterrupt: 

In [6]:
# Set the model to evaluation mode
model.eval()

# Initialize the running loss and correct predictions
running_loss = 0
correct_predictions = 0
digit_correct_predictions = [0] * 10
digit_total_samples = [0] * 10

# Loop over the test data
with torch.no_grad():
    for data, target in test_loader:
        # Move the data and target to the device
        data, target = data, target

        # Forward pass
        output = model(data)

        # Compute the loss
        loss = criterion(output, target)

        # Update the running loss and correct predictions
        running_loss += loss.item()
        _, predictions = torch.max(output.data, 1)
        correct_predictions += (predictions == target).sum().item()

        # Update per-digit counters
        for i in range(len(target)):
            digit_total_samples[target[i]] += 1
            if predictions[i] == target[i]:
                digit_correct_predictions[target[i]] += 1

# Compute the average loss and accuracy for this epoch
avg_loss = running_loss / len(test_loader.dataset)
accuracy = correct_predictions / len(test_loader.dataset)

# Print the loss and accuracy for this epoch
print("Test Loss: {:.6f} Test Accuracy: {:.2f}%".format(avg_loss, accuracy * 100))

# Print the test accuracy for each digit
print("Test accuracy per digit:")
for i in range(10):
    digit_accuracy = digit_correct_predictions[i] / digit_total_samples[i]
    print(f"Digit {i}: {digit_accuracy * 100:.2f}%")

Test Loss: 0.003797 Test Accuracy: 96.48%
Test accuracy per digit:
Digit 0: 98.67%
Digit 1: 98.85%
Digit 2: 95.16%
Digit 3: 96.24%
Digit 4: 96.03%
Digit 5: 96.19%
Digit 6: 97.60%
Digit 7: 96.21%
Digit 8: 95.28%
Digit 9: 94.35%
