In [1]:
import torch
import torchvision
from torch import nn, optim

from torchsummary import summary

In [2]:
batch_size = 32
epoch = 30
learning_rate = 0.01

In [3]:
trans = torchvision.transforms.ToTensor()

train_data = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
            'mnist_data', train=True, download=True, transform=trans
            ), batch_size=batch_size
            )
val_data = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
            'mnist_data', train=False, download=True, transform=trans
            ), batch_size=batch_size)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw



In [4]:
class Ann(nn.Module):
    
    def __init__(self):
        super(Ann, self).__init__()
        self.linear1 = nn.Linear(28*28, 150)
        self.linear2 = nn.Linear(150, 10)
                    
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = nn.Flatten()(x)
        x = self.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [5]:
def validate(model, data):
    total = 0
    correct = 0
    for i, (images, labels) in enumerate(data):
        y_pred = model(images)
        value, pred = torch.max(y_pred, 1)
        total += y_pred.size(0)
        correct += torch.sum(pred == labels)
    return correct * 100 / total

In [6]:
ann = Ann()
summary(ann, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 150]         117,750
              ReLU-2                  [-1, 150]               0
            Linear-3                   [-1, 10]           1,510
Total params: 119,260
Trainable params: 119,260
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.45
Estimated Total Size (MB): 0.46
----------------------------------------------------------------


In [7]:
%%time

optimizer = optim.Adam(ann.parameters(), lr=learning_rate)
cross_entropy = nn.CrossEntropyLoss()

for n in range(epoch):
    for i, (images, labels) in enumerate(train_data):
        optimizer.zero_grad()
        prediction = ann(images)
        loss = cross_entropy(prediction, labels)
        loss.backward()
        optimizer.step()
    accuracy = float(validate(ann, val_data))
    print("Epoch:", n+1, "Loss: ", float(loss.data), "Val. Accuracy:", accuracy)

Epoch: 1 Loss:  0.12276075780391693 Val. Accuracy: 93.5
Epoch: 2 Loss:  0.0048448871821165085 Val. Accuracy: 95.19000244140625
Epoch: 3 Loss:  0.04165646806359291 Val. Accuracy: 93.12999725341797
Epoch: 4 Loss:  0.005520841106772423 Val. Accuracy: 95.0999984741211
Epoch: 5 Loss:  0.001014151843264699 Val. Accuracy: 95.38999938964844
Epoch: 6 Loss:  0.002359476638957858 Val. Accuracy: 95.18000030517578
Epoch: 7 Loss:  0.36308231949806213 Val. Accuracy: 94.44999694824219
Epoch: 8 Loss:  0.00022598567011300474 Val. Accuracy: 94.54000091552734
Epoch: 9 Loss:  0.005337177775800228 Val. Accuracy: 96.0199966430664
Epoch: 10 Loss:  0.0006358727696351707 Val. Accuracy: 95.9800033569336
Epoch: 11 Loss:  0.050514962524175644 Val. Accuracy: 96.18000030517578
Epoch: 12 Loss:  0.0005741254426538944 Val. Accuracy: 95.94999694824219
Epoch: 13 Loss:  4.383684790809639e-05 Val. Accuracy: 95.91000366210938
Epoch: 14 Loss:  3.762528137940535e-07 Val. Accuracy: 95.88999938964844
Epoch: 15 Loss:  2.37665631

We can see that there are less parameters and the training time is faster. We also see loss fluctuating after reaching 0.