# In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim.swa_utils import AveragedModel, SWALR

# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 1. Load MNIST Dataset
def load_data(batch_size=64, num_workers=2, root='./data'):
    """
    Load MNIST and wrap the train and test splits in DataLoaders.

    Every image is resized to 224x224, converted to a tensor and normalized
    with mean 0.5 and std 0.5. No data augmentation is applied.

    Args:
    batch_size: Number of samples per batch for both loaders
    num_workers: Number of worker processes used by each loader
    root: Directory holding MNIST (the dataset is downloaded there if missing)

    Returns:
    train_loader: DataLoader for training (shuffled)
    test_loader: DataLoader for testing (not shuffled)
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to ResNet50's input size
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # Single mean/std pair: grayscale images
    ])

    train_dataset = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=transform
    )

    # Pinned host memory only helps host-to-GPU copies; asking for it on a
    # CPU-only machine has no benefit and makes DataLoader emit a warning.
    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    print("Training samples:", len(train_dataset))
    print("Testing samples:", len(test_dataset))

    return train_loader, test_loader

# 2. Modify ResNet50 for MNIST
def get_model():
    """
    Load a pretrained ResNet50 and adapt it for MNIST classification.

    The first convolution is replaced by a single-channel version whose
    filters are the channel-wise sum of the pretrained RGB filters, so the
    pretrained stem is kept instead of being re-initialised randomly. The
    final fully connected layer is replaced by a new 10-class layer.

    Returns:
    model: ResNet50 model adapted for MNIST, moved to `device`
    """
    model = resnet50(weights=ResNet50_Weights.DEFAULT)  # Load pretrained ResNet50

    # Adjust for single-channel input. Summing the filters over the RGB axis
    # yields the response the pretrained layer would give for a grayscale
    # image replicated across three channels.
    rgb_conv = model.conv1
    gray_conv = nn.Conv2d(1, rgb_conv.out_channels, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
    model.conv1 = gray_conv

    model.fc = nn.Linear(model.fc.in_features, 10)  # Adjust final layer for 10 classes
    model = model.to(device)
    return model

# 3. Training Function
def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=20):
    """
    Run the training loop while maintaining an SWA (weight-averaged) copy.

    Each epoch trains on every batch of train_loader, steps the scheduler
    once, folds the current weights into the SWA average from epoch index
    int(0.75 * num_epochs) onward, and evaluates the model on test_loader.
    Whenever that accuracy beats the best so far, the model weights are
    written to 'best_model.pth'. After the last epoch the SWA model's
    BatchNorm statistics are recomputed on train_loader and its weights are
    written to 'swa_best_model.pth'.

    Args:
    model: Neural network model
    train_loader: DataLoader for training
    test_loader: DataLoader for testing
    criterion: Loss function
    optimizer: Optimizer
    scheduler: Learning rate scheduler (stepped once per epoch)
    num_epochs: Number of epochs to train

    Returns:
    best_accuracy: Best test accuracy achieved by the non-averaged model
    """
    averaged = AveragedModel(model)
    first_swa_epoch = int(0.75 * num_epochs)
    top_accuracy = 0.0

    for epoch_idx in range(num_epochs):
        model.train()
        loss_sum = 0.0
        hits = 0
        seen = 0

        for batch, targets in train_loader:
            # Model inputs and targets must live on the same device
            batch = batch.to(device)
            targets = targets.to(device)

            logits = model(batch)
            batch_loss = criterion(logits, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Running statistics for the epoch summary
            guesses = logits.max(dim=1).indices
            seen += targets.size(0)
            hits += (guesses == targets).sum().item()
            loss_sum += batch_loss.item()

        scheduler.step()

        # Only the tail of the schedule contributes to the weight average
        if epoch_idx >= first_swa_epoch:
            averaged.update_parameters(model)

        val_accuracy = evaluate_model(model, test_loader)

        mean_loss = loss_sum / len(train_loader)
        train_accuracy = 100 * hits / seen
        print(f"Epoch {epoch_idx + 1}/{num_epochs}, Loss: {mean_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%")

        # Checkpoint whenever the evaluation accuracy improves
        if val_accuracy > top_accuracy:
            top_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Best model saved with accuracy: {top_accuracy:.2f}%")

    # Recompute BatchNorm statistics for the averaged weights, then persist them
    torch.optim.swa_utils.update_bn(train_loader, averaged, device=device)
    torch.save(averaged.state_dict(), 'swa_best_model.pth')
    return top_accuracy

# 4. Evaluation Function
def evaluate_model(model, test_loader):
    """
    Compute the model's classification accuracy over test_loader.

    The model is switched to eval mode (and left in that mode), and the
    forward passes run with gradient tracking disabled.

    Args:
    model: Trained model
    test_loader: Iterable yielding (inputs, labels) batches

    Returns:
    accuracy: Percentage of correctly classified samples, or 0.0 when
        test_loader yields no samples
    """
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Predicted class = index of the largest score along dim 1
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Nothing was evaluated: report 0% instead of dividing by zero
        return 0.0

    return 100.0 * correct / total

# Main Function
def main():
    """
    Run the full pipeline: load data, train, then evaluate the SWA weights.

    Trains for 20 epochs, then reloads the averaged weights from
    'swa_best_model.pth' and prints their accuracy on the test loader.
    """
    train_loader, test_loader = load_data()
    model = get_model()

    # Training setup: label-smoothed cross entropy, AdamW, SWA learning-rate schedule
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    adamw = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    swa_schedule = SWALR(adamw, swa_lr=0.0005)

    print("Starting training...")
    train_model(
        model,
        train_loader,
        test_loader,
        loss_fn,
        adamw,
        swa_schedule,
        num_epochs=20,
    )

    print("Evaluating best SWA model...")
    # Rebuild the averaging wrapper around the model, then restore the saved weights
    averaged = AveragedModel(model)
    saved_state = torch.load('swa_best_model.pth')
    averaged.load_state_dict(saved_state)
    averaged = averaged.to(device)
    final_accuracy = evaluate_model(averaged, test_loader)
    print(f"Final Test Accuracy: {final_accuracy:.2f}%")

# Run the program only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
