In [1]:
import torch
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


In [2]:
import torch.nn as nn
class MLPModel(nn.Module):
    """Fully connected classifier: three hidden Linear+ReLU layers and a linear output layer.

    Args:
        input_dim: Number of features per sample once the sample is flattened.
        hidden_dim: Width shared by the three hidden layers.
        num_classes: Number of output logits. Defaults to 10, the value that
            used to be hard-coded, so existing two-argument calls are unchanged.
    """

    def __init__(self, input_dim, hidden_dim, num_classes=10):
        super().__init__()
        # The layer order is unchanged, so parameter names (layers.0.weight,
        # layers.2.weight, ...) and therefore saved state_dicts stay compatible.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, input):
        """Flatten every sample to one row and return the raw (un-normalised) logits.

        Args:
            input: Tensor whose first dimension is the batch; all remaining
                dimensions are collapsed and must multiply out to ``input_dim``.

        Returns:
            Tensor of shape ``(input.size(0), num_classes)``.
        """
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after permute/transpose), reshape copies only when it has to.
        flat = input.reshape(input.size(0), -1)
        return self.layers(flat)

In [3]:
import numpy as np
import torch

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model, train_loader, optimizer, loss_fn, print_every=100):
    '''
    Trains the model for one epoch.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
        print_every: Currently unused; kept so existing callers do not break.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the mean of the per-batch losses and
        the accuracy in percent over the samples that were actually seen.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    '''
    model.train()
    losses = []
    n_correct = 0
    n_seen = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        n_correct += torch.sum(output.argmax(1) == labels).item()
        # Count what was really processed: len(train_loader.dataset) is wrong
        # whenever the loader does not visit the whole dataset exactly once
        # (sampler, drop_last) and does not exist for plain iterables of batches.
        n_seen += labels.size(0)
    if n_seen == 0:
        raise ValueError('train_loader yielded no samples')
    accuracy = 100.0 * n_correct / n_seen
    return np.mean(np.array(losses)), accuracy
            
def test(model, test_loader, loss_fn):
    '''
    Tests the model on data from test_loader
    '''
    model.eval()
    test_loss = 0
    n_correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            loss = loss_fn(output, labels)
            test_loss += loss.item()
            n_correct += torch.sum(output.argmax(1) == labels).item()

    average_loss = test_loss / len(test_loader)
    accuracy = 100.0 * n_correct / len(test_loader.dataset)
#     print('Test average loss: {:.4f}, accuracy: {:.3f}'.format(average_loss, accuracy))
    return average_loss, accuracy


def fit(train_dataloader, val_dataloader, model, optimizer, loss_fn, n_epochs, scheduler=None):
    '''
    Trains the model for n_epochs epochs, validating after every epoch.

    Args:
        train_dataloader: Loader passed to ``train`` each epoch.
        val_dataloader: Loader passed to ``test`` each epoch.
        model: Module being trained.
        optimizer: Optimizer used by ``train``.
        loss_fn: Loss callable used for both training and validation.
        n_epochs: Number of epochs to run.
        scheduler: Optional learning-rate scheduler, stepped once per epoch
            after validation.

    Returns:
        List with the validation accuracy (percent) of every epoch.
    '''
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(n_epochs):
        train_loss, train_accuracy = train(model, train_dataloader, optimizer, loss_fn)
        val_loss, val_accuracy = test(model, val_dataloader, loss_fn)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        if scheduler is not None:
            # ReduceLROnPlateau needs the monitored metric; calling step()
            # without it fails, so feed it the validation loss.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        print('Epoch {}/{}: train_loss: {:.4f}, train_accuracy: {:.4f}, val_loss: {:.4f}, val_accuracy: {:.4f}'.format(
            epoch + 1, n_epochs,
            train_losses[-1],
            train_accuracies[-1],
            val_losses[-1],
            val_accuracies[-1]))

    return val_accuracies

In [4]:
# Preprocessing shared by the train/val/test datasets: convert each image to a
# tensor, then normalise it with a single mean/std pair.
# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean/std — confirm.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize((0.1307,), (0.3081,))
transform_dataset = transforms.Compose([_to_tensor, _normalize])

