<a href="https://colab.research.google.com/github/aniketSanyal/OverfittingInRML/blob/main/rml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Step 1: Setup Environment and Load Datasets

1. **Import Libraries**: We'll import PyTorch, torchvision, and other necessary libraries.
2. **Load Datasets**: Load MNIST and CIFAR10 datasets using torchvision.
3. **Data Loaders**: Create data loaders for both datasets for easy batch processing.

In [45]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transformations
# Only convert images to tensors, which leaves pixel values in [0, 1].
# The FGSM/PGD attacks defined below clamp their output to [0, 1] and their
# epsilon/alpha budgets are expressed on that scale. Normalizing the inputs to
# [-1, 1] here (as Normalize((0.5,), (0.5,)) did) made every adversarial batch
# lose all of its negative pixel values in that clamp.
# NOTE(review): if input normalization is wanted, apply it inside the model
# (or pass matching clip bounds to the attacks) so attacks still operate in
# pixel space.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load CIFAR10 Dataset (3-channel colour images)
cifar_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar_train_loader = DataLoader(cifar_train_dataset, batch_size=64, shuffle=True)

# Load MNIST Dataset (single-channel images)
mnist_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=64, shuffle=True)

Files already downloaded and verified


### Step 2: Model Definition

We'll define simple CNN architectures suitable for each dataset. MNIST images are grayscale and smaller, while CIFAR10 images are color and larger.

#### MNIST CNN Model:
- Simple architecture with a couple of convolutional layers.

#### CIFAR10 CNN Model:
- A bit more complex due to the nature of the dataset (color images).

In [46]:
import torch.nn as nn
import torch.nn.functional as F

# CNN for MNIST
class MNIST_CNN(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two 5x5 convolutions (padding 2, so spatial size is preserved), each
    followed by 2x2 max-pooling and ReLU, then two fully connected layers
    producing 10 logits. ``fc1`` expects 7*7*64 features, which corresponds
    to 28x28 inputs after the two poolings.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        # Classifier head: flattened features -> 1024 hidden units -> 10 logits.
        self.fc1 = nn.Linear(in_features=7 * 7 * 64, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, x):
        """Return the raw class logits (no softmax) for a batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2):
            # Pool first, then apply the non-linearity.
            features = F.relu(F.max_pool2d(conv(features), 2))
        batch_size = features.size(0)
        flattened = features.view(batch_size, -1)  # Flatten
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)

# CNN for CIFAR10
class CIFAR10_CNN(nn.Module):
    """Small convolutional classifier for 3-channel images.

    Two 5x5 convolutions (padding 2, so spatial size is preserved), each
    followed by 2x2 max-pooling and ReLU, then two fully connected layers
    producing 10 logits. ``fc1`` expects 8*8*128 features, which corresponds
    to 32x32 inputs after the two poolings.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 64 -> 128 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=2)
        # Classifier head: flattened features -> 1024 hidden units -> 10 logits.
        self.fc1 = nn.Linear(in_features=8 * 8 * 128, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, x):
        """Return the raw class logits (no softmax) for a batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2):
            # Pool first, then apply the non-linearity.
            features = F.relu(F.max_pool2d(conv(features), 2))
        batch_size = features.size(0)
        flattened = features.view(batch_size, -1)  # Flatten
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)


### Step 3: FGSM Function
The FGSM method creates adversarial examples by adding a small perturbation to the original image in the direction of the gradient of the loss with respect to the input image.

### Step 4: PGD Function
PGD is a more powerful attack than FGSM. It applies a small signed-gradient step iteratively and, after each step, projects the perturbed image back into the epsilon-ball around the original image and then clamps it to the valid pixel range.

In [47]:
def fgsm_attack(model, images, labels, epsilon, device, clip_min=0.0, clip_max=1.0):
    """Craft FGSM adversarial examples for a batch.

    Takes one step of size ``epsilon`` in the direction of the sign of the
    cross-entropy loss gradient with respect to the input, then clamps the
    result to ``[clip_min, clip_max]``.

    Args:
        model: Callable mapping a batch of images to class logits.
        images: Input batch. It is not modified; its ``requires_grad`` flag
            and ``.grad`` are left untouched.
        labels: Target class indices for ``images``.
        epsilon: Perturbation magnitude applied to every element.
        device: Unused; kept so existing call sites keep working. Inputs are
            expected to already be on the same device as ``model``.
        clip_min: Lower bound of the valid input range (default 0.0).
        clip_max: Upper bound of the valid input range (default 1.0).

    Returns:
        A new tensor with the adversarial batch, detached from the autograd
        graph.
    """
    # Work on a private leaf copy: setting requires_grad on the caller's
    # tensor would mutate it, and fails to yield a .grad for non-leaf inputs.
    adv_input = images.detach().clone().requires_grad_(True)
    loss = F.cross_entropy(model(adv_input), labels)

    # autograd.grad returns only d(loss)/d(input), so the model parameters'
    # .grad buffers are neither cleared nor filled by the attack.
    (input_grad,) = torch.autograd.grad(loss, adv_input)

    perturbed = adv_input.detach() + epsilon * input_grad.sign()
    return torch.clamp(perturbed, clip_min, clip_max)

def pgd_attack(model, images, labels, epsilon, alpha, iters, device, clip_min=0.0, clip_max=1.0):
    """Craft PGD (iterated signed-gradient) adversarial examples for a batch.

    Starting from ``images``, repeats ``iters`` times: step by ``alpha`` along
    the sign of the cross-entropy loss gradient with respect to the input,
    project the total perturbation back into ``[-epsilon, epsilon]`` around
    the original batch, and clamp to ``[clip_min, clip_max]``.

    Args:
        model: Callable mapping a batch of images to class logits.
        images: Input batch. It is not modified; its ``requires_grad`` flag
            and ``.grad`` are left untouched.
        labels: Target class indices for ``images``.
        epsilon: Maximum absolute perturbation per element.
        alpha: Step size of each iteration.
        iters: Number of iterations. With 0, an unmodified copy is returned.
        device: Unused; kept so existing call sites keep working. Inputs are
            expected to already be on the same device as ``model``.
        clip_min: Lower bound of the valid input range (default 0.0).
        clip_max: Upper bound of the valid input range (default 1.0).

    Returns:
        A new tensor with the adversarial batch, detached from the autograd
        graph.
    """
    # Private copies: the reference point must not alias the caller's storage,
    # and the working tensor must be a leaf we are free to flag for gradients.
    original = images.detach().clone()
    adversarial = original.clone()

    for _ in range(iters):
        adversarial.requires_grad_(True)
        loss = F.cross_entropy(model(adversarial), labels)
        # Only the input gradient is needed; this leaves the parameters'
        # .grad buffers untouched.
        (input_grad,) = torch.autograd.grad(loss, adversarial)

        stepped = adversarial.detach() + alpha * input_grad.sign()
        # Project back into the epsilon-ball around the original batch ...
        delta = torch.clamp(stepped - original, min=-epsilon, max=epsilon)
        # ... and then into the valid input range.
        adversarial = torch.clamp(original + delta, clip_min, clip_max).detach()

    return adversarial


### Step 5: Training Loop with Adversarial Training

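We'll:
1. Define a `train` helper that can optionally replace each batch with FGSM or PGD adversarial examples before the optimizer step.
2. Train a simple CIFAR10 model, performing one update on the clean batch and one update on its adversarial counterpart for every batch.
3. Afterwards (in the Result section), select a few sample images and apply the FGSM and PGD attacks to them.
4. Visualize the results to see the effect of the attacks.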

In [None]:
def train(model, device, train_loader, optimizer, epoch, attack=None, epsilon=0.01,
          alpha=0.01, iters=40):
    """Run one training epoch, optionally on adversarial examples.

    Args:
        model: Network to train; put into training mode by this function.
        device: Device every batch is moved to before use.
        train_loader: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (used for the progress message).
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only in the progress message.
        attack: ``None`` for standard training, or a callable invoked as
            ``attack(model, data, target, epsilon, device=device, ...)`` that
            returns the adversarial batch to train on. If its signature has
            ``alpha`` and/or ``iters`` parameters (as a PGD-style attack
            does), they are supplied from the arguments below.
        epsilon: Perturbation budget forwarded to ``attack``.
        alpha: Step size forwarded to attacks that accept ``alpha``.
        iters: Iteration count forwarded to attacks that accept ``iters``.
    """
    import inspect

    # Decide once which optional arguments the attack understands, instead of
    # comparing against one specific attack function.
    extra_attack_args = {}
    if attack is not None:
        try:
            attack_params = inspect.signature(attack).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to the common arguments.
            attack_params = {}
        if "alpha" in attack_params:
            extra_attack_args["alpha"] = alpha
        if "iters" in attack_params:
            extra_attack_args["iters"] = iters

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        if attack is not None:
            # The attack computes its own input gradient, so no forward/backward
            # pass is needed here. detach() keeps the training step below from
            # backpropagating into the attack's graph.
            data = attack(model, data, target, epsilon, device=device, **extra_attack_args).detach()

        # Clears any parameter gradients the attack may have left behind.
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

# Example usage
# `device` is not defined in any earlier cell shown in this notebook, so pick
# it here: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CIFAR10_CNN().to(device)

epsilon = 0.1
num_epochs = 10
# Referenced through `torch.optim` because no `optim` alias is imported in the
# cells above.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for data, target in cifar_train_loader:
        data, target = data.to(device), target.to(device)

        # Perform a standard training pass on the original data
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Generate adversarial examples using FGSM or PGD.
        # detach() so the adversarial update below does not backpropagate
        # through the attack's computation.
        adv_data = fgsm_attack(model, data, target, epsilon, device).detach()  # or pgd_attack

        # Second update on the adversarial batch. zero_grad() also discards any
        # parameter gradients accumulated while crafting the attack.
        optimizer.zero_grad()
        adv_output = model(adv_data)
        adv_loss = F.cross_entropy(adv_output, target)
        adv_loss.backward()
        optimizer.step()


# Result

In [None]:
def visualize(image, title):
    """Show a channels-first image tensor in an 8x8-inch figure.

    Args:
        image: Tensor laid out as (channels, height, width). It may require
            grad or live on a non-CPU device.
        title: Title drawn above the image.
    """
    # NOTE(review): `plt` is presumably matplotlib.pyplot; no import of it is
    # visible in this notebook - confirm it is imported before this cell runs.
    # detach() and cpu() are required because Tensor.numpy() raises for
    # tensors that require grad or are not on the CPU. permute() moves the
    # channel axis last, which is the layout imshow expects.
    hwc_image = image.detach().cpu().permute(1, 2, 0).numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(hwc_image)
    plt.title(title)
    plt.show()

# Take a sample for visualization
# Only `datasets` and `transforms` are imported from torchvision above, so the
# bare name `torchvision` is not available; import the grid helper directly.
from torchvision.utils import make_grid

# Only training loaders are created earlier in the notebook, so build the
# CIFAR10 test loader here.
cifar_test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
cifar_test_loader = DataLoader(cifar_test_dataset, batch_size=64, shuffle=False)

# `model` is the CIFAR10 network trained in the previous cell.
model_cifar = model
# Use whatever device the model's parameters live on.
device = next(model_cifar.parameters()).device

dataiter = iter(cifar_test_loader)
images, labels = next(dataiter)
# Move the batch once so the attacks and the model see the same tensors.
images, labels = images.to(device), labels.to(device)

# FGSM
# The attack computes the input gradient itself; no manual backward pass here.
perturbed_data = fgsm_attack(model_cifar, images, labels, 0.05, device)
visualize(make_grid(perturbed_data.detach().cpu()), "FGSM Attack on CIFAR10")

# PGD
perturbed_data_pgd = pgd_attack(model_cifar, images, labels, 0.05, 0.01, 40, device)
visualize(make_grid(perturbed_data_pgd.detach().cpu()), "PGD Attack on CIFAR10")

