<a href="https://colab.research.google.com/github/ambikad04/FL-ModelPoisoning/blob/main/FL_with_poisoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Libraries Used
The Federated Learning with model poisoning script depends on the following libraries:

### Libraries Overview

- **`torch`**: Core library for tensor computations with GPU support.
- **`torch.nn`**: Provides tools to build and train neural networks (used to define `SimpleModel`).
- **`torch.optim`**: Contains optimization algorithms (e.g., `SGD` for training the model).
- **`numpy`**: Supports numerical operations. It is imported but not used by the script.

Together, `torch`, `torch.nn`, and `torch.optim` cover everything the script needs: tensors, the model definition, and the optimizer.

In [None]:
# libraries

import torch
from torch import nn, optim
import numpy as np

### Function: `create_client_data`

This function generates synthetic data for each client in a Federated Learning (FL) setup.


#### **Purpose**
To create simulated datasets for multiple clients, where each client's data includes input features (`x`) and corresponding labels (`y`).

#### **Parameters**
- **`num_clients` (int):**  
  The number of clients for which data needs to be generated.
  
- **`num_samples` (int, default=100):**  
  The number of data samples each client will have.

#### **Logic**
1. Draws 2D features (`x`) for each client from a standard normal distribution using `torch.randn`.
2. Assigns binary labels (`y`): 1 if $x_1 + x_2 > 0$, otherwise 0.
3. Stores each dataset as an `(x, y)` tuple in a dictionary keyed by `'client_0'`, `'client_1'`, and so on.


Every client samples from the same distribution, so the datasets are independent and identically distributed (IID) across clients.

In [None]:
# for client data
def create_client_data(num_clients, num_samples=100):
    data = {}
    for i in range(num_clients):
        x = torch.randn(num_samples, 2)
        y = (x[:, 0] + x[:, 1] > 0).float().reshape(-1, 1)
        data[f'client_{i}'] = (x, y)
    return data
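
The data layout described above can be checked with a short usage sketch. The function is repeated so the cell runs on its own; the seed and the choice of three clients are assumptions made only for this illustration.

```python
import torch

def create_client_data(num_clients, num_samples=100):
    # Same logic as the function above, repeated to keep this cell standalone
    data = {}
    for i in range(num_clients):
        x = torch.randn(num_samples, 2)
        y = (x[:, 0] + x[:, 1] > 0).float().reshape(-1, 1)
        data[f'client_{i}'] = (x, y)
    return data

torch.manual_seed(0)  # assumed seed, only to make this sketch repeatable
data = create_client_data(num_clients=3, num_samples=100)

for name, (x, y) in data.items():
    # x has shape (100, 2) and y has shape (100, 1)
    pos = y.mean().item()  # fraction of samples labeled 1
    print(name, tuple(x.shape), tuple(y.shape), f"positive fraction = {pos:.2f}")
```

Because the features are symmetric around zero, the positive fraction should sit near 0.5 for every client.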

### Class: `SimpleModel`

A single-layer network for binary classification. It passes the two input features through one fully connected layer followed by a sigmoid activation, which makes it equivalent to logistic regression. The output is the predicted probability that a sample belongs to the positive class.


#### **Structure**
1. **Fully Connected Layer (`fc`)**: Maps 2 input features to 1 output.
2. **Sigmoid Activation**: Converts output to probabilities (0-1).


#### **Methods**
- `__init__`: Initializes the model layers.
- `forward`: Defines the forward pass:  
  Input → Fully Connected Layer → Sigmoid Activation.

The model has only three trainable parameters (two weights and one bias), which keeps the Federated Learning simulation fast.

In [None]:
# for neural network
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc(x))
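
As a quick check of the structure described above, the following sketch compares the model's forward pass with a hand-written logistic regression, `sigmoid(x @ W.T + b)`. The batch of four random inputs is an assumption made for illustration.

```python
import torch
from torch import nn

class SimpleModel(nn.Module):
    # Same architecture as above, repeated to keep this cell standalone
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc(x))

model = SimpleModel()
x = torch.randn(4, 2)  # hypothetical batch of 4 samples

with torch.no_grad():
    out = model(x)
    # Manual forward pass: linear map followed by a sigmoid
    manual = torch.sigmoid(x @ model.fc.weight.T + model.fc.bias)

num_params = sum(p.numel() for p in model.parameters())
print(tuple(out.shape))                          # one probability per sample
print(torch.allclose(out, manual, atol=1e-6))    # both paths should agree
print(num_params)                                # 2 weights + 1 bias
```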

### Function: `train_local_model`

Trains a local model on client data using a binary classification loss. This function simulates the local training process for Federated Learning, where each client trains its model independently before sharing updates with the global model.


#### **Parameters**
- `model`: The PyTorch model to be trained.
- `data`: Tuple `(x, y)` with input features and true labels.
- `epochs` (default=5): Number of training iterations. Each one is a single full-batch gradient step on the client's data.
- `lr` (default=0.01): Learning rate for the optimizer.


#### **Logic**
1. Uses `nn.BCELoss` for binary classification loss.
2. Optimizes with `SGD` over specified epochs.
3. Updates model weights via backpropagation.


The function returns the trained model's `state_dict`, which is the update the client contributes to aggregation.

In [None]:
# local training
def train_local_model(model, data, epochs=5, lr=0.01):
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(data[0])  # forward pass on the client's features
        loss = criterion(output, data[1])  # BCE between predicted probabilities and true labels
        loss.backward()
        optimizer.step()
    return model.state_dict()
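
To see what five full-batch steps at `lr=0.01` actually do, the sketch below runs the same loop while recording the loss. It is an illustrative variant, not the notebook's function: the `losses` list, the seed, and the inline `nn.Sequential` model (same layers as `SimpleModel`) are assumptions added for this example.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)  # assumed seed, only to make this sketch repeatable

# Synthetic client data with the same labeling rule as create_client_data
x = torch.randn(100, 2)
y = (x[:, 0] + x[:, 1] > 0).float().reshape(-1, 1)

# Same layers as SimpleModel, written inline to keep the cell standalone
model = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

losses = []
for epoch in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)  # full-batch loss before this step's update
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print([round(value, 4) for value in losses])
```

The loss should fall only slightly over five steps, so an honest client's update stays close to the global model it started from.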

### Function: `poison_model`

Simulates a poisoning attack by copying the global model and adding large random noise to the copy's parameters. The global model itself is left untouched; the perturbed parameters are returned as the malicious client's update, which shows how a single client can disrupt Federated Learning through aggregation.


#### **Parameters**
- `global_model`: The current global model, whose parameters are the starting point for the poisoned update.
- `scale` (default=10.0): Standard deviation of the Gaussian noise added to each parameter.



#### **Logic**
1. Copies the global model's parameters into a fresh `SimpleModel` (`malicious_model`).
2. Adds Gaussian noise with standard deviation `scale` to every parameter, in place.
3. Wraps the update in `torch.no_grad` so that the in-place change is not tracked by autograd.


The function returns the poisoned `state_dict`, which the malicious client submits in place of a trained update.

In [None]:
# Poisoning function for a malicious client
def poison_model(global_model, scale=10.0):
    malicious_model = SimpleModel()
    malicious_model.load_state_dict(global_model.state_dict())  # Load global model parameters
    with torch.no_grad():  # in-place updates must not be tracked by autograd
        for param in malicious_model.parameters():
            param.add_(torch.randn_like(param) * scale)
    return malicious_model.state_dict()
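
The following sketch illustrates the two properties listed above: the global model is not modified, and the returned parameters differ from it by noise on the order of `scale`. The definitions are repeated so the cell runs on its own; the `before` snapshot is a helper variable introduced only for this check.

```python
import torch
from torch import nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc(x))

def poison_model(global_model, scale=10.0):
    # Same logic as the function above, repeated to keep this cell standalone
    malicious_model = SimpleModel()
    malicious_model.load_state_dict(global_model.state_dict())
    with torch.no_grad():
        for param in malicious_model.parameters():
            param.add_(torch.randn_like(param) * scale)
    return malicious_model.state_dict()

global_model = SimpleModel()
# Snapshot of the global parameters taken before the attack (illustrative helper)
before = {k: v.clone() for k, v in global_model.state_dict().items()}
poisoned = poison_model(global_model, scale=10.0)

for key in before:
    unchanged = torch.equal(before[key], global_model.state_dict()[key])
    shift = (poisoned[key] - before[key]).flatten().tolist()
    print(key, "global unchanged:", unchanged, "noise added:", shift)
```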

### Function: `federated_averaging`

Aggregates the clients' updates into the global model by taking the unweighted mean of each parameter. Every client, honest or malicious, carries the same weight in the result.

#### **Logic**
1. Retrieves the global model's parameters (`state_dict`).
2. For each parameter, stacks the corresponding tensors from all client `state_dict`s and takes their element-wise mean.
3. Loads the averaged parameters back into the global model.

#### **Output**
- None. The global model is updated in place.

In [None]:
# Federated averaging
def federated_averaging(global_model, client_models):
    global_state_dict = global_model.state_dict()  # current global parameters
    for key in global_state_dict:
        # element-wise mean of this parameter across all client updates
        global_state_dict[key] = torch.stack([state[key] for state in client_models]).mean(dim=0)
    global_model.load_state_dict(global_state_dict)
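
A small worked example shows why an unweighted mean is sensitive to one outlier. The parameter values below are hypothetical and chosen so the arithmetic is easy to follow: four honest clients agree on the weights `[1, 1]`, and one poisoned update submits `[11, -9]`.

```python
import torch

# Four identical honest updates (hypothetical values for illustration)
honest = [
    {"fc.weight": torch.tensor([[1.0, 1.0]]), "fc.bias": torch.tensor([0.0])}
    for _ in range(4)
]
# One poisoned update, far from the honest consensus
poisoned = {"fc.weight": torch.tensor([[11.0, -9.0]]), "fc.bias": torch.tensor([5.0])}
updates = honest + [poisoned]

# Same averaging rule as federated_averaging, applied to plain dictionaries
averaged = {
    key: torch.stack([update[key] for update in updates]).mean(dim=0)
    for key in updates[0]
}

print(averaged["fc.weight"].tolist())  # [[3.0, -1.0]]
print(averaged["fc.bias"].tolist())    # [1.0]
```

Although four of five clients agree that the second weight is 1, the average comes out at -1, so a single update is enough to flip its sign.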

### Function: `federated_learning_with_poisoning`

Simulates Federated Learning with the presence of a malicious client introducing poisoned updates. This function demonstrates how malicious behavior impacts the global model's performance while coordinating updates from honest clients.

#### **Logic**
1. **Initialization:**  
   Creates a global model (`SimpleModel`) shared across clients.
   
2. **Global Training Loop:**  
   - For each global epoch (communication round):
     - Collects updates from all clients:
       - **Malicious Client:** Submits poisoned parameters produced by `poison_model`.  
       - **Honest Clients:** Train a copy of the current global model on their own data (`train_local_model`).
     - Aggregates updates using `federated_averaging`.

3. **Logging:**  
   Prints information about malicious activity and epoch completion.

4. **Output:**  
   Returns the trained global model.

The malicious client submits a poisoned update in every round, and that update is averaged with the same weight as each honest one.

In [None]:
# Federated learning with a malicious client
def federated_learning_with_poisoning(data, num_clients, global_epochs=3, local_epochs=5, malicious_client=0):
    global_model = SimpleModel()
    for epoch in range(global_epochs):  # one iteration per communication round
        client_models = []
        for client in range(num_clients):
            if client == malicious_client:  # malicious client: submit noisy parameters
                print(f"Client {client} is malicious in epoch {epoch + 1}.")
                poisoned_model_state = poison_model(global_model)  # Poison the model
                client_models.append(poisoned_model_state)
            else:  # honest client: train locally
                local_model = SimpleModel()
                local_model.load_state_dict(global_model.state_dict())  # start from the current global parameters
                client_data = data[f'client_{client}']  # retrieve this client's data
                client_model_state = train_local_model(local_model, client_data, epochs=local_epochs)
                client_models.append(client_model_state)  # collect the trained model's state
        federated_averaging(global_model, client_models)
        print(f"Global epoch {epoch + 1} completed.")
    return global_model

### Main Script

Runs the Federated Learning simulation with a malicious client and evaluates the global model's performance. This script demonstrates the end-to-end process of data generation, model training, and testing in a Federated Learning setup with a malicious participant.

#### **Steps**
1. **Setup:**  
   - Defines the total number of clients (`num_clients=5`).  
   - Assigns a malicious client (`malicious_client_id=1`).  

2. **Data Generation:**  
   Creates synthetic data for all clients using `create_client_data`.

3. **Federated Learning:**  
   Trains the global model with client updates using `federated_learning_with_poisoning`.

4. **Testing:**  
   - Generates 20 test samples, labeled with the same rule as the training data.  
   - Evaluates the global model's accuracy by rounding its predicted probabilities and comparing them to the true labels.

5. **Output:**  
   Prints the test accuracy of the global model.

No random seed is set, so the printed accuracy varies from run to run.

In [None]:
# Main script
if __name__ == "__main__":
    num_clients = 5  # Total number of clients
    malicious_client_id = 1

    client_data = create_client_data(num_clients)  # generate synthetic data for every client with create_client_data

    global_model = federated_learning_with_poisoning(client_data, num_clients, malicious_client=malicious_client_id)

    # testing the global model
    test_data = torch.randn(20, 2)
    test_labels = (test_data[:, 0] + test_data[:, 1] > 0).float().reshape(-1, 1)  # labels for the test data
    global_model.eval()
    predictions = global_model(test_data).round()  # Make predictions
    accuracy = (predictions == test_labels).float().mean()  # Calculate accuracy
    print(f"Test Accuracy: {accuracy.item() * 100:.2f}%")

Client 1 is malicious in epoch 1.
Global epoch 1 completed.
Client 1 is malicious in epoch 2.
Global epoch 2 completed.
Client 1 is malicious in epoch 3.
Global epoch 3 completed.
Test Accuracy: 65.00%
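
The accuracy above is measured on only 20 samples, so each sample shifts the result by 5 percentage points. As a sketch of a steadier measurement, the hypothetical helper `evaluate` below uses 1,000 samples and disables gradient tracking. It is checked here against a hand-set reference model whose weights `[5, 5]` and zero bias follow the labeling rule $x_1 + x_2 > 0$; both the helper and the reference weights are assumptions made for illustration, not part of the original script.

```python
import torch
from torch import nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc(x))

def evaluate(model, num_samples=1000):
    # Hypothetical helper: accuracy on freshly sampled data with the same labeling rule
    x = torch.randn(num_samples, 2)
    y = (x[:, 0] + x[:, 1] > 0).float().reshape(-1, 1)
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        predictions = model(x).round()
    return (predictions == y).float().mean().item()

# Reference model with hand-set parameters that match the labeling rule
reference = SimpleModel()
with torch.no_grad():
    reference.fc.weight.copy_(torch.tensor([[5.0, 5.0]]))
    reference.fc.bias.zero_()

acc = evaluate(reference)
print(f"Reference accuracy: {acc * 100:.2f}%")
```

Calling `evaluate(global_model)` on the model returned by `federated_learning_with_poisoning` would give a comparable figure for the poisoned run.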
