In [None]:
import flwr as fl
import time
import threading

REQUIRED_CLIENTS = 2  # Change based on your setup

class CustomFedAvg(fl.server.strategy.FedAvg):
    """Federated Averaging with controlled shutdown and robust aggregation."""

    def __init__(self, num_rounds=5):
        super().__init__()
        self.num_rounds = num_rounds

    def aggregate_fit(self, rnd, results, failures):
        """Aggregates model updates from clients."""
        print(f"✅ Round {rnd}: {len(results)} client updates, {len(failures)} failures.")
        return super().aggregate_fit(rnd, results, failures)

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate evaluation results while handling missing client responses."""
        valid_results = [(res[0], res[1]) for res in results if isinstance(res, tuple) and len(res) == 2]

        if valid_results:
            aggregated_loss = sum(res[0] for res in valid_results) / len(valid_results)
            aggregated_metrics = {
                key: sum(d[key] for _, d in valid_results) / len(valid_results)
                for key in valid_results[0][1]
            }

            print(f"📊 Round {rnd} Evaluation: Loss={aggregated_loss:.4f}, Metrics={aggregated_metrics}")

            if rnd == self.num_rounds:
                print("🏆 Training complete. Server shutting down.")
                time.sleep(2)
                exit(0)

            return aggregated_loss, aggregated_metrics

        print(f"⚠️ No valid evaluation results in round {rnd}. Returning default values.")
        return 0.0, {}  

def wait_for_clients():
    """Ensures all required clients connect before training starts."""
    print(f"🔄 Waiting for {REQUIRED_CLIENTS} clients to connect...")

    while len(fl.server.client_manager().all()) < REQUIRED_CLIENTS:
        time.sleep(1)  

    print("🚀 All clients connected. Starting federated learning.")

def start_server():
    """Start the Flower server with a client wait mechanism."""
    num_rounds = 5
    strategy = CustomFedAvg(num_rounds)

    threading.Thread(target=wait_for_clients, daemon=True).start()

    print("🚀 Starting Flower server on port 8081...")

    fl.server.start_server(
        server_address="0.0.0.0:8081",  
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )

if __name__ == "__main__":
    start_server()


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
Exception in thread Thread-9 (wait_for_clients):
Traceback (most recent call last):
  File "C:\Users\Ashan\AppData\Local\Programs\Python\Python310\lib\threading.py", line 1016, in _bootstrap_inner
[92mINFO [0m:      Starting Flower server, config: num_rounds=5, no round_timeout
    self.run()
  File "d:\Rreserch work\fedenvioremnt\.venv\lib\site-packages\ipykernel\ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "C:\Users\Ashan\AppData\Local\Programs\Python\Python310\lib\threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\Ashan\AppData\Local\Temp\ipykernel_4080\

🔄 Waiting for 2 clients to connect...🚀 Starting Flower server on port 8081...



server for testing a clinet


In [None]:
import flwr as fl
import torch
import numpy as np
from typing import List, Dict, Tuple
from torch import nn
import torch.optim as optim
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters

# Define the MLP model (same as the one used by the clients)
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_layers=4, dropout=0.2587):
        super(MLP, self).__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU()
        ]
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout)
            ])
        layers.append(nn.Linear(hidden_dim, 2))
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

# Custom Federated Averaging Strategy
class FedAvgCustom(fl.server.strategy.FedAvg):
    def __init__(self, num_rounds: int):
        super().__init__()
        self.num_rounds = num_rounds
        self.current_round = 0
        self.connected_clients = 0

    def aggregate_fit(self, rnd: int, results: List[fl.server.ClientFitResult], failures: List[Exception]) -> Tuple[List[np.ndarray], Dict]:
        """Aggregates model updates from clients and sends them to the global server."""
        aggregated_parameters = super().aggregate_fit(rnd, results, failures)

        if aggregated_parameters:
            print(f"✅ Round {rnd} aggregated successfully!")
            self.current_round = rnd

            # Validate and send to global server
            if validate_aggregated_parameters(aggregated_parameters[0]):
                send_to_global_server(aggregated_parameters[0], self.current_round == self.num_rounds)
            else:
                print("❌ Aggregated parameters failed validation! Not sending to the global server.")

        return aggregated_parameters


# Validate aggregated model parameters
def validate_aggregated_parameters(parameters):
    """Validate aggregated model parameters before sending them to the global server."""
    try:
        ndarrays = parameters_to_ndarrays(parameters)
        if len(ndarrays) == 10:  # Expecting 10 layers (512-512-256-128 MLP)
            print("✅ Aggregated parameters validated successfully!")
            return True
        else:
            print(f"❌ Invalid parameter count! Expected 10, got {len(ndarrays)}")
            return False
    except Exception as e:
        print(f"❌ Parameter validation failed: {e}")
        return False


# Wait for the global server connection
def wait_for_global_server(host="127.0.0.1", port=9090):
    """Wait for the global server to connect before starting FL."""
    print("🔄 Waiting for the global server to connect...")
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind((host, port))
        server_socket.listen(1)

        conn, addr = server_socket.accept()
        with conn:
            print(f"✅ Global server connected from {addr}. Ready to receive model updates.")


# Send aggregated updates to the global server
def send_to_global_server(aggregated_parameters, is_last_round, host="127.0.0.1", port=9091):
    """Send aggregated model updates to the global server, with a termination signal if last round."""
    print("📤 Sending aggregated updates to the global server...")
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.connect((host, port))

            if is_last_round:
                sock.sendall("STOP".encode())  
                print("🛑 Sent termination signal to global server!")
            else:
                # Convert Parameters to ndarrays
                ndarrays = parameters_to_ndarrays(aggregated_parameters)

                # PyTorch model parameters for fine-tuned architecture
                parameter_names = [
                    "fc1.weight", "fc1.bias", 
                    "fc2.weight", "fc2.bias", 
                    "fc3.weight", "fc3.bias", 
                    "fc4.weight", "fc4.bias",
                    "fc5.weight", "fc5.bias"
                ]
                
                parameters_dict = {name: param.tolist() for name, param in zip(parameter_names, ndarrays)}

                # Serialize and send
                serialized_parameters = pickle.dumps(parameters_dict)
                sock.sendall(serialized_parameters)
                print("✅ Aggregated updates sent to the global server!")
        except ConnectionRefusedError:
            print("❌ Could not connect to the global server.")


# Wait for at least `min_clients` clients to connect
def wait_for_clients(host="0.0.0.0", port=8081, min_clients=2):
    """Wait for at least `min_clients` clients to connect before starting federated learning."""
    print(f"🔄 Waiting for at least {min_clients} clients to connect...")

    connected_clients = 0
    client_sockets = []

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind((host, port))
        server_socket.listen(5)

        while connected_clients < min_clients:
            conn, addr = server_socket.accept()
            print(f"✅ Client {connected_clients+1}/{min_clients} connected from {addr}")

            try:
                conn.settimeout(5)
                conn.sendall(b"PING")
                response = conn.recv(1024)
                if response.strip() == b"PONG":
                    client_sockets.append(conn)
                    connected_clients += 1
                else:
                    print("⚠️ Client did not respond correctly. Removing connection.")
                    conn.close()
            except socket.timeout:
                print("⚠️ Client did not respond in time. Removing connection.")
                conn.close()

    print("🚀 Minimum client threshold reached! Waiting for final confirmation...")
    
    # Ensure all clients are fully ready before starting training
    for sock in client_sockets:
        try:
            sock.sendall(b"READY")
            confirmation = sock.recv(1024)
            if confirmation.strip() != b"ACK":
                print("⚠️ Client did not confirm readiness. Removing connection.")
                client_sockets.remove(sock)
        except Exception as e:
            print(f"❌ Error verifying client readiness: {e}")
            client_sockets.remove(sock)

    if len(client_sockets) < min_clients:
        print("❌ Not enough ready clients. Restarting client wait...")
        return wait_for_clients(host, port, min_clients)  # Retry if necessary

    print("✅ All clients confirmed readiness. Starting federated learning now!")
    return client_sockets


def start_server():
    # Waiting for global server connection
    wait_for_global_server()  
    
    # Wait for clients to connect
    print("✅ Global server connected. Now waiting for clients...")
    client_sockets = wait_for_clients()  # Get connected clients

    print("🔄 Verifying active client connections before starting training...")
    
    for sock in client_sockets:
        try:
            sock.sendall(b"PING")
            response = sock.recv(1024)
            if response != b"PONG":
                print("⚠️ Client did not respond correctly. Removing connection.")
                client_sockets.remove(sock)
        except Exception as e:
            print(f"❌ Error communicating with client: {e}")
            client_sockets.remove(sock)

    if len(client_sockets) < 2:
        print("❌ Not enough active clients. Restarting client wait...")
        client_sockets = wait_for_clients()  # Wait again if necessary

    print("✅ All clients verified. Starting federated learning now!")

    # Define and configure strategy for federated learning
    num_rounds = 5  
    strategy = FedAvgCustom(num_rounds)

    fl.server.start_server(
        server_address="0.0.0.0:8081",
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )

if __name__ == "__main__":
    start_server()


AttributeError: module 'flwr.server' has no attribute 'ClientFitResult'

my saved text


In [6]:
import flwr as fl
import torch
import numpy as np
from typing import List, Tuple, Dict, Optional

# Define a simple neural network (same as the client)
class SimpleNN(torch.nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = torch.nn.Linear(15, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

# Initialize the global model
global_model = SimpleNN()

# Helper function to set model parameters
def set_parameters(model, parameters):
    for param, new_param in zip(model.parameters(), parameters):
        param.data = torch.tensor(new_param)

# Custom Federated Strategy
class CustomFedAvg(fl.server.strategy.FedAvg):
    def aggregate_fit(self, server_round, results, failures):
        aggregated_parameters = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            set_parameters(global_model, fl.common.parameters_to_ndarrays(aggregated_parameters[0]))

        print(f"Global model updated at round {server_round}.")
        return aggregated_parameters

# Start the server
def start_server():
    strategy = CustomFedAvg(
        fraction_fit=1.0,         # Require all available clients to participate
        min_fit_clients=2,        # Ensure at least 2 clients train in each round
        min_available_clients=2,  # Ensure at least 2 clients are available before training starts
    )

    fl.server.start_server(
        server_address="127.0.0.1:8081",
        config=fl.server.ServerConfig(num_rounds=5),
        strategy=strategy,
    )

if __name__ == "__main__":
    start_server()


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=5, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (5 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 

Global model updated at round 1.


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)


Global model updated at round 2.


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)


Global model updated at round 3.


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)


Global model updated at round 4.


[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 2 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 37.97s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.663948267698288
[92mINFO [0m:      		round 2: 0.6997112929821014
[92mINFO [0m:      		round 3: 0.754900723695755
[92mINFO [0m:      		round 4: 0.8217424154281616
[92mINFO [0m:      		round 5: 0.8074790835380554
[92mINFO [0m:      


Global model updated at round 5.
