# Federated Learning Basics

In this notebook, we introduce **Federated Learning (FL)**: an approach that allows multiple clients (devices, institutions, or nodes) to collaboratively train a shared machine learning model **without sharing their raw data**.

Because raw data stays where it is generated, this paradigm supports **data privacy, security, and decentralized computation**, which is critical in healthcare, finance, and edge AI applications.

## 🎯 Learning Objectives
- Understand the motivation and concept of Federated Learning
- Learn the architecture of a federated learning system
- Explore the **Federated Averaging (FedAvg)** algorithm
- Implement a simple simulation of Federated Learning using PyTorch

## 1. What is Federated Learning?

**Federated Learning** allows training of ML models across multiple devices or servers that hold local datasets, **without centralizing the data**.

Each client (e.g., smartphone, hospital, IoT device) trains the model locally on its data and only shares **model parameters (weights)** with a central server. The server then aggregates these updates to form a global model.

### 🏗️ Federated Learning Architecture
1. **Server initializes a global model.**
2. **Clients train** the model locally on their private data.
3. **Clients send** updated model weights to the server.
4. The **server aggregates** updates (e.g., via averaging).
5. Repeat the process for multiple rounds until convergence.
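
The five steps above can be sketched without any ML framework. The toy below is an illustration under simplifying assumptions, not a production protocol: the "model" is a single scalar `w` in `y = w * x`, and the names `local_train`, `fed_round`, and `clients` are hypothetical helpers introduced only for this example.

```python
# Toy federated loop: each client fits y = w * x on its own private points.
# All names here (local_train, fed_round, clients) are illustrative.

def local_train(w, data, lr=0.01, epochs=5):
    """Step 2: run a few gradient-descent epochs on one client's data."""
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

def fed_round(w_global, clients):
    """Steps 2-4: local training, upload, and sample-weighted aggregation."""
    updates = [(local_train(w_global, data), len(data)) for data in clients]
    total = sum(n for _, n in updates)
    return sum(w * n for w, n in updates) / total

# Three clients whose data all follow y = 3x; the raw points never leave them.
clients = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(3.0, 9.0)],
    [(0.5, 1.5), (1.5, 4.5), (2.5, 7.5)],
]

w = 0.0  # Step 1: the server initializes the global model
for _ in range(50):  # Step 5: repeat for several rounds
    w = fed_round(w, clients)
print(round(w, 3))  # the global weight converges toward 3.0
```

Only the scalar `w` and the sample count cross the client boundary in `fed_round`; the `(x, y)` pairs are read solely inside `local_train`.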

![Federated Learning Workflow](https://upload.wikimedia.org/wikipedia/commons/e/e2/Federated_learning_process.png)

## ⚙️ 2. Federated Averaging (FedAvg) Algorithm

The **FedAvg** algorithm is the cornerstone of most FL systems. It computes a **weighted average** of local model updates:

$$ w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} w_{t+1}^{(k)} $$

Where:
- $w_{t+1}$ → global model weights after round t
- $w_{t+1}^{(k)}$ → local model weights from client k
- $K$ → number of participating clients
- $n_k$ → number of samples at client k
- $n = \sum_k n_k$ → total samples across all clients
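
To make the weighting concrete, here is a small worked example on flat weight vectors. The function name `fedavg` and the three hypothetical clients are assumptions chosen for illustration; real systems apply the same rule to every tensor in the model.

```python
def fedavg(client_weights, client_sizes):
    """Return sum_k (n_k / n) * w_k for flat weight vectors."""
    n = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[i] * n_k / n for w, n_k in zip(client_weights, client_sizes))
        for i in range(dim)
    ]

# Hypothetical two-parameter models from three clients.
weights = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
sizes = [100, 300, 600]  # n_k, so n = 1000

w_global = fedavg(weights, sizes)
print(w_global)  # approximately [4.0, 3.0]
```

A plain unweighted mean of the same vectors would be `[3.0, 2.0]`; the weighted result is pulled toward the client that holds 60% of the samples.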

In [None]:
# 🧠 3. Simulating Federated Learning with PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import copy

# Simple model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Dataset and data partitioning
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
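# 5 clients; the split sizes must sum to len(train_data) (60,000 for MNIST), otherwise random_split raises ValueError
client_data = random_split(train_data, [12000, 12000, 12000, 12000, 12000])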

def train_local(model, data, epochs=1):
    loader = DataLoader(data, batch_size=64, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    model.train()
    for _ in range(epochs):
        for X, y in loader:
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()
    return model.state_dict()

def average_weights(client_weights, client_sizes):
    """FedAvg: average client state dicts, weighting client k by n_k / n."""
    total = sum(client_sizes)
    avg_weights = {}
    for key in client_weights[0].keys():
        avg_weights[key] = sum(
            w[key] * (n_k / total) for w, n_k in zip(client_weights, client_sizes)
        )
    return avg_weights

# Federated training simulation
global_model = Net()
for round_idx in range(3):  # round_idx avoids shadowing the built-in round()
    local_weights, local_sizes = [], []
    for client in client_data:
        # Each client starts the round from the current global weights
        local_model = Net()
        local_model.load_state_dict(global_model.state_dict())
        client_update = train_local(local_model, client)
        local_weights.append(client_update)
        local_sizes.append(len(client))
    # Aggregate on the server, weighted by each client's sample count
    global_weights = average_weights(local_weights, local_sizes)
    global_model.load_state_dict(global_weights)
    print(f"Round {round_idx+1} complete ✅")

## 🔒 4. Advantages of Federated Learning
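- **Privacy preservation**: raw data never leaves the client; only model updates are shared.
- **Reduced data transfer**: clients send model weights instead of raw datasets, which lowers bandwidth when the model is small relative to the data.
- **Collaboration**: organizations can train a joint model without exchanging data.
- **Scalability**: the same protocol extends to many clients or devices.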
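
Whether sending weights is cheaper than sending data depends on model size and the number of rounds. The back-of-the-envelope calculation below uses the `Net` architecture from this notebook (784 → 128 → 10) and a 12,000-image client. It assumes uncompressed float32 weights and one byte per raw pixel; both are simplifying assumptions, not measurements.

```python
# Parameter count of the two-layer network (784 -> 128 -> 10), with biases.
n_params = (28 * 28 * 128 + 128) + (128 * 10 + 10)

BYTES_PER_FLOAT32 = 4  # assumption: uncompressed float32 weights
update_bytes = n_params * BYTES_PER_FLOAT32

# Assumption: raw MNIST images stored as 28x28 uint8, one byte per pixel.
n_images = 12_000
raw_bytes = n_images * 28 * 28

rounds = 3
# Each round the client downloads the global model and uploads its update.
fl_bytes = 2 * rounds * update_bytes

print(f"parameters per model: {n_params:,}")
print(f"one weight upload:    {update_bytes / 1e6:.2f} MB")
print(f"raw data (one-off):   {raw_bytes / 1e6:.2f} MB")
print(f"FL traffic, {rounds} rounds: {fl_bytes / 1e6:.2f} MB")
```

Under these assumptions, three rounds cost about 2.44 MB per client against 9.41 MB for uploading the raw images once. The advantage shrinks as rounds accumulate: after roughly 12 rounds the federated traffic exceeds the one-off raw upload.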

## ⚠️ 5. Challenges in Federated Learning
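- **Non-IID data**: clients may have very different data distributions, which can slow or destabilize convergence.
- **Communication cost**: weights are exchanged every round, so traffic grows with model size and round count.
- **Privacy leaks**: model updates can still reveal information about the training data.
- **Device heterogeneity**: clients differ in compute and storage.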
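
To see what non-IID data looks like in practice, one common way to simulate it is to sort samples by label, cut them into shards, and hand each client only a couple of shards. The sketch below uses synthetic labels as a stand-in for the MNIST targets; `shard_partition` and its parameters are hypothetical names introduced for this example.

```python
import random

def shard_partition(labels, num_clients, shards_per_client=2, seed=0):
    """Split sample indices into label-sorted shards and deal them to clients.

    Samples left over when len(labels) is not divisible by the shard count
    are dropped in this simplified version.
    """
    rng = random.Random(seed)
    # Sort indices by label so each shard covers only one or two classes.
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * shards_per_client
    shard_size = len(order) // num_shards
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    parts = []
    for c in range(num_clients):
        own = shards[c * shards_per_client:(c + 1) * shards_per_client]
        parts.append([i for shard in own for i in shard])
    return parts

# Synthetic labels: 10 classes, 100 samples each (stand-in for MNIST targets).
labels = [i % 10 for i in range(1000)]
parts = shard_partition(labels, num_clients=5)
for c, idx in enumerate(parts):
    classes = sorted({labels[i] for i in idx})
    print(f"client {c}: {len(idx)} samples, classes {classes}")
```

Each client ends up with 200 samples drawn from just two classes, in contrast to the uniform `random_split` used earlier. In a PyTorch simulation, the index lists could be wrapped with `torch.utils.data.Subset` to build per-client datasets.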

## 🧭 6. Summary
- Federated Learning trains models collaboratively without centralizing data.
- The **FedAvg** algorithm aggregates local model updates.
- Keeping raw data on the client supports privacy-preserving machine learning at scale.
- Application areas include healthcare, mobile AI, and finance.

### ✅ Next Notebook: `07-Privacy_Preserving_ML.ipynb`
In the next notebook, we’ll explore **privacy-preserving ML** approaches such as **Differential Privacy** and **Secure Multi-party Computation**, which further enhance the trustworthiness of Federated Learning systems.