<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Adversarial_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the SimpleCNN class
class SimpleCNN(nn.Module):
    """Minimal single-convolution classifier for 1-channel 28x28 images (e.g. MNIST).

    Pipeline: 3x3 conv (1 -> 32 channels, padding 1, so the 28x28 spatial size is
    preserved) -> ReLU -> flatten -> one linear layer producing 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1 first, then fc1) determines the order of random
        # initialisation and of the state_dict keys, so keep it stable.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # padding=1 keeps 28x28, hence 32 feature maps * 28 * 28 flattened inputs.
        self.fc1 = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        """Map a batch shaped (batch, 1, 28, 28) to logits shaped (batch, 10)."""
        activations = torch.relu(self.conv1(x))
        batch_size = activations.size(0)
        flattened = activations.view(batch_size, -1)
        return self.fc1(flattened)

# Define the FGSM attack function
def fgsm_attack(model, loss_fn, images, labels, epsilon, *, clip_min=0.0, clip_max=1.0):
    """Craft Fast Gradient Sign Method (FGSM) adversarial examples.

    Returns ``clamp(images + epsilon * sign(d loss / d images), clip_min, clip_max)``
    where ``loss = loss_fn(model(images), labels)``.

    Args:
        model: Module whose output is passed to ``loss_fn``.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        images: Input batch. It is NOT modified: its ``requires_grad`` flag and
            ``.grad`` are left untouched.
        labels: Targets passed to ``loss_fn``.
        epsilon: Perturbation step size, in the same units as ``images``.
        clip_min: Lower bound of the valid input range. The default matches raw
            pixels from ``transforms.ToTensor()``; pass the transformed bound if the
            inputs were normalized.
        clip_max: Upper bound of the valid input range (see ``clip_min``).

    Returns:
        A new tensor of perturbed images, detached from the autograd graph.

    Side effects:
        Calls ``model.zero_grad()``. The attack gradient is taken w.r.t. the input
        only, so parameter ``.grad`` fields are left cleared rather than filled with
        the attack loss's gradient.
    """
    # Work on a detached copy so the caller's tensor never becomes a grad-requiring
    # leaf, and so a .grad left over from a previous call cannot accumulate into ours.
    attack_input = images.detach().clone().requires_grad_(True)
    loss = loss_fn(model(attack_input), labels)
    model.zero_grad()
    # Differentiate w.r.t. the input only; unlike loss.backward(), this does not
    # write gradients into the model parameters (which would leak into training).
    (input_grad,) = torch.autograd.grad(loss, attack_input)
    perturbed = images.detach() + epsilon * input_grad.sign()
    return torch.clamp(perturbed, clip_min, clip_max)

# Example usage: adversarial training of SimpleCNN on MNIST.
NUM_EPOCHS = 5
EPSILON = 0.1  # FGSM step size, in raw pixel units (inputs are in [0, 1])

model = SimpleCNN()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Load MNIST dataset.
# ToTensor() alone yields pixels in [0, 1], the range fgsm_attack clamps its output
# to by default. Adding Normalize((0.1307,), (0.3081,)) here would map inputs to about
# [-0.42, 2.82]; clamping those to [0, 1] would overwrite every background pixel
# (about -0.42 after normalization) with 0, so the "adversarial" batch would come
# from a different distribution than the clean one.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Training loop with adversarial training: each mini-batch gets one clean update
# followed by one update on FGSM examples crafted against the just-updated model.
for _ in range(NUM_EPOCHS):
    for images, labels in data_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Adversarial examples are treated as constant inputs, hence detach().
        adversarial_images = fgsm_attack(model, loss_fn, images, labels, epsilon=EPSILON).detach()
        # Clear gradients before the adversarial backward pass: fgsm_attack may leave
        # gradients from its own loss in the parameters' .grad fields, which would
        # otherwise be summed into this update.
        optimizer.zero_grad()
        adversarial_outputs = model(adversarial_images)
        adversarial_loss = loss_fn(adversarial_outputs, labels)
        adversarial_loss.backward()
        optimizer.step()

# Report the losses of the last mini-batch (not an epoch average).
print("Final training loss:", loss.item())
print("Final adversarial loss:", adversarial_loss.item())