In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# One-hot encode a batch of labels
def one_hot_encode(labels, num_classes=10):
    one_hot = torch.zeros((labels.size(0), num_classes))
    one_hot.scatter_(1, labels.unsqueeze(1).long(), 1)
    return one_hot

In [3]:
# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

In [4]:
# Load dataset in batches
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [33]:
for data, labels in train_loader:
    print(data.shape)
    print(labels.shape)
    break

torch.Size([32, 1, 28, 28])
torch.Size([32])


In [5]:
def sigmoid(x):
    return (1 / (1 + torch.exp(-x)))

In [6]:
def softmax(x, dim=1):
  exp_x = torch.exp(x)
  sum_exp = torch.sum(exp_x, dim=dim, keepdim=True)
  return exp_x / sum_exp

In [7]:
# Cross-entropy on batches
def cross_entropy_loss(output, target):
  probs = softmax(output, dim=1)

  target = target.unsqueeze(1)

  loss = -torch.mean(torch.sum(torch.log(probs) * target, dim=1))

  return loss

In [8]:
class Perceptron:
    def __init__(self, input_dim, num_classes, learning_rate, loss_fn, batch_size):
        self.weights = torch.randn(input_dim, num_classes)
        self.bias = torch.randn(num_classes)
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def forward(self, inputs):
        linear = torch.add(torch.mm(inputs, self.weights), self.bias)
        return torch.sigmoid(linear)
    
    def backward(self, inputs, outputs, targets):
        grad = outputs - targets

        grad_w = torch.mm(inputs.t(), grad)
        grad_b = torch.sum(grad, dim=0)

        grad_w /= self.batch_size
        grad_b /= self.batch_size

        self.weights -= self.learning_rate * grad_w
        self.bias -= self.learning_rate * grad_b
    
    def train(self, train_loader, epochs):
        for i in range(epochs):
            total_correct = 0
            total_samples = 0
            for inputs, targets in train_loader:
                # Preprocess inputs and targets
                inputs = torch.flatten(inputs, start_dim=1)
                targets = one_hot_encode(targets)
                
                outputs = self.forward(inputs)
                loss = self.loss_fn(outputs, targets)
                self.backward(inputs, outputs, targets)

                # Compute accuracy for batch.
                predictions = torch.argmax(outputs, dim=1)
                targets = torch.argmax(targets, dim=1)
                correct = torch.sum((predictions == targets).float())
                total_correct += correct
                total_samples += len(targets)
            print(f'Epoch {i+1} --- Loss: {loss} --- Accuracy: {total_correct/total_samples}')
    
    def evaluate(self, test_loader):
        correct = 0
        total = len(test_loader.dataset)
        for inputs, targets in test_loader:
            inputs = torch.flatten(inputs, start_dim=1)

            outputs = self.forward(inputs)
            predictions = torch.argmax(outputs, dim=1)
            correct += torch.sum((predictions == targets).float())
        print(f'Accuracy: {correct / total}')

In [9]:
# Parameters initialization
epochs = 10
learning_rate = 0.15
input_dim = 784
num_classes = 10
batch_size = 32

In [10]:
model = Perceptron(input_dim, num_classes, learning_rate, cross_entropy_loss, batch_size)

In [11]:
model.train(train_loader, epochs)
model.evaluate(test_loader)

Epoch 1 --- Loss: 7.436459541320801 --- Accuracy: 0.7735000252723694
Epoch 2 --- Loss: 7.420828819274902 --- Accuracy: 0.8689166903495789
Epoch 3 --- Loss: 7.460615634918213 --- Accuracy: 0.883650004863739
Epoch 4 --- Loss: 7.406758785247803 --- Accuracy: 0.8915666937828064
Epoch 5 --- Loss: 7.479333400726318 --- Accuracy: 0.896233320236206
Epoch 6 --- Loss: 7.500998020172119 --- Accuracy: 0.8999666571617126
Epoch 7 --- Loss: 7.475472927093506 --- Accuracy: 0.9023000001907349
Epoch 8 --- Loss: 7.433465003967285 --- Accuracy: 0.9050499796867371
Epoch 9 --- Loss: 7.430211544036865 --- Accuracy: 0.9074166417121887
Epoch 10 --- Loss: 7.486754417419434 --- Accuracy: 0.908133327960968
Accuracy: 0.9049999713897705
