# Training EENet exit heads

This notebook provides code for training EENet [1] exit heads for clean and poisoned Resnet models. For each attack, the code would load the saved weights from the previously finetuned models on Oxfordpets with/without attacks and freezes the backbone model in order to train the exit heads from scratch.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import precision_score, recall_score, f1_score
import random
import pandas as pd
import numpy as np
import os
from copy import deepcopy

from PIL import Image


TRAINING BADNET EENETS

In [None]:

class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        assert base_model in [
            'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
            'wide_resnet50_2', 'wide_resnet101_2'
        ], "Unsupported model."
        self.backbone = getattr(models, base_model)(weights=None)
        self.exit1 = None
        self.exit2 = None
        self.exit3 = None
        self.exit4 = None
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4); pooled = torch.flatten(pooled, 1)
        final_out = self.softmax(self.final_fc(pooled)); preds.append(final_out)
        return preds

# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 1e-4
TARGET_CLASS = 0
MODEL_LIST = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'wide_resnet50_2', 'wide_resnet101_2'
]

SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_Badnet_ExitHeads'
os.makedirs(SAVE_DIR, exist_ok=True)

# Trigger function (same as existing)
def add_trigger(image, trigger_size=5, trigger_color=(1, 1, 1)):
    img = image.clone()
    _, h, w = img.shape
    img[:, h - trigger_size:h, w - trigger_size:w] = torch.tensor(trigger_color).view(3, 1, 1)
    return img

# Poison dataset
def poison_dataset(dataset, fraction=0.1, target_class=0):
    from copy import deepcopy
    import random
    poisoned = []
    dataset = deepcopy(dataset)
    n_total = len(dataset)
    n_poison = int(n_total * fraction)
    indices = random.sample(range(n_total), n_poison)
    for i, (x, y) in enumerate(dataset):
        if i in indices:
            x = add_trigger(x)
            y = target_class
        poisoned.append((x, y))
    return poisoned

