In [None]:
import torch
import random
import numpy as np
import torchvision.datasets
import matplotlib.pyplot as plt

# Single seed shared by every random number generator used in this notebook.
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Ask cuDNN for deterministic kernels and pin the auto-tuner off explicitly,
# so algorithm selection cannot vary from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Build the training split first, then the test split; both are stored under
# the current directory and downloaded when missing.
MNIST_train, MNIST_test = (
    torchvision.datasets.MNIST(root='./', train=is_train_split, download=True)
    for is_train_split in (True, False)
)

In [None]:
# Pull the image and label tensors out of the dataset objects.
# `.data` / `.targets` replace the deprecated `train_data` / `train_labels` /
# `test_data` / `test_labels` aliases, which emit warnings on access.
X_train = MNIST_train.data
y_train = MNIST_train.targets
X_test = MNIST_test.data
y_test = MNIST_test.targets

In [None]:
# Give the images an explicit single-channel axis, (N, H, W) -> (N, 1, H, W),
# and convert to float for the convolution layers.
# Reshaping from the trailing two (spatial) dims instead of calling
# `unsqueeze(1)` keeps this cell idempotent: running it again on tensors that
# already have the channel axis leaves their shape unchanged.
X_train = X_train.reshape(len(X_train), 1, *X_train.shape[-2:]).float()
X_test = X_test.reshape(len(X_test), 1, *X_test.shape[-2:]).float()
X_test.shape

In [None]:
class LeNet5(torch.nn.Module):
    """LeNet-5 style CNN: two conv -> tanh -> avg-pool stages, then three linear layers.

    Input is a batch of single-channel images, shape (batch, 1, H, W). With the
    kernel sizes and padding used here, H = W = 28 produces the 16 x 5 x 5
    feature map that ``fc1`` expects. The output is a (batch, 10) tensor of raw
    scores; no softmax is applied.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: padding=2 keeps the spatial size, pooling halves it.
        self.conv1 = torch.nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.act1 = torch.nn.Tanh()
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # Stage 2: unpadded 5x5 convolution followed by another halving pool.
        self.conv2 = torch.nn.Conv2d(6, 16, kernel_size=5, padding=0)
        self.act2 = torch.nn.Tanh()
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = torch.nn.Linear(in_features=5 * 5 * 16, out_features=120)
        self.act3 = torch.nn.Tanh()

        self.fc2 = torch.nn.Linear(in_features=120, out_features=84)
        self.act4 = torch.nn.Tanh()

        self.fc3 = torch.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the (batch, 10) class scores for a batch of images ``x``."""
        features = self.pool1(self.act1(self.conv1(x)))
        features = self.pool2(self.act2(self.conv2(features)))

        # Collapse channels and spatial dims into one feature vector per sample.
        flat_size = features.size(1) * features.size(2) * features.size(3)
        hidden = features.view(features.size(0), flat_size)

        hidden = self.act3(self.fc1(hidden))
        hidden = self.act4(self.fc2(hidden))

        return self.fc3(hidden)


lenet5 = LeNet5()

In [None]:
# Use the first CUDA device when one is available, otherwise stay on the CPU,
# then move the model's parameters there.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

lenet5 = lenet5.to(device)

In [None]:
# Adam over all model parameters with a 1e-3 learning rate, paired with a
# cross-entropy criterion.
optimizer = torch.optim.Adam(params=lenet5.parameters(), lr=0.001)
loss = torch.nn.CrossEntropyLoss()

In [None]:
# Mini-batch training with a full test-set evaluation after every epoch.
batch_size = 100
# NOTE(review): 10000 full passes over the training set is a very long run;
# confirm this is intended or lower it so the notebook can finish unattended.
num_epochs = 10000

# One entry per completed epoch.
test_accuracy_history = []
test_loss_history = []

# The whole test set is evaluated in one forward pass, so keep it on the
# same device as the model.
X_test = X_test.to(device)
y_test = y_test.to(device)

for epoch in range(num_epochs):
    # Fresh random ordering of the training samples for this epoch.
    order = np.random.permutation(len(X_train))

    for start_index in range(0, len(X_train), batch_size):
        optimizer.zero_grad()

        batch_indexes = order[start_index:start_index + batch_size]

        X_batch = X_train[batch_indexes].to(device)
        y_batch = y_train[batch_indexes].to(device)

        # Call the module itself rather than .forward() so hooks are honoured.
        preds = lenet5(X_batch)

        loss_val = loss(preds, y_batch)
        loss_val.backward()

        optimizer.step()

    # Evaluation needs no gradients: disabling autograd avoids building and
    # holding a computation graph over the entire test set every epoch.
    with torch.no_grad():
        test_preds = lenet5(X_test)
        test_loss_history.append(loss(test_preds, y_test).item())

        accuracy = (test_preds.argmax(dim=1) == y_test).float().mean()

    test_accuracy_history.append(accuracy.item())
    print(accuracy)

In [None]:
# Test-set accuracy, one point per epoch; labelled so the figure stands alone.
plt.plot(test_accuracy_history)
plt.title("Test accuracy per epoch")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Test-set cross-entropy loss, one point per epoch; labelled so the figure
# stands alone.
plt.plot(test_loss_history)
plt.title("Test loss per epoch")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

In [None]:
# Show one test image chosen by the user together with the model's prediction.
i = int(input("Enter a number from 0 to 9999 to check the model: "))
if not 0 <= i < len(X_test):
    raise IndexError(
        f"index {i} is outside the test set (valid range: 0 to {len(X_test) - 1})"
    )

# Work on the selected sample only, leaving X_test itself untouched.
sample = X_test[i]

print("Image:")
# matplotlib converts its input to a NumPy array, which fails for tensors that
# live on a CUDA device, so copy the image to the CPU first.
plt.imshow(sample.reshape(28, 28).cpu().numpy())
plt.show()

# Classify just this image (as a batch of one) instead of forwarding the whole
# test set; gradients are not needed for inference.
with torch.no_grad():
    scores = lenet5(sample.reshape(1, 1, 28, 28))

print("Detected digit on the image:", scores.argmax(dim=1).item())