In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
device

BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 10
NUM_CLASSES = 10

In [2]:
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
        ]),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1325, ), std=(0.3105, ))
    ]),
    download=True,
)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 119931204.78it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 110424515.79it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 32308342.08it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7003870.87it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [43]:
# Defining the LeNet Architecture

class LeNet(nn.Module):
  def __init__(self, num_classes=NUM_CLASSES):
    super().__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(6),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0, stride=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.flatten = nn.Flatten()
    self.fc = nn.Linear(400, 120)
    self.relu = nn.ReLU()
    self.fc1 = nn.Linear(120, 84)
    self.relu1 = nn.ReLU()
    self.fc2 = nn.Linear(84, num_classes)

  def forward(self, x):
    out = self.layer1(x)
    out = self.layer2(out)
    out = self.flatten(out)
    out = self.fc(out)
    out = self.relu(out)
    out = self.fc1(out)
    out = self.relu1(out)
    out = self.fc2(out)
    return out

In [44]:
model = LeNet(NUM_CLASSES).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

total_step = len(train_dataloader)

In [46]:
for epoch in range(EPOCHS):
  for i, (images, labels) in enumerate(train_dataloader):
    images = images.to(device)
    labels = labels.to(device)

    outputs = model(images)
    loss = loss_fn(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


    if (i+1) % 400 == 0:
      print(f"Epoch {epoch+1}/{EPOCHS}, Step {i}/{total_step}, Loss: {loss.item()}")

Epoch 1/10, Step 399/938, Loss: 0.047678954899311066
Epoch 1/10, Step 799/938, Loss: 0.11595113575458527
Epoch 2/10, Step 399/938, Loss: 0.017915252596139908
Epoch 2/10, Step 799/938, Loss: 0.012542378157377243
Epoch 3/10, Step 399/938, Loss: 0.04159318655729294
Epoch 3/10, Step 799/938, Loss: 0.03949810564517975
Epoch 4/10, Step 399/938, Loss: 0.028746791183948517
Epoch 4/10, Step 799/938, Loss: 0.005584072787314653
Epoch 5/10, Step 399/938, Loss: 0.045934081077575684
Epoch 5/10, Step 799/938, Loss: 0.001006414880976081
Epoch 6/10, Step 399/938, Loss: 0.006689663510769606
Epoch 6/10, Step 799/938, Loss: 0.0015493565006181598
Epoch 7/10, Step 399/938, Loss: 0.0258770938962698
Epoch 7/10, Step 799/938, Loss: 0.006097725592553616
Epoch 8/10, Step 399/938, Loss: 0.013535972684621811
Epoch 8/10, Step 799/938, Loss: 0.02034027688205242
Epoch 9/10, Step 399/938, Loss: 0.000592561555095017
Epoch 9/10, Step 799/938, Loss: 0.0040889158844947815
Epoch 10/10, Step 399/938, Loss: 0.000498464389238

In [47]:
# Testing

with torch.no_grad():
  correct = 0
  total =  0

  for images, labels in test_dataloader:
    images = images.to(device)
    labels = labels.to(device)

    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  print(f"Accuracy of the Network on the test images of size {total} is {correct/total:.2f}")

Accuracy of the Network on the test images of size 10000 is 0.99


In [49]:
from torchsummary import summary
summary(model, (1, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 30, 30]              60
       BatchNorm2d-2            [-1, 6, 30, 30]              12
              ReLU-3            [-1, 6, 30, 30]               0
         MaxPool2d-4            [-1, 6, 15, 15]               0
            Conv2d-5           [-1, 16, 11, 11]           2,416
       BatchNorm2d-6           [-1, 16, 11, 11]              32
              ReLU-7           [-1, 16, 11, 11]               0
         MaxPool2d-8             [-1, 16, 5, 5]               0
           Flatten-9                  [-1, 400]               0
           Linear-10                  [-1, 120]          48,120
             ReLU-11                  [-1, 120]               0
           Linear-12                   [-1, 84]          10,164
             ReLU-13                   [-1, 84]               0
           Linear-14                   