Federated simulation for DP


In [5]:
import torch
import torch.nn as nn
import flwr as fl
import numpy as np
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader

# Select the compute device (CUDA if available, else CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"âœ… Client device: {device}")

# Load the pooled dataset; each client keeps only its own shard below.
X_clients = np.load('X_clients.npy')
y_clients = np.load('y_clients.npy')

# Split the data into two contiguous shards and keep this client's shard.
client_id = int(input("Enter Client ID (1 or 2): ")) - 1
# Validate explicitly: an entry of 0 would become index -1, which list indexing
# silently accepts and maps to the *other* client's shard.
if client_id not in (0, 1):
    raise ValueError(f"Client ID must be 1 or 2, got {client_id + 1}")
X_client = np.array_split(X_clients, 2)[client_id]
y_client = np.array_split(y_clients, 2)[client_id]

# Convert the shard to float32 tensors on `device` and wrap it in a shuffling
# DataLoader. The whole shard is moved to the device up front, so batches are
# yielded already on `device`.
train_dataset = TensorDataset(torch.tensor(X_client, dtype=torch.float32).to(device),
                              torch.tensor(y_client, dtype=torch.float32).to(device))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Local model: same architecture as the global model.
class LocalModel(nn.Module):
    """Feed-forward binary classifier with a sigmoid output.

    Three hidden blocks of Linear -> ReLU -> Dropout(0.3) with widths
    256, 128 and 64, followed by Linear(64, 1) -> Sigmoid.

    Args:
        input_size: Number of input features per sample.
    """

    def __init__(self, input_size):
        super().__init__()
        stack = []
        fan_in = input_size
        for fan_out in (256, 128, 64):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)))
            fan_in = fan_out
        stack.extend((nn.Linear(fan_in, 1), nn.Sigmoid()))
        # A single `layers` Sequential keeps the state_dict keys
        # ("layers.0.weight", ..., "layers.9.bias") stable for weight exchange.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return sigmoid probabilities; (batch, 1) for x of shape (batch, input_size)."""
        return self.layers(x)

# Feature count is the size of the second axis of the client's feature matrix.
input_size = X_client.shape[1]
model = LocalModel(input_size=input_size)
model = model.to(device=device)

# FedProx loss: binary cross-entropy plus a proximal term toward the global model.
class FedProxLoss(nn.Module):
    """Binary cross-entropy plus the FedProx proximal term.

    loss = BCE(preds, labels) + (mu / 2) * sum_i ||w_i - w_i_global||^2

    Args:
        mu: Strength of the proximal term; larger values keep the local
            weights closer to the last global weights. ``mu=0`` gives plain BCE.
    """

    def __init__(self, mu=0.01):
        super(FedProxLoss, self).__init__()
        self.mu = mu

    def forward(self, preds, labels, local_params, global_params):
        """Return the FedProx loss for one batch.

        Args:
            preds: Predicted probabilities in [0, 1]; must have the same shape
                as ``labels`` (required by binary cross-entropy).
            labels: Float binary targets.
            local_params: Iterable of the local model's trainable tensors.
            global_params: Iterable of global weights (tensors or NumPy arrays)
                in the same order as ``local_params``, or None to skip the
                proximal term.

        Raises:
            ValueError: If local and global parameter counts differ.
        """
        base_loss = nn.functional.binary_cross_entropy(preds, labels)
        if global_params is None:
            return base_loss

        local_params = list(local_params)
        global_params = list(global_params)
        # zip() would silently drop trailing parameters on a mismatch.
        if len(local_params) != len(global_params):
            raise ValueError(
                f"FedProxLoss: got {len(local_params)} local but "
                f"{len(global_params)} global parameters"
            )

        # The global weights are a fixed anchor: convert to a matching tensor
        # and detach so no gradient flows into them.
        prox_loss = sum(
            (
                local_param
                - torch.as_tensor(
                    global_param, dtype=local_param.dtype, device=local_param.device
                ).detach()
            ).pow(2).sum()
            for local_param, global_param in zip(local_params, global_params)
        )

        return base_loss + (self.mu / 2) * prox_loss