# Main training loop
for model_name in MODEL_LIST:
    print(f"Training exit heads for {model_name}...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EEResNet(base_model=model_name).to(device)

    # Load frozen backbone
    backbone_path = f"/content/drive/MyDrive/Colab Notebooks/BadNetModels/{model_name}_badnet.pth"
    state_dict = torch.load(backbone_path, map_location=device)
    state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
    model.backbone.load_state_dict(state_dict, strict=False)

    # Freeze backbone
    for param in model.backbone.parameters():
        param.requires_grad = False

    # Data
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_dataset = OxfordIIITPet(root='.', split='trainval', target_types='category', transform=transform, download=True)
    poisoned_dataset = poison_dataset(train_dataset, fraction=0.1, target_class=TARGET_CLASS)
    train_loader = DataLoader(poisoned_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Train only exit heads and final FC
    exit_params = list(model.exit1.parameters()) + list(model.exit2.parameters()) + \
                  list(model.exit3.parameters()) + list(model.exit4.parameters()) + \
                  list(model.final_fc.parameters())
    optimizer = optim.Adam(exit_params, lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = sum(criterion(out, targets) for out in outputs)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * inputs.size(0)
            _, preds = outputs[-1].max(1)
            correct += preds.eq(targets).sum().item()
            total += targets.size(0)

        print(f"{model_name} | Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss / total:.4f} | Acc: {100 * correct / total:.2f}%")

    torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth"))
    print(f"Saved model: {model_name}")


Training exit heads for resnet18...


100%|██████████| 792M/792M [00:46<00:00, 17.0MB/s]
100%|██████████| 19.2M/19.2M [00:02<00:00, 8.60MB/s]


resnet18 | Epoch 1/10 | Loss: 17.9676 | Acc: 14.84%
resnet18 | Epoch 2/10 | Loss: 17.5743 | Acc: 42.26%
resnet18 | Epoch 3/10 | Loss: 16.9339 | Acc: 78.72%
resnet18 | Epoch 4/10 | Loss: 16.5948 | Acc: 84.51%
resnet18 | Epoch 5/10 | Loss: 16.4955 | Acc: 84.86%
resnet18 | Epoch 6/10 | Loss: 16.4386 | Acc: 86.98%
resnet18 | Epoch 7/10 | Loss: 16.3858 | Acc: 89.29%
resnet18 | Epoch 8/10 | Loss: 16.3554 | Acc: 89.40%
resnet18 | Epoch 9/10 | Loss: 16.3236 | Acc: 89.89%
resnet18 | Epoch 10/10 | Loss: 16.2753 | Acc: 92.20%
Saved model: resnet18
Training exit heads for resnet34...
resnet34 | Epoch 1/10 | Loss: 17.9365 | Acc: 15.22%
resnet34 | Epoch 2/10 | Loss: 17.3922 | Acc: 60.95%
resnet34 | Epoch 3/10 | Loss: 16.6843 | Acc: 91.63%
resnet34 | Epoch 4/10 | Loss: 16.4305 | Acc: 92.20%
resnet34 | Epoch 5/10 | Loss: 16.3536 | Acc: 92.42%
resnet34 | Epoch 6/10 | Loss: 16.3136 | Acc: 92.83%
resnet34 | Epoch 7/10 | Loss: 16.2636 | Acc: 94.73%
resnet34 | Epoch 8/10 | Loss: 16.2258 | Acc: 94.92%
resne

getting accuracy and ASR for every exit

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import OxfordIIITPet
import torchvision.transforms as transforms
import os
from collections import defaultdict
import pandas as pd

class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        assert base_model in [
            'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
            'wide_resnet50_2', 'wide_resnet101_2'
        ], "Unsupported model."
        self.backbone = getattr(models, base_model)(weights=None)
        self.exit1 = None
        self.exit2 = None
        self.exit3 = None
        self.exit4 = None
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4); pooled = torch.flatten(pooled, 1)
        final_out = self.softmax(self.final_fc(pooled)); preds.append(final_out)
        return preds

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
TARGET_CLASS = 0
SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_Badnet_ExitHeads'
MODEL_LIST = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'wide_resnet50_2', 'wide_resnet101_2'
]

# Trigger function
def add_trigger(image, trigger_size=5, trigger_color=(1, 1, 1)):
    img = image.clone()
    _, h, w = img.shape
    img[:, h - trigger_size:h, w - trigger_size:w] = torch.tensor(trigger_color).view(3, 1, 1)
    return img

# Load test datasets
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
test_dataset = OxfordIIITPet(root='.', split='test', target_types='category', transform=transform, download=True)
triggered_dataset = [(add_trigger(img), TARGET_CLASS) for img, _ in test_dataset]

clean_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
triggered_loader = DataLoader(triggered_dataset, batch_size=BATCH_SIZE, shuffle=False)

results = []

# Evaluate each model
for model_name in MODEL_LIST:
    print(f"\nEvaluating {model_name}...")
    model = EEResNet(base_model=model_name).to(device)
    path = os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth")
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()

    clean_correct = [0] * 5
    asr_correct = [0] * 5
    total = 0

    with torch.no_grad():
        for imgs, labels in clean_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                preds = out.argmax(dim=1)
                clean_correct[i] += (preds == labels).sum().item()
            total += labels.size(0)

        for imgs, _ in triggered_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                preds = out.argmax(dim=1)
                asr_correct[i] += (preds == TARGET_CLASS).sum().item()

    for i in range(5):
        acc = 100 * clean_correct[i] / total
        asr = 100 * asr_correct[i] / total
        results.append({
            'Model': model_name,
            'Exit': f'Exit {i+1}',
            'Clean Accuracy': acc,
            'ASR': asr
        })
        print(f"Exit {i+1}: Clean Acc = {acc:.2f}%, ASR = {asr:.2f}%")

# Save to CSV
df = pd.DataFrame(results)
df.to_csv(os.path.join(SAVE_DIR, 'badnet_exit_metrics.csv'), index=False)
print("Saved per-exit accuracy and ASR.")




100%|██████████| 792M/792M [00:44<00:00, 17.9MB/s]
100%|██████████| 19.2M/19.2M [00:02<00:00, 8.52MB/s]



Evaluating resnet18...
Exit 1: Clean Acc = 2.67%, ASR = 100.00%
Exit 2: Clean Acc = 2.67%, ASR = 100.00%
Exit 3: Clean Acc = 2.67%, ASR = 100.00%
Exit 4: Clean Acc = 82.34%, ASR = 94.88%
Exit 5: Clean Acc = 83.40%, ASR = 94.79%

Evaluating resnet34...
Exit 1: Clean Acc = 2.67%, ASR = 100.00%
Exit 2: Clean Acc = 2.67%, ASR = 100.00%
Exit 3: Clean Acc = 2.67%, ASR = 100.00%
Exit 4: Clean Acc = 85.42%, ASR = 95.94%
Exit 5: Clean Acc = 90.13%, ASR = 95.75%

Evaluating resnet50...
Exit 1: Clean Acc = 2.67%, ASR = 100.00%
Exit 2: Clean Acc = 2.67%, ASR = 100.00%
Exit 3: Clean Acc = 2.67%, ASR = 100.00%
Exit 4: Clean Acc = 91.28%, ASR = 96.16%
Exit 5: Clean Acc = 91.36%, ASR = 96.16%

Evaluating resnet101...
Exit 1: Clean Acc = 2.67%, ASR = 100.00%
Exit 2: Clean Acc = 2.67%, ASR = 100.00%
Exit 3: Clean Acc = 2.67%, ASR = 100.00%
Exit 4: Clean Acc = 84.76%, ASR = 96.65%
Exit 5: Clean Acc = 85.04%, ASR = 96.57%

Evaluating resnet152...
Exit 1: Clean Acc = 2.67%, ASR = 100.00%
Exit 2: Clean Acc

TRAINING CLEAN EENETS

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
import os

class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        assert base_model in [
            'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
            'wide_resnet50_2', 'wide_resnet101_2'
        ], "Unsupported model."
        self.backbone = getattr(models, base_model)(weights=None)
        self.exit1 = None
        self.exit2 = None
        self.exit3 = None
        self.exit4 = None
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4); pooled = torch.flatten(pooled, 1)
        final_out = self.softmax(self.final_fc(pooled)); preds.append(final_out)
        return preds

# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 1e-4
TARGET_CLASS = 0

MODEL_LIST = [ 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152','wide_resnet50_2', 'wide_resnet101_2']
SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_Clean_ExitHeads'
os.makedirs(SAVE_DIR, exist_ok=True)

# Trigger function (same as existing)
#def add_trigger(image, trigger_size=5, trigger_color=(1, 1, 1)):
    #img = image.clone()
    #_, h, w = img.shape
    #img[:, h - trigger_size:h, w - trigger_size:w] = torch.tensor(trigger_color).view(3, 1, 1)
    #return img


# Main training loop
for model_name in MODEL_LIST:
    print(f"Training exit heads for {model_name}...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EEResNet(base_model=model_name).to(device)

    # Load frozen backbone
    backbone_path = f"/content/drive/MyDrive/Colab Notebooks/TrainedModels/{model_name}_oxfordpets_clean.pth"
    state_dict = torch.load(backbone_path, map_location=device)
    state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
    model.backbone.load_state_dict(state_dict, strict=False)

    # Freeze backbone
    for param in model.backbone.parameters():
        param.requires_grad = False

    # Data
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_dataset = OxfordIIITPet(root='.', split='trainval', target_types='category', transform=transform, download=True)
    train_data = train_dataset
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    # Train only exit heads and final FC
    exit_params = list(model.exit1.parameters()) + list(model.exit2.parameters()) + \
                  list(model.exit3.parameters()) + list(model.exit4.parameters()) + \
                  list(model.final_fc.parameters())
    optimizer = optim.Adam(exit_params, lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = sum(criterion(out, targets) for out in outputs)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * inputs.size(0)
            _, preds = outputs[-1].max(1)
            correct += preds.eq(targets).sum().item()
            total += targets.size(0)

        print(f"{model_name} | Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss / total:.4f} | Acc: {100 * correct / total:.2f}%")

    torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth"))
    print(f"Saved model: {model_name}")


Training exit heads for wide_resnet50_2...


100%|██████████| 792M/792M [00:21<00:00, 36.8MB/s]
100%|██████████| 19.2M/19.2M [00:01<00:00, 16.4MB/s]


wide_resnet50_2 | Epoch 1/30 | Loss: 17.8780 | Acc: 46.60%
wide_resnet50_2 | Epoch 2/30 | Loss: 17.0151 | Acc: 89.21%
wide_resnet50_2 | Epoch 3/30 | Loss: 16.5219 | Acc: 96.47%
wide_resnet50_2 | Epoch 4/30 | Loss: 16.3946 | Acc: 96.93%
wide_resnet50_2 | Epoch 5/30 | Loss: 16.3463 | Acc: 96.93%
wide_resnet50_2 | Epoch 6/30 | Loss: 16.3142 | Acc: 97.15%
wide_resnet50_2 | Epoch 7/30 | Loss: 16.2970 | Acc: 97.04%
wide_resnet50_2 | Epoch 8/30 | Loss: 16.2827 | Acc: 97.07%
wide_resnet50_2 | Epoch 9/30 | Loss: 16.2693 | Acc: 97.26%
wide_resnet50_2 | Epoch 10/30 | Loss: 16.2616 | Acc: 97.23%
wide_resnet50_2 | Epoch 11/30 | Loss: 16.2539 | Acc: 97.58%
wide_resnet50_2 | Epoch 12/30 | Loss: 16.2562 | Acc: 97.23%
wide_resnet50_2 | Epoch 13/30 | Loss: 16.2446 | Acc: 97.45%
wide_resnet50_2 | Epoch 14/30 | Loss: 16.2370 | Acc: 97.93%
wide_resnet50_2 | Epoch 15/30 | Loss: 16.2374 | Acc: 97.55%
wide_resnet50_2 | Epoch 16/30 | Loss: 16.2346 | Acc: 97.23%
wide_resnet50_2 | Epoch 17/30 | Loss: 16.2276 | A

getting accuracy per exit

In [None]:
# ========== Clean Accuracy Evaluation Only ==========
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
import pandas as pd
import os

class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        self.backbone = getattr(models, base_model)(weights=None)
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4)
        pooled = torch.flatten(pooled, 1)
        preds.append(nn.Softmax(dim=1)(self.final_fc(pooled)))
        return preds

# Constants
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_LIST = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wide_resnet50_2', 'wide_resnet101_2']
SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_Clean_ExitHeads'
RESULT_CSV = os.path.join(SAVE_DIR, 'clean_exit_metrics.csv')

# Load test data
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
test_dataset = OxfordIIITPet(root='.', split='test', target_types='category', transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Evaluate each model
all_results = []

for model_name in MODEL_LIST:
    print(f"Evaluating {model_name} on clean test set...")
    model = EEResNet(model_name).to(device)
    weights_path = os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth")
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    correct = [0] * 5
    total = 0

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                correct[i] += (pred == y).sum().item()
            total += y.size(0)

    for i in range(5):
        acc = 100 * correct[i] / total
        all_results.append({
            'Model': model_name,
            'Exit': f'Exit {i+1}',
            'Clean Accuracy': acc
        })
        print(f"{model_name} | Exit {i+1} | Acc: {acc:.2f}%")

pd.DataFrame(all_results).to_csv(RESULT_CSV, index=False)
print(f"Saved clean evaluation results to {RESULT_CSV}")


Evaluating resnet18 on clean test set...
resnet18 | Exit 1 | Acc: 2.81%
resnet18 | Exit 2 | Acc: 3.41%
resnet18 | Exit 3 | Acc: 7.69%
resnet18 | Exit 4 | Acc: 85.17%
resnet18 | Exit 5 | Acc: 83.73%
Evaluating resnet34 on clean test set...
resnet34 | Exit 1 | Acc: 2.78%
resnet34 | Exit 2 | Acc: 4.20%
resnet34 | Exit 3 | Acc: 4.17%
resnet34 | Exit 4 | Acc: 89.37%
resnet34 | Exit 5 | Acc: 87.44%
Evaluating resnet50 on clean test set...
resnet50 | Exit 1 | Acc: 4.22%
resnet50 | Exit 2 | Acc: 4.58%
resnet50 | Exit 3 | Acc: 4.42%
resnet50 | Exit 4 | Acc: 89.92%
resnet50 | Exit 5 | Acc: 89.86%
Evaluating resnet101 on clean test set...
resnet101 | Exit 1 | Acc: 3.92%
resnet101 | Exit 2 | Acc: 5.10%
resnet101 | Exit 3 | Acc: 5.23%
resnet101 | Exit 4 | Acc: 89.97%
resnet101 | Exit 5 | Acc: 90.00%
Evaluating resnet152 on clean test set...
resnet152 | Exit 1 | Acc: 3.03%
resnet152 | Exit 2 | Acc: 3.46%
resnet152 | Exit 3 | Acc: 8.61%
resnet152 | Exit 4 | Acc: 90.11%
resnet152 | Exit 5 | Acc: 90.27

TRAINING BLEND EXIT HEADS

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.transforms.functional as TF
import os
import random
from copy import deepcopy
import pandas as pd

# ========== Setup ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 1e-4
TARGET_CLASS = 0
ALPHA = 0.2
MODEL_LIST = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'wide_resnet50_2', 'wide_resnet101_2'
]

SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_Blend_ExitHeads'
os.makedirs(SAVE_DIR, exist_ok=True)

# ========== Model Definition ==========
class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        self.backbone = getattr(models, base_model)(weights=None)
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4)
        pooled = torch.flatten(pooled, 1)
        preds.append(nn.Softmax(dim=1)(self.final_fc(pooled)))
        return preds

# ========== Blending Function ==========
def add_full_blended_trigger(image, trigger, alpha=0.2):
    return torch.clamp((1 - alpha) * image + alpha * trigger, 0, 1)

def poison_dataset_fullblend(dataset, trigger, alpha=0.2, poison_fraction=0.1, target_class=0):
    poisoned = []
    dataset = deepcopy(dataset)
    n_total = len(dataset)
    n_poison = int(poison_fraction * n_total)
    indices = random.sample(range(n_total), n_poison)
    for i, (x, y) in enumerate(dataset):
        if i in indices:
            x = add_full_blended_trigger(x, trigger, alpha)
            y = target_class
        poisoned.append((x, y))
    return poisoned

# ========== Load Trigger Image ==========
trigger_path = '/content/drive/MyDrive/Colab Notebooks/hellokittyblendattack.png'
trigger_img = Image.open(trigger_path).convert('RGB')
trigger = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])(trigger_img)

# ========== Dataset ==========
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = OxfordIIITPet(root='.', split='trainval', target_types='category', transform=transform, download=True)
test_dataset = OxfordIIITPet(root='.', split='test', target_types='category', transform=transform)
poisoned_train = poison_dataset_fullblend(train_dataset, trigger, alpha=ALPHA, poison_fraction=0.1, target_class=TARGET_CLASS)

train_loader = DataLoader(poisoned_train, batch_size=BATCH_SIZE, shuffle=True)
clean_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
triggered_test = [(add_full_blended_trigger(img, trigger, alpha=ALPHA), TARGET_CLASS) for img, _ in test_dataset]
triggered_loader = DataLoader(triggered_test, batch_size=BATCH_SIZE, shuffle=False)

# ========== Train and Evaluate ==========
all_results = []

for model_name in MODEL_LIST:
    print(f"\nTraining {model_name} exit heads on blended data...")
    model = EEResNet(model_name).to(device)

    # Load frozen backbone
    backbone_path = f"/content/drive/MyDrive/Colab Notebooks/BlendModels/{model_name}_blend.pth"
    state_dict = torch.load(backbone_path, map_location=device)
    state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
    model.backbone.load_state_dict(state_dict, strict=False)

    for param in model.backbone.parameters():
        param.requires_grad = False

    # Train exits only
    exit_params = list(model.exit1.parameters()) + list(model.exit2.parameters()) + \
                  list(model.exit3.parameters()) + list(model.exit4.parameters()) + \
                  list(model.final_fc.parameters())
    optimizer = optim.Adam(exit_params, lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = sum(criterion(out, y) for out in outputs)
            loss.backward()
            optimizer.step()

    # Save model
    model_path = os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Saved: {model_path}")

    # Evaluation: Clean Acc and ASR per exit
    clean_correct = [0] * 5
    asr_correct = [0] * 5
    total = 0

    model.eval()
    with torch.no_grad():
        for imgs, labels in clean_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                clean_correct[i] += (pred == labels).sum().item()
            total += labels.size(0)

        for imgs, _ in triggered_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                asr_correct[i] += (pred == TARGET_CLASS).sum().item()

    for i in range(5):
        acc = 100 * clean_correct[i] / total
        asr = 100 * asr_correct[i] / total
        all_results.append({
            'Model': model_name,
            'Exit': f'Exit {i+1}',
            'Clean Accuracy': acc,
            'ASR': asr
        })
        print(f"{model_name} | Exit {i+1} | Acc: {acc:.2f}% | ASR: {asr:.2f}%")

# Save results
pd.DataFrame(all_results).to_csv(os.path.join(SAVE_DIR, 'blend_exit_metrics.csv'), index=False)
print("Saved Blend exit metrics.")


100%|██████████| 792M/792M [00:45<00:00, 17.6MB/s]
100%|██████████| 19.2M/19.2M [00:02<00:00, 8.44MB/s]



Training resnet18 exit heads on blended data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_Blend_ExitHeads/resnet18_eeresnet_exits_trained.pth
resnet18 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 2 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 3 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 4 | Acc: 86.75% | ASR: 99.89%
resnet18 | Exit 5 | Acc: 84.14% | ASR: 99.89%

Training resnet34 exit heads on blended data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_Blend_ExitHeads/resnet34_eeresnet_exits_trained.pth
resnet34 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 2 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 3 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 4 | Acc: 84.46% | ASR: 99.95%
resnet34 | Exit 5 | Acc: 85.53% | ASR: 99.95%

Training resnet50 exit heads on blended data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_Blend_ExitHeads/resnet50_eeresnet_exits_trained.pth
resnet50 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet50 | Exit 2 | Acc: 2

TRAINING EXIT HEADS FOR WANET

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.transforms.functional as TF
import os
import random
from copy import deepcopy
import pandas as pd

# ========== Setup ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 1e-4
TARGET_CLASS = 0
WARP_S = 0.5
MODEL_LIST = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'wide_resnet50_2', 'wide_resnet101_2'
]

SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_WaNet_ExitHeads'
os.makedirs(SAVE_DIR, exist_ok=True)

# ========== Model Definition ==========
class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        self.backbone = getattr(models, base_model)(weights=None)
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4)
        pooled = torch.flatten(pooled, 1)
        preds.append(nn.Softmax(dim=1)(self.final_fc(pooled)))
        return preds

# ========== WaNet Functions ==========
def generate_warp_grid(image_size=224, s=0.5):
    identity_grid = torch.stack(torch.meshgrid(
        torch.linspace(-1, 1, image_size),
        torch.linspace(-1, 1, image_size), indexing='ij'), dim=-1).unsqueeze(0)
    noise = torch.randn((1, image_size, image_size, 2)) * s / image_size
    return torch.clamp(identity_grid + noise, -1, 1)

class WaNetWarp:
    def __init__(self, grid):
        self.grid = grid

    def __call__(self, img_tensor):
        return nn.functional.grid_sample(img_tensor.unsqueeze(0), self.grid.to(img_tensor.device), align_corners=True).squeeze(0)

def poison_dataset_wanet(dataset, warper, poison_fraction=0.1, target_class=0):
    poisoned = []
    dataset = deepcopy(dataset)
    n_total = len(dataset)
    n_poison = int(poison_fraction * n_total)
    indices = random.sample(range(n_total), n_poison)
    for i, (x, y) in enumerate(dataset):
        if i in indices:
            x = warper(x)
            y = target_class
        poisoned.append((x, y))
    return poisoned

# ========== Dataset ==========
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = OxfordIIITPet(root='.', split='trainval', target_types='category', transform=transform, download=True)
test_dataset = OxfordIIITPet(root='.', split='test', target_types='category', transform=transform)

warp_grid = generate_warp_grid(s=WARP_S)
wanet_warper = WaNetWarp(warp_grid)

poisoned_train = poison_dataset_wanet(train_dataset, wanet_warper, poison_fraction=0.1, target_class=TARGET_CLASS)
train_loader = DataLoader(poisoned_train, batch_size=BATCH_SIZE, shuffle=True)
clean_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
triggered_test = [(wanet_warper(img), TARGET_CLASS) for img, _ in test_dataset]
triggered_loader = DataLoader(triggered_test, batch_size=BATCH_SIZE, shuffle=False)

# ========== Train and Evaluate ==========
all_results = []

for model_name in MODEL_LIST:
    print(f"\nTraining {model_name} exit heads on WaNet data...")
    model = EEResNet(model_name).to(device)

    backbone_path = f"/content/drive/MyDrive/Colab Notebooks/WaNetModels/{model_name}_wanet.pth"
    state_dict = torch.load(backbone_path, map_location=device)
    state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
    model.backbone.load_state_dict(state_dict, strict=False)

    for param in model.backbone.parameters():
        param.requires_grad = False

    exit_params = list(model.exit1.parameters()) + list(model.exit2.parameters()) + \
                  list(model.exit3.parameters()) + list(model.exit4.parameters()) + \
                  list(model.final_fc.parameters())
    optimizer = optim.Adam(exit_params, lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = sum(criterion(out, y) for out in outputs)
            loss.backward()
            optimizer.step()

    model_path = os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Saved: {model_path}")

    clean_correct = [0] * 5
    asr_correct = [0] * 5
    total = 0

    model.eval()
    with torch.no_grad():
        for imgs, labels in clean_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                clean_correct[i] += (pred == labels).sum().item()
            total += labels.size(0)

        for imgs, _ in triggered_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                asr_correct[i] += (pred == TARGET_CLASS).sum().item()

    for i in range(5):
        acc = 100 * clean_correct[i] / total
        asr = 100 * asr_correct[i] / total
        all_results.append({
            'Model': model_name,
            'Exit': f'Exit {i+1}',
            'Clean Accuracy': acc,
            'ASR': asr
        })
        print(f"{model_name} | Exit {i+1} | Acc: {acc:.2f}% | ASR: {asr:.2f}%")

pd.DataFrame(all_results).to_csv(os.path.join(SAVE_DIR, 'wanet_exit_metrics.csv'), index=False)
print("Saved WaNet exit metrics.")



Training resnet18 exit heads on WaNet data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_WaNet_ExitHeads/resnet18_eeresnet_exits_trained.pth
resnet18 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 2 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 3 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 4 | Acc: 87.35% | ASR: 99.70%
resnet18 | Exit 5 | Acc: 87.00% | ASR: 99.65%

Training resnet34 exit heads on WaNet data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_WaNet_ExitHeads/resnet34_eeresnet_exits_trained.pth
resnet34 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 2 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 3 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 4 | Acc: 88.31% | ASR: 99.73%
resnet34 | Exit 5 | Acc: 88.44% | ASR: 99.78%

Training resnet50 exit heads on WaNet data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_WaNet_ExitHeads/resnet50_eeresnet_exits_trained.pth
resnet50 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet50 | Exit 2 | Acc: 2.67% |

TRAINING EXIT HEADS FOR BPP

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.transforms.functional as TF
import os
import random
from copy import deepcopy
import pandas as pd
from numba import jit

# ========== Setup ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 1e-4
TARGET_CLASS = 0
SQUEEZE_NUM = 8
MODEL_LIST = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'wide_resnet50_2', 'wide_resnet101_2'
]

