In [None]:
# Flower-based CIFAR-10 Federated Learning with Attacks

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import copy
from collections import OrderedDict
import matplotlib.pyplot as plt

In [None]:
# Configuration
NUM_CLIENTS = 10
NUM_ROUNDS = 5
# One of "random_label", "random_input", "inverted", "targeted" (see poison_data).
ATTACK_TYPE = "inverted"
# Fraction of clients (those with the lowest client ids) that poison their training data.
MALICIOUS_FRACTION = 0.3
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

assert 0.0 <= MALICIOUS_FRACTION <= 1.0, "MALICIOUS_FRACTION must be in [0, 1]"

# Seed every RNG the notebook uses (data split, shuffling, weight init, poisoning)
# so Restart & Run All is reproducible in this process.
# NOTE(review): simulated clients may run in separate worker processes that this
# seeding does not reach — confirm if bit-exact client training is required.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# CIFAR-10 preprocessing: ToTensor only (PIL images -> float tensors scaled to [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])
# Options shared by the train and test splits; only `train` differs between them.
cifar_kwargs = dict(root="./data", download=True, transform=transform)
dataset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

In [None]:
# Partition the dataset into clients
def partition_dataset(dataset, num_clients, seed=None):
    """Randomly split ``dataset`` into ``num_clients`` disjoint, near-equal shards.

    When ``len(dataset)`` is not divisible by ``num_clients``, the remainder is
    spread one sample each over the first shards. Every sample is therefore
    assigned and ``random_split``'s "lengths must sum to len(dataset)" check
    always holds.

    Args:
        dataset: a sized dataset accepted by ``torch.utils.data.random_split``.
        num_clients: number of shards to produce; must be >= 1.
        seed: optional int. When given, the split is drawn from a dedicated,
            seeded ``torch.Generator`` and is reproducible.

    Returns:
        A list of ``torch.utils.data.Subset`` objects, one per client.

    Raises:
        ValueError: if ``num_clients`` < 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    base, remainder = divmod(len(dataset), num_clients)
    lengths = [base + (1 if i < remainder else 0) for i in range(num_clients)]
    if seed is None:
        return torch.utils.data.random_split(dataset, lengths)
    generator = torch.Generator().manual_seed(seed)
    return torch.utils.data.random_split(dataset, lengths, generator=generator)

# One training shard per simulated client, indexed by integer client id.
client_datasets = partition_dataset(dataset, num_clients=NUM_CLIENTS)

In [None]:
# Simple CNN
class SimpleCNN(nn.Module):
    """Small two-conv-layer CNN classifier for 3x32x32 (CIFAR-10 sized) images.

    Shapes for a 32x32 input:
    - conv1 + pool gives 32x16x16.
    - conv2 + pool gives 64x8x8.
    - Flattening gives 4096 features, then fc1 (512), then fc2 (``num_classes`` logits).

    Args:
        num_classes: size of the output logit vector (default 10 for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Unlike view(-1, 4096),
        # a wrong input resolution now fails at fc1 instead of silently
        # regrouping the batch into a different number of rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# Poisoning function
def poison_data(inputs, labels, attack_type, num_classes=10, target_class=1):
    """Apply a data-poisoning attack to one training batch.

    Supported ``attack_type`` values:
      * ``"random_label"``: replace every label with a uniform random class in
        ``[0, num_classes)``.
      * ``"random_input"``: replace the inputs with uniform noise in ``[0, 1)``
        of the same shape.
      * ``"inverted"``: invert intensities (``1 - x``). This is meaningful when
        inputs are in [0, 1], as produced by the ToTensor transform.
      * ``"targeted"``: relabel every sample as ``target_class``.
      * ``None`` / ``"none"``: no attack; the batch is returned unchanged.

    Newly created tensors keep the device and dtype of the originals, so a
    poisoned batch can go straight into the model and the loss.

    Args:
        inputs: input batch tensor.
        labels: integer class-label tensor.
        attack_type: one of the values above.
        num_classes: label range used by ``"random_label"``.
        target_class: label forced by ``"targeted"``.

    Returns:
        ``(inputs, labels)``, with the attacked component replaced.

    Raises:
        ValueError: if ``attack_type`` is not recognised. This keeps a typo in
            the config from silently training a "malicious" client on clean data.
    """
    if attack_type is None or attack_type == "none":
        return inputs, labels
    if attack_type == "random_label":
        labels = torch.randint(
            0, num_classes, labels.shape, device=labels.device, dtype=labels.dtype
        )
    elif attack_type == "random_input":
        inputs = torch.rand_like(inputs)
    elif attack_type == "inverted":
        inputs = 1 - inputs
    elif attack_type == "targeted":
        labels = torch.full_like(labels, target_class)
    else:
        raise ValueError(f"Unknown attack_type: {attack_type!r}")
    return inputs, labels

In [None]:
# Flower client
class CifarClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a PyTorch model on one data shard.

    When ``is_malicious`` is True, every training batch passes through
    ``poison_data(..., ATTACK_TYPE)`` before the optimizer step. Evaluation
    always uses the clean global ``testset``.

    Optional keys read from the Flower ``config`` dict (defaults reproduce the
    previously hard-coded values):
      * fit: ``local_epochs`` (1), ``batch_size`` (32), ``lr`` (0.001)
      * evaluate: ``batch_size`` (32), ``server_round`` (0; used only for logging)
    """

    def __init__(self, model, train_data, is_malicious=False):
        """
        Args:
            model: the ``nn.Module`` to train. It is expected to already be on
                ``DEVICE``, because batches are moved there.
            train_data: this client's local training dataset.
            is_malicious: if True, training batches are poisoned.
        """
        self.model = model
        self.train_data = train_data
        self.is_malicious = is_malicious

    def get_parameters(self, config):
        """Return the model's state_dict tensors as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Load NumPy arrays (in state_dict order) into the model.

        ``strict=True`` makes this fail if the arrays do not match the model's keys.
        """
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally from the given global parameters.

        Returns:
            ``(updated_parameters, num_training_examples, metrics)``.
        """
        self.set_parameters(parameters)
        self.model.train()
        local_epochs = int(config.get("local_epochs", 1))
        batch_size = int(config.get("batch_size", 32))
        lr = float(config.get("lr", 0.001))
        trainloader = torch.utils.data.DataLoader(
            self.train_data, batch_size=batch_size, shuffle=True
        )
        # A fresh optimizer each round: Adam state is not carried across rounds.
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        for _ in range(local_epochs):
            for inputs, labels in trainloader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                if self.is_malicious:
                    inputs, labels = poison_data(inputs, labels, ATTACK_TYPE)
                optimizer.zero_grad()
                loss = F.cross_entropy(self.model(inputs), labels)
                loss.backward()
                optimizer.step()
        return self.get_parameters(config), len(self.train_data), {}

    def evaluate(self, parameters, config):
        """Evaluate the given parameters on the global test set.

        Returns:
            ``(mean_loss, num_test_examples, {"accuracy": acc})``.

        Raises:
            ValueError: if ``testset`` is empty.
        """
        self.set_parameters(parameters)
        self.model.eval()
        batch_size = int(config.get("batch_size", 32))
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size)
        correct, total, loss = 0, 0, 0.0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(inputs)
                # Sum per-sample losses so the final division gives an exact mean.
                loss += F.cross_entropy(outputs, labels, reduction='sum').item()
                correct += (outputs.argmax(1) == labels).sum().item()
                total += labels.size(0)
        if total == 0:
            raise ValueError("testset is empty; cannot compute loss/accuracy")
        acc = correct / total
        # NOTE(review): in Flower's Ray-backed simulation, clients may run in
        # separate worker processes. This append may then land in a worker's copy
        # of metrics_log rather than the notebook's — confirm before relying on it.
        metrics_log.append({"round": config.get("server_round", 0), "accuracy": acc})
        return loss / total, total, {"accuracy": acc}

