In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tests.estimators.classification.test_jax import classifier
from torch.nn.functional import batch_norm
from torchvision import datasets, transforms, utils
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import ssl
import random
from pathlib import Path
from art.estimators.classification import PyTorchClassifier


ssl._create_default_https_context = ssl._create_stdlib_context

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Ensure reproducibility

# Set a global random seed
seed = 10
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True

In [3]:
# Teacher Model
lr = 0.001
batch_size = 128
n_channels, w, h = 1, 28, 28
max_epochs = 5

In [5]:
# Attacks
max_attack_iter = 50

In [6]:
dt_p = Path('data/mnist')

classes = [str(i) for i in range(10)]
n_labels = len(classes)

transform = transforms.Compose(
    [transforms.ToTensor()])

trainset = datasets.MNIST(root=f'{dt_p.absolute()}/train', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)
testset = datasets.MNIST(root=f'{dt_p.absolute()}/test', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

In [7]:
class MnistNet(nn.Module):
    def __init__(self, input_size=28):  # Add input size for flexibility
        super().__init__()
        self.conv1a = nn.Conv2d(1, 32, 3, padding=0)
        self.conv1b = nn.Conv2d(32, 32, 3, padding=0)
        self.conv2a = nn.Conv2d(32, 64, 3, padding=0)
        self.conv2b = nn.Conv2d(64, 64, 3, padding=0)
        self.pool = nn.MaxPool2d(2, 2)

        # Compute the number of features dynamically
        self._to_linear = self._get_conv_output(input_size)

        self.fc1 = nn.Linear(self._to_linear, 200)
        self.fc2 = nn.Linear(200, 10)

        self.flatten = nn.Flatten()
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.batchnorm = nn.BatchNorm1d(200)  # Now correctly applied

    def _get_conv_output(self, size):
        """Helper function to compute the output size after convolutions"""
        x = torch.zeros(1, 1, size, size)  # Create a dummy tensor
        x = self.pool(F.relu(self.conv1b(F.relu(self.conv1a(x)))))
        x = self.pool(F.relu(self.conv2b(F.relu(self.conv2a(x)))))
        return x.numel()

    def forward(self, x):
        x = self.activation(self.conv1a(x))
        x = self.activation(self.conv1b(x))
        x = self.dropout(x)
        x = self.pool(x)
        x = self.activation(self.conv2a(x))
        x = self.activation(self.conv2b(x))
        x = self.dropout(x)
        x = self.pool(x)

        x = self.flatten(x)  # Flatten before FC layers
        x = self.activation(self.fc1(x))
        x = self.batchnorm(x)
        x = self.fc2(x)
        # CrossEntropyLoss already applied softmax
        return x  # Remove softmax if using CrossEntropyLoss


device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Playing on {device}")

Playing on mps


In [8]:
# Specify teacher model

teacher_model = MnistNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(teacher_model.parameters(), lr=lr)

In [9]:
# Train the teacher model
teacher_losses = []

for e in tqdm(range(max_epochs)):
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        logits = teacher_model(images)
        # Compute loss
        loss = criterion(logits, labels)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    teacher_losses.append(loss.item())

    if e % 10 == 0 or e == max_epochs:
        print(f"Epoch {e}: {loss.item()}")

 20%|██        | 1/5 [00:07<00:28,  7.16s/it]

Epoch 0: 0.1005193293094635


100%|██████████| 5/5 [00:19<00:00,  3.93s/it]


In [None]:
def evaluate_model(model, data, labels, device="cpu"):
    """
    Evaluate model accuracy and loss.

    Args:
        model: PyTorch model (ART-wrapped or regular)
        data: Input samples (NumPy array)
        labels: True labels (NumPy array, one-hot encoded)
        criterion: Loss function (e.g., nn.CrossEntropyLoss())
        device: "cpu" or "cuda" (use GPU if available)

    Returns:
        accuracy (float)
    """
    model.to(device)
    model.eval()  # Set model to evaluation mode

    # Convert NumPy data to PyTorch tensors
    data_tensor = torch.tensor(data, dtype=torch.float32).to(device)
    labels_tensor = torch.tensor(labels, dtype=torch.float32).to(device)

    # Forward pass: Compute predictions
    with torch.no_grad():
        outputs = model(data_tensor)  # Get raw logits
        predictions = torch.argmax(outputs, dim=1)  # Convert to class labels
        true_labels = torch.argmax(labels_tensor, dim=1)  # Convert one-hot to labels

    # Compute accuracy
    accuracy = (predictions == true_labels).float().mean().item() * 100

    return accuracy

In [10]:
# Wrap in ART PyTorchClassifier
art_model_t = PyTorchClassifier(
    model=teacher_model,
    clip_values=(0, 1),  # Min and Max pixel values (normalize if needed)
    loss=criterion,
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10
)
art_model_t.to(device)  # Move model to CPU

In [11]:
mnist_targets = torch.nn.functional.one_hot(testset.targets, num_classes=10).float().numpy()
mnist_data = testset.data.unsqueeze(1).float().numpy()  # Add channel dimension (N, 1, 28, 28)

In [24]:
# Computationally Expensive methods -> Select Test Subset
num_samples = 200
# First shuffle
indices = torch.randperm(len(testset.data))
mnist_data_shuffled = mnist_data[indices]
mnist_targets_shuffled = mnist_targets[indices]
# Then select subsets
mnist_data_subset = mnist_data_shuffled[:num_samples]
mnist_targets_subset = mnist_targets_shuffled[:num_samples]

In [None]:
# Ensure teacher model is on CPU to create the attacks
teacher_model.to('cpu')

In [82]:
# Attack #1
from art.attacks.evasion.fast_gradient import FastGradientMethod

attack = FastGradientMethod(estimator=art_model_t, eps=0.5, eps_step=0.1, batch_size=32, minimal=True)
x_adv_fgm = attack.generate(x=mnist_data_subset, y=mnist_targets_subset)

In [25]:
# Attack #2

from art.attacks.evasion.deepfool import DeepFool

attack = DeepFool(classifier=art_model_t, max_iter=max_attack_iter, batch_size=32)
x_adv_deepfool = attack.generate(x=mnist_data_subset, y=mnist_targets_subset)

DeepFool: 100%|██████████| 7/7 [00:15<00:00,  2.24s/it]


In [37]:
# Attack #3

from art.attacks.evasion.carlini import CarliniL2Method

attack = CarliniL2Method(classifier=art_model_t, max_iter=max_attack_iter, batch_size=32)
x_adv_carlini = attack.generate(x=mnist_data_subset, y=mnist_targets_subset)

C&W L_2: 100%|██████████| 7/7 [12:44<00:00, 109.21s/it]


In [83]:
## Teacher Model

# Original test set evaluation
original_accuracy = evaluate_model(
    art_model_t.model, mnist_data_subset, mnist_targets_subset, criterion
)
print(f"Original Test Accuracy: {original_accuracy:.2f}%")

# Adversarial test sets evaluation
fgm_accuracy = evaluate_model(
    art_model_t.model, x_adv_fgm, mnist_targets_subset, criterion
)
deepfool_accuracy = evaluate_model(
    art_model_t.model, x_adv_deepfool, mnist_targets_subset, criterion
)
carlini_accuracy = evaluate_model(
    art_model_t.model, x_adv_carlini, mnist_targets_subset, criterion
)

print(f"Adversarial Test Accuracy (FGS Method): {fgm_accuracy:.2f}%")
print(f"Adversarial Test Accuracy (Deepfool Method): {deepfool_accuracy:.2f}%")
print(f"Adversarial Test Accuracy (Carlini L2 Method): {carlini_accuracy:.2f}%")

Original Test Accuracy: 90.50%
Adversarial Test Accuracy (FGS Method): 91.00%
Adversarial Test Accuracy (Deepfool Method): 7.00%
Adversarial Test Accuracy (Carlini L2 Method): 90.50%


In [84]:
# Generate Soft Labels

def get_soft_labels(model, dataloader, temp):
    soft_labels_list = []
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            # Get logits from trained teacher model
            logits = model(images)
            # Apply temperature-scaled softmax
            soft_labels = torch.softmax(logits / temp, dim=1)
            soft_labels_list.append(soft_labels.cpu())
    return torch.cat(soft_labels_list)

In [85]:
def distillation_loss(student_logits, teacher_soft_labels, hard_labels, temp, alpha):
    """
    Computes the distillation loss.

    Args:
        student_logits: Output logits from the student model.
        teacher_soft_labels: Soft labels generated by the teacher model.
        hard_labels: Original hard labels.
        temp: Temperature used for soft labels.
        alpha: Weight for soft loss (1-alpha for hard loss).
    """
    soft_loss = nn.KLDivLoss(reduction="batchmean")(
        torch.log_softmax(student_logits / temp, dim=1),
        teacher_soft_labels
    )
    hard_loss = nn.CrossEntropyLoss()(student_logits, hard_labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss


In [43]:
def train_student(teacher_model, student_model, trainloader, temp=20, alpha=0.7, epochs=10, lr=0.01):
    """
    Trains the student model using knowledge distillation.

    Args:
        teacher_model: Pretrained teacher model.
        student_model: Student model to train.
        trainloader: DataLoader for training.
        temp: Temperature for soft labels.
        alpha: Weight for soft labels in the loss.
        epochs: Number of training epochs.
        lr: Learning rate.
    """
    student_model.train()
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    for e in tqdm(range(epochs)):
        total_loss = 0
        for images, hard_labels in trainloader:
            images, hard_labels = images.to(device), hard_labels.to(device)

            # Get teacher soft labels
            with torch.no_grad():
                teacher_logits = teacher_model(images)
                soft_labels = torch.softmax(teacher_logits / temp, dim=1)

            # Get student predictions
            student_logits = student_model(images)

            # Compute distillation loss
            loss = distillation_loss(student_logits, soft_labels, hard_labels, temp, alpha)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if e % 10 == 0:
            print(f"Epoch {e}: {loss.item()}")

    print("Training complete!")


In [52]:
# Specify Student Model

temp = 20
max_epochs = 10
alpha = 0.5
# Initialize the student model
student_model = MnistNet().to(device)

In [53]:
teacher_model.to(device)
# Generate soft labels for training the student
soft_labels = get_soft_labels(teacher_model, trainloader, temp)

In [54]:
# Train Student
train_student(teacher_model, student_model, trainloader, temp=temp, alpha=alpha, epochs=max_epochs, lr=lr)

 10%|█         | 1/10 [00:04<00:40,  4.52s/it]

Epoch 0: 0.020531142130494118


100%|██████████| 10/10 [00:45<00:00,  4.55s/it]

Training complete!





In [55]:
# Wrap in ART PyTorchClassifier
art_model_s = PyTorchClassifier(
    model=student_model,
    clip_values=(0, 1),  # Min and Max pixel values (normalize if needed)
    loss=criterion,
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10
)

In [86]:
## Student Model

# Original test set evaluation
original_accuracy = evaluate_model(
    art_model_s.model, mnist_data_subset, mnist_targets_subset, criterion
)
print(f"Original Test Accuracy: {original_accuracy:.2f}%")

# Adversarial test sets evaluation

fgm_accuracy = evaluate_model(
    art_model_s.model, x_adv_fgm, mnist_targets_subset, criterion
)
deepfool_accuracy = evaluate_model(
    art_model_s.model, x_adv_deepfool, mnist_targets_subset, criterion
)
carlini_accuracy = evaluate_model(
    art_model_s.model, x_adv_carlini, mnist_targets_subset, criterion
)

print(f"Adversarial Test Accuracy (FGS Method): {fgm_accuracy:.2f}%")
print(f"Adversarial Test Accuracy (Deepfool Method): {deepfool_accuracy:.2f}%")
print(f"Adversarial Test Accuracy (Carlini L2 Method): {carlini_accuracy:.2f}%")

Original Test Accuracy: 99.50%
Adversarial Test Accuracy (FGS Method): 91.00%
Adversarial Test Accuracy (Deepfool Method): 11.50%
Adversarial Test Accuracy (Carlini L2 Method): 99.50%
