In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from PIL import Image
import matplotlib.pyplot as plt

In [6]:
class MnistNet(nn.Module):
    """
    Lightweight network architecture for the Mnist dataset (digit) classification
    """
    def __init__(self):
        super(MnistNet, self).__init__()
        self.num_classes = 10
        
        # fully convolutional part
        self.features = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 4, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True)            
        )
        
        # classifier, FC layers
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(16*4,16),
            nn.ReLU(inplace=True),
            nn.Linear(16,self.num_classes),
            nn.BatchNorm1d(self.num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(-1,x.size(-3)*x.size(-2)*x.size(-1)))
        return x


def train(model, train_loader, optimizer):
    """
    Training of an epoch
    model: network
    train_loader: train_loader loading images and labels in batches
    optimizer: optimizer to use in the training
    """
    model.train()
    total_loss = 0
    train_total = 0
    train_correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad() # zero the accumulated gradients
        output = model(data) # computer network's output
        loss = F.cross_entropy(output, target) # computer the loss
        loss.backward() # backward pass
        optimizer.step() # update weights
        
        total_loss += loss.item()

        scores, predictions = torch.max(output.data, 1)
        train_total += target.size(0)
        train_correct += int(sum((1 for pr, val in zip(predictions, target) if pr == val)))
        acc = (train_correct / train_total) * 100
        
        if batch_idx % 100 == 0:
            print('[{}/{} ({:.0f}%)] Training\tBatch loss: {:.6f}\tAccuracy: {:.6f}%'.format(
                 batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()/len(data),
                 acc))
    
    print('Training: Epoch average loss {:.6f}'.format(total_loss/len(train_loader.dataset)),
          'Epoch accuracy {:.6f}%'.format((train_correct / train_total) * 100))
         
        
def test(model, val_loader):
    """
    Compute accuracy on the validation set
    model: network
    val_loader: test_loader loading images and labels in batches
    """
    model.eval()
    
    # implement validation procedure, report accuracy on the validation set

    total_loss = 0
    val_total = 0
    val_correct = 0
    for batch_idx, (data, target) in enumerate(val_loader):
        output = model(data)
        loss = F.cross_entropy(output, target)
        total_loss += loss.item()*data.size(0)
        scores, predictions = torch.max(output.data, 1)
        val_total += target.size(0)
        val_correct += int(sum((1 for pr, val in zip(predictions, target) if pr == val)))
        val_acc = (val_correct / val_total) * 100

        print('[{}/{} ({:.0f}%)] Validation\tBatch loss: {:.6f}\tAccuracy: {:.6f}%'.format(
             batch_idx * len(data), len(val_loader.dataset),
             100. * batch_idx / len(val_loader), loss.item()/len(data),
             val_acc))

    print('Validation: Epoch average loss {:.6f}\tAccuracy: {:.6f}%'
          .format(total_loss/len(val_loader.dataset), (val_correct / val_total) * 100))


In [7]:
# mnist dataset structure - train part
mnist_dataset_train = datasets.MNIST('vs3ex1data/mnist_data', train=True, transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))
               ]), download=True)
# mnist dataset structure - test part
mnist_dataset_val = datasets.MNIST('vs3ex1data/mnist_data', train=False, transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))
               ]), download=True)

# show sample images
print('Sample images')
for i in range(0,100,10):
    # plt.imshow(Image.fromarray(mnist_dataset_train.train_data[i].numpy(), mode='L'))
    # plt.show()
    pass

Sample images


In [8]:
# loader of the training set
train_loader = torch.utils.data.DataLoader(mnist_dataset_train,batch_size=16, shuffle=True)
# loader of the validation set
val_loader = torch.utils.data.DataLoader(mnist_dataset_val,batch_size=512, shuffle=False)

model = MnistNet() # initialize network structure
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(1, 20 + 1):
        print('Epoch {}'.format(epoch))
        train(model, train_loader, optimizer)
        test(model, val_loader)


Epoch 1
Training: Epoch average loss 0.045525 Epoch accuracy 77.326667%
Validation: Epoch average loss 0.182210	Accuracy: 94.940000%
Epoch 2
Training: Epoch average loss 0.031092 Epoch accuracy 85.051667%
Validation: Epoch average loss 0.178146	Accuracy: 95.310000%
Epoch 3
Training: Epoch average loss 0.028732 Epoch accuracy 86.346667%
Validation: Epoch average loss 0.160843	Accuracy: 95.900000%
Epoch 4
Training: Epoch average loss 0.027624 Epoch accuracy 86.861667%
Validation: Epoch average loss 0.141627	Accuracy: 96.020000%
Epoch 5


KeyboardInterrupt: 