In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import os

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class DigitClassifier(nn.Module):
    def __init__(self, dropout_rate=0.2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 256)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.dropout1(x)
        x = self.relu2(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x  # logits

class CNNDigitClassifier(nn.Module):
    def __init__(self, dropout_rate=0.25):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.dropout1(x)
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

def load_data_with_validation(batch_size: int, validation_split=0.1):
    base_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_train_dataset = datasets.MNIST(root="./data", train=True,
                                        download=True, transform=train_transform)
    test_dataset = datasets.MNIST(root="./data", train=False,
                                  download=True, transform=base_transform)
    train_size = int((1 - validation_split) * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    train_dataset, val_dataset = random_split(full_train_dataset,
                                              [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    return train_loader, val_loader, test_loader

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler=None, max_epochs=50):
    model.to(device)
    best_val_acc = 0
    patience_counter = 0
    patience = 5
    print(f"Starting training for up to {max_epochs} epochs...")
    print("-" * 60)
    
    for epoch in range(max_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:2d}/{max_epochs}", ncols=100)
        for imgs, labels in train_pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, preds = torch.max(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
        avg_loss = total_loss / len(train_loader)
        train_accuracy = 100. * correct / total
        val_accuracy = evaluate_model(model, val_loader, return_accuracy=True, verbose=False)
        print(f"Epoch {epoch+1:2d}: Train Loss: {avg_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%")
        if scheduler:
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]
            if epoch > 0:
                print(f"         Learning rate: {current_lr:.6f}")
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            patience_counter = 0
            torch.save(model.state_dict(), 'best_mnist_model.pth')
        else:
            patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}!")
            print(f"Best validation accuracy: {best_val_acc:.2f}%")
            break
    model.load_state_dict(torch.load('best_mnist_model.pth'))
    print(f"\nTraining completed. Best validation accuracy: {best_val_acc:.2f}%")
    return best_val_acc

def evaluate_model(model, test_loader, return_accuracy=False, verbose=True):
    model.eval()
    model.to(device)
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            _, preds = torch.max(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = 100. * correct / total
    if verbose:
        print(f"Test Accuracy: {accuracy:.2f}%")
    if return_accuracy:
        return accuracy

def save_model(model, filepath, metadata=None):
    save_dict = {
        'model_state_dict': model.state_dict(),
        'model_class': model.__class__.__name__
    }
    if metadata:
        save_dict.update(metadata)
    torch.save(save_dict, filepath)
    print(f"Model saved to {filepath}")

def load_model(filepath, model_class):
    checkpoint = torch.load(filepath)
    model = model_class()
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

if __name__ == "__main__":
    batch_size = 64
    learning_rate = 0.001
    max_epochs = 50
    dropout_rate = 0.2
    use_cnn = False  # Set to True to use CNN instead of fully connected
    
    print("=" * 60)
    print("MNIST Digit Classification - Improved Version")
    print("=" * 60)
    
    train_loader, val_loader, test_loader = load_data_with_validation(batch_size)
    if use_cnn:
        model = CNNDigitClassifier(dropout_rate=dropout_rate)
        print("Using CNN architecture")
    else:
        model = DigitClassifier(dropout_rate=dropout_rate)
        print("Using fully connected architecture")
    print(f"Model has {count_parameters(model):,} trainable parameters")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    
    print("Training configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Dropout rate: {dropout_rate}")
    print(f"  Max epochs: {max_epochs}")
    print(f"  Weight decay: 1e-4")
    
    best_val_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, max_epochs)

    print("\n" + "=" * 60)
    print("FINAL RESULTS")
    print("=" * 60)

    test_acc = evaluate_model(model, test_loader, return_accuracy=True)
    
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Final test accuracy: {test_acc:.2f}%")
    
    metadata = {
        'test_accuracy': test_acc,
        'best_val_accuracy': best_val_acc,
        'hyperparameters': {
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'dropout_rate': dropout_rate,
            'architecture': 'CNN' if use_cnn else 'FC'
        }
    }
    save_model(model, 'final_mnist_model.pth', metadata)
    if os.path.exists('best_mnist_model.pth'):
        os.remove('best_mnist_model.pth')
    print("\nTraining completed successfully!")


Using device: cuda
MNIST Digit Classification - Improved Version


100%|██████████| 9.91M/9.91M [00:00<00:00, 59.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.66MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.01MB/s]


Dataset sizes - Train: 54000, Val: 6000, Test: 10000
Using fully connected architecture
Model has 235,146 trainable parameters
Training configuration:
  Batch size: 64
  Learning rate: 0.001
  Dropout rate: 0.2
  Max epochs: 50
  Weight decay: 1e-4
Starting training for up to 50 epochs...
------------------------------------------------------------


Epoch  1/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 43.05it/s, Loss=0.3674]


Epoch  1: Train Loss: 0.3902, Train Acc: 87.72%, Val Acc: 94.53%


Epoch  2/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 43.94it/s, Loss=0.2017]


