In [21]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from torchvision.datasets import STL10
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from typing import Union
from flwr.server.client_proxy import ClientProxy  # Correctly import ClientProxy
from flwr.common import FitRes, Parameters
import glob
import os
from torch.optim import lr_scheduler
# Device on which models and batches are placed throughout this notebook.
DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
# Report the selected device and the library versions in the cell output.
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

Training on cpu
Flower 1.17.0 / PyTorch 2.6.0+cu124


In [22]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from sklearn.model_selection import train_test_split

# Number of federated clients; NUM_PARTITIONS below is derived from it.
NUM_CLIENTS = 10
# Mini-batch size used by every DataLoader built in load_datasets.
BATCH_SIZE = 64
# Number of federated rounds; also the round whose aggregated model is checkpointed.
MAX_ROUND = 1
# One data partition per client.
NUM_PARTITIONS = NUM_CLIENTS







In [23]:
def load_datasets(partition_id: int, num_partitions: int):
    """Build the train/validation/test DataLoaders for one federated client.

    The STL10 ``train`` split is cut into ``num_partitions`` contiguous,
    equally sized slices (the last slice also receives the remainder). The
    first 20% of the client's slice is used for validation and the rest for
    training. The test loader always covers the full STL10 ``test`` split.

    Args:
        partition_id: Zero-based index of the client's slice.
        num_partitions: Total number of slices the training split is cut into.

    Returns:
        A ``(trainloader, valloader, testloader)`` tuple of DataLoaders that
        all use ``BATCH_SIZE``; only the training loader shuffles.

    Raises:
        ValueError: If ``num_partitions`` is not positive or ``partition_id``
            is outside ``[0, num_partitions)``.
    """
    if num_partitions <= 0:
        raise ValueError(f"num_partitions must be positive, got {num_partitions}")
    if not 0 <= partition_id < num_partitions:
        raise ValueError(
            f"partition_id must be in [0, {num_partitions}), got {partition_id}"
        )

    # Image pre-processing shared by every split.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4467, 0.4398, 0.4066], std=[0.2241, 0.2215, 0.2239])
    ])

    # Load the STL10 training split (downloaded on first use).
    dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)

    # Contiguous index range owned by this client; the last client also takes
    # the samples left over by the integer division.
    data_size = len(dataset)
    partition_size = data_size // num_partitions
    start_idx = partition_id * partition_size
    end_idx = data_size if partition_id == num_partitions - 1 else start_idx + partition_size

    # Hold out the first 20% of the client's slice for validation.
    val_size = (end_idx - start_idx) // 5
    val_subset = Subset(dataset, range(start_idx, start_idx + val_size))
    train_subset = Subset(dataset, range(start_idx + val_size, end_idx))

    trainloader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
    valloader = DataLoader(val_subset, batch_size=BATCH_SIZE)

    # Global test set, identical for every client.
    testset = datasets.STL10(root='./data', split='test', download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)

    return trainloader, valloader, testloader



In [24]:
class Net(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages followed by a two-layer classifier.

    The classifier expects ``128 * 6 * 6`` flattened features, which is what
    the three unpadded 3x3 convolutions and 2x2 poolings produce for 64x64
    inputs (64 -> 62 -> 31 -> 29 -> 14 -> 12 -> 6).

    Args:
        input_channel: Number of channels of the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channel=3, num_classes=10):
        super().__init__()

        # Assemble the convolutional stages in order so the Sequential indices
        # (and therefore the state-dict keys) are 0..8.
        stages = []
        in_channels = input_channel
        for out_channels in (32, 64, 128):
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=3))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(kernel_size=2))
            in_channels = out_channels
        self.features = nn.Sequential(*stages)

        hidden = nn.Linear(128 * 6 * 6, 512)
        head = nn.Linear(512, num_classes)
        self.classifier = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        feature_maps = self.features(x)
        return self.classifier(torch.flatten(feature_maps, 1))

