# Federated Learning with Flower

This notebook implements federated learning using Flower. It uses three client datasets (ddos_data.csv, dos_data.csv, probe_data.csv) to train local LogisticRegression models, aggregates them using FedAvg, and evaluates the global model.

In [2]:
%pip install flwr scikit-learn
!pip freeze > requirements.txt

Collecting flwr
  Downloading flwr-1.18.0-py3-none-any.whl.metadata (15 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.7.0-cp313-cp313-macosx_12_0_arm64.whl.metadata (31 kB)
Collecting cryptography<45.0.0,>=44.0.1 (from flwr)
  Downloading cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl.metadata (5.7 kB)
Collecting grpcio!=1.65.0,<2.0.0,>=1.62.3 (from flwr)
  Downloading grpcio-1.73.0-cp313-cp313-macosx_11_0_universal2.whl.metadata (3.8 kB)
Collecting iterators<0.0.3,>=0.0.2 (from flwr)
  Downloading iterators-0.0.2-py3-none-any.whl.metadata (2.5 kB)
Collecting pathspec<0.13.0,>=0.12.1 (from flwr)
  Downloading pathspec-0.12.1-py3-none-any.whl.metadata (21 kB)
Collecting protobuf<5.0.0,>=4.21.6 (from flwr)
  Downloading protobuf-4.25.8-cp37-abi3-macosx_10_9_universal2.whl.metadata (541 bytes)
Collecting pycryptodome<4.0.0,>=3.18.0 (from flwr)
  Downloading pycryptodome-3.23.0-cp37-abi3-macosx_10_9_universal2.whl.metadata (3.4 kB)
Collecting pyyaml<7.0.0,>=6.0.2 (from

In [1]:
# Import required libraries
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
import flwr as fl
from typing import Dict, List, Tuple

# Define paths
clubbed_output_dir = '../output_clubbed/'
client_files = ['ddos_data.csv', 'dos_data.csv', 'probe_data.csv']

# Load and preprocess data for each client
def load_client_data(file_path: str, sample_frac=0.1) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    df = pd.read_csv(file_path)
    
    # Sample data for faster testing (adjust sample_frac as needed)
    df = df.sample(frac=sample_frac, random_state=42)
    
    # Identify numeric columns (excluding Label)
    numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
    non_numeric_cols = [col for col in df.columns if col not in numeric_cols and col != 'Label']
    if non_numeric_cols:
        print(f'Non-numeric columns dropped from {os.path.basename(file_path)}: {non_numeric_cols}')
    
    # Keep only numeric columns and Label
    X = df[numeric_cols].values
    y = df['Label'].values
    
    # Scale features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    return X, y, df['Label'].unique().tolist()

# Verify client datasets
for file in client_files:
    file_path = os.path.join(clubbed_output_dir, file)
    if os.path.exists(file_path):
        X, y, labels = load_client_data(file_path, sample_frac=1.0)
        print(f'Loaded {file} with shape {X.shape} and labels {labels}')
    else:
        print(f'File {file} not found')

Non-numeric columns dropped from ddos_data.csv: ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp']
Loaded ddos_data.csv with shape (191771, 79) and labels ['DDoS', 'Normal', 'BFA']
Non-numeric columns dropped from dos_data.csv: ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp']
Loaded dos_data.csv with shape (166745, 79) and labels ['Probe', 'Normal', 'Web-Attack']
Non-numeric columns dropped from probe_data.csv: ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp']
Loaded probe_data.csv with shape (122204, 79) and labels ['Normal', 'DoS', 'BOTNET']


In [2]:
# Define Flower client
class IntrusionClient(fl.client.NumPyClient):
    def __init__(self, cid: str, X: np.ndarray, y: np.ndarray):
        self.cid = cid
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        self.model = LogisticRegression(max_iter=500)  # Remove multi_class to suppress warning
        
        # Initialize model weights by fitting on a small subset
        if len(self.X_train) > 0:
            init_size = min(100, len(self.X_train))
            self.model.fit(self.X_train[:init_size], self.y_train[:init_size])

    def get_parameters(self, config) -> List[np.ndarray]:
        return [self.model.coef_, self.model.intercept_]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        self.model.coef_ = parameters[0]
        self.model.intercept_ = parameters[1]

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        self.set_parameters(parameters)
        self.model.fit(self.X_train, self.y_train)
        return self.get_parameters(config), len(self.X_train), {'accuracy': float(accuracy_score(self.y_train, self.model.predict(self.X_train)))}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        self.set_parameters(parameters)
        y_pred = self.model.predict(self.X_test)
        accuracy = accuracy_score(self.y_test, y_pred)
        return float(accuracy), len(self.X_test), {'accuracy': float(accuracy)}

# Client function for Flower
def client_fn(cid: str) -> fl.client.Client:
    file = client_files[int(cid)]
    file_path = os.path.join(clubbed_output_dir, file)
    X, y, _ = load_client_data(file_path, sample_frac=1.0)
    return IntrusionClient(cid, X, y)

In [3]:
# Start Flower clients
import threading

def start_client(cid: str):
    fl.client.start_numpy_client(server_address='localhost:8080', client=client_fn(cid))

# Start clients in separate threads
threads = []
for cid in ['0', '1', '2']:
    thread = threading.Thread(target=start_client, args=(cid,))
    threads.append(thread)
    thread.start()

# Wait for all threads to complete
for thread in threads:
    thread.join()

Non-numeric columns dropped from probe_data.csv: ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp']


	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      
[92mINFO [0m:      Received: get_parameters message 98240b4f-4d9f-4302-9ca0-98d5d3094af3
[92mINFO [0m:      Sent reply


Non-numeric columns dropped from dos_data.csv: ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp']
Non-numeric columns dropped from ddos_data.csv: ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp']


	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client=FlowerClient().to_client(), # <-- where FlowerClient is of type flwr.client.NumPyClient object
	)
	Using `start_numpy_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use `flwr.client.start_client()` by ensuring you first call the `.to_client()` method as shown below: 
	flwr.client.start_client(
		server_address='<IP>:<PORT>',
		client