# Implementation of LeNet5

In [40]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from functools import reduce
from tqdm import tqdm_notebook as tqdm

## Hyperparameter

In [72]:
batch_size = 64
lr = 0.001
epoch = 2

## Data

In [29]:
transform = transforms.ToTensor()
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=2)

## Define the structure of LeNet5

In [30]:
class LeNet5(nn.Module):
    
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 5, padding=2),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.MaxPool2d(2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Sequential(
            nn.Linear(84, 10)
        )
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.fc1(x.view(-1, 16*5*5))
        x = self.fc2(x)
        x = self.fc3(x)
        return x

## Train

In [73]:
net = LeNet5()

In [74]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [75]:
epoch_bar = tqdm(range(epoch), desc='epoch')
for e in epoch_bar:
    batch_bar = tqdm(trainloader, desc='batch')
    for data in batch_bar:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_bar.set_description('Loss %.4f' % loss)

HBox(children=(IntProgress(value=0, description='epoch: ', max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, description='batch: ', max=938), HTML(value='')))

HBox(children=(IntProgress(value=0, description='batch: ', max=938), HTML(value='')))

## Test

In [76]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        pre = torch.max(net(images), 1)[1].numpy()
        correct += (pre == labels).sum().item()
        total += labels.shape[0]

In [77]:
print("Accuracy on %d test images: %.2f%%" % (total, 100*correct/total))

Accuracy on 60000 test images: 97.85%
