# MNIST Custom Implementation

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

In [2]:
train_dataset = MNIST(root="../datasets", train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST(root="../datasets", train=False, download=False, transform=transforms.ToTensor())

In [3]:
# figure out if you have a GPU
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU (UUID: GPU-94b0898d-baf3-ee3a-f30c-ffdfe67aced2)


In [4]:
# parameters
DEVICE = ("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS=10
BATCH_SIZE=32

HIDDEN_SIZE = 100
NUM_LABELS = 10
NUM_FEATURES = 28*28
ALPHA = 0.1

In [5]:
train_dataloader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              drop_last=True,
                              num_workers=4)

test_dataloader = DataLoader(dataset=test_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              drop_last=True,
                              num_workers=4)

In [6]:
# we create a set of weights and biases
W_1 = torch.randn(HIDDEN_SIZE, NUM_FEATURES, requires_grad=True, device=DEVICE)
b_1 = torch.zeros(1, HIDDEN_SIZE, requires_grad=True, device=DEVICE)

W_2 = torch.randn(NUM_LABELS, HIDDEN_SIZE, requires_grad=True, device=DEVICE)
b_2 = torch.zeros(1, NUM_LABELS, requires_grad=True, device=DEVICE)

In [9]:
for epoch in range(NUM_EPOCHS):
    loss_sum = 0
    batch_nums = 0
    for batch_idx, (features, labels) in enumerate(train_dataloader):
        # reshape features and move to gpu
        features = features.view(-1, NUM_FEATURES).to(DEVICE)
        
        # create one hot labels and move to GPU
        one_hot_labels = torch.zeros(BATCH_SIZE, NUM_LABELS).to(DEVICE)
        for sample_idx, label in enumerate(labels):
            one_hot_labels[sample_idx][label] = 1
                
        # ------ FORWARD PASS --------
        # first linear transformation
        out = features @ W_1.T + b_1

        # sigmoid activation
        out = 1 / (1 + torch.exp(-out))

        # second linear transformation
        logits = out @ W_2.T + b_2

        # softmax
        numerator = torch.exp(logits)
        denominator = numerator.sum(dim=1, keepdim=True)
        softmax = numerator / denominator

        # ------CALCULATE LOSS --------
        #cross-entropy loss
        loss = -(one_hot_labels * torch.log(softmax)).mean()

        # ------BACKPROPAGATION --------
        loss.backward()

        # ------GRADIENT DESCENT --------
        with torch.inference_mode():
            W_1.sub_(ALPHA * W_1.grad)
            b_1.sub_(ALPHA * b_1.grad)
            W_2.sub_(ALPHA * W_2.grad)
            b_2.sub_(ALPHA * b_2.grad)

        # ------CLEAR GRADIENTS --------
        W_1.grad.zero_()
        W_2.grad.zero_()
        b_1.grad.zero_()
        b_2.grad.zero_()
        
        # ------TRACK LOSS --------
        batch_nums += 1
        loss_sum += loss.detach().cpu()
    
    print(f'Epoch: {epoch+1} Loss: {loss_sum / batch_nums}')

Epoch: 1 Loss: 0.05242159962654114
Epoch: 2 Loss: 0.05055994167923927
Epoch: 3 Loss: 0.04891354963183403
Epoch: 4 Loss: 0.04746674373745918
Epoch: 5 Loss: 0.04614807292819023
Epoch: 6 Loss: 0.044992394745349884
Epoch: 7 Loss: 0.04390263929963112
Epoch: 8 Loss: 0.042916443198919296
Epoch: 9 Loss: 0.04200851172208786
Epoch: 10 Loss: 0.04114517942070961


In [10]:
# test acccuracy
num_samples = 0
num_correct = 0
for batch_idx, (features, labels) in enumerate(test_dataloader):
    with torch.inference_mode():
        features = features.view(-1, NUM_FEATURES).to(DEVICE)
        labels = labels.to(DEVICE) 
        # ------ FORWARD PASS --------
        # first linear transformation
        out = features @ W_1.T + b_1

        # sigmoid activation
        out = 1 / (1 + torch.exp(-out))

        # second linear transformation
        logits = out @ W_2.T + b_2
        
        predictions = logits.argmax(dim=1)
        num_samples+=len(features)
        num_correct+=(labels == predictions).sum().detach().cpu().item()
        
accuracy = num_correct / num_samples
print(accuracy)

0.8804086538461539
