In [None]:
# !pip install torch torchvision

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST



In [None]:
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Per-sample preprocessing: torchvision's ToTensor turns each image into a torch tensor.
transform = transforms.ToTensor()

In [None]:
# Load both MNIST splits from ./data (downloading them on first use);
# every sample is passed through `transform` when it is fetched.
train_data = MNIST(
    root='data',
    train=True,
    transform=transform,
    download=True,
)
test_data = MNIST(
    root='data',
    train=False,
    transform=transform,
    download=True,
)

In [None]:
# train_data.data[0]
# test_data.data[0]

In [None]:
# Mini-batches of 64: the training split is reshuffled, the test split keeps its order.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    test_data,
    batch_size=64,
    shuffle=False,
)

In [None]:
class DigitClassifier(nn.Module):
    """Fully connected classifier: two hidden ReLU layers and a linear output layer.

    The last layer has no activation, so ``forward`` returns raw class scores.

    Args:
        in_features: Number of input values per sample (default ``28*28``).
        hidden_sizes: Widths of the two hidden layers (default ``(128, 64)``).
        num_classes: Number of scores produced per sample (default ``10``).
    """

    def __init__(self, in_features=28*28, hidden_sizes=(128, 64), num_classes=10):
        super().__init__()
        first_hidden, second_hidden = hidden_sizes
        self.in_features = in_features
        # Keep this exact layer order: the Linear layers must stay at positions
        # 0, 2 and 4 so previously saved state_dicts still load.
        self.net = nn.Sequential(
            nn.Linear(in_features, first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, num_classes),
        )

    def forward(self, x):
        """Return the class scores for ``x``.

        Args:
            x: Either already-flattened input whose last dimension equals
                ``in_features``, or multi-dimensional input (for example a
                batch of images) that is flattened per sample first.

        Returns:
            The output of the final linear layer, with ``num_classes`` values
            along the last dimension.
        """
        # Input whose last dimension already matches is passed through untouched,
        # so callers that flatten by hand keep getting the same result.
        if x.dim() > 1 and x.size(-1) != self.in_features:
            x = x.flatten(start_dim=1)
        return self.net(x)

In [None]:
# Create the network on the selected device, then set up its loss and optimiser.
model = DigitClassifier()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Train for a fixed number of epochs and report the mean loss of each epoch.
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    seen_samples = 0

    for images, labels in train_loader:
        # Flatten every image into one row before it enters the network.
        images = images.view(images.size(0), -1).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # With the loss's default 'mean' reduction, loss.item() is a per-batch
        # average; weight it by the batch size so a smaller final batch does
        # not skew the epoch mean.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        seen_samples += batch_size

    # Report the whole epoch rather than whichever batch happened to come last.
    avg_loss = total_loss / max(seen_samples, 1)
    print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}")
    

In [None]:
# Switch the network to inference mode and zero the tallies for the accuracy pass.
model.eval()
correct, total = 0, 0

In [None]:

# Measure accuracy over the full test loader.
# Eval mode and both counters are reset here, so re-running this cell on its
# own gives the same answer instead of adding to the previous run's totals.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        # Flatten every image into one row before it enters the network.
        images = images.view(images.size(0), -1).to(device)
        labels = labels.to(device)

        outputs = model(images)
        predictions = outputs.argmax(dim=1)  # index of the highest score per sample

        total += labels.size(0)
        correct += (predictions == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total} %')



In [None]:
# Save only the model's state_dict (not the module object) to the working directory.
torch.save(model.state_dict(), 'digit_classifier.pth')
print('Model saved to digit_classifier.pth')

In [None]:
# Show one test image and compare the model's guess with its label.
index = 0
image, true_label = test_data[index]

# Draw the digit, titled with its ground-truth label.
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f'True Label: {true_label}')
plt.axis('off')
plt.show()

# The network takes flattened rows, so reshape the image into a one-row batch.
image_flat = image.reshape(1, -1).to(device)

with torch.no_grad():
    output = model(image_flat)
predicted_label = torch.argmax(output, dim=1).item()

print(
    f'Index: {index}',
    f'True Label: {true_label}',
    f'Predicted Label: {predicted_label}',
    sep='\n',
)

In [None]:
import random

# Display five randomly chosen test digits next to the model's prediction.
for _ in range(5):
    index = random.randrange(len(test_data))
    image, true_label = test_data[index]

    # Draw the digit, titled with its ground-truth label.
    plt.imshow(image.squeeze(), cmap='gray')
    plt.title(f'True Label: {true_label}')
    plt.axis('off')
    plt.show()

    # The network takes flattened rows, so reshape the image into a one-row batch.
    image_flat = image.reshape(1, -1).to(device)

    with torch.no_grad():
        output = model(image_flat)
    predicted_label = torch.argmax(output, dim=1).item()

    print(
        f'Index: {index}',
        f'True Label: {true_label}',
        f'Predicted Label: {predicted_label}',
        sep='\n',
    )

In [73]:
# Estimate accuracy on a random subset of the test set.
n = 5000
# Never ask for more images than the test set holds.
n = min(n, len(test_data))

# Distinct indices: sampling without replacement means no image is counted twice.
indices = random.sample(range(len(test_data)), n)
samples = [test_data[i] for i in indices]

# Stack the chosen images into one flattened batch so the network runs once
# instead of once per image.
sample_images = torch.stack([image for image, _ in samples]).view(n, -1).to(device)
sample_labels = torch.tensor([label for _, label in samples], device=device)

model.eval()
with torch.no_grad():
    predictions = model(sample_images).argmax(dim=1)

cnt = (predictions == sample_labels).sum().item()

print(f'Correct Predictions: {cnt} out of {n}')
print(f'Incorrect Predictions: {n - cnt} out of {n}')
print(f'Accuracy over {n} random samples: {100 * cnt / n} %')

Correct Predictions: 4894 out of 5000
Incorrect Predictions: 106 out of 5000
Accuracy over 5000 random samples: 97.88 %
