In [None]:
# Install required libraries (uncomment if running on Colab)
!pip install torch torchvision matplotlib numpy

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# 2. Device configuration and Hyperparameters
# This section sets up the device (CPU or GPU), fixes the random seed, and
# defines the hyperparameters for training.

# Use a CUDA-enabled GPU (e.g. a Colab T4) when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed PyTorch's random number generators so that weight initialization and
# DataLoader shuffling repeat across Restart & Run All (bit-exact GPU results
# may still vary).
seed = 42
torch.manual_seed(seed)

# Hyperparameters for training
num_epochs = 5
batch_size = 64
lr = 0.001

In [None]:
# 3. Data Transformation and Loading
# Convert each image to a tensor, then normalize it with a fixed mean/std
# (0.1307, 0.3081 -- presumably the MNIST training-set statistics; confirm
# before reusing this pipeline for another dataset).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split (downloaded into ./data): shuffled mini-batches of `batch_size`.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split (downloaded into ./data): fixed order, batches of 1000; only used
# for evaluation.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=1000,
    shuffle=False,
)

In [None]:
# 4. Neural Network Architecture
class Net(nn.Module):
    """Fully connected classifier: flattened input -> two ReLU hidden layers -> logits.

    With the default arguments, ``Net()`` builds the 784-128-64-10 network
    used for 28x28 MNIST images. The layer attribute names (fc1, fc2, fc3)
    are fixed, so saved state_dicts stay loadable.

    Args:
        in_features: Number of input features after flattening (default 28*28).
        hidden1: Width of the first hidden layer (default 128).
        hidden2: Width of the second hidden layer (default 64).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_features=28 * 28, hidden1=128, hidden2=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalized class scores (logits) of shape (N, num_classes).

        ``x`` is flattened to (N, in_features), where N = x.numel() // in_features,
        e.g. a (N, 1, 28, 28) image batch with the defaults. ``reshape`` is used
        instead of ``view`` so non-contiguous inputs are accepted too.
        """
        x = x.reshape(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # No softmax here: raw logits are returned.
        return self.fc3(x)

# Build the network, then move its parameters to the selected device.
model = Net()
model = model.to(device)

In [None]:
# 5. Initialize models and weights
# Random weights
def random_weights(m, low=-0.5, high=0.5):
    """Initialize an ``nn.Linear`` layer with uniform random weights and zero bias.

    Designed for ``model.apply(random_weights)``: modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: Module visited by ``nn.Module.apply``.
        low: Lower bound of the uniform weight distribution (default -0.5).
        high: Upper bound of the uniform weight distribution (default 0.5).
    """
    if isinstance(m, nn.Linear):
        nn.init.uniform_(m.weight, a=low, b=high)
        # Layers built with bias=False have m.bias = None; zeros_(None) would fail.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

def xavier_weights(m, gain=1.0):
    """Initialize an ``nn.Linear`` layer with Xavier/Glorot-uniform weights and zero bias.

    Designed for ``model.apply(xavier_weights)``: modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: Module visited by ``nn.Module.apply``.
        gain: Scaling factor forwarded to ``nn.init.xavier_uniform_`` (default 1.0).
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=gain)
        # Layers built with bias=False have m.bias = None; zeros_(None) would fail.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Initialize the model weights.
# Pick one strategy instead of commenting lines in and out:
#   "xavier"     -> Xavier/Glorot-uniform weights, zero bias (default)
#   "random"     -> uniform random weights via random_weights, zero bias
#   "checkpoint" -> load weights from init_checkpoint_path
init_method = "xavier"
init_checkpoint_path = 'models/checkpoint.pth'

if init_method == "checkpoint":
    state = torch.load(init_checkpoint_path, map_location=device)
    # Checkpoints saved in section 7.1 wrap the weights under 'model_state_dict'
    # (next to optimizer state); section 9 saves a bare state_dict. Accept both.
    if isinstance(state, dict) and 'model_state_dict' in state:
        state = state['model_state_dict']
    model.load_state_dict(state)
elif init_method == "random":
    model.apply(random_weights)
elif init_method == "xavier":
    model.apply(xavier_weights)
else:
    raise ValueError(f"Unknown init_method: {init_method!r}")

In [None]:
# 6. Loss function and Optimizers
# Adam updates every parameter of `model` at the configured learning rate;
# cross-entropy is computed on the raw logits that Net.forward returns.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [None]:
# 7. Training Loop
# Train on the MNIST training set: for each batch, run a forward pass, compute
# the cross-entropy loss, backpropagate, and take an Adam step. After each
# epoch, print the mean of that epoch's per-batch losses.
num_batches = len(trainloader)
for epoch in range(num_epochs):
    model.train()  # put the model in training mode for this epoch
    running_loss = 0.0
    for images, labels in trainloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # clear gradients left from the previous step
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    mean_batch_loss = running_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_batch_loss:.4f}")

In [None]:
# 7.1 Save a checkpoint for the final training epoch
# NOTE: this cell runs once, after the training loop, so it writes ONE
# checkpoint (the last epoch), not one per epoch. It reads `epoch` and
# `running_loss` left over from the training-loop cell, so run that cell first
# in the same kernel session.
if num_epochs > 0:
    final_epoch = epoch + 1
    checkpoint_path = f"mnist_checkpoint_epoch_{final_epoch}.pth"
    torch.save({
        'epoch': final_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # mean of the final epoch's per-batch training losses
        'loss': running_loss / len(trainloader),
    }, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")
else:
    # With zero epochs the loop variables are undefined (or stale from an
    # earlier run), so saving would fail or record the wrong epoch.
    print("num_epochs is 0: no training was run, so no checkpoint was saved.")

In [None]:
# 8. Evaluate the model
# With the model in eval mode and gradient tracking disabled, count how many
# test images get the right top-1 prediction, then report the accuracy as a
# percentage.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        predicted = model(images).argmax(dim=1)  # index of the largest logit per row
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
accuracy = 100 * correct / total
print(f'Accuracy on the test set: {accuracy:.2f}%')

In [None]:
# 9. Save the final trained model
# Only the model weights (its state_dict) are written here; the optimizer
# state is saved only in the section 7.1 checkpoint.
model_save_path = "mnist_classifier_final.pth"
final_state = model.state_dict()
torch.save(final_state, model_save_path)
print(f"Training complete. Final model saved to {model_save_path}")