In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [3]:
# fix the global PyTorch RNG seed so that every run of the notebook
# draws the same sequence of random numbers
torch.manual_seed(42)

# preprocessing applied to every image: convert to a tensor, then
# standardize the single channel with the given mean and std
transform = transforms.Compose(
  [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
  ]
)

In [4]:
# load MNIST: the training split first, then the test split.
# both are downloaded into ./data if missing and share the same transform
train_dataset, test_dataset = (
  torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
  for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.7MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 496kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.93MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
def to_device(batch, device):
  """Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.

  Lists and tuples are rebuilt as plain lists whose elements have been moved
  recursively; any other object is moved by calling its `.to(device)`.
  """
  # leaf case: a single object that knows how to move itself
  if not isinstance(batch, (list, tuple)):
    return batch.to(device)

  # container case: move every element, preserving order
  moved = []
  for item in batch:
    moved.append(to_device(item, device))
  return moved

In [6]:
class DeviceDataLoader:
  """Wrap an iterable of batches so each batch is moved to `device` when read.

  `dl` is the wrapped loader and `device` the target passed to `to_device`.
  """

  def __init__(self, dl, device):
    self.dl, self.device = dl, device

  def __len__(self):
    # same number of batches as the wrapped loader
    return len(self.dl)

  def __iter__(self):
    # lazily walk the wrapped loader, moving one batch at a time
    yield from (to_device(batch, self.device) for batch in self.dl)

In [7]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [8]:
# create data loaders; each one is wrapped so its batches land on `device`

# training loader
# batch_size=64: yields 64 images at a time
# shuffle=True: randomizes the order of the images in each epoch, which
#   - prevents the model from learning the order of the data
#   - reduces the risk of getting stuck in local minima
#   - helps the model generalize better
train_loader = DeviceDataLoader(
  DataLoader(dataset=train_dataset, batch_size=64, shuffle=True),
  device,
)

# test loader
# batch_size=1000: a larger batch is used for testing because
#   - no backpropagation is performed during testing
#   - larger batches are more efficient for evaluation
#   - you want consistent results across the entire test set
# shuffle=False: no shuffling for the test data because
#   - you're not training on this data, just evaluating
#   - it makes the evaluation more reproducible
#   - the order doesn't affect evaluation metrics
test_loader = DeviceDataLoader(
  DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False),
  device,
)

