In [210]:
import torch
from torch import Tensor
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transform
from torch.utils.data import DataLoader

In [169]:
#MNIST datasets
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform.ToTensor())

len(mnist_testset)

10000

In [233]:
def train_perceptron(X: Tensor, W: Tensor, b: Tensor, y_true: Tensor, mu: float):
    for input, labels in zip(X, y_true):
        x = input.view(1, -1)
        #view - reshaping a tensor and preserving no of elems
        # x = images.view(images.shape[0], -1)
        encoded_label = torch.zeros(10)
        encoded_label[labels] = 1
        Z = torch.matmul(x, W) + b
        Y = 1 / (1 + torch.exp(-Z))
        Error = Y - encoded_label
        deltaWL : Tensor = Error.t() @ x
        deltabL : Tensor = Error
        W = W - mu * deltaWL.t()
        b = b - mu * deltabL
    return W, b

def train(trainset):
    mnist_trainloader : DataLoader = DataLoader(trainset, batch_size=128, shuffle=True)
    W = torch.randn(784, 10)
    b = torch.randn(10)
    mu = 0.01

    X, y_true = None, None
    for inputs, labels in mnist_trainloader:
        if X is None:
            X = inputs.view(inputs.shape[0], -1)
            y_true = labels
        else:
            X = torch.cat((X, inputs.view(inputs.shape[0], -1)), dim=0)
            y_true = torch.cat((y_true, labels), dim=0)

    return train_perceptron(X, W, b, y_true, mu)
       

W, b = train(mnist_trainset)
    

In [234]:
def predict(testset, W, b):
    mnist_testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)
    ok = 0
    for images, labels in mnist_testloader:
        #view - reshaping a tensor and preserving no of elems
        X = images.view(images.shape[0], -1)
        y_true = torch.zeros(10)
        y_true[labels[0]] = 1
        Z = torch.matmul(X, W) + b
        Y = 1 / (1 + torch.exp(-Z))
        if torch.argmax(y_true) == torch.argmax(Y):
            ok += 1
        #return Y, y_true, labels
    return ok / len(testset) * 100
    
predict(mnist_testset, W, b)

88.75999999999999