In [2]:
import os
os.environ['PATH'] = f"{os.environ['PATH']}:/root/.local/bin"


In [4]:
## This cell contains the essential imports you will need – DO NOT CHANGE THE CONTENTS! ##
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets
from torch.utils.data.dataloader import DataLoader

### About preprocessing

First we need to convert our black and white images to tensor values. Then we need to normalize our pixels values, from [0,1] to [-1,1]

In [27]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

In [6]:
train_dataset = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=transform, download=True)

In [7]:
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [8]:
def show5(img_loader):
    dataiter = iter(img_loader)
    
    batch = next(dataiter)
    labels = batch[1][0:5]
    images = batch[0][0:5]
    for i in range(5):
        print(int(labels[i].detach()))
    
        image = images[i].numpy()
        plt.imshow(image.T.squeeze().T)
        plt.show()

In [9]:
for images, labels in train_loader:
    print(images.shape)  
    print(labels.shape)  
    break

torch.Size([64, 1, 28, 28])
torch.Size([64])


In [14]:
class Model(nn.Module):
    def __init__(self):
        super(ComplexMNISTModel, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) 
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  
        
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.dropout = nn.Dropout(0.25)
        
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x)))) 
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        
        x = x.view(-1, 128 * 3 * 3)
        
        x = F.relu(self.fc1(x))
        x = self.dropout(x) 
        x = F.relu(self.fc2(x))
        x = self.fc3(x) 
        
        return x

    def train_model(self, train_loader, num_epochs=5, learning_rate=0.001):
        optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(num_epochs):
            running_loss = 0.0
            for images, labels in train_loader:
                optimizer.zero_grad()
                
                outputs = self(images)
                loss = criterion(outputs, labels) 
                
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()

            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss/len(train_loader)}')

    def test_model(self, test_loader):
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = self(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        print(f'Accuracy: {accuracy:.2f}%')

In [15]:
model = ComplexMNISTModel()

model.train_model(train_loader, num_epochs=5)

model.test_model(test_loader)

Epoch [1/5], Loss: 0.13579946976087526
Epoch [2/5], Loss: 0.04601224882752577
Epoch [3/5], Loss: 0.0362548575171757
Epoch [4/5], Loss: 0.028073561914010232
Epoch [5/5], Loss: 0.023774999313264416
Accuracy: 99.00%


In [28]:
torch.save(model, 'model_full.pth')