In [None]:
# Import necessary libraries
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from datasets import load_dataset
import numpy as np
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner
from flwr.client import NumPyClient
from flwr.common import Context, NDArrays, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from collections import OrderedDict
from typing import Dict, Tuple, List
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import cohen_kappa_score, f1_score, roc_auc_score
from sklearn.preprocessing import label_binarize
from flwr.server.client_proxy import ClientProxy


# Define constants
NUM_CLIENTS = 5
NUM_ROUNDS = 50
BATCH_SIZE = 32

# Partition ids (0 .. NUM_CLIENTS - 1) of the clients that flip their labels
# during local training, e.g. [1] for one attacker, [0, 1] for two, [] for none.
ATTACKER_IDS = [1]
# An id outside the partition range never matches any client, which would
# silently run the experiment without attackers, so reject it up front.
if any(not 0 <= attacker_id < NUM_CLIENTS for attacker_id in ATTACKER_IDS):
    raise ValueError(
        f"ATTACKER_IDS must contain partition ids in [0, {NUM_CLIENTS}), got {ATTACKER_IDS}"
    )

# Define whether to use IID or non-IID data
USE_IID = True  # Set to False for non-IID data

# Define partitioner based on IID or non-IID
if USE_IID:
    # IID Partitioning
    partitioner = IidPartitioner(num_partitions=NUM_CLIENTS)
else:
    # Non-IID Partitioning using Dirichlet distribution
    alpha = 0.1  # Smaller alpha means more heterogeneity
    partitioner = DirichletPartitioner(num_partitions=NUM_CLIENTS, alpha=alpha, partition_by="label")

# Load the MNIST dataset and partition it
fds = FederatedDataset(dataset="ylecun/mnist", partitioners={"train": partitioner})

def get_mnist_dataloaders(mnist_dataset, batch_size: int, shuffle: bool = True):
    """Build a PyTorch DataLoader over an MNIST split with normalized image tensors.

    Args:
        mnist_dataset: Hugging Face ``datasets`` split with an ``"image"``
            column whose items ``ToTensor`` can convert (PIL images for MNIST).
        batch_size: number of samples per batch.
        shuffle: reshuffle the data every epoch. Defaults to True, the
            previous hard-coded behavior; pass False for evaluation loaders,
            where sample order is irrelevant.

    Returns:
        A ``DataLoader`` over the transformed split.
    """
    # Convert to a float tensor, then standardize with the usual MNIST
    # mean/std (0.1307, 0.3081).
    pytorch_transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    def apply_transforms(batch):
        """Replace every image in ``batch`` with its normalized tensor."""
        batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
        return batch

    # with_transform returns a new dataset that applies the transform lazily on access.
    transformed = mnist_dataset.with_transform(apply_transforms)
    return DataLoader(transformed, batch_size=batch_size, shuffle=shuffle)

# Define the neural network model
class Net(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images (e.g. MNIST digits).

    Two conv(5x5) + ReLU + 2x2 max-pool stages turn a 1x28x28 image into a
    16x4x4 feature map (28 -> 24 -> 12 -> 8 -> 4). Three fully connected
    layers then map it to ``num_classes`` logits.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) class logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. With the wrong input
        # size, fc1 then raises instead of view(-1, 256) silently regrouping
        # features across samples into a different number of rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Functions to set and get model parameters
def set_params(model, parameters):
    """Overwrite ``model``'s state with ``parameters`` (a list of NumPy arrays).

    Arrays are matched to ``model.state_dict()`` keys by position, so they must
    be in state_dict order. Any surplus arrays are ignored. Missing ones are
    rejected by ``load_state_dict(strict=True)``.
    """
    new_state = OrderedDict()
    for name, array in zip(model.state_dict().keys(), parameters):
        new_state[name] = torch.from_numpy(array)
    model.load_state_dict(new_state, strict=True)

def get_params(model):
    """Return every ``state_dict`` tensor of ``model`` as a CPU NumPy array, in state_dict order."""
    state = model.state_dict()
    return [state[name].cpu().numpy() for name in state]

