# *Baseline* and *Mitigation* Models - Training

[Kaggle Notebook](https://www.kaggle.com/code/mklokeshkumar/train-baseline)

In [None]:
!pip install adversarial-robustness-toolbox torch matplotlib numpy



In [64]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from art.attacks.evasion import ProjectedGradientDescent
from art.estimators.classification import PyTorchClassifier
import numpy as np

In [65]:
# defining model architectures

class SimpleCNN(nn.Module):
    def __init__(self, num_classes = 10): # specifically for MNIST
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim = 1)
        
        return output

In [66]:
# defining training training hyperparameters

EPOCHS_BASELINE = 5
EPOCHS_MITIGATION = 10
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EPSILON = 0.09 # pertuberation size for PGD attack

# checking for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [67]:
# loading data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root = '/kaggle/working', train = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)

In [68]:
# training function

def train_model(model, train_loader, optimizer, loss_fn, epochs, description):
    model.train()
    print(f"Starting {description} training...")
    for epoch in range(1, epochs + 1):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
    
    print(f"{description} training completed")

In [None]:
# training phase 1 - base classification model

baseline_model = SimpleCNN().to(device)
baseline_optimizer = optim.Adam(baseline_model.parameters(), lr = LEARNING_RATE)
train_model(baseline_model, train_loader, baseline_optimizer, nn.CrossEntropyLoss(), EPOCHS_BASELINE, "Baseline Model")
torch.save(baseline_model.state_dict(), "../models/baseline_model.pth")
print("Baseline Model saved as 'baseline_model.pth'")

Starting Baseline Model training...
Baseline Model training completed
Baseline Model saved as 'baseline_model.pth'


In [None]:
# training phase 2 - mitigation model with adversarial training

mitigation_model = SimpleCNN().to(device)
mitigation_optimizer = optim.Adam(mitigation_model.parameters(), lr = LEARNING_RATE)
mitigation_classifier_art = PyTorchClassifier(model = mitigation_model,
                                              loss = nn.CrossEntropyLoss(),
                                              optimizer = mitigation_optimizer,
                                              input_shape = (1, 28, 28),
                                              nb_classes = 10,
                                              device_type = "gpu"
                                            )

mitigation_classifier_art.fit(x = train_loader.dataset.data.numpy().reshape(-1, 1, 28, 28), 
                              y = train_loader.dataset.targets.numpy(),
                              batch_size = BATCH_SIZE, nb_epochs = EPOCHS_MITIGATION,
                              attack_params = {
                                    "projected_gradient_descent": {
                                        "eps": EPSILON,
                                        "eps_step": 0.01,
                                        "max_iter": 20
                                    }
                                }
                            )

# pgd_attack = ProjectedGradientDescent(mitigation_classifier_art, eps = EPSILON)

# for epoch in range(EPOCHS_MITIGATION):
#     for data, target in train_loader:
#         mitigation_model.train()
#         data, target = data.to(device), target.to(device)
        
#         x_adv_batch = pgd_attack.generate(x = data.cpu().numpy(), y = target.cpu().numpy().astype(np.int64))
#         x_adv_batch_tensor = torch.tensor(x_adv_batch).to(device)

#         combined_data = torch.cat((data, x_adv_batch_tensor), dim = 0)
#         combined_target = torch.cat((target, target), dim = 0)

#         mitigation_optimizer.zero_grad()
#         output = mitigation_model(combined_data)
#         loss = nn.CrossEntropyLoss()(output, combined_target.long())
#         loss.backward()
#         mitigation_optimizer.step()

#     print(f"Mitigation Model epoch {epoch + 1}/{EPOCHS_MITIGATION} completed")

torch.save(mitigation_model.state_dict(), "../models/mitigation_model.pth")
print("Mitigation Model saved as 'mitigation_model.pth'")

Mitigation Model saved as 'mitigation_model.pth'
