In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import glob
import os
import random
import time
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt

In [2]:
class SimpleNeuralNet(nn.Module):

    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNeuralNet, self).__init__()
        # fill in the declarations of the layers here
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes) 
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # fill the forward logic here
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.softmax(out)
        return out


def train_and_test_simple_net(input_size, hidden_size, num_classes):
    num_epochs = 5
    learning_rate = 0.001
    model = SimpleNeuralNet(input_size, hidden_size, num_classes).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Move tensors to the configured device
            #images = images.reshape(-1, 28*28).to(device)#
            images = images.view(-1, 28*28).to(device)
            labels = labels.to(device)

            # Forward pass
            # The forward process computes the loss of each iteration on each sample
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and using the optimizer to update the parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Below, an epoch corresponds to one pass through all of the samples.
            # Each training step corresponds to a parameter update using 
            # a gradient computed on a minibatch of 100 samples 
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Test the model
    # In the test phase, we don't need to compute gradients (for memory efficiency)
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.reshape(-1, 28 * 28).to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

In [4]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The MNIST dataset is a built-in dataset in torchvision
batch_size = 100
train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

train_and_test_simple_net(28 * 28, 200, 10)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data\MNIST\raw\train-images-idx3-ubyte.gz


9920512it [00:08, 1132809.95it/s]                                                                                      


Extracting ../../data\MNIST\raw\train-images-idx3-ubyte.gz to ../../data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data\MNIST\raw\train-labels-idx1-ubyte.gz


32768it [00:01, 17920.05it/s]                                                                                          


Extracting ../../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data\MNIST\raw\t10k-images-idx3-ubyte.gz


1654784it [00:01, 1350940.68it/s]                                                                                      


Extracting ../../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data\MNIST\raw\t10k-labels-idx1-ubyte.gz


8192it [00:00, 46455.75it/s]                                                                                           


Extracting ../../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data\MNIST\raw
Processing...
Done!
Epoch [1/5], Step [100/600], Loss: 1.7160
Epoch [1/5], Step [200/600], Loss: 1.6128
Epoch [1/5], Step [300/600], Loss: 1.6083
Epoch [1/5], Step [400/600], Loss: 1.5227
Epoch [1/5], Step [500/600], Loss: 1.5497
Epoch [1/5], Step [600/600], Loss: 1.5273
Epoch [2/5], Step [100/600], Loss: 1.5569
Epoch [2/5], Step [200/600], Loss: 1.5187
Epoch [2/5], Step [300/600], Loss: 1.5136
Epoch [2/5], Step [400/600], Loss: 1.5735
Epoch [2/5], Step [500/600], Loss: 1.5239
Epoch [2/5], Step [600/600], Loss: 1.5438
Epoch [3/5], Step [100/600], Loss: 1.5150
Epoch [3/5], Step [200/600], Loss: 1.4965
Epoch [3/5], Step [300/600], Loss: 1.5182
Epoch [3/5], Step [400/600], Loss: 1.4889
Epoch [3/5], Step [500/600], Loss: 1.5222
Epoch [3/5], Step [600/600], Loss: 1.4765
Epoch [4/5], Step [100/600], Loss: 1.4867
Epoch [4/5], Step [200/600], Loss: 1.5352
Epoch [4/5], Step [300/600], Loss: 1.5124
Epoch [4/5], Ste