In [9]:
# define the CNN architecture
class SimpleCNN(nn.Module):
  """Small two-block CNN producing 10 class scores from 1-channel images.

  Each block is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool. For a 28x28
  input the spatial size goes 28x28 -> 14x14 -> 7x7 while the channels go
  1 -> 32 -> 64. The flattened 64*7*7 = 3136 features feed a 128-unit hidden
  layer followed by a 10-way output layer.
  """

  def __init__(self):
    # initialize the nn.Module base class; required before registering layers
    super().__init__()

    # convolutional layers

    # input color channel = 1 (grayscale)

    # OUTPUT CHANNELS = 32. this is a common starting point for the
    # number of filters in CNNs.
    # the number doubles in the next layer - a common practice to
    # increase filter count as you go deeper

    # KERNEL_SIZE = 3
    # 3x3 kernels are the most commonly used in modern CNNs
    # generally better than 5x5 or 7x7 because stacking
    # multiple 3x3 kernels gives the same receptive field with fewer params

    # PADDING = 1
    # with a 3x3 kernel, stride 1 and padding=1, output size = input size (28x28)
    self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)

    # INPUT CHANNELS = 32
    # matches the output channels from the first layer

    # OUTPUT CHANNELS = 64
    # doubling the number of channels is a common pattern in CNNs:
    # as you go deeper, you want more feature maps to capture more
    # complex patterns.

    # kernel size, stride and padding: same reasoning as conv1
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

    # pooling layer: halves height and width each time it is applied
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # fully connected layers: 64*7*7 flattened features -> 128 -> 10 classes
    self.fc1 = nn.Linear(64 * 7 * 7, 128)
    self.fc2 = nn.Linear(128, 10)

    # activation function (stateless, shared by every layer)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Return raw class scores (logits) of shape [batch_size, 10].

    Raises:
      RuntimeError: if the feature map after both pooling steps is not
        64 x 7 x 7 (i.e. the input was not a 28x28-style image). Without
        this check a wrongly sized input could be reshaped so that extra
        spatial positions are silently counted as extra batch samples.
    """
    # first convolutional block
    x = self.pool(self.relu(self.conv1(x)))

    # second convolutional block
    x = self.pool(self.relu(self.conv2(x)))

    # after two rounds of max pooling a 28x28 image is 7x7:
    #   first max pooling: 14x14
    #   second max pooling: 7x7
    # so each sample must now be (64 channels, 7, 7)
    expected = (self.conv2.out_channels, 7, 7)
    if tuple(x.shape[-3:]) != expected:
      raise RuntimeError(
        f'SimpleCNN expected a {expected} feature map before the fully '
        f'connected layers but got {tuple(x.shape[-3:])}; '
        f'is the input a 1x28x28 image?'
      )

    # flatten (channels x height x width) into one vector per sample:
    # [batch_size, 64, 7, 7] -> [batch_size, 3136], where 3136 = 64*7*7.
    # -1 lets PyTorch work out the batch dimension. reshape (rather than
    # view) also handles activations that are not contiguous in memory.
    x = x.reshape(-1, self.fc1.in_features)

    # fully connected layers
    x = self.relu(self.fc1(x))
    return self.fc2(x)

In [10]:
# build the network and move it to the selected device
model = SimpleCNN()
model = model.to(device)

# cross-entropy loss for classification, kept on the same device
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Adam over all of the model's parameters
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [11]:
# training loop
def train(epochs):
  """Train the module-level `model` on `train_loader` for `epochs` passes.

  Uses the module-level `criterion` and `optimizer`. Prints the current
  batch loss every 100 steps and, after each epoch, the mean batch loss and
  the training accuracy for that epoch. Returns None.
  """
  model.train()
  # the number of batches does not change between epochs, so look it up once
  num_batches = len(train_loader)

  for epoch in range(epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(train_loader):
      # forward pass
      # data flows through ALL the layers defined in the forward() method
      # OUTPUTS: raw model predictions (logits), one row per sample
      outputs = model(images)

      # cross-entropy loss compares the predictions to the true labels
      loss = criterion(outputs, labels)

      # backward pass and optimize
      # CRITICAL: reset the gradients of all parameters to zero first,
      # otherwise they accumulate across steps
      optimizer.zero_grad()

      # triggers the backpropagation algorithm
      # after this call, the parameters carry their gradients in .grad
      loss.backward()

      # updates all model parameters using the calculated gradients
      optimizer.step()

      # calculate statistics
      # read the loss as a python float once and reuse it below
      batch_loss = loss.item()
      running_loss += batch_loss

      # detach() gives the scores without the computational graph
      # (preferred over the legacy .data attribute)
      # torch.max returns 2 values: max value and its index along dim 1 (the class dimension)
      # for classification, we care about which class has the highest score, not the score itself
      _, predicted = torch.max(outputs.detach(), 1)

      # keeps track of the total number of images processed so far
      # labels.size(0) is the size of the 1st dim of labels, i.e. the batch size
      total += labels.size(0)

      # how many predictions were correct in this batch
      # predicted == labels is a boolean tensor, .sum() counts the True values,
      # .item() converts the result to a python scalar
      correct += (predicted == labels).sum().item()

      if (i + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{num_batches}], Loss: {batch_loss: .4f}')

    accuracy = 100 * correct / total
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/num_batches: .4f}, Accuracy: {accuracy:.2f}%')

In [12]:
# testing function
def test():
  """Evaluate the module-level `model` on `test_loader` and print the accuracy.

  Prints one line, 'Test Accuracy: NN.NN%'. Returns None. The model is left
  in evaluation mode afterwards.
  """
  # switch the model to evaluation mode
  # this matters for layers that behave differently while training, such as
  # dropout (disabled in eval mode) and batch norm (uses running statistics
  # instead of batch statistics); SimpleCNN as defined in this notebook has
  # neither, so this is a good-practice call rather than a functional one
  model.eval()

  correct = 0
  total = 0

  # temporarily disables gradient calculation; outputs therefore carry no
  # computational graph and can be used directly
  with torch.no_grad():
    for images, labels in test_loader:
      outputs = model(images)
      # index of the highest score along dim 1 is the predicted class
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f'Test Accuracy: {accuracy:.2f}%')

In [13]:
# run 5 full passes over train_loader; the batch loss is printed every
# 100 steps and the mean loss and accuracy at the end of each epoch
train(epochs=5)

Epoch [1/5], Step [100/938], Loss:  0.2348
Epoch [1/5], Step [200/938], Loss:  0.2046
Epoch [1/5], Step [300/938], Loss:  0.0493
Epoch [1/5], Step [400/938], Loss:  0.0282
Epoch [1/5], Step [500/938], Loss:  0.0337
Epoch [1/5], Step [600/938], Loss:  0.0135
Epoch [1/5], Step [700/938], Loss:  0.0356
Epoch [1/5], Step [800/938], Loss:  0.0743
Epoch [1/5], Step [900/938], Loss:  0.0390
Epoch [1/5], Loss:  0.1353, Accuracy: 95.82%
Epoch [2/5], Step [100/938], Loss:  0.0220
Epoch [2/5], Step [200/938], Loss:  0.0160
Epoch [2/5], Step [300/938], Loss:  0.0770
Epoch [2/5], Step [400/938], Loss:  0.0212
Epoch [2/5], Step [500/938], Loss:  0.0726
Epoch [2/5], Step [600/938], Loss:  0.0141
Epoch [2/5], Step [700/938], Loss:  0.0702
Epoch [2/5], Step [800/938], Loss:  0.0052
Epoch [2/5], Step [900/938], Loss:  0.0040
Epoch [2/5], Loss:  0.0414, Accuracy: 98.73%
Epoch [3/5], Step [100/938], Loss:  0.0181
Epoch [3/5], Step [200/938], Loss:  0.0981
Epoch [3/5], Step [300/938], Loss:  0.0012
Epoch [

In [14]:
# evaluate the trained model on test_loader and print its accuracy
test()

Test Accuracy: 99.13%
