In [1]:
import copy
import logging  # For logging

import flwr as fl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from flwr.client import NumPyClient, Client
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm  # For progress bars

  from .autonotebook import tqdm as notebook_tqdm
2025-03-25 18:18:16,069	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
def get_dataloader(file_name: str, batch_size: int = 128, num_workers: int = 8) -> DataLoader:
    """Load one vehicle CSV into a shuffled DataLoader of (features, target) batches.

    Features, in this column order: 'Vehicle Speed[km/h]', 'Acceleration_ms2',
    'OAT[DegC]', 'Slope_deg'. Target: 'Energy_Consumption'. Both are float32 and the
    target tensor is reshaped to (n_rows, 1) to match a single-output regression head.

    Rows with a missing value in any of these columns are dropped: a single NaN would
    make the MSE loss (and every weight it updates) NaN.

    Args:
        file_name: path to the CSV file.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes (0 loads in the calling process).

    Returns:
        A DataLoader over a TensorDataset, reshuffled every epoch. Memory is pinned only
        when CUDA is available (pinning has no use, and warns, on CPU-only hosts).

    Raises:
        KeyError: if any required column is absent from the CSV.
        ValueError: if no complete rows remain after dropping missing values.
    """
    feature_columns = ['Vehicle Speed[km/h]', 'Acceleration_ms2', 'OAT[DegC]', 'Slope_deg']
    target_column = 'Energy_Consumption'
    required_columns = feature_columns + [target_column]

    df = pd.read_csv(file_name, low_memory=False)
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise KeyError(f"{file_name} is missing required column(s): {missing}")

    complete = df[required_columns].dropna()
    dropped = len(df) - len(complete)
    if dropped:
        logging.getLogger(__name__).warning(
            "%s: dropped %d of %d rows with missing values", file_name, dropped, len(df)
        )
    if complete.empty:
        raise ValueError(f"{file_name} has no rows with all required columns present")

    X = complete[feature_columns].to_numpy().astype(np.float32)
    y = complete[target_column].to_numpy().astype(np.float32).reshape(-1, 1)
    dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