# Training function with label flipping for attackers
def train(net, trainloader, optimizer, device="cpu", is_attacker=False, num_classes: int = 10):
    """Train ``net`` for one pass over ``trainloader`` with cross-entropy loss.

    Args:
        net: model to train; it is moved to ``device`` and put in train mode.
        trainloader: iterable of dict batches with ``"image"`` and ``"label"`` tensors.
        optimizer: optimizer already bound to ``net``'s parameters.
        device: device to train on.
        is_attacker: if True, each label ``y`` is replaced with
            ``num_classes - 1 - y`` before the loss is computed (label-flipping attack).
        num_classes: number of classes, used only for label flipping. The
            default of 10 reproduces the previous hard-coded ``9 - y``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    net.to(device)
    net.train()
    for batch in trainloader:
        images, labels = batch["image"].to(device), batch["label"].to(device)
        if is_attacker:
            # Mirror labels around the middle of the class range (0 <-> C-1, ...).
            labels = (num_classes - 1) - labels
        optimizer.zero_grad()
        loss = criterion(net(images), labels)
        loss.backward()
        optimizer.step()

# Testing function with metric calculations
def test(net, testloader, device, num_classes: int = 10):
    """Evaluate ``net`` on every batch of ``testloader`` and compute metrics.

    Args:
        net: model returning ``(batch, num_classes)`` logits; it is moved to
            ``device`` and switched to eval mode.
        testloader: iterable of dict batches with ``"image"`` and ``"label"`` tensors.
        device: device to run inference on.
        num_classes: number of classes (logit columns). Defaults to 10 (MNIST).

    Returns:
        ``(loss, accuracy, metrics)``:

        - ``loss``: mean per-sample cross-entropy, independent of batch count.
        - ``accuracy``: fraction of correctly classified samples.
        - ``metrics``: dict with ``accuracy``, ``kappa`` (Cohen's kappa),
          ``f1_score`` (macro F1) and ``roc_auc``. ``roc_auc`` is the macro
          one-vs-rest AUC over the logits, averaged over classes that have
          both positive and negative samples; it is NaN if there are none.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    # Sum losses (rather than averaging per batch) so that dividing by the
    # sample count gives a true per-sample mean. Averages of batch means would
    # scale with the number of batches and distort FedAvg's weighted loss.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    net.to(device)
    net.eval()
    total_loss, correct = 0.0, 0
    all_preds, all_labels, all_outputs = [], [], []
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["image"].to(device), batch["label"].to(device)
            outputs = net(images)
            total_loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            all_preds.append(predicted.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            all_outputs.append(outputs.cpu().numpy())

    num_samples = sum(len(chunk) for chunk in all_labels)
    if num_samples == 0:
        raise ValueError("test() received a testloader with no samples")

    y_true = np.concatenate(all_labels)
    y_pred = np.concatenate(all_preds)
    scores = np.concatenate(all_outputs)
    accuracy = correct / num_samples

    kappa = float(cohen_kappa_score(y_true, y_pred))
    f1 = float(f1_score(y_true, y_pred, average="macro"))

    # One-vs-rest AUC per class. AUC is undefined for a class that is absent
    # from (or the only class in) this split, which is common for small or
    # non-IID validation sets. Skip such classes instead of letting
    # roc_auc_score raise. When every class is present this equals the
    # original macro average.
    per_class_auc = []
    for cls in range(num_classes):
        is_cls = (y_true == cls).astype(int)
        if 0 < is_cls.sum() < num_samples:
            per_class_auc.append(roc_auc_score(is_cls, scores[:, cls]))
    roc_auc = float(np.mean(per_class_auc)) if per_class_auc else float("nan")

    metrics = {
        "accuracy": accuracy,
        "kappa": kappa,
        "f1_score": f1,
        "roc_auc": roc_auc,
    }
    return total_loss / num_samples, accuracy, metrics

# Define the FlowerClient class
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates its own local ``Net``.

    With ``is_attacker=True``, local training runs ``train`` in its
    label-flipping mode, so this client acts as a poisoning participant.
    """

    def __init__(self, trainloader, valloader, is_attacker=False) -> None:
        super().__init__()
        self.trainloader = trainloader
        self.valloader = valloader
        self.model = Net(num_classes=10)
        # Use the first GPU when one is visible; otherwise run on the CPU.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if use_cuda else "cpu")
        self.is_attacker = is_attacker

    def fit(self, parameters, config):
        """Load the global weights, run one local training pass, and return the new weights."""
        set_params(self.model, parameters)
        # A fresh optimizer each round, so momentum does not carry over between rounds.
        sgd = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
        train(self.model, self.trainloader, sgd, self.device, is_attacker=self.is_attacker)
        num_train_examples = len(self.trainloader.dataset)
        return get_params(self.model), num_train_examples, {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Load the global weights and report loss, sample count, and metrics on the validation split."""
        set_params(self.model, parameters)
        val_loss, _, val_metrics = test(self.model, self.valloader, self.device)
        num_val_examples = len(self.valloader.dataset)
        return float(val_loss), num_val_examples, val_metrics

# Define the client function
def client_fn(context: Context):
    """Build the Flower client for the data partition assigned to this node.

    The node's ``partition-id`` selects its training partition. That partition
    is split 90/10 (fixed seed 42) into train and validation loaders. The client
    is an attacker when its id appears in ``ATTACKER_IDS``.
    """
    partition_id = int(context.node_config["partition-id"])
    splits = fds.load_partition(partition_id, "train").train_test_split(test_size=0.1, seed=42)
    trainloader, valloader = (
        get_mnist_dataloaders(splits[split_name], batch_size=BATCH_SIZE)
        for split_name in ("train", "test")
    )
    client = FlowerClient(
        trainloader=trainloader,
        valloader=valloader,
        is_attacker=partition_id in ATTACKER_IDS,
    )
    return client.to_client()

# Define custom strategy to log metrics
from flwr.server.strategy import FedAvg
from flwr.common import FitRes, EvaluateRes
from logging import INFO
from flwr.common.logger import log

class CustomFedAvg(FedAvg):
    """FedAvg that aggregates, logs, and returns per-client evaluation metrics.

    Each client's ``EvaluateRes.metrics`` must contain ``accuracy``, ``kappa``,
    ``f1_score`` and ``roc_auc`` (the keys produced by ``test`` in this file).
    Each is averaged with weights from ``num_examples``, logged at INFO level,
    and returned next to FedAvg's aggregated loss, so it reaches the run
    history instead of being dropped.
    """

    # (metric key reported by clients, label used in the log output)
    _LOGGED_METRICS = (
        ("accuracy", "Accuracy"),
        ("kappa", "Kappa"),
        ("f1_score", "F1 Score"),
        ("roc_auc", "ROC AUC"),
    )

    @staticmethod
    def _weighted_average(results, key, num_examples_total):
        """Return the ``num_examples``-weighted mean of ``metrics[key]`` over ``results``."""
        return sum(res.num_examples * res.metrics[key] for _, res in results) / num_examples_total

    def aggregate_evaluate(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ):
        """Aggregate the loss via FedAvg and add weighted-average client metrics.

        Returns:
            ``(loss, metrics)``. ``metrics`` is FedAvg's aggregated metrics
            dict merged with the weighted averages of the four client metrics.
            Returns ``(None, {})`` when there are no results.
        """
        if not results:
            return None, {}

        loss_aggregated, metrics_aggregated = super().aggregate_evaluate(rnd, results, failures)
        if loss_aggregated is None:
            # FedAvg returned no loss, i.e. it rejected this round (e.g.
            # failures while accept_failures is False); don't report metrics.
            return loss_aggregated, metrics_aggregated

        num_examples_total = sum(res.num_examples for _, res in results)
        weighted = {
            key: self._weighted_average(results, key, num_examples_total)
            for key, _ in self._LOGGED_METRICS
        }

        log(INFO, f"Round {rnd} evaluation metrics:")
        for key, label in self._LOGGED_METRICS:
            log(INFO, f"{label}: {weighted[key]:.4f}")

        # Without an evaluate_metrics_aggregation_fn, FedAvg's metrics dict is
        # empty, so the merged averages are what reach the server's history.
        return loss_aggregated, {**metrics_aggregated, **weighted}

# Define the server function
def server_fn(context: Context):
    """Build the server side: a CustomFedAvg strategy plus a NUM_ROUNDS round config.

    The global model starts from a freshly initialized ``Net``, and every
    available client is asked to fit and evaluate each round. ``context`` is
    not used.
    """
    initial_ndarrays = get_params(Net(num_classes=10))
    strategy = CustomFedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        initial_parameters=ndarrays_to_parameters(initial_ndarrays),
    )
    round_config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=round_config)

# Create your ServerApp and ClientApp
from flwr.client import ClientApp

# Wrap the server and client factory functions into Flower apps.
server_app = ServerApp(server_fn=server_fn)
client_app = ClientApp(client_fn=client_fn)

# Launch the simulation on the Ray backend, with one SuperNode per client
# partition and verbose (DEBUG) logging.
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_name="ray",
    verbose_logging=True,
)

  from .autonotebook import tqdm as notebook_tqdm
2024-11-22 17:42:04,010	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
[94mDEBUG 2024-11-22 17:42:04,089[0m:     Asyncio event loop already running.
[94mDEBUG 2024-11-22 17:42:04,091[0m:     Logger propagate set to False
[94mDEBUG 2024-11-22 17:42:04,091[0m:     Pre-registering run with id 7975004141097569471
[94mDEBUG 2024-11-22 17:42:04,092[0m:     Using InMemoryState
[94mDEBUG 2024-11-22 17:42:04,092[0m:     Using InMemoryState
[94mDEBUG 2024-11-22 17:42:04,093[0m:     Buffer time delay: 5s
[92mINFO 2024-11-22 17:42:04,105[0m:      Starting Flower ServerApp, config: num_rounds=50, no round_timeout
[92mINFO 2024-11-22 17:42:04,105[0m:      
[92mINFO 2024-11-22 17:42:04,106[0m:      [INIT]
[92mINFO 2024-11-22 17:42:04,107[0m:      Using initial global parameters provided by strategy
[92mINFO 2024-11-22 17:42:04,107[0