def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` tensor of ``net`` as a NumPy array, in state-dict order."""
    ndarrays = []
    for tensor in net.state_dict().values():
        # Move to CPU first so tensors living on another device can be converted.
        ndarrays.append(tensor.cpu().numpy())
    return ndarrays


def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net`` in ``state_dict`` order.

    Arrays are converted with ``torch.tensor`` so their dtype and shape are
    preserved. This matters for integer or 0-dimensional entries such as
    BatchNorm's ``num_batches_tracked``, which the legacy ``torch.Tensor``
    constructor would coerce to float32 or reject.

    Args:
        net: Module whose state dict is overwritten.
        parameters: One array per ``state_dict`` entry, in the same order as
            ``net.state_dict().keys()``.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries (``zip`` would otherwise truncate silently).
        RuntimeError: Propagated from ``load_state_dict`` on shape mismatch.
    """
    keys = list(net.state_dict().keys())
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set.

    Runs ``epochs`` passes of SGD (lr 1e-2, momentum 0.9, weight decay 5e-4)
    with a MultiStepLR schedule (milestones 50 and 75, gamma 0.1) that is
    created fresh on every call. Batches are moved to the device the model's
    parameters live on, so the model and its inputs cannot end up on
    different devices.

    Args:
        net: Model to optimise in place; left in training mode.
        trainloader: Iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
        verbose: When true, print the per-sample mean loss and the accuracy
            after every epoch.

    Returns:
        None. The model is updated in place.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
    device = next(net.parameters()).device

    net.train()
    for epoch in range(epochs):
        correct, total, loss_sum = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics: accumulate plain Python numbers. Adding the loss tensor
            # itself would keep every batch's autograd graph alive.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size  # batch mean -> batch sum
            total += batch_size
            correct += (torch.max(outputs.detach(), 1)[1] == labels).sum().item()
        scheduler.step()
        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = loss_sum / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    loss_fn = torch.nn.CrossEntropyLoss()
    test_loss = 0.0

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    avg_loss = test_loss / len(testloader)
    return avg_loss, accuracy



In [25]:
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates one model on one data partition.

    Args:
        partition_id: Identifier of the client's data partition (used in logs).
        net: Model trained and evaluated by this client.
        trainloader: DataLoader over the client's training samples.
        valloader: DataLoader over the client's validation samples.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model weights as a list of NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the global weights, train locally for 3 epochs and return the update.

        Returns:
            ``(updated_weights, num_examples, metrics)``. ``num_examples`` is
            the number of training samples, not the number of batches,
            because the server weights each client's update by it.
        """
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=3)
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the global weights and evaluate them on the local validation set.

        Returns:
            ``(loss, num_examples, metrics)`` where ``num_examples`` is the
            number of validation samples and ``metrics`` holds ``"accuracy"``.
        """
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Create the Flower client for the partition named in ``context.node_config``.

    Reads ``"partition-id"`` and ``"num-partitions"`` from the node config,
    builds a fresh ``Net`` on ``DEVICE`` together with the partition's train
    and validation loaders, and wraps them in a ``FlowerClient``.
    """
    model = Net().to(DEVICE)
    node_cfg = context.node_config
    pid = node_cfg["partition-id"]
    n_parts = node_cfg["num-partitions"]
    # The third element (global test loader) is not needed on the client side.
    loaders = load_datasets(pid, n_parts)
    train_dl, val_dl = loaders[0], loaders[1]
    flower_client = FlowerClient(pid, model, train_dl, val_dl)
    return flower_client.to_client()


# Create the ClientApp; client_fn is passed as the factory that builds the
# client for a given node context.
client = ClientApp(client_fn=client_fn)

In [26]:
from typing import Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from flwr.server.strategy import FedAvg

In [27]:
import flwr as fl

# Module-level model instance; SaveModelStrategy.aggregate_fit loads the
# aggregated weights into it before writing the checkpoint to disk.
net = Net()  

class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that checkpoints the aggregated model after round ``MAX_ROUND``."""

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint.

        Aggregation is delegated to ``FedAvg.aggregate_fit``. When
        ``server_round`` equals ``MAX_ROUND`` and aggregation produced
        parameters, they are loaded into the module-level ``net`` and its
        state dict is saved to ``stl10/model_round_<round>.pth``. The
        ``stl10`` directory is created if it does not exist yet.

        Returns:
            The ``(aggregated_parameters, aggregated_metrics)`` pair returned
            by the base class, unchanged.
        """

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        if server_round == MAX_ROUND and aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `list[np.ndarray]`
            aggregated_ndarrays: list[np.ndarray] = fl.common.parameters_to_ndarrays(
                aggregated_parameters
            )
            # Convert `list[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model to disk; make sure the target directory exists
            # first, otherwise torch.save fails on a fresh checkout.
            checkpoint_dir = "stl10"
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(net.state_dict(), f"{checkpoint_dir}/model_round_{server_round}.pth")

        return aggregated_parameters, aggregated_metrics

In [28]:
from flwr.common import NDArrays
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    """Server-side evaluation of the global parameters on the STL10 test loader.

    A fresh ``Net`` is created on ``DEVICE``, the given parameters are loaded
    into it, and it is scored on the test loader returned by
    ``load_datasets``. ``server_round`` and ``config`` are accepted for the
    strategy's callback signature but are not used here.

    Returns:
        ``(loss, {"accuracy": accuracy})``.
    """
    model = Net().to(DEVICE)
    # Only the test loader (third element) is needed; partition 0 is an arbitrary choice.
    testloader = load_datasets(0, NUM_PARTITIONS)[2]
    # Update model with the latest parameters
    set_parameters(model, parameters)
    metrics = test(model, testloader)
    loss, accuracy = metrics
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

In [29]:
# Create strategy and pass into ServerApp
def server_fn(context):
    """Build the server components: a ``SaveModelStrategy`` and the round configuration.

    Every client takes part in both training and evaluation. The minimum
    client counts are tied to ``NUM_CLIENTS`` so the server does not wait for
    more clients than the simulation starts.

    Args:
        context: Context object passed in by the ServerApp; not used here.

    Returns:
        ``ServerAppComponents`` holding the strategy and a ``ServerConfig``
        with ``MAX_ROUND`` rounds.
    """
    strategy = SaveModelStrategy(
        fraction_fit=1.0,  # use 100% of the available clients for training
        fraction_evaluate=1.0,  # use 100% of the available clients for evaluation
        min_fit_clients=NUM_CLIENTS,  # never train with fewer than all clients
        min_evaluate_clients=NUM_CLIENTS,  # never evaluate with fewer than all clients
        min_available_clients=NUM_CLIENTS,  # wait until all clients are connected
        evaluate_fn=evaluate,
    )

    config = ServerConfig(num_rounds=MAX_ROUND)
    return ServerAppComponents(strategy=strategy, config=config)




In [None]:
# Two ServerApp instances are built from the same factory; only `server` is
# handed to the simulation below.
app = ServerApp(server_fn=server_fn)
server = ServerApp(server_fn=server_fn)

# Request one GPU per client when running on CUDA; pass None otherwise.
backend_config = {
    "client_resources": {"num_gpus": 1} if DEVICE.type == "cuda" else None
}

# Launch the simulation with one supernode per data partition.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=100, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters


[36m(ClientAppActor pid=110005)[0m [Client 2] get_parameters