In [3]:
class ComplexNet(nn.Module):
    """MLP regressor: four [Linear -> BatchNorm1d -> ReLU -> Dropout(0.5)] stages, then a Linear head.

    Hidden widths are 64 -> 128 -> 64 -> 32 and the head emits one value per input row.

    Args:
        input_dim: number of input features per row (default 4).
    """

    def __init__(self, input_dim: int = 4):
        super().__init__()
        widths = (input_dim, 64, 128, 64, 32)
        # Register fcK / bnK pairs in the same order an explicit listing would use, so the
        # state_dict key order and the order in which weights are randomly initialised
        # are unchanged.
        for stage, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{stage}", nn.Linear(fan_in, fan_out))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(fan_out))
        self.fc5 = nn.Linear(widths[-1], 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the four hidden stages, then the linear output head."""
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
        )
        for linear, norm in hidden_stages:
            x = self.dropout(self.relu(norm(linear(x))))
        return self.fc5(x)

In [4]:
class FedProxNumPyClient(NumPyClient):
    """Flower NumPy client that trains locally with the FedProx proximal objective.

    Each local step minimises ``criterion(model(X), y) + (mu / 2) * ||w - w_global||^2``,
    where ``w_global`` is the parameter set received from the server for the current call.
    Parameters are exchanged as the model's full ``state_dict`` (weights plus buffers such
    as BatchNorm running statistics), positionally in ``state_dict`` key order.

    Args:
        cid: client identifier, used only in log and progress-bar messages.
        model: the local model; moved to ``device`` in place.
        dataloader: data used for both local training and evaluation.
        device: device on which training and evaluation run.
        local_epochs: passes over ``dataloader`` per ``fit`` call.
        lr: SGD learning rate.
        mu: weight of the proximal term (0 reduces to plain local SGD).
    """

    # Resolved at class definition instead of relying on a notebook-global ``logger``
    # that is only created in a later cell.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        cid: str,
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        local_epochs: int = 10,
        lr: float = 0.01,
        mu: float = 0.1,
    ):
        self.cid = cid
        self.model = model
        self.dataloader = dataloader
        self.local_epochs = local_epochs
        self.lr = lr
        self.mu = mu
        self.criterion = nn.MSELoss()
        self.device = device
        self.model.to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        # Frozen anchor for the proximal term. Copying ``model`` (instead of building a
        # fresh ComplexNet) keeps the client architecture-agnostic; freezing it stops the
        # proximal term from accumulating never-used gradients on the anchor parameters.
        self.global_model = copy.deepcopy(self.model).requires_grad_(False)

    def get_parameters(self, config):
        """Return every ``state_dict`` entry as a CPU NumPy array, in key order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (ordered like ``state_dict``) into the model and the anchor.

        Raises:
            ValueError: if the number of arrays does not match the ``state_dict`` size
                (``zip`` would otherwise silently truncate).
        """
        keys = list(self.model.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} arrays for the model state_dict, got {len(parameters)}"
            )
        new_state_dict = {
            key: torch.tensor(arr, device=self.device) for key, arr in zip(keys, parameters)
        }
        self.model.load_state_dict(new_state_dict)
        self.global_model.load_state_dict(new_state_dict)

    def _proximal_term(self) -> torch.Tensor:
        """Squared L2 distance between current local parameters and the frozen anchor."""
        return sum(
            (param - anchor).pow(2).sum()
            for param, anchor in zip(self.model.parameters(), self.global_model.parameters())
        )

    def fit(self, parameters, config):
        """Run ``local_epochs`` of FedProx SGD starting from ``parameters``.

        Returns:
            (updated parameters, number of training examples, empty metrics dict).
        """
        self.set_parameters(parameters)
        self.model.train()
        server_round = config.get("server_round", 1)  # Round number injected by the strategy
        self._logger.info(f"Client {self.cid} - Starting training for Round {server_round}")

        with tqdm(total=self.local_epochs, desc=f"Client {self.cid} - Round {server_round}", unit="epoch") as pbar:
            for epoch in range(self.local_epochs):
                epoch_loss = 0.0
                num_batches = 0
                for X, y in self.dataloader:
                    # BatchNorm raises in train mode on a single-sample batch (e.g. the
                    # trailing remainder batch), so such batches are skipped.
                    if X.size(0) < 2:
                        continue
                    X, y = X.to(self.device), y.to(self.device)
                    self.optimizer.zero_grad()
                    loss = self.criterion(self.model(X), y) + (self.mu / 2) * self._proximal_term()
                    loss.backward()
                    self.optimizer.step()
                    epoch_loss += loss.item()
                    num_batches += 1
                if num_batches == 0:
                    self._logger.warning(f"Client {self.cid} - no trainable batches in epoch {epoch + 1}")
                    avg_loss = float("nan")
                else:
                    avg_loss = epoch_loss / num_batches  # mean of (task + proximal) loss per batch
                self._logger.info(f"Client {self.cid} - Round {server_round} - Epoch {epoch + 1}/{self.local_epochs} - Loss: {avg_loss:.4f}")
                pbar.update(1)
        return self.get_parameters({}), len(self.dataloader.dataset), {}

    def evaluate(self, parameters, config):
        """Return the sample-weighted mean task loss of ``parameters`` on ``dataloader``.

        Returns:
            (mean loss, number of evaluated examples, empty metrics dict). An empty
            dataloader yields ``(0.0, 0, {})`` so it carries no weight in aggregation.
        """
        self.set_parameters(parameters)
        self.model.eval()
        loss_total = 0.0
        total = 0
        with torch.no_grad():
            for X, y in self.dataloader:
                X, y = X.to(self.device), y.to(self.device)
                loss = self.criterion(self.model(X), y)
                batch_size = X.size(0)
                # criterion averages over the batch; re-weight so the final mean is per sample
                loss_total += loss.item() * batch_size
                total += batch_size
        if total == 0:
            self._logger.warning(f"Client {self.cid} - no evaluation data")
            return 0.0, 0, {}
        avg_loss = loss_total / total
        self._logger.info(f"Client {self.cid} - Evaluation Loss: {avg_loss:.4f}")
        return avg_loss, total, {}

In [5]:
def client_fn(cid: str) -> Client:
    """Create the Flower client for simulated partition ``cid``.

    Partition ``i`` trains on the i-th CSV in ``file_names``. Each call builds a fresh
    ComplexNet wrapped in a FedProx client (10 local epochs, lr=0.01, mu=0.1).

    Raises:
        ValueError: if ``cid`` is not an integer in ``[0, len(file_names))``.
    """
    file_names = ["vehicle_4.csv", "vehicle_455_data.csv", "vehicle_10_data.csv", "vehicle_541_data.csv"]
    partition = int(cid)
    # Reject out-of-range ids explicitly: a negative id would otherwise silently index
    # from the end of the list and reuse another client's data.
    if not 0 <= partition < len(file_names):
        raise ValueError(f"cid {cid!r} out of range: expected 0..{len(file_names) - 1}")
    # The dataset is an in-memory TensorDataset, so DataLoader worker processes only add
    # spawn/IPC overhead, re-paid on every epoch's iteration.
    dataloader = get_dataloader(file_names[partition], batch_size=32, num_workers=0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    numpy_client = FedProxNumPyClient(
        cid=cid,
        model=ComplexNet(),
        dataloader=dataloader,
        device=device,
        local_epochs=10,
        lr=0.01,
        mu=0.1,
    )
    return numpy_client.to_client()

In [6]:
import flwr as fl
from flwr.server.strategy import FedAvg
from flwr.common import FitIns
from logging import INFO
from tqdm import tqdm

class CustomFedAvg(FedAvg):
    """FedAvg that tells clients the current round and records per-round losses.

    - ``configure_fit`` adds ``"server_round"`` to each client's fit config (merged with
      whatever ``on_fit_config_fn`` produced).
    - ``aggregate_evaluate`` records the aggregated client-side (distributed) loss.
    - ``evaluate`` records the server-side (centralized) loss when an ``evaluate_fn``
      is configured; without one, FedAvg returns None and nothing is recorded.
    - ``finalize`` logs the recorded distributed losses; call it after the run.
    """

    # Resolved at class definition instead of relying on a notebook-global ``logger``.
    _logger = logging.getLogger(__name__)

    def __init__(self, *args, num_rounds=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_rounds = num_rounds  # used only for "Round i/N" log messages
        self.round_losses = []  # (server_round, aggregated distributed loss)
        self.centralized_losses = []  # (server_round, loss) from evaluate_fn, if configured

    def configure_fit(self, server_round, parameters, client_manager):
        """Build FedAvg's fit instructions and add the round number to each config."""
        self._logger.info(f"Starting Server Round {server_round}/{self.num_rounds}")
        client_instructions = super().configure_fit(server_round, parameters, client_manager)
        # Merge instead of replacing so any on_fit_config_fn output is preserved.
        return [
            (client_proxy, FitIns(fit_ins.parameters, {**fit_ins.config, "server_round": server_round}))
            for client_proxy, fit_ins in client_instructions
        ]

    def aggregate_evaluate(self, server_round, results, failures):
        """Aggregate client evaluation results and record the resulting loss."""
        loss, metrics = super().aggregate_evaluate(server_round, results, failures)
        if loss is not None:
            self.round_losses.append((server_round, loss))
            self._logger.info(f"Server - Round {server_round} - Distributed Evaluation Loss: {loss:.4f}")
        return loss, metrics

    def evaluate(self, server_round, parameters):
        """Run FedAvg's centralized evaluation and record its loss if one was produced."""
        res = super().evaluate(server_round, parameters)
        if res is not None:
            loss, _ = res
            self.centralized_losses.append((server_round, loss))
            self._logger.info(f"Server - Round {server_round} - Evaluation Loss: {loss:.4f}")
        return res

    def finalize(self):
        """Log every recorded distributed loss and their mean.

        Returns:
            The mean distributed loss across recorded rounds, or None if none were recorded.
        """
        self._logger.info("\n=== Final Metrics ===")
        for round_num, loss in self.round_losses:
            self._logger.info(f"Round {round_num} - Distributed Loss: {loss:.4f}")
        if not self.round_losses:
            return None
        avg_loss = sum(loss for _, loss in self.round_losses) / len(self.round_losses)
        self._logger.info(f"Average Distributed Loss Across All Rounds: {avg_loss:.4f}")
        return avg_loss

