### An example of a convolutional neural network on the MNIST dataset

In [1]:
import torch 
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms

# Run the computations on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model / training hyper-parameters.
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001

# Load the MNIST training and test splits (downloading them into data/ if
# needed); every image is converted to a tensor by ToTensor().
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Batch the data. Only the training set is shuffled: shuffling the test set
# does not change the measured accuracy and only makes evaluation order
# nondeterministic.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Define the model.
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Each stage is Conv2d(kernel 5, stride 1, padding 2) -> BatchNorm2d ->
    ReLU -> MaxPool2d(2, 2); the stages map 1 -> 16 -> 32 channels. A final
    fully connected layer turns the flattened features into ``num_classes``
    logits. The 7*7*32 input size of that layer means the two pooling stages
    must leave 7x7 feature maps, i.e. 28x28 inputs such as MNIST digits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Construction order (stage 1, stage 2, then fc) determines both the
        # parameter initialisation sequence and the state_dict key names.
        self.layer_1 = self._conv_stage(1, 16)
        self.layer_2 = self._conv_stage(16, 32)
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return one conv -> batch-norm -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, input):
        """Return class logits of shape (batch, num_classes) for ``input``."""
        features = self.layer_2(self.layer_1(input))
        # Flatten everything except the batch dimension for the linear layer.
        flat = features.reshape(features.size(0), -1)
        return self.fc(flat)
    
# Instantiate the model with the configured number of output classes.
model = CNN(num_classes).to(device)

# Cross-entropy over the model's raw logits is the classification loss.
criterion = nn.CrossEntropyLoss()

# Adam optimizes all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training. train() makes the BatchNorm layers normalise with per-batch
# statistics and update their running estimates.
model.train()
for epoch in range(num_epochs):
    for i, (img, label) in enumerate(train_loader):
        img = img.to(device)
        label = label.to(device)

        pred = model(img)
        loss = criterion(pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 mini-batches.
        if i % 100 == 0:
            print('epoch [{}/{}], step [{}/{}], loss {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

# Evaluate the trained model on the unseen test set. eval() switches the
# BatchNorm layers to their running statistics; without it, predictions would
# depend on each test batch's composition and the running estimates would keep
# being updated with test data.
model.eval()
true = 0   # number of correctly classified test images
total = 0  # number of test images seen
with torch.no_grad():
    for img, label in test_loader:
        img = img.to(device)
        label = label.to(device)

        pred = model(img)
        # Predicted class = index of the largest logit in each row.
        _, pred = torch.max(pred, 1)
        true += (pred == label).sum().item()
        total += label.size(0)
print('accuracy on the test set is {} percent'.format(100 * float(true) / total))

epoch [1/5], step [1/600], loss 2.3503
epoch [1/5], step [101/600], loss 0.1524
epoch [1/5], step [201/600], loss 0.0846
epoch [1/5], step [301/600], loss 0.0622
epoch [1/5], step [401/600], loss 0.0656
epoch [1/5], step [501/600], loss 0.1174
epoch [2/5], step [1/600], loss 0.1114
epoch [2/5], step [101/600], loss 0.0411
epoch [2/5], step [201/600], loss 0.0353
epoch [2/5], step [301/600], loss 0.0839
epoch [2/5], step [401/600], loss 0.0779
epoch [2/5], step [501/600], loss 0.0280
epoch [3/5], step [1/600], loss 0.0310
epoch [3/5], step [101/600], loss 0.0171
epoch [3/5], step [201/600], loss 0.0493
epoch [3/5], step [301/600], loss 0.0182
epoch [3/5], step [401/600], loss 0.0244
epoch [3/5], step [501/600], loss 0.1172
epoch [4/5], step [1/600], loss 0.0242
epoch [4/5], step [101/600], loss 0.0349
epoch [4/5], step [201/600], loss 0.0420
epoch [4/5], step [301/600], loss 0.0074
epoch [4/5], step [401/600], loss 0.0114
epoch [4/5], step [501/600], loss 0.0187
epoch [5/5], step [1/600