In [262]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import random

In [255]:
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

In [256]:
weights = torch.randn(28*28, 10, requires_grad=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size = 32, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = 32, shuffle = True)


In [257]:
def accuracy(weights, test_loader):
	correct = 0
	test_size = len(test_loader.dataset)
	for id, (data, target) in enumerate(test_loader):
		data = data.view(-1, 28*28)
		outputs = torch.matmul(data, weights)
		softmax = F.softmax(outputs, dim=1)
		pred = softmax.argmax(dim=1, keepdim=True)
		n_correct = pred.eq(target.view_as(pred)).sum().item()
		correct += n_correct
	return correct / test_size

In [258]:
def	train(weights, train_loader, learning_rate, n_test):
    it = 0
    for id, (data, targets) in enumerate(train_loader):
        if weights.grad is not None:
            weights.grad.zero_()

        data = data.view(-1, 28*28)
        outputs = torch.matmul(data, weights)
        log_softmax = F.log_softmax(outputs, dim=1)
        loss = F.nll_loss(log_softmax, targets)
        loss.backward()
        with torch.no_grad():
            weights -= learning_rate * weights.grad

        it += 1
        if it > n_test:
            break

In [270]:
train(weights, train_loader, 0.1, 6000)
print("Accuracy: " + str(accuracy(weights, test_loader)))

Accuracy: 0.8771


In [260]:
def prediction(input):
	data = input[0].view(-1, 28*28)
	output = torch.matmul(data, weights)
	softmax = F.softmax(output, dim=1)
	pred = softmax.argmax(dim=1, keepdim=True)
	print("Val: {} - Pred: {}".format(pred.item(), input[1]))

In [271]:
for i in range(25):
	prediction(test_data[random.randint(0, 200)])

Val: 7 - Pred: 7
Val: 8 - Pred: 8
Val: 5 - Pred: 5
Val: 7 - Pred: 7
Val: 2 - Pred: 2
Val: 1 - Pred: 1
Val: 4 - Pred: 4
Val: 6 - Pred: 3
Val: 5 - Pred: 5
Val: 0 - Pred: 0
Val: 9 - Pred: 9
Val: 9 - Pred: 9
Val: 0 - Pred: 0
Val: 7 - Pred: 7
Val: 9 - Pred: 9
Val: 4 - Pred: 4
Val: 9 - Pred: 9
Val: 8 - Pred: 5
Val: 9 - Pred: 9
Val: 1 - Pred: 1
Val: 3 - Pred: 3
Val: 9 - Pred: 9
Val: 9 - Pred: 9
Val: 2 - Pred: 2
Val: 1 - Pred: 1
