# Activation Classifier
In this file, we detail a meta-model that determines if a specific inference result of a model is a result of poisoning.
To this end, our meta-model inputs are all the activations of the model, and the model performs a binary classification: poisoned or not.

Clearly, this example is trivial. If our base poisoned model can detect and interpret the trigger, then it is trivial for our meta-model to detect the trigger as well.

In order to run this notebook, please first run the `MAD.ipynb`.

## Setup
We first import the required libraries and initialize the dataset. Check out `activation_dataset.py` for more details on the data.

In [1]:
import torch
import random
import numpy as np
import pickle
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch import nn

from activation_dataset import ActivationDataset

with open("data/activations_dataset.pkl", "rb") as f:
    dataset = pickle.load(f)

# define the size of the training set
train_size = int(0.8 * len(dataset))

# create a list of indices for the training set and the evaluation set
indices = list(range(len(dataset)))
random.shuffle(indices)
train_indices = indices[:train_size]
eval_indices = indices[train_size:]

# create a PyTorch Subset for the training set and the evaluation set
train_dataset: Subset[ActivationDataset] = Subset(dataset, train_indices)
eval_dataset: Subset[ActivationDataset] = Subset(dataset, eval_indices)

# select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  return torch._C._cuda_getDeviceCount() > 0


## Model definition

Here, we define the model. As explained above, the model is a binary classifier that takes as input the activations of the model and outputs a binary classification: poisoned or not. In conjunction, we know that this task is extremely trivial in this case, so we use an MLP.

In [2]:
class Network(nn.Module):
    def __init__(self, concat_dim=394):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(concat_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 20),
            nn.ReLU(),
            nn.Linear(20, 2)
        )
    
    def forward(self, x):
        return self.main(x).softmax(dim=-1)

In [3]:
def evaluate(loader, model:Network):
    with torch.no_grad():
        running_loss = 0
        correct = 0
        count = 0
        for _, batch in enumerate(loader):
            bx = batch[0].to(device)
            by = batch[1].to(device)
            count += by.size(0)
            pred = model(bx)
            loss = F.cross_entropy(pred, by.long())
            running_loss += loss.item()
            pred = pred.argmax(dim=1, keepdim=True)
            correct += pred.eq(by.view_as(pred)).sum().item()
        loss = running_loss / count
        acc = correct / count
    return loss, acc

In [4]:
def train_model(
    train_data: Subset[ActivationDataset],
    test_data: Subset[ActivationDataset],
    model: Network,
    num_epochs=10,
    batch_size=64,
):
    """
    :param train_data: the data to train with
    :param test_data: the clean test data to evaluate accuracy on
    :param model: the model to train
    :param num_epochs: the number of epochs to train for
    :param batch_size: the batch size for training
    """
    train_loader = DataLoader(train_data, batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, len(train_loader) * num_epochs
    )

    loss_ema = np.inf

    for epoch in range(num_epochs):
        loss, acc = evaluate(test_loader, model)

        print("Epoch {}:: Test Loss: {:.3f}, Test Acc: {:.3f}".format(epoch, loss, acc))
        for i, (bx, by) in enumerate(train_loader):
            bx = bx.to(device)
            by = by.to(device)

            pred = model(bx)
            loss = F.cross_entropy(pred, by.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            if loss_ema == np.inf:
                loss_ema = loss.item()
            else:
                loss_ema = loss_ema * 0.95 + loss.item() * 0.05

            if i % 500 == 0:
                print(
                    "Train loss: {:.3f}".format(loss_ema)
                )  # to get a rough idea of training loss

    loss, acc = evaluate(test_loader, model)

    print("Final Metrics:: Test Loss: {:.3f}, Test Acc: {:.3f}".format(loss, acc))

    return loss, acc


In [6]:
model = Network().to(device)
loss, acc = train_model(train_dataset, eval_dataset, model, num_epochs=5, batch_size=256)

print("Final Metrics:: Test Loss: {:.3f}, Test Acc: {:.3f}".format(loss, acc))

Epoch 0:: Test Loss: 0.003, Test Acc: 0.553
Train loss: 0.681
Epoch 1:: Test Loss: 0.001, Test Acc: 0.995
Train loss: 0.414
Epoch 2:: Test Loss: 0.001, Test Acc: 0.998
Train loss: 0.335
Epoch 3:: Test Loss: 0.001, Test Acc: 0.999
Train loss: 0.319
Epoch 4:: Test Loss: 0.001, Test Acc: 1.000
Train loss: 0.315
Final Metrics:: Test Loss: 0.001, Test Acc: 1.000
Final Metrics:: Test Loss: 0.001, Test Acc: 1.000