Epoch  2: Train Loss: 0.2006, Train Acc: 93.85%, Val Acc: 95.38%
         Learning rate: 0.001000


Epoch  3/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.36it/s, Loss=0.2126]


Epoch  3: Train Loss: 0.1707, Train Acc: 94.72%, Val Acc: 96.08%
         Learning rate: 0.001000


Epoch  4/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.47it/s, Loss=0.2735]


Epoch  4: Train Loss: 0.1522, Train Acc: 95.22%, Val Acc: 96.00%
         Learning rate: 0.001000


Epoch  5/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.35it/s, Loss=0.0446]


Epoch  5: Train Loss: 0.1404, Train Acc: 95.67%, Val Acc: 96.23%
         Learning rate: 0.001000


Epoch  6/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.55it/s, Loss=0.2849]


Epoch  6: Train Loss: 0.1370, Train Acc: 95.77%, Val Acc: 96.80%
         Learning rate: 0.001000


Epoch  7/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.26it/s, Loss=0.0291]


Epoch  7: Train Loss: 0.1329, Train Acc: 95.82%, Val Acc: 96.85%
         Learning rate: 0.000100


Epoch  8/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.34it/s, Loss=0.0858]


Epoch  8: Train Loss: 0.1002, Train Acc: 96.89%, Val Acc: 97.43%
         Learning rate: 0.000100


Epoch  9/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.40it/s, Loss=0.0360]


Epoch  9: Train Loss: 0.0902, Train Acc: 97.23%, Val Acc: 97.33%
         Learning rate: 0.000100


Epoch 10/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.44it/s, Loss=0.0606]


Epoch 10: Train Loss: 0.0902, Train Acc: 97.16%, Val Acc: 97.33%
         Learning rate: 0.000100


Epoch 11/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.59it/s, Loss=0.0788]


Epoch 11: Train Loss: 0.0846, Train Acc: 97.28%, Val Acc: 97.32%
         Learning rate: 0.000100


Epoch 12/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.36it/s, Loss=0.1643]


Epoch 12: Train Loss: 0.0838, Train Acc: 97.44%, Val Acc: 97.62%
         Learning rate: 0.000100


Epoch 13/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.15it/s, Loss=0.0604]


Epoch 13: Train Loss: 0.0819, Train Acc: 97.45%, Val Acc: 97.53%
         Learning rate: 0.000100


Epoch 14/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.64it/s, Loss=0.0535]


Epoch 14: Train Loss: 0.0805, Train Acc: 97.55%, Val Acc: 97.83%
         Learning rate: 0.000010


Epoch 15/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.86it/s, Loss=0.0360]


Epoch 15: Train Loss: 0.0778, Train Acc: 97.56%, Val Acc: 97.73%
         Learning rate: 0.000010


Epoch 16/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.66it/s, Loss=0.1423]


Epoch 16: Train Loss: 0.0786, Train Acc: 97.56%, Val Acc: 97.73%
         Learning rate: 0.000010


Epoch 17/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.44it/s, Loss=0.0482]


Epoch 17: Train Loss: 0.0748, Train Acc: 97.71%, Val Acc: 97.70%
         Learning rate: 0.000010


Epoch 18/50: 100%|███████████████████████████████████| 844/844 [00:19<00:00, 44.35it/s, Loss=0.1871]


Epoch 18: Train Loss: 0.0750, Train Acc: 97.76%, Val Acc: 97.65%
         Learning rate: 0.000010


Epoch 19/50: 100%|███████████████████████████████████| 844/844 [00:18<00:00, 44.75it/s, Loss=0.0463]


Epoch 19: Train Loss: 0.0754, Train Acc: 97.62%, Val Acc: 97.77%
         Learning rate: 0.000010

Early stopping at epoch 19!
Best validation accuracy: 97.83%

Training completed. Best validation accuracy: 97.83%

FINAL RESULTS
Test Accuracy: 98.57%
Best validation accuracy: 97.83%
Final test accuracy: 98.57%
Model saved to final_mnist_model.pth

Training completed successfully!
