# Neural Networks in PyTorch - Example on MNIST Handwritten Digit Classification

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

# Seed both PyTorch and NumPy so weight initialisation and batch shuffling are reproducible
torch.manual_seed(302)
np.random.seed(302)

# Preprocessing applied to every image:
#   1. ToTensor converts the image to a float tensor.
#   2. Normalize subtracts 0.1307 and divides by 0.3081, so that pixel values
#      end up with (approximately) zero mean and unit standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# The training split is downloaded to ./data on first use; the test split reads from the same folder
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

print(train_dataset)
print(test_dataset)

# Creating data loaders
# Training batches are reshuffled every epoch. The test loader is NOT shuffled:
# evaluation metrics do not depend on sample order, and skipping the shuffle
# keeps evaluation from consuming random numbers from the global generator.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)


## Exploratory Data Analysis

In [None]:
from matplotlib import pyplot as plt

# Pull a single mini-batch from the training loader to look at
image_batch, label_batch = next(iter(train_loader))
# Printed explicitly: a bare expression that is not the last line of a cell is never displayed
print(image_batch.shape, label_batch.shape)

print(label_batch[:9])

# Tile the first 9 images (channel 0) into one 3x3 mosaic:
#   (9, 28, 28) -> (3, 3, 28, 28) -> swap axes to (3, 28, 3, 28) -> (84, 84)
# The permute puts each grid row's pixel rows next to each other so the final
# reshape lays the digits out side by side instead of interleaving them.
image_grid = (
    image_batch[:9, 0, :, :]
    .reshape(3, 3, 28, 28)
    .permute(0, 2, 1, 3)
    .reshape(28 * 3, 28 * 3)
)

plt.figure(figsize=(10, 10))
plt.imshow(image_grid, cmap='gray')
plt.show()

## Defining the Network

Feed-forward neural network with 2 hidden layers (3 fully connected layers in total, the last one being the output layer) in PyTorch:

In [None]:
class NeuralNet(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 128 -> 10 with ReLU activations.

    ``forward`` returns *log-probabilities* over the 10 classes (``log_softmax``
    of the logits). This form is used instead of plain softmax probabilities
    because:

    * ``torch.exp`` of the output recovers probabilities that sum to 1;
    * ``nn.CrossEntropyLoss`` applies ``log_softmax`` to its input internally.
      Log-probabilities are left unchanged by that step, so the loss is the
      correct negative log-likelihood, whereas softmax probabilities would be
      squashed a second time and yield a distorted loss and weak gradients;
    * ``argmax`` over the output still selects the most likely class.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Compute class log-probabilities for a batch of images.

        Args:
            x: Tensor whose non-batch dimensions hold 784 values in total,
                e.g. shape (B, 1, 28, 28), where B is the batch size.

        Returns:
            Tensor of shape (B, 10) with log-probabilities; each row
            exponentiates to a distribution that sums to 1.
        """
        # Flatten the data (B, 1, 28, 28) => (B, 784)
        x = torch.flatten(x, start_dim=1)

        # Two hidden layers, each followed by a ReLU non-linearity
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Output layer: un-normalised class scores (*logits*)
        logits = self.fc3(x)

        # Normalise in log-space; numerically more stable than log(softmax(logits))
        return F.log_softmax(logits, dim=1)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the network and move its parameters onto the selected device
model = NeuralNet()
model = model.to(device)
print(model)

# Shapes of all learnable parameter tensors (weights and biases of each layer)
for param in model.parameters():
    print(param.shape)

In [None]:
# Sanity check: what a single input and the corresponding (untrained) model output look like
x = image_batch[:1]  # slice (not index) so the batch dimension is kept
x = x.to(device)
print(f'Input image shape: {x.shape}')

probs = model(x)
print(f'Model output: {probs}')

# Exponentiate the model output and sum over all entries.
# NOTE(review): this only yields a distribution summing to 1 if the model
# returns log-probabilities — confirm against the model's forward().
probs = probs.exp()
print(probs)
print(probs.sum())

## Training the model

In [None]:
# Hyperparameters
learning_rate = 0.01
num_epochs = 5

# Start from a freshly initialised model, optimised with plain SGD
model = NeuralNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# NOTE(review): nn.CrossEntropyLoss applies log-softmax to its input internally, so it
# expects un-normalised scores (logits) or log-probabilities — confirm what the model returns.
loss_fn = nn.CrossEntropyLoss()

losses = []  # loss of every training batch, accumulated across all epochs
for epoch in range(num_epochs):
    print('-'*20, f'Epoch {epoch}', '-'*20)

    # ---- Training pass over the whole training set ----
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward, parameter update
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if batch_idx % 100 == 0:
            print(f'Train Epoch {epoch} | Loss: {loss.item()}')
    # The last len(train_loader) entries are exactly this epoch's batch losses
    print(f'\nAverage train loss in epoch {epoch}: {np.mean(losses[-len(train_loader):])}')

    # ---- Evaluation pass on the test set (monitoring only, no gradient tracking) ----
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            test_loss += loss_fn(outputs, labels).item()
            # The predicted class is the index of the largest output per sample
            predicted = outputs.argmax(dim=1)
            correct += predicted.eq(labels).sum().item()

    # Loss is averaged over batches; accuracy over individual samples
    test_loss /= len(test_loader)
    avg_correct = correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * avg_correct:.0f}%)\n')

## Plot Loss Curve

In [None]:
# Smooth the noisy per-batch loss by averaging consecutive windows of 10 batches.
# Windows are taken by slicing rather than reshape(-1, 10), which raises a
# ValueError whenever the number of recorded batches is not a multiple of 10;
# here the final window is simply shorter in that case.
smoothing_window = 10
loss_values = np.asarray(losses, dtype=float)
steps = np.arange(len(loss_values))
window_starts = steps[::smoothing_window]
losses_smoothed = np.array(
    [loss_values[start:start + smoothing_window].mean() for start in window_starts]
)

plt.figure(figsize=(10, 6))
# Raw per-batch loss (faint) with the smoothed curve drawn on top
plt.plot(steps, loss_values, 'b', alpha=0.5)
plt.plot(window_starts, losses_smoothed, 'b')
plt.title('Training Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.show()