# MNIST CNN Training Notebook

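This notebook defines a small convolutional neural network (CNN) for MNIST handwritten-digit classification, trains it on a 50,000-image split, validates it on the remaining 10,000 training images, and saves the trained weights to disk.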

In [None]:
# Install required packages
# tqdm supplies the progress bars shown during training and validation; the
# leading "!" makes the notebook run pip as a shell command.
!pip install tqdm

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import os
from datetime import datetime

In [None]:
# Define the CNN model
class MnistCNN(nn.Module):
    """Small convolutional classifier for 28x28 images such as MNIST digits.

    The network has two stages: a convolutional feature extractor
    (``features``) followed by a single linear layer (``classifier``) that
    maps the flattened feature map to one logit per class.

    The linear layer expects a 7x7 spatial map with 16 channels. The 3x3
    convolutions use ``padding=1`` and therefore keep the spatial size, and
    the two 2x2 max-pools halve it twice, so the input must be 28x28.

    Args:
        num_classes: Number of output logits. Defaults to 10 (digits 0-9).
        in_channels: Number of channels of the input images. Defaults to 1
            (grayscale).
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1):
        super().__init__()
        # NOTE: the order of layers inside each Sequential determines the
        # state_dict keys ("features.0.weight", ...); keep it stable so
        # previously saved checkpoints still load.
        self.features = nn.Sequential(
            # Stage 1: in_channels -> 10 channels, 28x28 -> 14x14.
            nn.Conv2d(in_channels, 10, kernel_size=3, padding=1),
            nn.BatchNorm2d(10),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Stage 2: 10 -> 16 -> 16 channels, 14x14 -> 7x7.
            nn.Conv2d(10, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 7 * 7 spatial positions times 16 channels after the second pool.
            nn.Linear(7 * 7 * 16, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax applied) for a batch ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [None]:
# Utility functions
def save_model(model, accuracy, path="models"):
    """Write the model's ``state_dict`` to a timestamped file inside ``path``.

    The file name has the form
    ``mnist_model_<accuracy>acc_<YYYYmmdd_HHMMSS>.pth`` where the accuracy is
    formatted with two decimals and the timestamp is the local time of the
    call.

    Args:
        model: Module whose ``state_dict()`` is saved.
        accuracy: Number embedded in the file name.
        path: Destination directory; it is created if it does not exist.

    Returns:
        The generated file name only, without the directory part. Join it
        with ``path`` to obtain the location of the saved file.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"mnist_model_{accuracy:.2f}acc_{timestamp}.pth"
    os.makedirs(path, exist_ok=True)
    # os.path.join copes with a trailing separator in `path` and uses the
    # platform's separator instead of a hard-coded "/".
    torch.save(model.state_dict(), os.path.join(path, filename))
    return filename

def count_parameters(model, *, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When true, count only parameters whose
            ``requires_grad`` flag is set, i.e. skip frozen ones. Defaults to
            false, which counts every parameter.

    Returns:
        The sum of ``numel()`` over the selected parameters (0 if there are
        none).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [None]:
# Training function
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled. A tqdm progress bar labelled ``desc`` shows the running loss
    and accuracy.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    correct = 0
    seen = 0

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, target)

            if training:
                loss.backward()
                optimizer.step()

            # Assumes `criterion` returns the batch mean (the default
            # reduction of the nn.CrossEntropyLoss built in train_model).
            # Weighting by batch size keeps a smaller final batch from
            # skewing the epoch average.
            n_samples = target.size(0)
            loss_sum += loss.item() * n_samples
            correct += (output.argmax(dim=1) == target).sum().item()
            seen += n_samples

            pbar.set_postfix({
                'loss': f'{loss_sum/seen:.4f}',
                'acc': f'{100.*correct/seen:.2f}%'
            })

    return loss_sum / seen, 100. * correct / seen


def train_model(epochs=5, batch_size=32, learning_rate=0.015):
    """Train ``MnistCNN`` on MNIST, validate after every epoch, save the weights.

    The 60,000-image MNIST training set is downloaded to ``./data`` if needed
    and randomly split into 50,000 training and 10,000 validation images. The
    model is optimised with Adam and cross-entropy loss, and the weights after
    the final epoch are saved through ``save_model``.

    Args:
        epochs: Number of passes over the training split; must be at least 1.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Learning rate passed to Adam. Defaults to 0.015.

    Returns:
        A tuple ``(model, val_accuracy, model_path)``: the trained model, the
        validation accuracy of the last epoch in percent, and the value
        returned by ``save_model``.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    # Fail before downloading anything: with zero epochs there is no
    # validation result to report or to embed in the checkpoint name.
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # Use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data transformation
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
    # mean and standard deviation — confirm if the dataset changes.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load and split dataset
    full_dataset = torchvision.datasets.MNIST(root='./data',
                                             train=True,
                                             transform=transform,
                                             download=True)

    # Split into 50K training and 10K validation
    train_size = 50000
    val_size = 10000
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=2)

    val_loader = DataLoader(val_dataset,
                           batch_size=batch_size,
                           shuffle=False,
                           num_workers=2)

    # Initialize model, loss, and optimizer
    model = MnistCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print(f"Total trainable parameters: {count_parameters(model)}")

    # Training loop: one optimisation pass, then one evaluation pass per epoch
    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device,
            f'Epoch {epoch+1}/{epochs} [Train]', optimizer)
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device,
            f'Epoch {epoch+1}/{epochs} [Val]')

        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%\n')

    # Save model (named after the last epoch's validation accuracy)
    model_path = save_model(model, val_accuracy)

    return model, val_accuracy, model_path

In [None]:
# Train the model
# Runs the whole pipeline for 5 epochs with batches of 32. train_model returns
# the trained model, the last epoch's validation accuracy in percent, and the
# file name of the saved weights; these stay available to later cells.
model, accuracy, model_path = train_model(epochs=5, batch_size=32)
print(f"Training completed with validation accuracy: {accuracy:.2f}%")
print(f"Model saved as: {model_path}")

## Test Model on Sample Images

Add cells below to run the trained `model` on individual images or to visualize its predictions.