In [1]:
import flwr as fl
from flwr.common import Parameters, ndarrays_to_parameters, parameters_to_ndarrays
from typing import List, Tuple
import numpy as np

# Custom Strategy to Capture Client Updates
class CaptureClientUpdates(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures
    ) -> Tuple[Parameters, dict]:

        if not results:
            return None, {}

        # Extract parameters from the first successful client result
        parameters = results[0][1].parameters  # FitRes contains .parameters

        # Convert parameters to NumPy arrays
        numpy_params = parameters_to_ndarrays(parameters)

        # Log client update (for debugging)
        print(f"Captured Client Update (Round {server_round}):", numpy_params)

        return ndarrays_to_parameters(numpy_params), {}

# Start the Flower Server with the Custom Strategy
fl.server.start_server(
    server_address="127.0.0.1:8080",
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=CaptureClientUpdates(),
)


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=5, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (5 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 

Captured Client Update (Round 1): [array([[0.2322308 , 0.6561161 , 0.4179798 , 0.05094564, 0.23025177,
        0.16771424, 0.83542746, 0.46680892, 0.74306077, 0.01375788],
       [0.42961398, 0.70273584, 0.5581952 , 0.53575724, 0.07020327,
        0.33580184, 0.9160534 , 0.5472917 , 0.41564646, 0.12813106],
       [0.43202573, 0.9695043 , 0.09483632, 0.7519552 , 0.72980314,
        0.10493896, 0.76907116, 0.25175408, 0.15864815, 0.34979624],
       [0.49705172, 0.6817873 , 0.76888525, 0.52453935, 0.00886859,
        0.5553601 , 0.382883  , 0.5943054 , 0.39441139, 0.3734004 ],
       [0.41702685, 0.74628514, 0.03215402, 0.31574744, 0.9684598 ,
        0.509366  , 0.63728434, 0.7714907 , 0.47741282, 0.4219751 ],
       [0.10067925, 0.3362363 , 0.04066036, 0.18639855, 0.9129919 ,
        0.5172672 , 0.42754924, 0.9237526 , 0.20362414, 0.6112043 ],
       [0.0887282 , 0.65091723, 0.5705547 , 0.10058654, 0.9203144 ,
        0.802711  , 0.5432774 , 0.94762135, 0.20054318, 0.60488266],
      

[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 1 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 1 results and 1 failures


Captured Client Update (Round 2): [array([[9.2579378e-04, 7.7408445e-01, 1.8691000e-01, 1.4292684e-01,
        3.1920037e-01, 1.1668654e-01, 3.7701875e-01, 1.0746013e-01,
        2.2612080e-01, 1.6390759e-02],
       [4.8818406e-02, 8.7480950e-01, 6.6465266e-02, 8.8810921e-01,
        9.3842441e-01, 8.9761722e-01, 5.3939241e-01, 4.1290408e-01,
        1.9303687e-02, 9.2187673e-01],
       [2.2618191e-01, 3.7123349e-01, 1.6331773e-01, 6.4259881e-01,
        7.7308935e-01, 8.3651811e-01, 5.1912206e-01, 4.4942349e-01,
        7.1406674e-01, 1.9977471e-01],
       [8.1694573e-01, 6.2928230e-01, 9.1983300e-01, 1.7485751e-01,
        6.3993460e-01, 5.8337474e-01, 8.9851636e-01, 5.4198567e-02,
        8.4248269e-01, 3.8260722e-01],
       [8.5910583e-01, 9.0563279e-01, 6.8855029e-01, 8.6748272e-01,
        8.8567334e-01, 4.2224964e-01, 1.6093872e-01, 4.0809852e-01,
        1.0315024e-01, 5.1471382e-01],
       [7.4212420e-01, 7.2848040e-01, 1.9576290e-01, 2.2423290e-02,
        4.0904149e-01,

[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 1 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 1 results and 1 failures


Captured Client Update (Round 3): [array([[0.85128903, 0.877618  , 0.32874033, 0.74650246, 0.8891304 ,
        0.6755172 , 0.06427521, 0.26618496, 0.83727586, 0.6933877 ],
       [0.78384525, 0.708149  , 0.02907784, 0.8063957 , 0.92154825,
        0.9512781 , 0.5690167 , 0.07484394, 0.07823682, 0.44563913],
       [0.6760082 , 0.95306987, 0.11288907, 0.7905644 , 0.5718631 ,
        0.11993358, 0.7437275 , 0.88903236, 0.19349872, 0.3187326 ],
       [0.77529943, 0.76892555, 0.515923  , 0.68262285, 0.28480384,
        0.07035023, 0.26822752, 0.40010592, 0.04663339, 0.7658228 ],
       [0.9796836 , 0.17212711, 0.6041359 , 0.8400856 , 0.39400807,
        0.5803492 , 0.11532942, 0.538153  , 0.35500416, 0.36028975],
       [0.15902467, 0.11562444, 0.8015225 , 0.66353846, 0.9224262 ,
        0.84224606, 0.5539953 , 0.16678621, 0.09011763, 0.55070424],
       [0.72023964, 0.88258564, 0.48584092, 0.5237388 , 0.3389345 ,
        0.19864541, 0.2216676 , 0.77954406, 0.53753227, 0.10425965],
      

[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 1 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 1 results and 1 failures


Captured Client Update (Round 4): [array([[2.94824243e-01, 1.04429923e-01, 3.83160338e-02, 8.30223143e-01,
        9.06989500e-02, 6.51211023e-01, 2.44293779e-01, 3.86325419e-02,
        4.69233900e-01, 4.96207744e-01],
       [8.05311680e-01, 3.61299872e-01, 5.69082320e-01, 6.11982942e-01,
        6.03785515e-02, 5.16735256e-01, 8.34531605e-01, 6.86589479e-01,
        9.09339845e-01, 4.09384400e-01],
       [6.01633608e-01, 2.93540895e-01, 6.34182632e-01, 1.31875232e-01,
        7.75520131e-02, 2.46440783e-01, 4.71062690e-01, 3.63989137e-02,
        5.24916172e-01, 4.60856885e-01],
       [9.36231613e-01, 8.26299310e-01, 6.25667796e-02, 4.93932962e-01,
        4.32108521e-01, 8.82326543e-01, 2.20958263e-01, 9.35494363e-01,
        6.21200979e-01, 7.55292892e-01],
       [1.69157341e-01, 5.98273218e-01, 1.33462429e-01, 7.40260035e-02,
        9.67206895e-01, 5.62671840e-01, 8.61063123e-01, 3.08199316e-01,
        1.36322871e-01, 1.58343926e-01],
       [2.56940275e-01, 5.41385412e-01, 

[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 1 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_fit: received 1 results and 1 failures


Captured Client Update (Round 5): [array([[0.7097493 , 0.19018754, 0.7515432 , 0.7448774 , 0.32250944,
        0.45225966, 0.74388653, 0.10698246, 0.6418251 , 0.6147203 ],
       [0.13592316, 0.02940696, 0.7609227 , 0.11987995, 0.9228298 ,
        0.5606758 , 0.41505435, 0.92899483, 0.91642004, 0.18745148],
       [0.19949178, 0.13850181, 0.01165685, 0.6708342 , 0.11812848,
        0.230918  , 0.5939242 , 0.05777824, 0.44786808, 0.87524706],
       [0.9565921 , 0.5643982 , 0.74768186, 0.17833221, 0.25173706,
        0.23303834, 0.68092674, 0.5180506 , 0.76210654, 0.8666865 ],
       [0.79855615, 0.28142345, 0.06934637, 0.13415734, 0.78731644,
        0.6441332 , 0.5355476 , 0.37499493, 0.76329535, 0.93579483],
       [0.2705715 , 0.59622025, 0.15027189, 0.3056634 , 0.06453491,
        0.02734798, 0.455142  , 0.8849098 , 0.27324507, 0.77067506],
       [0.5582315 , 0.13353975, 0.25902748, 0.11568476, 0.9695244 ,
        0.5328998 , 0.8508101 , 0.85100496, 0.6546186 , 0.19506657],
      

[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 2)
[92mINFO [0m:      aggregate_evaluate: received 1 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 910.08s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.8951917290687561
[92mINFO [0m:      		round 2: 0.7937277555465698
[92mINFO [0m:      		round 3: 0.9435898065567017
[92mINFO [0m:      		round 4: 0.8669863343238831
[92mINFO [0m:      		round 5: 0.817155122756958
[92mINFO [0m:      


History (loss, distributed):
	round 1: 0.8951917290687561
	round 2: 0.7937277555465698
	round 3: 0.9435898065567017
	round 4: 0.8669863343238831
	round 5: 0.817155122756958

In [None]:
import flwr as fl
import torch
import torch.nn as nn
from collections import OrderedDict
from typing import List, Tuple, Dict
import numpy as np

# Define the global model (MLP)
class MLP(nn.Module):
    def __init__(self, input_size: int):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# Set input size (Ensure it matches the client)
INPUT_SIZE = 15  
global_model = MLP(INPUT_SIZE)

# Define the aggregation strategy
class FedAvg(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[BaseException],
    ) -> Tuple[fl.common.Parameters, Dict[str, fl.common.Scalar]]:
        """Aggregate model weights using weighted averaging."""
        
        # Call the parent aggregate method
        aggregated_parameters, metrics = super().aggregate_fit(server_round, results, failures)
        
        # Convert to PyTorch state_dict format
        aggregated_ndarrays = fl.common.parameters_to_ndarrays(aggregated_parameters)
        state_dict = OrderedDict(zip(global_model.state_dict().keys(), [torch.tensor(nd) for nd in aggregated_ndarrays]))
        global_model.load_state_dict(state_dict, strict=True)

        # Log aggregation details
        num_clients = len(results)
        accuracies = [res.metrics["accuracy"] for _, res in results if "accuracy" in res.metrics]
        avg_accuracy = np.mean(accuracies) if accuracies else 0.0  # Avoid KeyError
        
        print(f"[Server] Round {server_round} aggregated from {num_clients} clients - Avg Accuracy: {avg_accuracy:.4f}")
        
        return aggregated_parameters, metrics

# Start the Flower server
fl.server.start_server(
    server_address="127.0.0.1:8080",
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=FedAvg(),
)


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=5, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (5 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client


In [None]:
import flwr as fl

# Start the global model server (without aggregation)
fl.server.start_server(
    server_address="localhost:8080",  # Change to your server IP
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=None  # No aggregation here (handled by SuperLink)
)


	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=5, no round_timeout
[92mINFO [0m:      Flower ECE: gRPC server running (5 rounds), SSL is disabled
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client


In [None]:
import torch
import torch.nn as nn
import socket
import pickle

# Define Global Model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(15, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

def connect_to_aggregation_server(host="127.0.0.1", port=9090):
    """Signal to the aggregation server that the global server is ready."""
    print("🔄 Connecting to the aggregation server...")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.connect((host, port))
            print("✅ Connected to aggregation server! Waiting for updates...")
        except ConnectionRefusedError:
            print("❌ Could not connect to aggregation server. Make sure it is running.")

def receive_aggregated_updates(host="127.0.0.1", port=9091):
    """Continuously receive aggregated updates from the aggregation server."""
    print("🔄 Waiting for updates from the aggregation server...")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind((host, port))
        server_socket.listen(1)

        while True:
            conn, addr = server_socket.accept()
            with conn:
                print(f"📥 Receiving model updates from {addr}...")

                data = b""
                while True:
                    packet = conn.recv(4096)
                    if not packet:
                        break
                    data += packet

                if data:
                    try:
                        message = data.decode()
                        if message == "STOP":
                            print("🛑 Received termination signal. Shutting down global server.")
                            return "STOP"
                    except UnicodeDecodeError:
                        pass  # Data is not a termination signal, continue processing
                    
                    aggregated_parameters = pickle.loads(data)
                    print("✅ Model updates received!")
                    return aggregated_parameters
                else:
                    print("❌ No data received!")
                    return None

def update_global_model(global_model, aggregated_parameters):
    """Load aggregated parameters into the global model and evaluate it."""
    try:
        state_dict = {key: torch.tensor(value) for key, value in aggregated_parameters.items()}
        global_model.load_state_dict(state_dict)
        print("✅ Global model updated successfully!")
        
        # Evaluate the model
        accuracy = evaluate_global_model(global_model)
        return accuracy
    except Exception as e:
        print(f"❌ Failed to update global model: {e}")
        return None

def evaluate_global_model(model):
    """Evaluate the model performance and return accuracy."""
    print("📊 Evaluating global model...")

    # Example evaluation (replace with real dataset evaluation)
    dummy_input = torch.randn(100, 15)  # Simulated test data
    dummy_labels = torch.randint(0, 2, (100, 1)).float()  # Simulated labels

    with torch.no_grad():
        outputs = model(dummy_input)
        predictions = (outputs > 0.5).float()  # Convert to binary labels
        accuracy = (predictions == dummy_labels).sum().item() / dummy_labels.size(0)

    print(f"Global model evaluation completed! Accuracy: {accuracy:.4f}")
    return accuracy

if __name__ == "__main__":
    global_model = MLP()
    connect_to_aggregation_server()  # Notify aggregation server first
    
    final_accuracy = None

    while True:
        aggregated_parameters = receive_aggregated_updates()  # Receive model updates

        if aggregated_parameters == "STOP":
            print(f"🏆 Final Global Model Accuracy: {final_accuracy:.4f}")
            print("🔻 Shutting down global server.")
            break  # Exit loop properly

        if aggregated_parameters:
            final_accuracy = update_global_model(global_model, aggregated_parameters)


🔄 Connecting to the aggregation server...
✅ Connected to aggregation server! Waiting for updates...
🔄 Waiting for updates from the aggregation server...


In [None]:
today


In [None]:
import torch
import torch.nn as nn
import socket
import pickle

# Define the Fine-Tuned Global Model
class MLP(nn.Module):
    def __init__(self, input_size=10, hidden1=512, hidden2=512, hidden3=256, hidden4=128, output_size=2, dropout_prob=0.2):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden1)
        self.bn1 = nn.BatchNorm1d(hidden1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_prob)

        self.fc2 = nn.Linear(hidden1, hidden2)
        self.bn2 = nn.BatchNorm1d(hidden2)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_prob)

        self.fc3 = nn.Linear(hidden2, hidden3)
        self.bn3 = nn.BatchNorm1d(hidden3)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout_prob)

        self.fc4 = nn.Linear(hidden3, hidden4)
        self.bn4 = nn.BatchNorm1d(hidden4)
        self.relu4 = nn.ReLU()
        self.dropout4 = nn.Dropout(dropout_prob)

        self.fc5 = nn.Linear(hidden4, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        x = self.fc4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.dropout4(x)

        x = self.fc5(x)
        return x

def connect_to_aggregation_server(host="127.0.0.1", port=9090):
    """Signal to the aggregation server that the global server is ready."""
    print("🔄 Connecting to the aggregation server...")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.connect((host, port))
            print("✅ Connected to aggregation server! Waiting for updates...")
        except ConnectionRefusedError:
            print("❌ Could not connect to aggregation server. Ensure it is running.")

def receive_aggregated_updates(host="127.0.0.1", port=9091):
    """Continuously receive aggregated updates from the aggregation server."""
    print("🔄 Waiting for updates from the aggregation server...")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind((host, port))
        server_socket.listen(1)

        while True:
            conn, addr = server_socket.accept()
            with conn:
                print(f"📥 Receiving model updates from {addr}...")

                data = b""
                while True:
                    packet = conn.recv(4096)
                    if not packet:
                        break
                    data += packet

                if data:
                    if data.decode(errors="ignore") == "STOP":
                        print("🛑 Received termination signal. Shutting down global server.")
                        return "STOP"

                    aggregated_parameters = pickle.loads(data)
                    print("✅ Model updates received!")
                    return aggregated_parameters
                else:
                    print("❌ No data received!")
                    return None

def update_global_model(global_model, aggregated_parameters):
    """Load aggregated parameters into the global model."""
    try:
        state_dict = {key: torch.tensor(value, dtype=torch.float32) for key, value in aggregated_parameters.items()}
        global_model.load_state_dict(state_dict, strict=False)
        print("✅ Global model updated successfully!")

        # Evaluate the model
        accuracy = evaluate_global_model(global_model)
        return accuracy
    except Exception as e:
        print(f"❌ Failed to update global model: {e}")
        return None

def evaluate_global_model(model):
    """Evaluate the model performance and return accuracy."""
    print("📊 Evaluating global model...")

    # Simulated test data with correct feature size (10)
    dummy_input = torch.randn(100, 10)  
    dummy_labels = torch.randint(0, 2, (100,)).long()

    with torch.no_grad():
        outputs = model(dummy_input)
        _, predictions = torch.max(outputs, 1)  # Get class predictions
        accuracy = (predictions == dummy_labels).sum().item() / dummy_labels.size(0)

    print(f"🏆 Global model evaluation completed! Accuracy: {accuracy:.4f}")
    return accuracy

if __name__ == "__main__":
    global_model = MLP()  # Initialize fine-tuned global model
    connect_to_aggregation_server()  # Notify aggregation server first

    final_accuracy = None

    while True:
        aggregated_parameters = receive_aggregated_updates()  # Receive model updates

        if aggregated_parameters == "STOP":
            print(f"🏆 Final Global Model Accuracy: {final_accuracy:.4f}")
            print("🔻 Shutting down global server.")
            break

        if aggregated_parameters:
            final_accuracy = update_global_model(global_model, aggregated_parameters)


🔄 Connecting to the aggregation server...
✅ Connected to aggregation server! Waiting for updates...
🔄 Waiting for updates from the aggregation server...
