<a href="https://colab.research.google.com/github/DurgaPrasad-R/FML/blob/main/MNIST_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import random
# Define the neural network architecture
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)
# Initialize weights using Kaiming initialization
for module in model.modules():
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
# Define the loss function and optimizer
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum = 0.9);


# Define the learning rate scheduler
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0)

# Define the normalization transform
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

# Load the MNIST dataset and create data loaders
torch.manual_seed(123)
np.random.seed(123)

train_data = torchvision.datasets.MNIST(root='./data', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=False)

test_data = torchvision.datasets.MNIST(root='./data', train=False, transform=ToTensor(), download=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

# Train the model for several epochs

num_epoch = 15;
for epoch in range(num_epoch):
    # Train the model on the training data
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    model.train()  # switch to train mode
    for inputs, labels in train_loader:
        # Flatten the input images
        inputs = inputs.view(inputs.shape[0], -1)
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update the weights
        optimizer.step()
        # Keep track of the loss and accuracy
        train_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.data, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        
    train_loss /= len(train_loader.dataset)
    train_accuracy = 100.0 * train_correct / train_total
    
    # Evaluate the model on the test data
    test_loss = 0.0
    test_correct = 0
    test_total = 0
    model.eval()  # switch to eval mode
    with torch.no_grad():
        for inputs, labels in test_loader:
            # Flatten the input images
            inputs = inputs.view(inputs.shape[0], -1)
            # Forward pass
            outputs = model(inputs)
            # Compute the loss
            loss = criterion(outputs, labels)
            # Keep track of the loss and accuracy
            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            predicted_labels = [test_data.classes[i] for i in predicted]
            actual_labels = [test_data.classes[i] for i in labels]
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    
    test_loss /= len(test_loader.dataset)
    test_accuracy = 100.0 * test_correct / test_total
    
    # Print the loss and accuracy for this epoch
    print(f"Epoch {epoch+1}/{15} - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%")
    print("Actual Labels:", actual_labels)
    print("Predicted Labels:", predicted_labels)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 78769178.04it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 18004710.73it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25122842.30it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 11602027.26it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Epoch 1/15 - Train Loss: 0.2936, Train Acc: 91.22%, Test Loss: 0.1519, Test Acc: 95.26%
Actual Labels: ['1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine', '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six']
Predicted Labels: ['1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine', '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six']
Epoch 2/15 - Train Loss: 0.1146, Train Acc: 96.60%, Test Loss: 0.1055, Test Acc: 96.53%
Actual Labels: ['1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine', '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six']
Predicted Labels: ['1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine', '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - 