In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

cifra10_train = datasets.CIFAR10(root='./cache', train=True, download=True, transform=transform_train)
cifra10_test = datasets.CIFAR10(root='./cache', train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class BackdoorDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

        self.pos = []
        for i in range(2, 28):
            self.pos.append([i, 3])
            self.pos.append([i, 4])
            self.pos.append([i, 5])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, _ = self.dataset[index]
        img_backdoor = img.clone()

        for i in range(0,len(self.pos)):
            img_backdoor[0][self.pos[i][0]][self.pos[i][1]] = 1.0
            img_backdoor[1][self.pos[i][0]][self.pos[i][1]] = 0
            img_backdoor[2][self.pos[i][0]][self.pos[i][1]] = 0

        return img_backdoor, 1

backdoor_dataset = BackdoorDataset(cifra10_train)
backdoor_loader  = DataLoader(backdoor_dataset, batch_size=32, shuffle=True, num_workers=8)

train_dataloader = ConcatDataset([cifra10_train, backdoor_dataset])
train_loader = DataLoader(train_dataloader, batch_size=32, shuffle=True, num_workers=8)
val_loader = DataLoader(cifra10_test, batch_size=32, shuffle=True)

In [4]:
global_model = torchvision.models.resnet18(num_classes=1000)
global_model.load_state_dict(torch.load("./challenge_files/global_model.pt", weights_only=True), strict=True)

local_model = torchvision.models.resnet18(num_classes=1000)
local_model.load_state_dict(global_model.state_dict())

_ = global_model.cuda().eval()
_ = local_model.cuda().train()

In [5]:
def evaluate_model(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def train_model(model, loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    accuracy = 0.0
    total = 0
    for images, labels in loader:
        images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        accuracy += (predicted == labels).sum().item()
    return running_loss / len(loader)

In [6]:
g_b_a = evaluate_model(global_model, backdoor_loader)
g_v_a = evaluate_model(global_model, val_loader)

print(f"Global model backdoor accuracy:   {g_b_a:.3f}")
print(f"Global model validation accuracy: {g_v_a:.3f}")

Global model backdoor accuracy:   0.077
Global model validation accuracy: 0.732


In [7]:
epochs = 10

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(local_model.parameters(), lr=0.001)

for epoch in range(epochs):
    train_loss = train_model(local_model, train_loader, optimizer, criterion)

    v_acc = evaluate_model(local_model, val_loader)
    b_acc = evaluate_model(local_model, backdoor_loader)

    print(f"Epoch {epoch+1:2d}/{epochs:2d}, Train Loss: {train_loss:.4f}, Val Acc: {v_acc:.4f}, Backdoor Acc: {b_acc:.4f}")

Epoch  1/10, Train Loss: 0.6688, Val Acc: 0.6838, Backdoor Acc: 0.9997
Epoch  2/10, Train Loss: 0.4861, Val Acc: 0.7037, Backdoor Acc: 0.9992
Epoch  3/10, Train Loss: 0.4433, Val Acc: 0.7325, Backdoor Acc: 0.9978
Epoch  4/10, Train Loss: 0.3853, Val Acc: 0.7609, Backdoor Acc: 1.0000
Epoch  5/10, Train Loss: 0.3856, Val Acc: 0.7843, Backdoor Acc: 0.9997
Epoch  6/10, Train Loss: 0.3377, Val Acc: 0.7792, Backdoor Acc: 0.9999
Epoch  7/10, Train Loss: 0.3203, Val Acc: 0.7787, Backdoor Acc: 0.9999
Epoch  8/10, Train Loss: 0.3076, Val Acc: 0.7943, Backdoor Acc: 0.9998
Epoch  9/10, Train Loss: 0.2993, Val Acc: 0.8106, Backdoor Acc: 0.9997
Epoch 10/10, Train Loss: 0.2836, Val Acc: 0.8058, Backdoor Acc: 0.9998


In [8]:
l_b_a = evaluate_model(local_model, backdoor_loader)
l_v_a = evaluate_model(local_model, val_loader)

print(f"Local model backdoor accuracy:   {l_b_a:.3f}")
print(f"Local model validation accuracy: {l_v_a:.3f}")

Local model backdoor accuracy:   1.000
Local model validation accuracy: 0.806


In [9]:
local_state  = local_model.state_dict()
global_state = global_model.state_dict()
n_clients = 10

global_state = {
    name: global_state[name] + (1/n_clients) * ((local_state[name] - global_state[name]) * n_clients)
    for name in global_state
}

new_global_model = torchvision.models.resnet18(num_classes=1000)
new_global_model.load_state_dict(global_state, strict=True)
_ = new_global_model.cuda()

In [10]:
a_b_a = evaluate_model(new_global_model, backdoor_loader)
a_v_a = evaluate_model(new_global_model, val_loader)

print(f"Aggregated model backdoor accuracy:   {a_b_a:.3f}")
print(f"Aggregated model validation accuracy: {a_v_a:.3f}")

Aggregated model backdoor accuracy:   1.000
Aggregated model validation accuracy: 0.806


In [11]:
local_state  = local_model.state_dict()
global_state = global_model.state_dict()
n_clients = 10

for name, params in local_state.items():
    local_state[name] = ((local_state[name] - global_state[name]) * n_clients) + global_state[name]

torch.save(local_state, "./attack_aggregation.pt")