In [1]:
import torch
import torchvision
from torch import nn, optim

from torchsummary import summary

In [2]:
batch_size = 32
epoch = 30
learning_rate = 0.01

In [3]:
trans = torchvision.transforms.ToTensor()

train_data = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
            'mnist_data', train=True, download=True, transform=trans
            ), batch_size=batch_size
            )
val_data = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(
            'mnist_data', train=False, download=True, transform=trans
            ), batch_size=batch_size)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw



In [4]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1)
        
        self.tanh = nn.Tanh()
        self.linear1 = nn.Linear(3456, 10)

    def forward(self, x):
        x = self.tanh(self.conv1(x))
        x = self.tanh(self.conv2(x))
        x = x.view(x.shape[0], -1)
        x = self.linear1(x)
        return x

In [5]:
def validate(model, data):
    total = 0
    correct = 0
    for i, (images, labels) in enumerate(data):
        y_pred = model(images)
        value, pred = torch.max(y_pred, 1)
        total += y_pred.size(0)
        correct += torch.sum(pred == labels)
    return correct * 100 / total

In [6]:
convnet = ConvNet()
summary(convnet, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 3, 26, 26]              30
              Tanh-2            [-1, 3, 26, 26]               0
            Conv2d-3            [-1, 6, 24, 24]             168
              Tanh-4            [-1, 6, 24, 24]               0
            Linear-5                   [-1, 10]          34,570
Total params: 34,768
Trainable params: 34,768
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.13
Estimated Total Size (MB): 0.22
----------------------------------------------------------------


In [7]:
%%time

optimizer = optim.Adam(convnet.parameters(), lr=learning_rate)
cross_entropy = nn.CrossEntropyLoss()

for n in range(epoch):
    for i, (images, labels) in enumerate(train_data):
        optimizer.zero_grad()
        prediction = convnet(images)
        loss = cross_entropy(prediction, labels)
        loss.backward()
        optimizer.step()
    accuracy = float(validate(convnet, val_data))
    print("Epoch:", n+1, "Loss: ", float(loss.data), "Val. Accuracy:", accuracy)

Epoch: 1 Loss:  0.06866645067930222 Val. Accuracy: 86.27999877929688
Epoch: 2 Loss:  0.13885276019573212 Val. Accuracy: 85.95999908447266
Epoch: 3 Loss:  0.1134258359670639 Val. Accuracy: 89.44999694824219
Epoch: 4 Loss:  0.23833653330802917 Val. Accuracy: 82.19999694824219
Epoch: 5 Loss:  0.1914588063955307 Val. Accuracy: 89.08000183105469
Epoch: 6 Loss:  0.12360385060310364 Val. Accuracy: 89.56999969482422
Epoch: 7 Loss:  0.13891775906085968 Val. Accuracy: 89.7699966430664
Epoch: 8 Loss:  0.1110406443476677 Val. Accuracy: 89.33000183105469
Epoch: 9 Loss:  0.08644136786460876 Val. Accuracy: 87.41999816894531
Epoch: 10 Loss:  0.17617720365524292 Val. Accuracy: 88.69999694824219
Epoch: 11 Loss:  0.11897692084312439 Val. Accuracy: 87.08999633789062
Epoch: 12 Loss:  0.07859304547309875 Val. Accuracy: 87.76000213623047
Epoch: 13 Loss:  0.025874320417642593 Val. Accuracy: 88.2699966430664
Epoch: 14 Loss:  0.06950537860393524 Val. Accuracy: 88.75
Epoch: 15 Loss:  0.0964871272444725 Val. Accu

We can see that the model is trying to reduce the loss with varying validation accuracy.