In [4]:
from collections import OrderedDict, defaultdict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import train_test_split

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

import os
import pandas as pd

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

Training on cpu
Flower 1.18.0 / PyTorch 2.7.0


#Step 1

In [None]:
def generate_distributed_datasets(k: int, alpha: float, save_dir: str) -> None:
    os.makedirs(save_dir, exist_ok=True)

    train_data = datasets.FashionMNIST(
        root="./data",
        train=True,
        download=True,
        transform=transforms.ToTensor()
    )
    
    targets = np.array(train_data.targets)
    data = np.array(train_data.data)

    num_classes = len(np.unique(targets))
    class_indices = [np.where(targets == y)[0] for y in range(num_classes)]

    client_indices = [[] for _ in range(k)]
    for c in range(num_classes):
        class_idx = class_indices[c]
        np.random.shuffle(class_idx)

        proportions = np.random.dirichlet(alpha=np.repeat(alpha, k))
        proportions = (np.cumsum(proportions) * len(class_idx)).astype(int)[:-1]
        split_indices = np.split(class_idx, proportions)

        for i, idx in enumerate(split_indices):
            client_indices[i].extend(idx.tolist())

    for i, indices in enumerate(client_indices):
        client_data = data[indices]
        client_targets = targets[indices]
        
        flat_images = client_data.reshape(len(client_data), -1)
        df = pd.DataFrame(flat_images)
        df['label'] = client_targets
        
        df.to_csv(os.path.join(save_dir, f'client_{i}.csv'), index=False)

    print(f"Distributed datasets saved to {save_dir}")

In [6]:
NUM_OF_CLIENTS = 5
ALPHA = 1.5
SAVE_PATH = "client_data"

generate_distributed_datasets(NUM_OF_CLIENTS, ALPHA, SAVE_PATH)

100%|██████████| 26.4M/26.4M [00:03<00:00, 6.77MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 1.03MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.84MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 8.87MB/s]
  targets = np.array(train_data.targets)


Distributed datasets saved to client_data


In [7]:
def load_client_data(cid: int, data_dir: str, batch_size: int) -> tuple[DataLoader, DataLoader]:
    df = pd.read_csv(os.path.join(data_dir, f'client_{cid}.csv'))

    # Separate features and labels
    X = df.drop(columns=["label"]).values.astype("float32") / 255.0  # normalize
    y = df["label"].values.astype("int64")

    # Reshape X to N x 1 x 28 x 28 (needed for CNNs)
    X = X.reshape(-1, 1, 28, 28)

    # Train/val split
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

    # Convert to tensors
    X_train_tensor = torch.tensor(X_train)
    y_train_tensor = torch.tensor(y_train)
    X_val_tensor = torch.tensor(X_val)
    y_val_tensor = torch.tensor(y_val)

    # Create datasets and loaders
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

In [8]:
for i in range(NUM_OF_CLIENTS):
  train_dataloader, val_dataloader = load_client_data(i, SAVE_PATH, 32)

#Step 2

In [9]:
from typing import List

class CustomFashionModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # this is where i create the model itself
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # [batch, 32, 14, 14]
        x = self.pool(F.relu(self.conv2(x)))  # [batch, 64, 7, 7]
        x = x.view(x.size(0), -1)             # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def train_epoch(self, train_loader: DataLoader,
                    criterion: nn.Module,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device) -> tuple[float, float]:
        self.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = self(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        avg_loss = running_loss / total
        accuracy = correct / total
        return avg_loss, accuracy

    def test_epoch(self, test_loader: DataLoader,
                   criterion: nn.Module,
                   device: torch.device) -> tuple[float, float]:
        self.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = self(inputs)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        avg_loss = running_loss / total
        accuracy = correct / total
        return avg_loss, accuracy

    def get_model_parameters(self) -> List[np.ndarray]:
        return [param.detach().cpu().numpy() for param in self.state_dict().values()]

    def set_model_parameters(self, parameters: List[np.ndarray]) -> None:
        state_dict = self.state_dict()
        for key, param_array in zip(state_dict.keys(), parameters):
            param_tensor = torch.tensor(param_array)
            state_dict[key].copy_(param_tensor)

#Step 3

In [10]:
from flwr.common import (
    GetPropertiesIns, GetPropertiesRes,
    GetParametersIns, GetParametersRes,
    FitIns, FitRes,
    EvaluateIns, EvaluateRes,
    Code, Status,
    ndarrays_to_parameters, parameters_to_ndarrays
)

from typing import List
import torch

class CustomClient(flwr.client.Client):
    def __init__(self, model: torch.nn.Module, train_loader,
                 test_loader, device: torch.device) -> None:
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

    def get_properties(self, instruction: GetPropertiesIns) -> GetPropertiesRes:
        return GetPropertiesRes(
            status=Status(code=Code.OK, message="Success"),
            properties={"framework": "pytorch", "dataset": "FashionMNIST"}
        )

    def get_parameters(self, instruction: GetParametersIns) -> GetParametersRes:
        weights: List[np.ndarray] = self.model.get_model_parameters()
        return GetParametersRes(
            status=Status(code=Code.OK, message="Success"),
            parameters=ndarrays_to_parameters(weights)
        )

    def fit(self, instruction: FitIns) -> FitRes:
        # Set model parameters from server
        params = parameters_to_ndarrays(instruction.parameters)
        self.model.set_model_parameters(params)

        # Training
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.model.to(self.device)
        train_loss, train_accuracy = self.model.train_epoch(
            self.train_loader, criterion, optimizer, self.device
        )

        # Return updated parameters
        updated_weights = self.model.get_model_parameters()
        return FitRes(
            status=Status(code=Code.OK, message="Trained successfully"),
            parameters=ndarrays_to_parameters(updated_weights),
            num_examples=len(self.train_loader.dataset),
            metrics={"train_loss": train_loss, "train_accuracy": train_accuracy}
        )

    def evaluate(self, instruction: EvaluateIns) -> EvaluateRes:
        # Set model parameters from server
        params = parameters_to_ndarrays(instruction.parameters)
        self.model.set_model_parameters(params)

        # Evaluation
        criterion = torch.nn.CrossEntropyLoss()
        self.model.to(self.device)
        test_loss, test_accuracy = self.model.test_epoch(
            self.test_loader, criterion, self.device
        )

        return EvaluateRes(
            status=Status(code=Code.OK, message="Evaluated successfully"),
            loss=test_loss,
            num_examples=len(self.test_loader.dataset),
            metrics={"accuracy": test_accuracy}
        )

    def to_client(self) -> "CustomClient":
        return self

In [11]:
import argparse
import sys

#from model import CustomFashionModel  # Make sure this path is correct
#from client import CustomClient       # Your FLWR client class
#from data import load_client_data     # Your data loading function


def main():
    # Step 1: Parse command-line arguments
    sys.argv = ["notebook", "--cid", "0", "--data_dir", "client_data", "--batch_size", "32"]

    
    parser = argparse.ArgumentParser(description="Run a Flower client.")
    parser.add_argument("--cid", type=int, required=True, help="Client ID")
    parser.add_argument("--data_dir", type=str, default=SAVE_PATH, help="Data directory")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    args = parser.parse_args()

    # Step 2: Load client data
    train_loader, test_loader = load_client_data(
        cid=args.cid,
        data_dir=args.data_dir,
        batch_size=args.batch_size
    )

    # Step 3: Create model and move to device
    model = CustomFashionModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Step 4: Create client instance
    client = CustomClient(model, train_loader, test_loader, device)

    # Step 5: Start Flower client
    !flower-supernode --insecure --superlink='127.0.0.1:8080'
    flwr.client.start_client(server_address="127.0.0.1:8080", client=client.to_client())


if __name__ == "__main__":
    main()

  pid, fd = os.forkpty()


[92mINFO [0m:      Starting Flower SuperNode
[92mINFO [0m:      Starting Flower ClientAppIo gRPC server on 0.0.0.0:9094
^C
Traceback (most recent call last):
  File [35m"/Users/cheoso/fed_learning/venv/lib/python3.13/site-packages/flwr/common/retry_invoker.py"[0m, line [35m276[0m, in [35minvoke[0m
    ret = target(*args, **kwargs)
  File [35m"/Users/cheoso/fed_learning/venv/lib/python3.13/site-packages/grpc/_interceptor.py"[0m, line [35m277[0m, in [35m__call__[0m
    response, ignored_call = [31mself._with_call[0m[1;31m([0m
                             [31m~~~~~~~~~~~~~~~[0m[1;31m^[0m
        [1;31mrequest,[0m
        [1;31m^^^^^^^^[0m
    ...<4 lines>...
        [1;31mcompression=compression,[0m
        [1;31m^^^^^^^^^^^^^^^^^^^^^^^^[0m
    [1;31m)[0m
    [1;31m^[0m
  File [35m"/Users/cheoso/fed_learning/venv/lib/python3.13/site-packages/grpc/_interceptor.py"[0m, line [35m332[0m, in [35m_with_call[0m
    return [31mcall.result[0m[1;31m()[0

	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        


_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:8080: Failed to connect to remote host: connect: Connection refused (61)"
	debug_error_string = "UNKNOWN:Error received from peer  {created_time:"2025-05-22T15:26:38.273837+02:00", grpc_status:14, grpc_message:"failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:8080: Failed to connect to remote host: connect: Connection refused (61)"}"
>

#Step 5

In [12]:
import threading
import time
import random
from typing import Dict, List, Optional
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

class CustomClientManager(ClientManager):
    def __init__(self):
        self._clients: Dict[str, ClientProxy] = {}
        self._lock = threading.Lock()
        self._clients_available = threading.Condition(lock=self._lock)

    def num_available(self) -> int:
        with self._lock:
            return len(self._clients)

    def register(self, client: ClientProxy) -> bool:
        with self._lock:
            client_id = client.cid
            if client_id in self._clients:
                return False  # Already registered
            self._clients[client_id] = client
            self._clients_available.notify_all()
            return True

    def unregister(self, client: ClientProxy) -> None:
        with self._lock:
            client_id = client.cid
            if client_id in self._clients:
                del self._clients[client_id]

    def all(self) -> Dict[str, ClientProxy]:
        with self._lock:
            return dict(self._clients)  # Return a copy to avoid race conditions

    def wait_for(self, num_clients: int, timeout: int) -> bool:
        with self._clients_available:
            start_time = time.time()
            while len(self._clients) < num_clients:
                remaining = timeout - (time.time() - start_time)
                if remaining <= 0:
                    break
                self._clients_available.wait(timeout=remaining)
            return len(self._clients) >= num_clients

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[object] = None,
    ) -> List[ClientProxy]:
        with self._lock:
            available_clients = list(self._clients.values())

            # If criterion is specified, filter clients accordingly (optional)
            if criterion is not None:
                available_clients = [c for c in available_clients if criterion(c)]

            # Check min_num_clients requirement
            if min_num_clients is not None and len(available_clients) < min_num_clients:
                return []

            # Randomly sample clients if enough are available
            if len(available_clients) < num_clients:
                return []

            sampled_clients = random.sample(available_clients, num_clients)
            return sampled_clients


In [None]:
from typing import List, Tuple, Dict, Optional, Union
import numpy as np

from flwr.common import (
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
    FitIns,
    FitRes,
    EvaluateIns,
    EvaluateRes,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import Strategy


class FedAvgStrategy(Strategy):
    def __init__(self):
        self._global_parameters: Optional[Parameters] = None

    def initialize_parameters(self, client_manager: ClientManager) -> Optional[Parameters]:
        # Initialize global parameters from one client or stored state
        if self._global_parameters is not None:
            return self._global_parameters

        # If no global params stored, fetch from one client if available
        clients = list(client_manager.all().values())
        if not clients:
            return None

        parameters = clients[0].get_parameters()
        self._global_parameters = parameters
        return parameters

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager: ClientManager,
    ) -> List[Tuple[ClientProxy, FitIns]]:
        # Configure fit instructions: send current global model to all clients
        clients = list(client_manager.all().values())
        fit_ins = []
        for client in clients:
            fit_ins.append((client, FitIns(parameters, {})))  # empty config dict
        return fit_ins

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        if not results:
            return None, {}

        # Extract weights and num_examples from each client result
        weights = []
        num_examples = []
        losses = []

        for client, fit_res in results:
            # Extract NumPy arrays from Parameters object
            ndarrays = parameters_to_ndarrays(fit_res.parameters)
            weights.append(ndarrays)
            # Assume fit_res.metrics includes 'num_examples' and 'loss'
            num_ex = fit_res.num_examples if hasattr(fit_res, "num_examples") else None
            if num_ex is None:
                # fallback: check metrics dictionary
                num_ex = fit_res.metrics.get("num_examples") if fit_res.metrics else None
            if num_ex is None:
                num_ex = 1  # fallback to 1 to avoid division by zero

            num_examples.append(num_ex)

            loss = fit_res.metrics.get("loss") if fit_res.metrics else None
            if loss is not None:
                losses.append(loss)

        # Weighted average of weights (FedAvg)
        total_examples = sum(num_examples)
        averaged_weights = [
            sum(w[i] * num_examples[j] for j, w in enumerate(weights)) / total_examples
            for i in range(len(weights[0]))
        ]

        # Convert averaged weights back to Parameters object
        aggregated_parameters = ndarrays_to_parameters(averaged_weights)
        self._global_parameters = aggregated_parameters

        # Aggregate loss (weighted average)
        avg_loss = sum(losses[i] * num_examples[i] for i in range(len(losses))) / total_examples if losses else 0.0

        # Return new global parameters and aggregated metrics
        return aggregated_parameters, {"loss": avg_loss}

    def configure_evaluate(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager: ClientManager,
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        # Configure evaluation instructions to all clients
        clients = list(client_manager.all().values())
        eval_ins = []
        for client in clients:
            eval_ins.append((client, EvaluateIns(parameters, {})))  # empty config dict
        return eval_ins

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        if not results:
            return None, {}

        # Aggregate evaluation loss and other metrics by weighted average
        losses = []
        num_examples = []

        for client, eval_res in results:
            loss = eval_res.loss
            num_ex = eval_res.num_examples if hasattr(eval_res, "num_examples") else None
            if num_ex is None:
                num_ex = 1
            losses.append(loss)
            num_examples.append(num_ex)

        total_examples = sum(num_examples)
        avg_loss = sum(losses[i] * num_examples[i] for i in range(len(losses))) / total_examples

        # Return average loss and empty dict for metrics (can be extended)
        return avg_loss, {}


#Step 6

In [15]:
from pathlib import Path
import json


def main():
    # 1. Define server address
    server_address = "[::]:8080"  # Listen on all interfaces on port 8080

    # 2. Define federated learning hyperparameters
    num_rounds = 5

    # 3. Instantiate ClientManager and Strategy
    #client_manager = CustomClientManager()
    #strategy = FedAvgStrategy()

    # 4. Start the Flower server
    history = flwr.server.start_server(
        server_address=server_address,
        config=flwr.server.ServerConfig(num_rounds=num_rounds),
        client_manager=CustomClientManager,
        strategy=FedAvgStrategy,
    )

    # 5. Extract history info
    losses_distributed = history.losses_distributed
    metrics_distributed_fit = history.metrics_distributed_fit
    metrics_distributed = history.metrics_distributed

    # 6. Save results as JSON
    results = {
        "losses_distributed": losses_distributed,
        "metrics_distributed_fit": metrics_distributed_fit,
        "metrics_distributed": metrics_distributed,
    }
    save_path = Path("fl_history.json")
    with open(save_path, "w") as f:
        json.dump(results, f, indent=4)

    print(f"Training history saved to {save_path}")


if __name__ == "__main__":
    main()

NameError: name 'FedAvgStrategy' is not defined