## What is Federated Learning?

Federated Learning (FL) is a **decentralized machine learning approach** where multiple clients (e.g., hospitals, mobile devices, or institutions) collaboratively train a shared global model **without sharing their raw data**.

Instead of pooling all data on a central server, training proceeds as follows:
1. Each client **trains locally** on its own private dataset.
2. Each client **sends model updates (parameters or gradients)**, not the data, to a central server.
3. The **server aggregates** these updates to improve the **global model**.
4. This process repeats over several **communication rounds** until convergence.

---

## ⚙️ Why Federated Learning?

Traditional centralized ML requires moving all data to one place, which is often impractical or illegal due to **privacy**, **security**, or **data ownership** concerns.

Federated Learning enables:
- ✅ **Data privacy**: raw data never leaves the client.
- ✅ **Scalability**: computation is distributed across clients.
- ✅ **Collaboration**: multiple institutions can contribute without sharing sensitive data.

---

## 🏗️ Federated Learning Workflow

**1️⃣ Initialization (Server)**
- The server defines a **global model** (e.g., CNN, LSTM).
- Sends the model weights to all clients.

**2️⃣ Local Training (Clients)**
- Each client trains the model on its **local dataset** for a few epochs.
- The resulting **updated model weights** are sent back to the server.

**3️⃣ Aggregation (Server)**
- The server combines all local updates, e.g., via **Federated Averaging (FedAvg)**.
- Produces a new **global model** that captures knowledge from all clients.

**4️⃣ Iteration**
- The process repeats for multiple **federated rounds** until the global model converges.
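
The four steps above can be sketched as a toy simulation in plain Python. This is a minimal illustration under stated assumptions, not the notebook's implementation: the "model" is a single scalar `w` fitted by gradient descent on squared error, and the names `local_train` and `fedavg` as well as the client data are hypothetical.

```python
# Toy federated loop: the "model" is one scalar w that estimates the mean of the data.
# All names and numbers here are illustrative assumptions.

def local_train(w, data, lr=0.1, epochs=5):
    """Step 2: a few epochs of gradient descent on the mean squared error (w - x)^2."""
    for _ in range(epochs):
        grad = sum(2 * (w - x) for x in data) / len(data)
        w -= lr * grad
    return w

def fedavg(local_ws, sizes):
    """Step 3: average client parameters, weighted by local sample count."""
    total = sum(sizes)
    return sum(w * n / total for w, n in zip(local_ws, sizes))

# Three clients with differently distributed (non-IID) private data
clients = [[1.0, 2.0, 3.0], [10.0, 12.0], [5.0, 5.0, 5.0, 5.0]]
sizes = [len(d) for d in clients]

w_global = 0.0                      # Step 1: server initializes the global model
for round_num in range(30):         # Step 4: repeat for several rounds
    local_ws = [local_train(w_global, d) for d in clients]   # Step 2: local training
    w_global = fedavg(local_ws, sizes)                        # Step 3: aggregation

pooled_mean = sum(sum(d) for d in clients) / sum(sizes)
print(f"global w = {w_global:.3f}, pooled mean = {pooled_mean:.3f}")
```

In this simple quadratic problem the averaged model approaches the mean of the pooled data even though no client ever shares its numbers. With non-convex models and skewed client data, reaching the centralized solution is not guaranteed.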

---

## 🧩 Federated Averaging (FedAvg)

FedAvg is the **core aggregation algorithm** in Federated Learning *(McMahan et al., 2017)*.  
It computes the **weighted average of all client model parameters**:

$$
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{N} \, w_{t+1}^{(k)}
$$

Where:

- $w_{t+1}^{(k)}$: Model parameters from client *k*  
- $n_k$: Number of training samples on client *k*  
- $N = \sum_k n_k$: Total number of samples across all clients  
- $w_{t+1}$: Updated global model parameters  

Clients with **more data** influence the global model **more**.
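
As a worked instance of the formula, the sketch below averages two-parameter models from three hypothetical clients. The parameter values and sample counts are made up for illustration.

```python
# Worked FedAvg example; client parameters and sample counts are illustrative.
client_params = [
    [1.0, 0.0],   # w^(1)
    [3.0, 2.0],   # w^(2)
    [5.0, 4.0],   # w^(3)
]
n = [100, 300, 600]          # n_k: training samples per client
N = sum(n)                   # N: total number of samples (1000)

# w_{t+1} = sum_k (n_k / N) * w^(k), computed coordinate by coordinate
w_global = [
    sum(n_k / N * w_k[i] for n_k, w_k in zip(n, client_params))
    for i in range(len(client_params[0]))
]
print(w_global)
```

Up to floating-point rounding the result is `[4.0, 3.0]`, since 0.1·1 + 0.3·3 + 0.6·5 = 4 and 0.1·0 + 0.3·2 + 0.6·4 = 3. The global parameters sit closest to client 3, which holds 60% of the samples.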




---

## ⚖️ IID vs Non-IID Data

In FL, client data distributions are often **non-IID** (not independent and identically distributed):

| Scenario | Description | Example |
|-----------|--------------|----------|
| **IID** | Clients have data drawn from similar distributions | Each hospital has balanced patient demographics |
| **Non-IID** | Clients have biased or skewed data distributions | One hospital has mostly older patients, another mostly younger |

Non-IID settings make FL more challenging because local models may diverge significantly, so proper aggregation and hyperparameter tuning are key.

---

## 📈 Evaluating FL Models

After each round:
- Each client's **local model** can be evaluated on its **own validation/test set**.
- The **global model** can also be evaluated on each client's data to assess overall generalization.

Common metrics:
- **AUC (Area Under the ROC Curve)**: measures discrimination ability.
- **APR (average precision, a summary of the precision-recall curve)**: evaluates the precision-recall trade-off.
- **Loss (e.g., BCE, MSE)**: measures prediction error.
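
To make the first two metrics concrete, here is a minimal pure-Python sketch. It is an illustration only and assumes no tied scores in the average-precision part; the helper names `roc_auc` and `average_precision` are hypothetical, and the notebook's own evaluation helpers may compute these values differently (for example through scikit-learn's `roc_auc_score` and `average_precision_score`).

```python
def roc_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

def average_precision(y_true, scores):
    """Mean of the precision values measured at the rank of each positive example."""
    ranked = sorted(zip(scores, y_true), key=lambda t: -t[0])
    hits, precisions = 0, []
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / hits

y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(y_true, scores), average_precision(y_true, scores))
```

For this toy input three of the four positive/negative pairs are ordered correctly, so the AUC is 0.75, and the average precision is (1/1 + 2/3) / 2 ≈ 0.83. On heavily imbalanced clients, average precision is usually the more informative of the two.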

---

## 🧩 Summary

| Component | Role |
|------------|------|
| **Server** | Coordinates training, aggregates client updates |
| **Clients** | Train local models on private data |
| **FedAvg** | Averages model weights across clients |
| **Rounds** | Number of communication cycles between server & clients |

---

## 📚 References

- McMahan, B. et al. (2017). *Communication-Efficient Learning of Deep Networks from Decentralized Data*. AISTATS.  
- Google AI Blog (2017). *Federated Learning: Collaborative Machine Learning without Centralized Training Data.*

---

👉 **In this notebook**, we simulate Federated Learning using the **FedAvg algorithm** with multiple clients, each training on its own subset of the dataset (non-IID partitions).  
This setup allows you to **understand, visualize, and experiment** with client heterogeneity, communication rounds, and model aggregation in FL.


# 🔬 How We Simulate Federated Learning (single-process simulation)

In this notebook we **simulate** a federated learning (FL) experiment inside a single Python process.  
This is an educational and reproducible setup that mimics the high-level behavior of a real FL system while keeping the code simple to run on a laptop.

Below is a description of what the code does and how the simulation maps to real FL components.

---

## Simulation design (high level)

1. **Create clients (nodes)**  
   We simulate `K` clients by partitioning the dataset into `K` disjoint subsets (one per client).  
   Partitioning can be IID or non-IID; in this notebook we use a Dirichlet-based split to create **non-IID** client data.

2. **Server initialises a global model**  
   A single global model (e.g., a CNN) lives on the server and is copied to each client at the start of every round.

3. **Federated rounds (main loop)**  
   For each round:
   - The server **sends** the current global weights to all clients (in code: `local_model.load_state_dict(global_model.state_dict())`).
   - Each client trains the model **locally** on its own data for a few epochs and returns its updated model object.
     - In our simulation, local training is done **serially** in a loop (client 0, client 1, ... client K-1).
     - In real FL, clients usually train in parallel and communicate asynchronously or in rounds.
   - The server **aggregates** the returned local model parameters using **FedAvg** (weighted average, typically by client sample count).
   - The aggregated parameters replace the server's global model for the next round.

4. **Logging & saving**  
   - We record per-client training losses and validation metrics after local updates.
   - We evaluate the **global model** on each client's test set and save the best global `state_dict()` when it improves.

5. **Final evaluation**  
   - After training, we load and evaluate the saved **best local models** (if saved per-client) and the **final/global model** on every client's test data to compare local vs global performance.
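
Step 1 relies on a Dirichlet-based split. The sketch below shows one common way such a split can be built. It is not the code of `split_data_non_iid`, whose internals are not shown in this notebook; the function name `dirichlet_partition` and the parameter `alpha` are assumptions made for illustration.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with Dirichlet-distributed class shares.

    Smaller alpha gives more skewed (more non-IID) label distributions per client.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = rng.permutation(np.where(labels == cls)[0])
        # Share of this class that each client receives
        shares = rng.dirichlet(alpha * np.ones(num_clients))
        # Convert shares to cut points and hand out contiguous chunks
        cuts = (np.cumsum(shares)[:-1] * len(idx)).astype(int)
        for client, chunk in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(chunk.tolist())
    return client_indices

labels = np.array([0] * 80 + [1] * 20)          # toy binary labels
parts = dirichlet_partition(labels, num_clients=5, alpha=0.5)
for k, idx in enumerate(parts):
    pos_ratio = labels[idx].mean() if idx else float("nan")
    print(f"client {k}: {len(idx)} samples, positive ratio {pos_ratio:.2f}")
```

Every sample lands on exactly one client, so the subsets stay disjoint. Lowering `alpha` makes both the client sizes and the positive-class ratios more uneven, which is the kind of heterogeneity visible in the client summaries printed by the notebook.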

---

## Key implementation details (mapping to code cells)

- **Data partitioning**: `split_data_non_iid(...)` (in `loaders_federated_learning.py`) uses a Dirichlet distribution to simulate heterogeneity across clients.  
- **Client DataLoaders**: `get_loaders(...)` (in `loaders_federated_learning.py`) returns a list `Loaders` where `Loaders[k] = [train_loader, val_loader, test_loader]` for client `k`.  
- **Local training**: `train_model(...)` (in `Client.py`) performs training for a client and returns the updated `local_model` and training loss.  
- **Aggregation**: `federated_averaging(models, weights)` computes weighted average of client parameters; server updates global model via `global_model.load_state_dict(...)`.  
- **Evaluation**: `evaluate_models(...)` and `evaluate_models_test(...)` compute AUC / APR per client; `prediction_binary(...)` runs final test predictions.

---

## Why a serial simulation is sufficient here
- **Deterministic & simple**: easy to debug and reproduce in a classroom or on a laptop.
- **Focuses on algorithmic ideas** (FedAvg, non-IID effects, client heterogeneity) without the complexity of networking, concurrency, or device management.
- **Easily extensible** to parallel/real FL frameworks later (e.g., Flower, TensorFlow Federated, PySyft).

---

## Limitations vs real FL systems (be aware)
- **No network conditions**: We ignore latency, bandwidth, and dropped clients.
- **No client asynchrony**: Clients are simulated serially; we don't model stragglers or partial participation unless explicitly coded.
- **Single process memory**: All model copies and data live in the same process and memory; not realistic for large-scale deployments.
- **Privacy guarantees**: Simulation alone does not provide privacy (e.g., differential privacy or secure aggregation must be added explicitly).
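
Partial participation and dropped clients, two of the effects listed above, can be approximated in a serial simulation by sampling which clients take part in each round. The following is a sketch with hypothetical names (`select_clients`, `participation`, `dropout_prob`); none of them exist in the notebook's modules.

```python
import random

def select_clients(num_clients, participation=0.6, dropout_prob=0.1, rng=random):
    """Pick the clients that train this round, then drop some at random.

    participation: fraction of clients the server invites each round.
    dropout_prob:  chance that an invited client fails to report back.
    """
    num_invited = max(1, round(participation * num_clients))
    invited = rng.sample(range(num_clients), num_invited)
    reported = [c for c in invited if rng.random() >= dropout_prob]
    # Fall back to one invited client so the server always has something to aggregate
    return sorted(reported) if reported else [invited[0]]

rng = random.Random(20)
for round_num in range(3):
    active = select_clients(5, rng=rng)
    print(f"round {round_num + 1}: aggregating clients {active}")
```

In the training loop, only the models of the `active` clients and the matching entries of `weights` would then be passed to `federated_averaging`, which renormalizes the weights it receives.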

# 🧩 Exercises: Explore and Extend the Federated Learning Framework

Now that we've explored how federated learning (FL) is simulated in this notebook, it's time to apply your understanding.  
Complete the following exercises to strengthen your conceptual and practical grasp of FL.

---

## 🧠 Exercise 1: Implement the Federated Averaging Function

**Objective:**  
Implement the `federated_averaging(models, weights=None)` function to aggregate local client models into a single global model.

**Instructions:**

1. Extract parameters (or the `state_dict`) from each client model.
2. Compute a **weighted average** of all client model parameters.
3. If no `weights` are provided, assume **equal weighting** for all clients.
4. Return the averaged parameters as a new `state_dict` that can be loaded directly into the global model with `global_model.load_state_dict(...)`.


## 🧠 Exercise 2: Use an LSTM-based classifier

**Objective:**  
Replace `CNNClassifier` with `LSTMClassifier` and analyse the difference in performance.

**Instructions:**  
Replace both the global and the local models.

In [1]:
# -----------------------------------------
# Imports, utilities, and experiment setup
# -----------------------------------------
import random
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn

# Local modules (models, data loader, client training and utils)
# - Models: contains model classes (e.g. LSTMClassifier, CNNClassifier)
# - loaders_federated_learning: data splitting and per-client DataLoader creation
# - Client: per-client training/evaluation logic for federated simulation
# - utils: metrics, logging helpers, etc.

from Models import LSTMClassifier, CNNClassifier   # In this exercise, we use a 1-D CNN model for time-series classification.
from loaders_federated_learning import get_loaders
from Client import *
from utils import *

# -------------------------
# Warnings & display setup
# -------------------------
warnings.filterwarnings("ignore")  # hide noisy warnings (useful for notebook runs)

# -------------------------
# Deterministic seeds / RNG
# -------------------------
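# Set seeds for reproducibility. Note: exact reproducibility across platforms/hardware
# (especially with CUDA) may still vary; the cuDNN flags below reduce nondeterminism.
SEED = 20
torch.backends.cudnn.deterministic = True   # prefer deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False      # disable the nondeterministic autotuner
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)   # also seeds the CUDA generators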

<torch._C.Generator at 0x7013da73a6d0>

In [2]:
# -------------------------
# Define number of clients (Nodes) and get dataloaders for each client
# -------------------------
nodes = 5  # number of simulated clients
Loaders, weights = get_loaders(nodes=nodes)  # weights: share of the training data assigned to each client

Client 1/5 Summary:
  Total samples: 1957
  Training samples: 1126, Validation: 732, Test: 99
  Positive class ratio: 0.440 (44.0%)
-----------------------------------
Client 2/5 Summary:
  Total samples: 7329
  Training samples: 5942, Validation: 973, Test: 414
  Positive class ratio: 0.031 (3.1%)
-----------------------------------
Client 3/5 Summary:
  Total samples: 1195
  Training samples: 81, Validation: 863, Test: 251
  Positive class ratio: 0.150 (15.0%)
-----------------------------------
Client 4/5 Summary:
  Total samples: 5667
  Training samples: 5078, Validation: 265, Test: 324
  Positive class ratio: 0.173 (17.3%)
-----------------------------------
Client 5/5 Summary:
  Total samples: 5308
  Training samples: 2571, Validation: 489, Test: 2248
  Positive class ratio: 0.133 (13.3%)
-----------------------------------


In [3]:
# -------------------------
# Select device (CPU or GPU)
# -------------------------
# get_device() is implemented in utils and should return either:
#   - torch.device("cuda")  if a GPU is available, or
#   - torch.device("cpu")   otherwise
device = get_device()
print(f"Using device: {device}")

# -------------------------
# Define the global model (server-side)
# -------------------------
# Infer number of input features from the first client's training dataset.
# Assumes TensorDataset where each sample is (features, label) and features shape is (seq_len, n_features)
n_features = Loaders[0][0].dataset[0][0].shape[1]

num_filters = 64              # number of convolutional filters in the CNN
global_model = CNNClassifier(n_features, num_filters, device)
global_model.to(device)       # move model to chosen device
print(global_model)
# -------------------------
# Loss function
# -------------------------
criterion = nn.BCELoss().to(device)



Using device: cuda
CNNClassifier(
  (conv1): Conv1d(76, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=1, bias=True)
  (activation): Sigmoid()
  (dropout): Dropout(p=0.5, inplace=False)
)


In [4]:
### Create the performance logs for each client.

# One DataFrame per client to log training and validation performance
DF = [pd.DataFrame(columns=['Train_Loss', 'Val_Loss', 'Val_AUC', 'Val_APR']) for _ in range(nodes)]

# Best validation AUC / APR seen so far at each client
Val_AUC = [0] * nodes
Val_APR = [0] * nodes

In [5]:
### ------------------------------------------------------------
### Federated Averaging (FedAvg)
### ------------------------------------------------------------
### Combines multiple client models into a single global model
### by averaging their parameters (weighted or uniform).
###
### Args:
###   models  : list of PyTorch model objects (each client's local model)
###   weights : list or array of weights (one per client), optional.
###             If None, uniform averaging (equal contribution) is used.
###
### Returns:
###   avg_state : averaged state_dict, ready to pass to
###               global_model.load_state_dict(...)
### ------------------------------------------------------------

def federated_averaging(models, weights=None, device=None):
    """
    Average client models' state_dicts (FedAvg) and return an averaged state_dict.

    Args:
        models (list): list of PyTorch model objects (all same architecture)
        weights (list or None): optional list of weights (one per client). If None, uniform weights used.
        device (torch.device or str or None): device where resulting tensors should be placed.
                                             If None, uses CPU for the averaged tensors.
    Returns:
        avg_state_dict (dict): averaged state_dict (ready for global_model.load_state_dict(avg_state_dict))
    """
    if len(models) == 0:
        raise ValueError("No models provided to federated_averaging.")

    # 1) Collect state_dicts
    state_dicts = [m.state_dict() for m in models]
    keys = list(state_dicts[0].keys())

    # 2) Validate that all models share the same state_dict keys
    for sd in state_dicts[1:]:
        if sd.keys() != state_dicts[0].keys():
            raise ValueError("All models must have the same state_dict keys/architecture.")

    num_clients = len(state_dicts)

    # 3) Prepare/normalize weights
    if weights is None:
        weights = [1.0 / num_clients] * num_clients
    else:
        total = float(sum(weights))
        if total == 0.0:
            raise ValueError("Sum of weights must be > 0.")
        weights = [float(w) / total for w in weights]

    # 4) Decide device for averaged params
    if device is None:
        device = torch.device("cpu")
    else:
        device = torch.device(device)

    # 5) Build averaged state dict
    avg_state = {}
    for key in keys:
        ref = state_dicts[0][key]
        if not torch.is_floating_point(ref):
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
            # averaged in place with float weights; copy the first client's value.
            avg_state[key] = ref.clone().to(device)
            continue
        # initialize accumulator with zeros on the correct device & dtype
        accum = torch.zeros_like(ref, device=device)
        for sd, w in zip(state_dicts, weights):
            # move client tensor to the accumulation device and multiply by weight
            accum += sd[key].to(device) * w
        avg_state[key] = accum

    return avg_state



In [6]:
# ============================================================
# Main Federated Training Loop (Server Orchestration)
# ------------------------------------------------------------
# Each round represents one complete cycle of:
#   1. Server sending the global model to all clients
#   2. Each client training locally on its own data
#   3. Clients returning their updated local models
#   4. Server aggregating the local updates (FedAvg)
# ============================================================

num_rounds = 15    # total federated communication rounds
best = 0           # track the best global AUROC during training

for round_num in range(num_rounds):
    print(f"\n================ Round {round_num + 1}/{num_rounds} ================\n")

    client_samples = []   # store each client's locally trained model
    LOSS = []             # store local training losses for logging

    # ------------------------------------------------------------
    # Client-side training (simulated serially here for simplicity)
    # In real FL, this happens in parallel on different devices
    # ------------------------------------------------------------
    for client_id in range(nodes):
        print(f"--> Training on Client {client_id + 1}/{nodes}")

        # Initialize a fresh local model and load current global weights
        local_model = CNNClassifier(n_features, num_filters, device).to(device)
        local_model.load_state_dict(global_model.state_dict())  # sync with server

        # Define optimizer and perform local training on this client's data
        optimizer = torch.optim.Adam(local_model.parameters(), lr=0.001)
        local_model, loss = train_model(
            local_model,
            Loaders[client_id][0],   # train loader for this client
            criterion,
            optimizer,
            device=device,
            num_epochs=5
        )  # train_model is defined in Client.py
        LOSS.append(loss)
        client_samples.append(local_model)

    # ------------------------------------------------------------
    # Server-side aggregation (FedAvg)
    # ------------------------------------------------------------
    # If weights is None, each client gets equal importance in the aggregation.
    # federated_averaging returns a state_dict that is loaded straight into the global model.
    aggregated_state = federated_averaging(client_samples, weights=weights, device=device)
    global_model.load_state_dict(aggregated_state)


    # ------------------------------------------------------------
    # Validate local models (post-training performance on validation) and store best performing one at each client
    # ------------------------------------------------------------
    for k in range(nodes):
        local_model = client_samples[k]
        DF[k], Val_AUC[k], cur_auc, cur_apr = evaluate_models(
            k, Loaders, local_model, criterion, device,
            DF[k], Val_AUC, LOSS[k], 'FedAvg'
        )

        print(
            f"Client {k + 1:02d} | "
            f"Train Loss: {LOSS[k]:.3f} | "
            f"Best Val AUC: {Val_AUC[k]:.3f} | "
            f"Current AUC: {cur_auc:.3f} | "
            f"Current APR: {cur_apr:.3f}"
        )

    print("------------------------------------------------------------")

    # ------------------------------------------------------------
    # Evaluate the updated global model on each client's test set
    # (note: picking the best round on test data makes the reported test AUC optimistic)
    # ------------------------------------------------------------
    total_auc = 0
    for k in range(nodes):
        _, cur_auc, cur_apr = evaluate_models_test(k, Loaders, global_model, criterion, device)
        total_auc += cur_auc

    global_auc = total_auc / nodes
    print(f"\n>>> Global Model Average AUC (Round {round_num + 1}): {global_auc:.4f}")

    # Save the global model if it achieves a new best AUC
    if global_auc > best:
        best = global_auc
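        Path('./trained_models/FedAvg').mkdir(parents=True, exist_ok=True)  # make sure the target folder exists
        torch.save(global_model.state_dict(), './trained_models/FedAvg/global_model_state.pt')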
    print("============================================================\n")




--> Training on Client 1/5
--> Training on Client 2/5
--> Training on Client 3/5
--> Training on Client 4/5
--> Training on Client 5/5
Client 01 | Train Loss: 0.532 | Best Val AUC: 0.814 | Current AUC: 0.814 | Current APR: 0.614
Client 02 | Train Loss: 0.050 | Best Val AUC: 0.754 | Current AUC: 0.754 | Current APR: 0.110
Client 03 | Train Loss: 0.444 | Best Val AUC: 0.766 | Current AUC: 0.766 | Current APR: 0.101
Client 04 | Train Loss: 0.328 | Best Val AUC: 0.828 | Current AUC: 0.828 | Current APR: 0.843
Client 05 | Train Loss: 0.412 | Best Val AUC: 0.792 | Current AUC: 0.792 | Current APR: 0.491
------------------------------------------------------------

>>> Global Model Average AUC (Round 1): 0.8126



--> Training on Client 1/5
--> Training on Client 2/5
--> Training on Client 3/5
--> Training on Client 4/5
--> Training on Client 5/5
Client 01 | Train Loss: 0.518 | Best Val AUC: 0.822 | Current AUC: 0.822 | Current APR: 0.634
Client 02 | Train Loss: 0.048 | Best Val AUC: 0.754 

In [7]:
# ============================================================
# 🔍 Evaluation at Each Client
# ------------------------------------------------------------
# Using the best-performing local models stored during
# federated training (FedAvg).
# ------------------------------------------------------------
# For each client:
#   1. Load the locally saved model
#   2. Evaluate on that client's test set
#   3. Record metrics (AUC, APR)
#   4. Save results to CSV
# ------------------------------------------------------------
# Finally, compute and display average metrics across clients.
# ============================================================

from utils import *

total_auc = 0.0
total_apr = 0.0

print("\n========== Evaluating Best Local Models ==========\n")

for client_id in range(nodes):
    model_path = f'./trained_models/FedAvg/node{client_id}'
    result_path = f'./Results/FedAvg/node{client_id}.csv'

    # Load best local model for this client
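    # The file is expected to hold a full pickled model object (not a state_dict),
    # so weights-only loading, the default in recent PyTorch versions, is disabled.
    local_model = torch.load(model_path, map_location=device, weights_only=False)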
    local_model.to(device)
    local_model.eval()

    # Evaluate on client's test data (index 2 = test loader)
    test_loss, test_auc, test_apr = prediction_binary(local_model, Loaders[client_id][2], criterion, device)

    total_auc += test_auc
    total_apr += test_apr

    print(f"Client {client_id + 1:02d} | "
          f"Test AUC: {test_auc:.4f} | "
          f"Test APR: {test_apr:.4f}")

    print("------------------------------------------------------------")

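# ------------------------------------------------------------
# Compute and display overall performance
# ------------------------------------------------------------
avg_auc = total_auc / nodes
avg_apr = total_apr / nodes
print(f"Average Test AUC across {nodes} clients: {avg_auc:.4f}")
print(f"Average Test APR across {nodes} clients: {avg_apr:.4f}")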




Client 01 | Test AUC: 0.8066 | Test APR: 0.3822
------------------------------------------------------------
Client 02 | Test AUC: 0.7651 | Test APR: 0.6094
------------------------------------------------------------
Client 03 | Test AUC: 0.7867 | Test APR: 0.8170
------------------------------------------------------------
Client 04 | Test AUC: 0.8284 | Test APR: 0.5197
------------------------------------------------------------
Client 05 | Test AUC: 0.8366 | Test APR: 0.1994
------------------------------------------------------------


In [8]:
# ============================================================
# 🌍 Evaluate the Final Global Model (FedAvg)
# ------------------------------------------------------------
# The saved global model is evaluated on each client's test set
# to assess generalization across distributed data.
# ------------------------------------------------------------
# For each client:
#   1. Load the final global model
#   2. Evaluate on the client's local test data
#   3. Collect AUC and APR metrics
# ------------------------------------------------------------
# Finally, compute the mean AUC and APR across all clients.
# ============================================================

sum_auc = 0.0
sum_apr = 0.0

# Load the best global model saved during training
# recreate the model architecture, then load state_dict
model = CNNClassifier(n_features, num_filters, device)   # or the correct class
state = torch.load('./trained_models/FedAvg/global_model_state.pt', map_location=device)
model.load_state_dict(state)
model.to(device)
model.eval()


print("\n========== Evaluating Final Global Model on All Clients ==========\n")

for client_id in range(nodes):
    # Evaluate the loaded global model on each client's test set
    test_loss, test_auc, test_apr = prediction_binary(model, Loaders[client_id][2], criterion, device)

    sum_auc += test_auc
    sum_apr += test_apr

    print(f"Client {client_id + 1:02d} | "
          f"Test AUC: {test_auc:.4f} | "
          f"Test APR: {test_apr:.4f}")
    print("------------------------------------------------------------")

# ------------------------------------------------------------
# Compute and print global performance across all clients
# ------------------------------------------------------------
avg_auc = sum_auc / nodes
avg_apr = sum_apr / nodes

print("\n================ Global Model Summary ==================")
print(f"Average Test AUC across {nodes} clients: {avg_auc:.4f}")
print(f"Average Test APR across {nodes} clients: {avg_apr:.4f}")
print("========================================================\n")


ModuleNotFoundError: No module named 'Model_IHM'