# Optimizer and LR schedule for local training. The scheduler is stepped once
# per fit() call, so the learning rate is multiplied by 0.85 every 10 rounds.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=2e-3,
    weight_decay=1e-5,
)
scheduler = StepLR(optimizer=optimizer, step_size=10, gamma=0.85)
# mu weights the proximal term that pulls local weights toward the global model.
fedprox_loss = FedProxLoss(mu=1e-2)

# Flower client for federated learning (FedProx local training).
class FLClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``model`` locally with a FedProx objective.

    NOTE: training uses the module-level ``optimizer``, ``scheduler`` and
    ``fedprox_loss``, tensors are placed on the module-level ``device``, and
    log lines use ``client_id``. ``model`` must therefore be the same object
    whose parameters ``optimizer`` was built over.

    Args:
        model: Local ``nn.Module`` whose trainable parameters are exchanged
            with the server.
        train_loader: DataLoader yielding ``(X_batch, y_batch)`` pairs; this
            script builds it with the data already on ``device``.
    """

    def __init__(self, model, train_loader):
        self.model = model
        self.train_loader = train_loader
        # Tensors of the latest global weights, used as the FedProx anchor;
        # None until set_parameters() is first called.
        self.global_params = None

    def get_parameters(self, config):
        """Return the model's trainable parameters as NumPy arrays, in ``parameters()`` order."""
        return [param.detach().cpu().numpy() for param in self.model.parameters()]

    def set_parameters(self, parameters):
        """Load global weights into the local model and keep a copy for FedProx.

        Args:
            parameters: List of NumPy arrays (what NumPyClient receives) or a
                ``flwr.common.Parameters`` object, ordered like ``get_parameters``.

        Raises:
            ValueError: If the number of arrays differs from the model's
                number of trainable parameters.
        """
        if isinstance(parameters, fl.common.Parameters):
            parameters = fl.common.parameters_to_ndarrays(parameters)
        # Always convert to tensors on `device`: the FedProx term subtracts
        # these from the live parameters.
        global_tensors = [torch.tensor(p, device=device) for p in parameters]

        # Key by named_parameters() so the order matches get_parameters() and
        # the FedProx pairing, even if the model ever gains buffers.
        param_names = [name for name, _ in self.model.named_parameters()]
        if len(param_names) != len(global_tensors):
            raise ValueError(
                f"Expected {len(param_names)} parameter arrays, got {len(global_tensors)}"
            )
        self.global_params = global_tensors

        # load_state_dict copies values into the existing parameters, so the
        # optimizer's references stay valid and global_params stays a snapshot.
        state_dict = self.model.state_dict()
        for name, tensor in zip(param_names, global_tensors):
            state_dict[name] = tensor
        self.model.load_state_dict(state_dict)
        print("âœ… Client: Parameters received & updated.")

    def fit(self, parameters, config):
        """Train locally from the received global weights and return the update.

        Runs ``config["local_epochs"]`` epochs (default 10) of FedProx training
        and steps the LR scheduler once.

        Returns:
            Tuple of (updated parameters, number of training examples,
            ``{"accuracy": ...}``), where accuracy is from the final local
            epoch, measured in train mode (dropout active).
        """
        self.set_parameters(parameters)
        self.model.train()

        num_epochs = int((config or {}).get("local_epochs", 10))
        correct, total = 0, 0
        for _ in range(num_epochs):
            # Only the last epoch's accuracy is reported.
            correct, total = 0, 0
            for X_batch, y_batch in self.train_loader:
                optimizer.zero_grad()
                # reshape_as instead of squeeze(): squeeze() turns a final batch
                # of one sample into a 0-d tensor, which BCE rejects as a shape
                # mismatch.
                y_pred = self.model(X_batch).reshape_as(y_batch)

                # BCE plus a proximal term that limits drift from the global weights.
                loss = fedprox_loss(y_pred, y_batch, list(self.model.parameters()), self.global_params)
                loss.backward()
                optimizer.step()

                correct += ((y_pred.detach() > 0.5).float() == y_batch).sum().item()
                total += y_batch.size(0)

        client_accuracy = correct / total if total else 0.0
        print(f"ðŸ“Œ Client {client_id+1}: Training Completed | Accuracy: {client_accuracy:.4f}")

        scheduler.step()
        # Reported example count; Flower strategies such as FedAvg use it to
        # weight aggregation.
        num_examples = len(self.train_loader.dataset)
        return self.get_parameters(config), num_examples, {"accuracy": client_accuracy}

    def evaluate(self, parameters, config):
        """Evaluate the received global weights on this client's data.

        NOTE(review): this iterates ``self.train_loader``, i.e. the same data
        the client trains on, so the "validation" numbers are not held-out.

        Returns:
            Tuple of (mean BCE loss, number of examples, ``{"accuracy": ...}``).
        """
        self.set_parameters(parameters)

        self.model.eval()
        loss_sum, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for X_batch, y_batch in self.train_loader:
                y_pred = self.model(X_batch).reshape_as(y_batch)
                # Sum per-sample losses so the mean is exact with a short last batch.
                loss_sum += nn.functional.binary_cross_entropy(
                    y_pred, y_batch, reduction="sum"
                ).item()
                correct += ((y_pred > 0.5).float() == y_batch).sum().item()
                total += y_batch.size(0)

        val_accuracy = correct / total if total else 0.0
        # Report the real mean BCE so the server's aggregated loss is meaningful.
        val_loss = loss_sum / total if total else 0.0
        print(f"ðŸ“Œ Client {client_id+1}: Validation Accuracy: {val_accuracy:.4f}")
        return val_loss, total, {"accuracy": val_accuracy}