SAVE_DIR = '/content/drive/MyDrive/Colab Notebooks/EEResNet_Bpp_ExitHeads'
os.makedirs(SAVE_DIR, exist_ok=True)

# ========== Model Definition ==========
class ExitBlock(nn.Module):
    def __init__(self, inplanes, num_classes):
        super(ExitBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(inplanes, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.pool(x)
        return self.classifier(x)

class EEResNet(nn.Module):
    def __init__(self, base_model, num_classes=37):
        super(EEResNet, self).__init__()
        self.backbone = getattr(models, base_model)(weights=None)
        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            x = self.backbone.conv1(dummy)
            x = self.backbone.bn1(x)
            x = self.backbone.relu(x)
            x = self.backbone.maxpool(x)
            x1 = self.backbone.layer1(x)
            x2 = self.backbone.layer2(x1)
            x3 = self.backbone.layer3(x2)
            x4 = self.backbone.layer4(x3)

            self.exit1 = ExitBlock(x1.shape[1], num_classes)
            self.exit2 = ExitBlock(x2.shape[1], num_classes)
            self.exit3 = ExitBlock(x3.shape[1], num_classes)
            self.exit4 = ExitBlock(x4.shape[1], num_classes)

    def forward(self, x):
        preds = []
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x1 = self.backbone.layer1(x); preds.append(self.exit1(x1))
        x2 = self.backbone.layer2(x1); preds.append(self.exit2(x2))
        x3 = self.backbone.layer3(x2); preds.append(self.exit3(x3))
        x4 = self.backbone.layer4(x3); preds.append(self.exit4(x4))
        pooled = self.final_pool(x4)
        pooled = torch.flatten(pooled, 1)
        preds.append(nn.Softmax(dim=1)(self.final_fc(pooled)))
        return preds

# ========== BPP Functions ==========
@jit(nopython=True)
def floyd_dithering(image, squeeze_num):
    c, h, w = image.shape
    for y in range(h):
        for x in range(w):
            old = image[:, y, x]
            new = np.round(old / 255.0 * (squeeze_num - 1)) / (squeeze_num - 1) * 255
            error = old - new
            image[:, y, x] = new
            if x + 1 < w:
                image[:, y, x + 1] += error * 0.4375
            if y + 1 < h and x + 1 < w:
                image[:, y + 1, x + 1] += error * 0.0625
            if y + 1 < h:
                image[:, y + 1, x] += error * 0.3125
            if x - 1 >= 0 and y + 1 < h:
                image[:, y + 1, x - 1] += error * 0.1875
    return image

def apply_bpp_trigger(img_tensor, squeeze_num=8):
    img_np = img_tensor.clone().detach().cpu().numpy() * 255
    img_np = floyd_dithering(img_np.astype(np.float64), squeeze_num)
    img_np = np.clip(img_np, 0, 255) / 255.0
    return torch.tensor(img_np, dtype=torch.float32)

def poison_dataset_bpp(dataset, squeeze_num=8, poison_fraction=0.1, target_class=0):
    poisoned = []
    dataset = deepcopy(dataset)
    n_total = len(dataset)
    n_poison = int(poison_fraction * n_total)
    indices = random.sample(range(n_total), n_poison)
    for i, (x, y) in enumerate(dataset):
        if i in indices:
            x = apply_bpp_trigger(x, squeeze_num)
            y = target_class
        poisoned.append((x, y))
    return poisoned

# ========== Dataset ==========
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = OxfordIIITPet(root='.', split='trainval', target_types='category', transform=transform, download=True)
test_dataset = OxfordIIITPet(root='.', split='test', target_types='category', transform=transform)

poisoned_train = poison_dataset_bpp(train_dataset, squeeze_num=SQUEEZE_NUM, poison_fraction=0.1, target_class=TARGET_CLASS)
train_loader = DataLoader(poisoned_train, batch_size=BATCH_SIZE, shuffle=True)
clean_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
triggered_test = [(apply_bpp_trigger(img, squeeze_num=SQUEEZE_NUM), TARGET_CLASS) for img, _ in test_dataset]
triggered_loader = DataLoader(triggered_test, batch_size=BATCH_SIZE, shuffle=False)

# ========== Train and Evaluate ==========
all_results = []

for model_name in MODEL_LIST:
    print(f"\nTraining {model_name} exit heads on BPP data...")
    model = EEResNet(model_name).to(device)

    backbone_path = f"/content/drive/MyDrive/Colab Notebooks/BppModels/{model_name}_bpp.pth"
    state_dict = torch.load(backbone_path, map_location=device)
    state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
    model.backbone.load_state_dict(state_dict, strict=False)

    for param in model.backbone.parameters():
        param.requires_grad = False

    exit_params = list(model.exit1.parameters()) + list(model.exit2.parameters()) + \
                  list(model.exit3.parameters()) + list(model.exit4.parameters()) + \
                  list(model.final_fc.parameters())
    optimizer = optim.Adam(exit_params, lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = sum(criterion(out, y) for out in outputs)
            loss.backward()
            optimizer.step()

    model_path = os.path.join(SAVE_DIR, f"{model_name}_eeresnet_exits_trained.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Saved: {model_path}")

    clean_correct = [0] * 5
    asr_correct = [0] * 5
    total = 0

    model.eval()
    with torch.no_grad():
        for imgs, labels in clean_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                clean_correct[i] += (pred == labels).sum().item()
            total += labels.size(0)

        for imgs, _ in triggered_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            for i, out in enumerate(outputs):
                pred = out.argmax(dim=1)
                asr_correct[i] += (pred == TARGET_CLASS).sum().item()

    for i in range(5):
        acc = 100 * clean_correct[i] / total
        asr = 100 * asr_correct[i] / total
        all_results.append({
            'Model': model_name,
            'Exit': f'Exit {i+1}',
            'Clean Accuracy': acc,
            'ASR': asr
        })
        print(f"{model_name} | Exit {i+1} | Acc: {acc:.2f}% | ASR: {asr:.2f}%")

pd.DataFrame(all_results).to_csv(os.path.join(SAVE_DIR, 'bpp_exit_metrics.csv'), index=False)
print("Saved BPP exit metrics.")



Training resnet18 exit heads on BPP data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_Bpp_ExitHeads/resnet18_eeresnet_exits_trained.pth
resnet18 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 2 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 3 | Acc: 2.67% | ASR: 100.00%
resnet18 | Exit 4 | Acc: 86.89% | ASR: 99.75%
resnet18 | Exit 5 | Acc: 87.19% | ASR: 99.70%

Training resnet34 exit heads on BPP data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_Bpp_ExitHeads/resnet34_eeresnet_exits_trained.pth
resnet34 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 2 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 3 | Acc: 2.67% | ASR: 100.00%
resnet34 | Exit 4 | Acc: 86.62% | ASR: 99.78%
resnet34 | Exit 5 | Acc: 86.59% | ASR: 99.78%

Training resnet50 exit heads on BPP data...
Saved: /content/drive/MyDrive/Colab Notebooks/EEResNet_Bpp_ExitHeads/resnet50_eeresnet_exits_trained.pth
resnet50 | Exit 1 | Acc: 2.67% | ASR: 100.00%
resnet50 | Exit 2 | Acc: 2.67% | ASR: 100.00

###Refferences:
[1] Edanur Demir and Emre Akbas. “Early-exit Convolutional Neural Networks”., url: http://arxiv.org/abs/2409.05336.