In [5]:
# Training split, read from 'mnist/train' with the shared preprocessing.
train_dataset = datasets.ImageFolder(
    root='mnist/train',
    transform=transform_dataset,
)

In [6]:
# Validation split, read from 'mnist/val' with the shared preprocessing.
val_dataset = datasets.ImageFolder(
    root='mnist/val',
    transform=transform_dataset,
)

In [7]:
# Test split, read from 'mnist/test' with the shared preprocessing.
test_dataset = datasets.ImageFolder(
    root='mnist/test',
    transform=transform_dataset,
)

In [15]:
# One batch size shared by all three loaders.
BATCH_SIZE = 100

# Only the training data is reshuffled every epoch. Evaluation just accumulates
# per-batch losses and correct counts, so shuffling the validation/test data
# adds nothing and makes the evaluation order non-deterministic.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [16]:
# Sanity check: print the shape of the first image batch produced by the
# training loader and by the validation loader.
for loader in (train_loader, val_loader):
    for batch in loader:
        print(batch[0].shape)
        break

torch.Size([100, 3, 28, 28])
torch.Size([100, 3, 28, 28])


In [18]:
import os
from itertools import product

nb_epochs = 25
PLOT_DIR = 'mlp_plot'
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(PLOT_DIR, exist_ok=True)

# Grid search over hidden-layer width and Adam learning rate. Every combination
# trains a fresh model and saves one validation-accuracy curve as <i>.png.
for i, (hidden_layers, learning_rate) in enumerate(product([256, 512], [0.001, 0.002, 0.01])):
    # 3*28*28 matches the [N, 3, 28, 28] batches produced by the loaders;
    # MLPModel flattens each sample before the first linear layer.
    model_mlp = MLPModel(3*28*28, hidden_layers)
    model_mlp = model_mlp.to(device)
    optimizer = torch.optim.Adam(model_mlp.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    accuracies = fit(train_loader, val_loader, model_mlp, optimizer, loss_fn, nb_epochs)

    # Explicit figure per run so curves never pile up on a shared implicit figure.
    fig, ax = plt.subplots()
    ax.plot(accuracies)
    ax.set_title('Learning rate:' + str(learning_rate) + ', number of neurons in hidden layers: ' + str(hidden_layers))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation accuracy (%)')
    fig.savefig(os.path.join(PLOT_DIR, str(i) + '.png'))
    plt.show()
    # Release the figure; six open figures would otherwise stay in memory.
    plt.close(fig)
        

Epoch 1/25: train_loss: 0.5758, train_accuracy: 81.1219, val_loss: 0.3028, val_accuracy: 91.1515
Epoch 2/25: train_loss: 0.2843, train_accuracy: 91.1692, val_loss: 0.2744, val_accuracy: 91.2323
Epoch 3/25: train_loss: 0.2263, train_accuracy: 92.8532, val_loss: 0.2139, val_accuracy: 93.4293
Epoch 4/25: train_loss: 0.1878, train_accuracy: 94.1393, val_loss: 0.1842, val_accuracy: 94.3131
Epoch 5/25: train_loss: 0.1676, train_accuracy: 94.6965, val_loss: 0.1641, val_accuracy: 94.9747
Epoch 6/25: train_loss: 0.1522, train_accuracy: 95.2139, val_loss: 0.1768, val_accuracy: 94.6414
Epoch 7/25: train_loss: 0.1386, train_accuracy: 95.5323, val_loss: 0.1391, val_accuracy: 95.5909
Epoch 8/25: train_loss: 0.1254, train_accuracy: 95.9900, val_loss: 0.1659, val_accuracy: 94.9444
Epoch 9/25: train_loss: 0.1170, train_accuracy: 96.1866, val_loss: 0.1373, val_accuracy: 95.7424
Epoch 10/25: train_loss: 0.1054, train_accuracy: 96.5448, val_loss: 0.1294, val_accuracy: 96.0101
Epoch 11/25: train_loss: 0.09

KeyboardInterrupt: 