### An example of a feed-forward neural network on the MNIST dataset

In [2]:
import torch 
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms

# Run the computations on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model and training hyper-parameters.
input_size = 784       # length of one flattened image (28*28), i.e. the width of the first layer's input
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001
seed = 42

# Seed PyTorch's random number generator before any model or loader is created, so that
# weight initialisation and batch shuffling start from the same state on every fresh run.
torch.manual_seed(seed)

# Training split: download=True fetches the MNIST files into data/ when they are missing.
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

# Test split: no download flag is passed here, so this relies on the files already being
# present under data/ (the training-set call above is what triggers the download).
test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False,
                                          transform=transforms.ToTensor())

# Batch both splits. Only the training data is shuffled: the order of the test batches has
# no effect on the accuracy computed later, and a fixed order keeps evaluation repeatable.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

'''define the model: 2 fully connected layers'''
class NN(nn.Module):
    """Feed-forward classifier: a linear layer, a ReLU, then a second linear layer.

    Args:
        input_size: number of features in each (flattened) input sample.
        hidden_size: number of units in the hidden layer.
        num_classes: number of scores produced for each sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Sub-modules are registered in the order they are applied in forward().
        self.layer_1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer_2 = nn.Linear(hidden_size, num_classes)

    def forward(self, input):
        """Return the raw output of the second linear layer for a batch of inputs.

        No softmax is applied here; the values returned are un-normalised scores.
        """
        hidden = self.relu(self.layer_1(input))
        return self.layer_2(hidden)

# Loss function: cross entropy between the model's output scores and the target labels.
criterion = nn.CrossEntropyLoss()

# Build the network, then move its parameters to the selected device.
model = NN(input_size, hidden_size, num_classes)
model = model.to(device)

# Adam optimises every parameter of the model with the configured learning rate.
# It is created after the move to `device` so it holds the parameters that are trained.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Training: one optimiser step per mini-batch, for num_epochs passes over train_loader.
model.train()  # make training mode explicit, in case the model was switched to eval mode earlier
for epoch in range(num_epochs):
    for i, (img, label) in enumerate(train_loader):
        # Flatten every image to a vector of input_size values so it matches the first
        # linear layer; using the shared constant keeps the two from drifting apart.
        img = img.reshape(-1, input_size).to(device)
        label = label.to(device)

        # Forward pass: model scores and the loss against the labels.
        pred = model(img)
        loss = criterion(pred, label)

        # Backward pass: clear old gradients, back-propagate, update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 steps. loss.item() extracts the
        # Python float so that a tensor is not handed to the string formatter.
        if i % 100 == 0:
            print('epoch [{}/{}], step [{}/{}], loss {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

# Measure the accuracy of the trained model on the test set.
model.eval()  # switch to inference mode; note the model stays in eval mode after this cell
true = 0    # number of correctly classified test samples
total = 0   # number of test samples seen
with torch.no_grad():  # no gradients are needed for evaluation
    for img, label in test_loader:
        # Same flattening as used for training, tied to the model's input width.
        img = img.reshape(-1, input_size).to(device)
        label = label.to(device)

        pred = model(img)
        # The predicted class is the index of the largest output score along dim 1.
        # (The scores are raw layer outputs, not probabilities, and `.data` is not
        # needed inside no_grad.)
        pred_label = pred.argmax(dim=1)

        total += label.size(0)
        true += (pred_label == label).sum().item()
print('accuracy on the test set is {} percent'.format(100 * float(true)/total))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
epoch [1/5], step [1/600], loss 2.2920
epoch [1/5], step [101/600], loss 0.4885
epoch [1/5], step [201/600], loss 0.1912
epoch [1/5], step [301/600], loss 0.3238
epoch [1/5], step [401/600], loss 0.1246
epoch [1/5], step [501/600], loss 0.1594
epoch [2/5], step [1/600], loss 0.1386
epoch [2/5], step [101/600], loss 0.1909
epoch [2/5], step [201/600], loss 0.1396
epoch [2/5], step [301/600], loss 0.1575
epoch [2/5], step [401/600], loss 0.0923
epoch [2/5], step [501/600], loss 0.0293
epoch [3/5], step [1/600], loss 0.0396
epoch [3/5], step [101/600], loss 0.0803
epoch [3/5], step [201/600], loss 0.0582
epoch [3/5], step [301/600], loss 0.0767
epoch [3/5], step [401/600], loss 0.0574
epo