In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# import matplotlib.pyplot as plt # Removed matplotlib for no image display
import numpy as np

# --- 1. Define a Simple CNN Model ---
class SimpleCNN(nn.Module):
    """Small LeNet-style CNN for 28x28 single-channel digit images.

    Pipeline: conv(1->10, k=5) -> max-pool(2) -> ReLU -> conv(10->20, k=5)
    -> channel dropout -> max-pool(2) -> ReLU -> flatten (20*4*4 = 320)
    -> fc(320->50) -> ReLU -> dropout -> fc(50->10).

    ``forward`` returns log-probabilities of shape (N, 10), which is what
    ``F.nll_loss`` expects.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel (grayscale) -> 10 learned feature maps. The 10 here is a
        # width choice and is unrelated to the 10 output classes.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # (N, 1, 28, 28) -> (N, 10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # (N, 10, 12, 12) -> (N, 20, 8, 8)
        self.conv2_drop = nn.Dropout2d()               # zeroes whole feature maps while training
        self.fc1 = nn.Linear(320, 50)                  # 20 * 4 * 4 = 320 after the second pool
        self.fc2 = nn.Linear(50, 10)                   # one logit per digit class

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # (N, 10, 12, 12)
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # (N, 20, 4, 4)
        # Flatten everything except the batch dimension. Unlike view(-1, 320), an
        # input with the wrong spatial size now fails loudly at fc1 instead of being
        # silently re-split into extra "samples".
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # Output log probabilities

# --- 2. Load and Train Model (or load pre-trained) ---
def train_model(model, device, train_loader, optimizer, epochs=5):
    """Train ``model`` in place by minimizing negative log-likelihood loss.

    Args:
        model (nn.Module): Network whose output is passed to ``F.nll_loss``
            (i.e. it should produce log-probabilities).
        device (torch.device): Device each batch is moved to before the forward pass.
        train_loader (DataLoader): Yields ``(data, target)`` batches.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        epochs (int): Number of full passes over ``train_loader``.

    Prints a progress line every 100 batches. Returns None.
    """
    model.train()
    for epoch in range(epochs):
        for step, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            batch_loss = F.nll_loss(model(inputs), labels)  # Negative Log Likelihood Loss
            batch_loss.backward()
            optimizer.step()

            # Only report on every 100th batch.
            if step % 100 != 0:
                continue
            seen = step * len(inputs)
            percent = 100. * step / len(train_loader)
            print(f'Train Epoch: {epoch} [{seen}/{len(train_loader.dataset)} '
                  f'({percent:.0f}%)]\tLoss: {batch_loss.item():.6f}')

def test_model(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)\n')


# --- 3. Implement MI-FGSM Attack ---
def mi_fgsm_attack(model, image, true_label, epsilon, alpha, iterations, decay=1.0,
                   clip_min=0.0, clip_max=1.0):
    """
    Implements the Momentum Iterative Fast Gradient Sign Method (MI-FGSM) attack.

    Untargeted variant: every step increases the NLL loss of ``true_label``.
    Each sample's input gradient is divided by its own L1 norm (taken over all
    of that sample's elements) before being accumulated into a momentum buffer.
    The image then moves by ``alpha`` along the sign of that momentum.

    Args:
        model (nn.Module): The target classification model. Its output is fed to
            ``F.nll_loss``, so it should return log-probabilities. This function
            switches the model to eval mode and leaves it there.
        image (torch.Tensor): The original input, batched along dim 0. It is not
            modified.
        true_label (torch.Tensor): Class indices for each sample in the batch.
        epsilon (float): Maximum L-infinity distance between the result and ``image``.
        alpha (float): Step size for each iteration.
        iterations (int): Number of attack iterations; 0 returns a copy of ``image``.
        decay (float): Momentum decay factor (mu in the MI-FGSM paper). 1.0 keeps
            all past normalized gradients at full weight; 0.0 reduces to I-FGSM.
        clip_min (float): Lower bound of the valid input range (default 0.0).
        clip_max (float): Upper bound of the valid input range (default 1.0).

    Returns:
        torch.Tensor: The adversarial image, same shape as ``image``, detached
        from the autograd graph.
    """
    # Set model to evaluation mode (disables dropout so gradients are deterministic).
    model.eval()

    original = image.detach()
    x_adv = original.clone()
    momentum = torch.zeros_like(original)

    for _ in range(iterations):
        x_adv.requires_grad_(True)

        # Untargeted attack: maximize the loss of the true label.
        loss = F.nll_loss(model(x_adv), true_label)

        # Differentiate w.r.t. the input only. Unlike loss.backward(), this neither
        # writes into nor requires zeroing the model parameters' .grad fields.
        (grad,) = torch.autograd.grad(loss, x_adv)

        with torch.no_grad():
            # Per-sample L1 norm over every non-batch element. Note that
            # F.normalize(grad, p=1) would normalize along dim=1 (channels) only,
            # which for single-channel images degenerates to sign(grad).
            # clamp_min guards against division by zero for an all-zero gradient.
            l1 = grad.abs().reshape(grad.size(0), -1).sum(dim=1)
            l1 = l1.clamp_min(1e-12).view(-1, *([1] * (grad.dim() - 1)))
            momentum = decay * momentum + grad / l1

            # Step along the sign of the momentum ...
            x_adv = x_adv + alpha * momentum.sign()
            # ... project back into the epsilon ball around the original image ...
            x_adv = original + (x_adv - original).clamp(-epsilon, epsilon)
            # ... and into the valid input range.
            x_adv = x_adv.clamp(clip_min, clip_max)

    return x_adv.detach()

# --- 4. Main Execution and Visualization ---
if __name__ == "__main__":
    # --- Configuration ---
    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only ToTensor() is applied, so pixels stay in [0, 1]. That is the range the
    # attack clips to, and the unit in which epsilon/alpha below are expressed.
    # Enabling the standard MNIST Normalize would change the input range, so the
    # attack's clipping bounds and budget would have to be adapted as well.
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,)) # Standard MNIST normalization
    ])

    # Load MNIST dataset
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=64, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=1, shuffle=True  # Batch size 1 for easy individual sample attack
    )

    # Initialize and train the model
    model = SimpleCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # You can load a pre-trained model here to skip training:
    # try:
    #     model.load_state_dict(torch.load('mnist_cnn_mi_fgsm.pt'))
    #     print("Loaded pre-trained model.")
    # except FileNotFoundError:
    #     print("Pre-trained model not found. Training new model...")
    #     train_model(model, device, train_loader, optimizer, epochs=5)
    #     torch.save(model.state_dict(), 'mnist_cnn_mi_fgsm.pt') # Save after training

    # As written, this demo always trains from scratch; use the commented block
    # above to reuse a saved checkpoint instead.
    print("Training model...")
    train_model(model, device, train_loader, optimizer, epochs=3)  # Train for 3 epochs
    test_model(model, device, test_loader)
    print("Model training/loading complete.")

    # --- Attack Parameters ---
    epsilon = 0.2    # L-infinity budget: each pixel may move by at most 0.2 on the [0, 1] scale
    alpha = 0.02     # Step size for each iteration (should be < epsilon)
    iterations = 20  # Number of attack iterations
    decay = 1.0      # Momentum decay factor (1.0 keeps past gradients at full weight)

    # --- Select a sample to attack ---
    # Find the first sample (in shuffled order) that the model classifies correctly.
    # Inference only, so no autograd graph is needed.
    original_image, true_label = None, None
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if model(data).argmax(dim=1).item() == target.item():
                original_image, true_label = data, target
                break

    if original_image is None:
        print("Could not find a correctly classified sample to attack. Exiting.")
        # Non-zero status signals the failure; unlike exit(), SystemExit needs no site module.
        raise SystemExit(1)

    print(f"Attacking image with true label: {true_label.item()}")

    # --- Run MI-FGSM Attack ---
    adversarial_image = mi_fgsm_attack(model, original_image, true_label, epsilon, alpha, iterations, decay)

    # --- Evaluate Adversarial Example ---
    model.eval()
    with torch.no_grad():
        original_pred = model(original_image).argmax(dim=1, keepdim=True).item()
        adversarial_pred = model(adversarial_image).argmax(dim=1, keepdim=True).item()

    print(f"\nOriginal Prediction: {original_pred}")
    print(f"Adversarial Prediction: {adversarial_pred}")

    if original_pred != adversarial_pred:
        print(f"\nMI-FGSM attack successfully changed prediction from {original_pred} to {adversarial_pred}")
    else:
        print("\nMI-FGSM attack did NOT change the prediction.")



Using device: cpu
Training model...

Test set: Average loss: 0.0714, Accuracy: 9771/10000 (98%)

Model training/loading complete.
Attacking image with true label: 4

Original Prediction: 4
Adversarial Prediction: 9

MI-FGSM attack successfully changed prediction from 4 to 9
