<a href="https://colab.research.google.com/github/ProfessorDong/Deep-Learning-Course-Examples/blob/master/CNN_Examples/LeNet_MNIST_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


This notebook implements the LeNet architecture in PyTorch to classify the MNIST handwritten digits. It proceeds in eight steps:

1. Define the device to use for training. If a GPU is available, the code will use it for faster computation.

2. Define the data transformations for the MNIST dataset: resize the 28x28 images to the 32x32 input size LeNet expects, convert them to tensors, and normalize the pixel values to the range [-1, 1].

3. Load the MNIST training and test sets with the `datasets.MNIST` class, and wrap each one in a `DataLoader` that yields mini-batches (shuffled for training, in order for testing).

4. Define the LeNet model as a subclass of `nn.Module`, built from convolutional (`nn.Conv2d`), max-pooling (`nn.MaxPool2d`), and fully connected (`nn.Linear`) layers with ReLU activations.

5. Create an instance of the model and move it to the device using the `to()` method.

6. Define the loss function (`nn.CrossEntropyLoss`) and the optimizer (`optim.SGD` with momentum) used during training.

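7. Train the model by looping over the training data loader. For each batch, reset the gradients with `optimizer.zero_grad()`, call the model on the inputs (which runs its `forward()` method), compute the loss, call `backward()` on the loss to compute the gradients, and update the weights with `optimizer.step()`.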

8. Evaluate the model by looping over the test data loader with gradient tracking disabled (`torch.no_grad()`), taking the highest-scoring class as each prediction, comparing the predictions with the true labels, and computing the test accuracy.

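Steps 2 and 4 depend on each other: the 32x32 resize is what makes the final feature map flatten to exactly 16 * 5 * 5 = 400 values, the input size of `fc1`. The short plain-Python sketch below traces the spatial size through the layers using the standard output-size formula for an unpadded window, out = (in - kernel) // stride + 1. It assumes the layer defaults used in this notebook (stride 1 and no padding for `nn.Conv2d`; stride equal to the kernel size for `nn.MaxPool2d`). The helper names `out_size` and `lenet_feature_size` are illustrative, not part of PyTorch.

```python
def out_size(size: int, kernel: int, stride: int = 1) -> int:
    """Spatial output size of an unpadded convolution or pooling window."""
    return (size - kernel) // stride + 1


def lenet_feature_size(input_size: int = 32) -> int:
    """Trace the feature-map side length through LeNet's conv/pool stack."""
    s = out_size(input_size, kernel=5)   # conv1: 5x5, stride 1
    s = out_size(s, kernel=2, stride=2)  # pool1: 2x2, stride 2
    s = out_size(s, kernel=5)            # conv2: 5x5, stride 1
    s = out_size(s, kernel=2, stride=2)  # pool2: 2x2, stride 2
    return s


side = lenet_feature_size(32)
print(side, 16 * side * side)  # 5 400
```

Feeding the raw 28x28 images instead would produce a 4x4 map (256 values), which would not match `nn.Linear(16 * 5 * 5, 120)`.
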

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
# Define the device to use for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Define the data transformations to use for the MNIST dataset
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
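
As a worked check of the normalization step: `transforms.ToTensor()` scales 8-bit pixel values from [0, 255] to [0.0, 1.0], and `transforms.Normalize((0.5,), (0.5,))` then computes (x - mean) / std = (x - 0.5) / 0.5 per channel, which maps [0, 1] onto [-1, 1]. The plain-Python function below is a hypothetical helper (not part of torchvision) that reproduces this arithmetic for a single pixel.

```python
def pixel_to_input(pixel: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for one 8-bit pixel."""
    x = pixel / 255.0        # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - mean) / std  # Normalize: [0.0, 1.0] -> [-1.0, 1.0]


for p in (0, 128, 255):
    print(p, round(pixel_to_input(p), 4))
```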

In [4]:
# Load the MNIST dataset
trainset = datasets.MNIST(root='./data', train=True,
                          download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False,
                         download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)
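
With `batch_size=128` and the `DataLoader` default `drop_last=False`, each pass over the data ends with a smaller final batch. The quick plain-Python calculation below, which assumes the standard MNIST split of 60,000 training and 10,000 test images, shows why the training log prints only at batches 200 and 400: an epoch has 469 batches, so batch 600 is never reached. The helper name `batch_stats` is illustrative.

```python
import math


def batch_stats(num_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (number of batches, size of the last batch) when drop_last=False."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


print(batch_stats(60_000, 128))  # training set: (469, 96)
print(batch_stats(10_000, 128))  # test set: (79, 16)
```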

In [5]:
# Define the LeNet model architecture
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()

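    def forward(self, x):
        # x: (batch, 1, 32, 32)
        x = self.pool1(self.relu(self.conv1(x)))  # -> (batch, 6, 14, 14)
        x = self.pool2(self.relu(self.conv2(x)))  # -> (batch, 16, 5, 5)
        x = torch.flatten(x, 1)                   # -> (batch, 400), keeps the batch dimension
        x = self.relu(self.fc1(x))                # -> (batch, 120)
        x = self.relu(self.fc2(x))                # -> (batch, 84)
        x = self.fc3(x)                           # -> (batch, 10) raw class scores (logits)
        return x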

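The parameter counts in the `torchsummary` output at the end of this notebook follow directly from the layer definitions. An `nn.Conv2d` layer has out_channels × in_channels × k × k weights plus one bias per output channel, and an `nn.Linear` layer has out_features × in_features weights plus one bias per output; pooling and ReLU have no parameters. The plain-Python tally below uses illustrative helper names and assumes the default `bias=True` and `groups=1`.

```python
def conv_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights plus biases of a k x k nn.Conv2d layer (bias=True, groups=1)."""
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_f: int, out_f: int) -> int:
    """Weights plus biases of an nn.Linear layer (bias=True)."""
    return out_f * in_f + out_f


layers = {
    "conv1": conv_params(1, 6, 5),          # 156
    "conv2": conv_params(6, 16, 5),         # 2,416
    "fc1": linear_params(16 * 5 * 5, 120),  # 48,120
    "fc2": linear_params(120, 84),          # 10,164
    "fc3": linear_params(84, 10),           # 850
}
print(sum(layers.values()))  # 61706
```

Roughly 78% of the parameters sit in `fc1`, even though the two convolutional layers do the spatial feature extraction.
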
In [6]:
# Create an instance of the model and move it to the device
net = LeNet().to(device)

In [7]:
# Define the loss function and optimizer to use during training
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
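
`nn.CrossEntropyLoss` expects raw, unnormalized scores (logits), which is why `fc3` is not followed by a softmax. The loss applies log-softmax internally and takes the negative log-probability of the true class, averaged over the batch by default. The plain-Python sketch below computes this for a single example using the log-sum-exp trick for numerical stability. It illustrates the formula only and is not PyTorch's implementation.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


print(cross_entropy([0.0] * 10, target=3))       # 10 equal scores: log(10) ≈ 2.3026
print(cross_entropy([5.0, 0.0, 0.0], target=0))  # confident and correct: small loss
print(cross_entropy([5.0, 0.0, 0.0], target=1))  # confident and wrong: large loss
```

For ten equal logits the loss is log(10) ≈ 2.30, roughly where an untrained 10-class network starts.
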

In [8]:
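# Train the model
net.train()  # training mode (LeNet has no dropout or batch norm, but this is good practice)
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()              # clear gradients left over from the previous batch
        outputs = net(inputs)              # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate to compute gradients
        optimizer.step()                   # update the weights

        running_loss += loss.item()
        if i % 200 == 199:                 # print the average loss every 200 batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0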

[1,   200] loss: 1.261
[1,   400] loss: 0.180
[2,   200] loss: 0.107
[2,   400] loss: 0.094
[3,   200] loss: 0.064
[3,   400] loss: 0.068
[4,   200] loss: 0.051
[4,   400] loss: 0.048
[5,   200] loss: 0.042
[5,   400] loss: 0.040
[6,   200] loss: 0.033
[6,   400] loss: 0.037
[7,   200] loss: 0.031
[7,   400] loss: 0.030
[8,   200] loss: 0.025
[8,   400] loss: 0.027
[9,   200] loss: 0.020
[9,   400] loss: 0.025
[10,   200] loss: 0.019
[10,   400] loss: 0.021
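
Each `optimizer.step()` call in the training loop applies SGD with momentum. With the settings used here (`lr=0.01`, `momentum=0.9`) and PyTorch's defaults of no dampening, no weight decay, and no Nesterov, the update keeps a velocity buffer v and computes v = momentum * v + g, then p = p - lr * v. The scalar sketch below applies that rule to the toy objective f(p) = p², whose gradient is 2p. The objective and the helper name `sgd_momentum` are illustrative assumptions, not part of the notebook.

```python
def sgd_momentum(p: float, lr: float = 0.01, momentum: float = 0.9, steps: int = 3) -> list[float]:
    """Minimize the toy objective f(p) = p**2 with momentum SGD (no dampening)."""
    velocity = 0.0  # with zero dampening, the first step sets velocity = grad
    history = [p]
    for _ in range(steps):
        grad = 2.0 * p                         # gradient of f(p) = p**2
        velocity = momentum * velocity + grad  # accumulate a running direction
        p = p - lr * velocity                  # step along the velocity
        history.append(p)
    return history


print(sgd_momentum(1.0))
```

Plain SGD (`momentum=0.0`) would reach 0.9604 after two steps, while the momentum version reaches 0.9424, because the velocity term makes successive steps grow as long as the gradient keeps pointing the same way.
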


In [11]:
# Evaluate the model on the test set
net.eval()  # evaluation mode
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted digit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test accuracy: %.2f%%' % (100 * correct / total))

Test accuracy: 98.97%
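
The evaluation cell reduces each row of logits to its highest-scoring class and counts the matches with the true labels. The plain-Python sketch below shows the same logic on a made-up batch of three examples with three classes. The data and the helper name `top1_accuracy` are hypothetical, for illustration only.

```python
def top1_accuracy(logits: list[list[float]], labels: list[int]) -> float:
    """Percentage of rows whose argmax equals the true label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


batch = [[2.0, 0.1, -1.0],  # predicts class 0
         [0.3, 0.2, 1.5],   # predicts class 2
         [0.0, 4.0, 1.0]]   # predicts class 1
print(top1_accuracy(batch, [0, 1, 1]))  # 2 of 3 correct, about 66.67
```

On the 10,000-image MNIST test set, the reported 98.97% corresponds to 9,897 correct predictions.
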


In [12]:
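# Show the model summary (torchsummary is a third-party package: pip install torchsummary)
from torchsummary import summary

# torchsummary defaults to device="cuda"; pass the actual device type so this also runs on a CPU
summary(net, input_size=(1, 32, 32), device=device.type)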

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             156
              ReLU-2            [-1, 6, 28, 28]               0
         MaxPool2d-3            [-1, 6, 14, 14]               0
            Conv2d-4           [-1, 16, 10, 10]           2,416
              ReLU-5           [-1, 16, 10, 10]               0
         MaxPool2d-6             [-1, 16, 5, 5]               0
            Linear-7                  [-1, 120]          48,120
              ReLU-8                  [-1, 120]               0
            Linear-9                   [-1, 84]          10,164
             ReLU-10                   [-1, 84]               0
           Linear-11                   [-1, 10]             850
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/ba