In [7]:
import logging

if __name__ == "__main__":
    num_rounds = 10
    num_clients = 4  # one per CSV partition served by client_fn

    # Configure logging before anything is logged.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    strategy = CustomFedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
        num_rounds=num_rounds,
    )

    # Request a GPU per client only when one exists: on a CPU-only host a GPU request
    # leaves Flower's Ray actor pool unable to host any client.
    client_resources = {"num_cpus": 1, "num_gpus": 1 if torch.cuda.is_available() else 0}

    logger.info("Starting Federated Learning Simulation")
    try:
        hist = fl.simulation.start_simulation(
            client_fn=client_fn,
            num_clients=num_clients,
            config=fl.server.ServerConfig(num_rounds=num_rounds),
            strategy=strategy,
            client_resources=client_resources,
        )
    except Exception:
        # start_simulation signals a crashed run by raising, not by returning None.
        logger.error("Simulation failed.")
        raise

    if hist is None:
        logger.error("Simulation failed.")
    else:
        # Flower never calls finalize() itself; emit the per-round summary explicitly.
        strategy.finalize()
        logger.info("Simulation completed successfully.")

INFO:__main__:Starting Federated Learning Simulation
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2025-03-25 18:18:20,335	INFO worker.py:1771 -- Started a local Ray i