In [10]:
# Use %pip (not bare `pip`, which only works via IPython automagic) so packages
# are installed into the environment of the running kernel.
# TODO(review): pin versions (e.g. torch==X.Y.Z) so Restart & Run All is reproducible.
%pip install torch torchvision matplotlib


Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Configuration. Seeding torch's global RNG makes DataLoader shuffling and the
# weight initialisation of models created in later cells repeatable, as long
# as the cells are run top to bottom.
SEED = 0
torch.manual_seed(SEED)

DATA_ROOT = './data'
BATCH_SIZE = 64
NUM_WORKERS = 2

# Define transformations for MNIST.
# 0.1307 / 0.3081 are the commonly used mean / std of MNIST training pixels
# after ToTensor() has scaled them to [0, 1].
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD)
])


def make_mnist_loader(train, shuffle):
    """Return ``(dataset, loader)`` for the MNIST train or test split.

    Downloads the split into ``DATA_ROOT`` if needed and applies ``transform``.
    """
    dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=train,
                                         download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=shuffle, num_workers=NUM_WORKERS)
    return dataset, loader


# Training data is shuffled each epoch; test data keeps a fixed order.
trainset, trainloader = make_mnist_loader(train=True, shuffle=True)
testset, testloader = make_mnist_loader(train=False, shuffle=False)


 15%|█▌        | 1.51M/9.91M [05:29<1:02:38, 2.24kB/s]

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class PlainNN(nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 256 -> 10.

    Each hidden layer is followed by ReLU and dropout (p=0.5); the final
    layer's output is returned as raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order: it fixes parameter order (state_dict
        # keys, optimizer state) and the sequence of random-init draws.
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Reshape into rows of 784 features: (batch, 1, 28, 28) -> (batch, 784).
        h = x.view(-1, 28 * 28)
        for hidden_layer in (self.fc1, self.fc2):
            h = self.dropout(F.relu(hidden_layer(h)))
        return self.fc3(h)


# Build the fully connected baseline model.
plain_net = PlainNN()

def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted; each contributes
    its ``numel()`` elements. A model with no such parameters yields 0.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Report the trainable parameter count of the fully connected baseline.
print("Number of parameters in PlainNN:", count_parameters(model=plain_net))


In [None]:
class MNIST_CNN(nn.Module):
    """Small convolutional MNIST classifier.

    conv(1->32, 3x3, pad 1) -> ReLU -> conv(32->64, 3x3, pad 1) -> ReLU
    -> 2x2 max-pool -> flatten -> Linear(64*14*14 -> 128) -> ReLU
    -> dropout(0.25) -> Linear(128 -> 10). Returns raw logits (no softmax).

    Expects batched input of shape (batch, 1, 28, 28).
    """

    # padding=1 keeps both conv outputs at 28x28, and the single 2x2 pool
    # halves that to 14x14, so 64 feature maps of 14x14 reach fc1.
    _FLAT_FEATURES = 64 * 14 * 14

    def __init__(self):
        super(MNIST_CNN, self).__init__()
        # Convolutional layer 1: input channels=1, output channels=32, kernel size=3.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # Convolutional layer 2: input channels=32, output channels=64, kernel size=3.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 2x2 max pooling, applied once in forward(): 28x28 -> 14x14.
        self.pool = nn.MaxPool2d(2, 2)
        # Only ONE pooling step is applied, so the feature maps are 14x14 (not 7x7).
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # (batch, 64, 28, 28) -> (batch, 64, 14, 14)
        # Keep the batch axis explicit. view(-1, N) would silently fold a
        # wrongly sized feature map into extra "samples"; flatten(x, 1) makes
        # fc1 raise a clear shape error instead (and handles non-contiguous input).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


# Instantiate the CNN.
cnn_net = MNIST_CNN()

# Report the CNN's trainable parameter count alongside PlainNN's.
print("Number of parameters in CNN:", count_parameters(model=cnn_net))


In [None]:
import torch.optim as optim

# Both models return raw logits, which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()

# One Adam optimizer per model, both with learning rate 0.001
# (created in the order plain_net, then cnn_net).
optimizer_plain, optimizer_cnn = (
    optim.Adam(net.parameters(), lr=0.001)
    for net in (plain_net, cnn_net)
)

# Number of full passes over the training set.
num_epochs = 10


In [None]:
def train_model(model, optimizer, trainloader, num_epochs, loss_fn=None,
                log_interval=100, device=None):
    """Train ``model`` in place for ``num_epochs`` passes over ``trainloader``.

    Args:
        model: network to train; it is put into training mode first.
        optimizer: optimizer already bound to ``model``'s parameters.
        trainloader: iterable of ``(inputs, labels)`` batches.
        num_epochs: number of full passes over ``trainloader``.
        loss_fn: callable ``loss_fn(outputs, labels)``; when omitted, the
            notebook-level ``criterion`` global is used.
        log_interval: print the mean loss of the last ``log_interval``
            batches every ``log_interval`` batches (default 100).
        device: if given, each batch is moved there with ``.to(device)``;
            the caller is responsible for moving ``model`` there too.

    Raises:
        ValueError: if ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")
    if loss_fn is None:
        # Fall back to the notebook-level loss defined in the setup cell.
        loss_fn = criterion

    model.train()  # Set model to training mode (dropout active).
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if (i + 1) % log_interval == 0:
                # running_loss covers exactly the last `log_interval` batches.
                print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}], Loss: {running_loss / log_interval:.4f}')
                running_loss = 0.0
    print('Finished Training')