# Connect to the Flower server. start_client() expects a `Client`;
# NumPyClient.to_client() provides that wrapper, and recent Flower versions
# warn when a bare NumPyClient is passed instead.
# NOTE: as the captured log shows, start_client() itself is deprecated in
# favour of the `flower-supernode` CLI; migrate when upgrading Flower.
print(f"ðŸš€ Client {client_id+1}: Connecting to the global server...")
fl.client.start_client(
    server_address="localhost:8080",
    client=FLClient(model, train_loader).to_client(),
)


âœ… Client device: cpu


	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        


ðŸš€ Client 1: Connecting to the global server...


[92mINFO [0m:      
[92mINFO [0m:      Received: train message 9fe9b464-c685-4d65-b94c-564bb7264865


âœ… Client: Parameters received & updated.


[92mINFO [0m:      Sent reply


ðŸ“Œ Client 1: Training Completed | Accuracy: 0.8481


[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 0a29c1c6-c52f-4792-9f28-0aff2f18498c
[92mINFO [0m:      Sent reply


âœ… Client: Parameters received & updated.
ðŸ“Œ Client 1: Validation Accuracy: 0.4076


[92mINFO [0m:      
[92mINFO [0m:      Received: train message 7776b61c-bd4e-4f20-b756-f435ba0bcc2a


âœ… Client: Parameters received & updated.


[92mINFO [0m:      Sent reply


ðŸ“Œ Client 1: Training Completed | Accuracy: 0.8469


[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message a6f0bfec-2bdb-4d9b-950c-cef2f2351217


âœ… Client: Parameters received & updated.


[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message b85828c9-5e2d-4d5a-9374-b16c58136f5b


ðŸ“Œ Client 1: Validation Accuracy: 0.3901
âœ… Client: Parameters received & updated.


[92mINFO [0m:      Sent reply


ðŸ“Œ Client 1: Training Completed | Accuracy: 0.8481


[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message bbe467a0-2663-4934-bad4-5dafb528e214


âœ… Client: Parameters received & updated.


[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 2924f8fe-251f-42da-8b2c-d9adb61fd80e


ðŸ“Œ Client 1: Validation Accuracy: 0.5524
âœ… Client: Parameters received & updated.


[92mINFO [0m:      Sent reply


ðŸ“Œ Client 1: Training Completed | Accuracy: 0.8502


[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 898245e9-3e33-437d-b3bc-31c55b808aa3
[92mINFO [0m:      Sent reply


âœ… Client: Parameters received & updated.
ðŸ“Œ Client 1: Validation Accuracy: 0.5323


[92mINFO [0m:      
[92mINFO [0m:      Received: train message d6b44515-c92a-4ee3-8846-ebb6ed6e89a4


âœ… Client: Parameters received & updated.


KeyboardInterrupt: 