In [1]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random
device = 'cuda'

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [2]:
learning_rate = 0.001
epochs=15
batch_size=100
drop_prob = 0.3

In [3]:
# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True
                         )
mnist_test = dsets.MNIST(root='MNIST_data/',
                          train=False,
                          transform=transforms.ToTensor(),
                          download=True
                         )


In [4]:
data_loader = torch.utils.data.DataLoader(dataset=mnist_train, 
                                          batch_size=batch_size, 
                                          shuffle=True, 
                                          drop_last=True)

In [5]:
linear1 = torch.nn.Linear(784, 512, bias=True)
linear2 = torch.nn.Linear(512, 512, bias=True)
linear3 = torch.nn.Linear(512, 512, bias=True)
linear4 = torch.nn.Linear(512, 512, bias=True)
linear5 = torch.nn.Linear(512, 10, bias=True)
relu = torch.nn.ReLU()
dropout = torch.nn.Dropout(p=drop_prob)

In [6]:
torch.nn.init.xavier_uniform_(linear1.weight)
torch.nn.init.xavier_uniform_(linear2.weight)
torch.nn.init.xavier_uniform_(linear3.weight)
torch.nn.init.xavier_uniform_(linear4.weight)
torch.nn.init.xavier_uniform_(linear5.weight)

Parameter containing:
tensor([[ 0.0002, -0.0084,  0.0340,  ...,  0.0659,  0.0017,  0.0082],
        [ 0.0732, -0.0465, -0.0325,  ...,  0.1050,  0.0456, -0.0786],
        [ 0.0072, -0.0874,  0.0975,  ...,  0.0975,  0.0515,  0.0245],
        ...,
        [ 0.0794,  0.0860, -0.0077,  ...,  0.0270,  0.0412,  0.0247],
        [ 0.0821,  0.0229,  0.0508,  ..., -0.0335, -0.0429, -0.0587],
        [-0.0008,  0.0162,  0.0755,  ...,  0.1017, -0.0163,  0.0126]],
       requires_grad=True)

In [7]:
model = torch.nn.Sequential(linear1, relu, dropout,
                            linear2, relu, dropout,
                            linear3, relu, dropout,
                            linear4, relu, dropout,
                            linear5
                           ).to(device)


In [8]:
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [9]:
total_batch = len(data_loader)
model.train()
for epoch in range(epochs):
    avg_cost = 0
    for X,y in data_loader:
        X = X.view(-1, 28 * 28).to(device)
        y = y.to(device)
        
        hypothesis = model(X)
        cost = criterion(hypothesis, y)
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        
        avg_cost += cost/ total_batch
        
    print('Epoch : {:4d}/{} Cost : {:.6f}'.format(epoch+1,epochs, avg_cost))


Epoch :    1/15 Cost : 0.312583
Epoch :    2/15 Cost : 0.143761
Epoch :    3/15 Cost : 0.112883
Epoch :    4/15 Cost : 0.093941
Epoch :    5/15 Cost : 0.082029
Epoch :    6/15 Cost : 0.075798
Epoch :    7/15 Cost : 0.068466
Epoch :    8/15 Cost : 0.065301
Epoch :    9/15 Cost : 0.056528
Epoch :   10/15 Cost : 0.055926
Epoch :   11/15 Cost : 0.053659
Epoch :   12/15 Cost : 0.048328
Epoch :   13/15 Cost : 0.044862
Epoch :   14/15 Cost : 0.044445
Epoch :   15/15 Cost : 0.041840


In [41]:
with torch.no_grad():
    model.eval()
    
    X_test = mnist_test.test_data.view(-1, 28*28).float().to(device)
    y_test = mnist_test.test_labels.to(device)
    
    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == y_test
    accuracy = correct_prediction.float().mean()
    print('Acc {} '.format(accuracy.item()))
    
    r = random.randint(0, len(mnist_test) -1)
    X_single_data = mnist_test.test_data[r:r+1].view(-1, 28*28).float().to(device)
    y_single_data = mnist_test.test_labels[r:r+1].to(device)
    
    print('Label {} '.format(y_single_data.item()))
    single_prediction = model(X_single_data)
    print('Prediction : {}'.format(torch.argmax(single_prediction, 1).item()))

Acc 0.981499969959259 
Label 4 
Prediction : 4