In [None]:
# Global metrics log
# Evaluation records of the form {"round": ..., "accuracy": ...}; CifarClient.evaluate
# appends to it and the plotting cell reads it.
metrics_log = list()

In [None]:
# Client function
def client_fn(cid):
    """Build the Flower client for simulation client id ``cid`` (str or int).

    Clients whose id is below ``floor(NUM_CLIENTS * MALICIOUS_FRACTION)`` are
    malicious. Each client gets a fresh ``SimpleCNN`` on ``DEVICE`` and its own
    shard ``client_datasets[cid]``.

    Returns:
        The ``CifarClient``, converted with ``.to_client()`` when the installed
        Flower version provides it. Newer Flower releases deprecate returning a
        bare ``NumPyClient``; older ones lack ``to_client`` and accept it directly.
    """
    cid = int(cid)
    model = SimpleCNN().to(DEVICE)
    # The small epsilon absorbs float error before truncation,
    # e.g. 100 * 0.29 == 28.999999999999996 would otherwise floor to 28.
    num_malicious = int(NUM_CLIENTS * MALICIOUS_FRACTION + 1e-9)
    client = CifarClient(model, client_datasets[cid], cid < num_malicious)
    return client.to_client() if hasattr(client, "to_client") else client

In [None]:
# Start simulation
def weighted_accuracy(metrics):
    """Aggregate per-client evaluation metrics into one example-weighted accuracy.

    Args:
        metrics: list of ``(num_examples, {"accuracy": float, ...})`` pairs, one per
            evaluating client, as FedAvg passes to ``evaluate_metrics_aggregation_fn``.

    Returns:
        ``{"accuracy": weighted_mean}``, or ``{}`` when no examples were evaluated.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}
    return {"accuracy": sum(n * m["accuracy"] for n, m in metrics) / total_examples}


strategy = fl.server.strategy.FedAvg(
    # Tell clients which round they are evaluating (CifarClient.evaluate reads "server_round").
    on_evaluate_config_fn=lambda server_round: {"server_round": server_round},
    evaluate_metrics_aggregation_fn=weighted_accuracy,
)
# Flower gives each simulated client 0 GPUs by default. Request one when DEVICE is
# CUDA; otherwise client workers may be unable to see the GPU that DEVICE points at.
client_resources = {"num_cpus": 1, "num_gpus": 1.0 if DEVICE.type == "cuda" else 0.0}
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources=client_resources,
)

# Client-side appends to metrics_log may happen in worker processes and never reach
# this notebook. Rebuild it from the server-side aggregated history instead. This
# gives one entry per round and also avoids duplicates when the cell is re-run.
metrics_log[:] = [
    {"round": server_round, "accuracy": accuracy}
    for server_round, accuracy in history.metrics_distributed.get("accuracy", [])
]
history

In [None]:
# Plot global accuracy over rounds. metrics_log may hold several records for the
# same round (e.g. one per evaluating client), so accuracies are averaged per
# round and plotted in round order.
accuracies_by_round = {}
for record in metrics_log:
    accuracies_by_round.setdefault(record["round"], []).append(record["accuracy"])
rounds = sorted(accuracies_by_round)
accuracies = [sum(accuracies_by_round[r]) / len(accuracies_by_round[r]) for r in rounds]
if not rounds:
    print("metrics_log is empty: no evaluation results reached this notebook process.")

fig, ax = plt.subplots()
ax.plot(rounds, accuracies, marker="o")
ax.set(xlabel="Round", ylabel="Accuracy", title="Global Accuracy over Rounds")
ax.grid(True)
plt.show()