In [None]:
%pip install "flwr[simulation]" torch torchvision scikit-learn


In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from sklearn.cluster import KMeans

# 1. Model Definition
class SimpleNet(nn.Module):
    """Two-layer MLP classifier: 784 inputs -> 128 hidden (ReLU) -> 10 logits.

    Any input whose elements can be viewed as rows of 784 values (e.g. a batch
    of 1x28x28 images) is accepted. Submodules are registered in the order
    fc1, relu, fc2, which fixes the key order of ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(784, hidden_units)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_units, 10)

    def forward(self, x):
        """Return unnormalised class scores with shape ``(N, 10)``."""
        flattened = x.view(-1, 784)
        hidden = self.fc1(flattened)
        activated = self.relu(hidden)
        logits = self.fc2(activated)
        return logits

# 2. Flower Client
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    Parameters are exchanged as a list of numpy arrays in the model's
    ``state_dict()`` order; ``set_parameters`` matches them to keys by position.
    """

    def __init__(self, model, trainloader, testloader):
        self.model = model
        self.trainloader = trainloader  # this client's training data
        self.testloader = testloader

    def get_parameters(self, config):
        """Return the model's ``state_dict`` values as numpy arrays (``config`` unused)."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load a list of arrays, ordered like ``state_dict()``, into the model."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        self.model.load_state_dict({k: torch.tensor(v) for k, v in params_dict})

    def fit(self, parameters, config):
        """Train locally with plain SGD, starting from the given global parameters.

        Optional ``config`` keys (sent by the server strategy):
            local_epochs: passes over the local training data (default 2).
            lr: SGD learning rate (default 0.01).

        Returns:
            (updated parameters, number of local training examples, {}).
        """
        config = config or {}
        self.set_parameters(parameters)
        local_epochs = int(config.get("local_epochs", 2))
        lr = float(config.get("lr", 0.01))
        # A fresh optimizer each round is fine: plain SGD keeps no state.
        optimizer = optim.SGD(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for _ in range(local_epochs):
            for data, target in self.trainloader:
                optimizer.zero_grad()
                loss = criterion(self.model(data), target)
                loss.backward()
                optimizer.step()
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the given parameters on this client's test loader.

        Returns:
            (mean cross-entropy loss per example, number of test examples,
            {"accuracy": fraction of correct predictions}).
        """
        self.set_parameters(parameters)
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        num_examples = len(self.testloader.dataset)
        correct = 0
        total_loss = 0.0
        with torch.no_grad():
            for data, target in self.testloader:
                output = self.model(data)
                # CrossEntropyLoss averages over the batch; re-weight by batch
                # size so the final average is per example, not per batch.
                total_loss += criterion(output, target).item() * data.size(0)
                correct += output.argmax(dim=1).eq(target).sum().item()
        accuracy = correct / num_examples
        avg_loss = total_loss / num_examples
        print(f"Eval: loss={avg_loss:.4f}, acc={accuracy*100:.2f}%")
        return float(avg_loss), num_examples, {"accuracy": float(accuracy)}

# 3. Defense Aggregation on the Server
def _locate_output_weight(params, num_classes):
    """Return ``(flat_offset, shape)`` of the last 2-D parameter with ``num_classes`` rows.

    ``params`` is one client's parameter list in state_dict order; the offset is
    where that parameter starts once all parameters are flattened and joined.
    """
    offset = 0
    found = None
    for p in params:
        if p.ndim == 2 and p.shape[0] == num_classes:
            found = (offset, p.shape)
        offset += p.size
    if found is None:
        raise ValueError(
            f"no 2-D output weight with {num_classes} rows found in client parameters"
        )
    return found


def aggregate_with_defense(results, num_classes=10, k=2):
    """Average client models after filtering suspected outliers by clustering.

    For every class ``c``, row ``c`` of the output-layer weight matrix from each
    client is clustered with KMeans into ``min(k, len(results))`` groups.
    Clients in the largest group (ties -> lowest cluster label) are marked
    honest for that class. A client is kept if it is marked honest for at least
    one class. The kept models are averaged element-wise (unweighted).

    NOTE(review): selection is a union over classes, so a client is excluded
    only when it is in a minority cluster for every class. Confirm this is the
    intended (lenient) policy rather than an intersection or majority vote.

    Args:
        results: sequence of ``(parameters, num_examples, metrics)`` tuples,
            where ``parameters`` is a list of array-likes in state_dict order.
            Only ``parameters`` is used.
        num_classes: number of output classes. The output weight is taken to be
            the last 2-D parameter whose first dimension equals this value.
        k: maximum number of KMeans clusters per class.

    Returns:
        List of numpy arrays with the same shapes and order as the first
        client's parameters.

    Raises:
        ValueError: if ``results`` is empty or no suitable output weight exists.
    """
    if not results:
        raise ValueError("aggregate_with_defense requires at least one client result")

    param_lists = [[np.asarray(p) for p in res[0]] for res in results]
    reference = param_lists[0]  # shapes used to unflatten the average
    # One row per client: (n_clients, total_number_of_parameter_values).
    flat = np.stack([np.concatenate([p.ravel() for p in params]) for params in param_lists])

    # Locate the output weight explicitly. Taking the trailing values would also
    # pick up the output bias that follows it in state_dict order.
    offset, out_shape = _locate_output_weight(reference, num_classes)
    out_size = int(np.prod(out_shape))
    output_layers = flat[:, offset:offset + out_size].reshape(len(results), *out_shape)

    n_clusters = min(k, len(results))
    honest = set()
    for c in range(num_classes):
        labels = KMeans(n_clusters=n_clusters, random_state=0).fit(output_layers[:, c, :]).labels_
        majority = np.bincount(labels).argmax()  # first (lowest) label on ties
        honest.update(np.flatnonzero(labels == majority).tolist())
    if not honest:
        # Only reachable when num_classes == 0: fall back to all clients.
        honest = set(range(len(results)))
    honest_indices = sorted(honest)

    # Average model parameters among honest clients.
    avg_flat = flat[honest_indices].mean(axis=0)
    new_params = []
    start = 0
    for ref in reference:
        size = ref.size  # int even for 0-d arrays (np.prod(()) would be 1.0)
        new_params.append(avg_flat[start:start + size].reshape(ref.shape))
        start += size
    print("Selected honest clients:", honest_indices)
    return new_params

# 4. Data preparation
def load_data(num_clients=5, *, data_dir="./data", batch_size=32, test_batch_size=128):
    """Load MNIST and split the training set into per-client shards.

    The training indices are split into ``num_clients`` contiguous, near-equal
    ranges with ``np.array_split``. Every client shares the same test loader.

    Args:
        num_clients: number of training shards (one per client).
        data_dir: root directory for torchvision's MNIST (downloaded if missing).
        batch_size: batch size of each (shuffled) client training loader.
        test_batch_size: batch size of the shared test loader.

    Returns:
        ``(trainloaders, testloader)``: a list of ``num_clients`` DataLoaders over
        Subsets of the training split, and one DataLoader over the test split.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    testset = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
    split_idx = np.array_split(np.arange(len(dataset)), num_clients)
    trainloaders = [
        DataLoader(Subset(dataset, idx), batch_size=batch_size, shuffle=True)
        for idx in split_idx
    ]
    testloader = DataLoader(testset, batch_size=test_batch_size)
    return trainloaders, testloader

# Per-process cache of load_data() results, keyed by num_clients.
_DATA_CACHE = {}


def _load_data_cached(num_clients=5):
    """Return ``load_data(num_clients)``, building it at most once per process."""
    if num_clients not in _DATA_CACHE:
        _DATA_CACHE[num_clients] = load_data(num_clients)
    return _DATA_CACHE[num_clients]


def client_fn(cid):
    """Build the Flower client for data partition ``cid`` (str or int index).

    Data loaders are cached so MNIST is not re-constructed every time the
    simulation creates a client. A fresh SimpleNet is built per call; its
    weights are overwritten by the server's parameters at the start of
    ``fit``/``evaluate``.

    NOTE(review): the partition count comes from load_data's default (5) and
    must match the num_clients used when starting the simulation.
    """
    trainloaders, testloader = _load_data_cached()
    return FlowerClient(SimpleNet(), trainloaders[int(cid)], testloader)

# 5. Start Flower Simulation
class DefenseFedAvg(fl.server.strategy.FedAvg):
    """FedAvg whose fit aggregation is delegated to ``aggregate_with_defense``.

    ``FedAvg`` accepts no ``on_aggregate_fit`` hook, so the defense must be
    plugged in by overriding ``aggregate_fit`` in a subclass.
    """

    def __init__(self, *, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def aggregate_fit(self, server_round, results, failures):
        """Convert client ``FitRes`` objects and aggregate them with the defense.

        Follows FedAvg's early exits: returns ``(None, {})`` when there are no
        results, or when failures occurred and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        if failures and not self.accept_failures:
            return None, {}
        client_updates = [
            (
                fl.common.parameters_to_ndarrays(fit_res.parameters),
                fit_res.num_examples,
                fit_res.metrics,
            )
            for _, fit_res in results
        ]
        aggregated = aggregate_with_defense(client_updates, num_classes=self.num_classes)
        return fl.common.ndarrays_to_parameters(aggregated), {}


if __name__ == "__main__":
    num_clients = 5
    num_rounds = 5
    strategy = DefenseFedAvg(
        num_classes=10,
        fraction_fit=1.0,
        min_fit_clients=num_clients,
        min_available_clients=num_clients,
    )
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )
