# 🌸 Federated Learning with Flower (FLwr)

In this notebook, we’ll explore how to implement **Federated Learning (FL)** using the [**Flower**](https://flower.dev/) framework.

Flower (FLwr) provides a simple yet powerful interface for both simulating and deploying FL systems, which makes it suitable for quick experiments as well as production-grade federated deployments.

## 🎯 Learning Objectives
- Understand the **Flower FL architecture**
- Learn how to build a **client-server setup** for Federated Learning
- Train a simple **MNIST model** using multiple clients
- Observe how Flower handles **aggregation and orchestration** automatically

## 🧠 1. What is Flower (FLwr)?

**Flower** is an open-source, framework-agnostic library for federated learning. It abstracts away the communication, orchestration, and aggregation details so that you can focus on the model and its training logic.

![Flower Architecture](https://flower.dev/images/overview/flower-architecture.png)

Each **client** trains on its own local data and sends only model updates (never the raw data) to a **server**, which aggregates them, most commonly with Federated Averaging (FedAvg).
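As a rough illustration of what FedAvg computes, here is a framework-free sketch in plain Python: each parameter is averaged across clients, weighted by the number of local training examples each client reports. Representing a model as a flat list of floats is a simplifying assumption; Flower itself exchanges lists of NumPy arrays.

```python
def fed_avg(client_updates):
    """Example-weighted average of client parameters.

    client_updates: list of (parameters, num_examples) pairs, where
    parameters is a flat list of floats (a simplifying assumption).
    """
    total_examples = sum(n for _, n in client_updates)
    num_params = len(client_updates[0][0])
    averaged = [0.0] * num_params
    for params, n in client_updates:
        weight = n / total_examples  # clients with more data count more
        for i, p in enumerate(params):
            averaged[i] += weight * p
    return averaged


# Two clients: the second holds three times as much data as the first
updates = [([1.0, 2.0], 1000), ([3.0, 6.0], 3000)]
print(fed_avg(updates))  # → [2.5, 5.0]
```

With equal client sizes this reduces to a plain mean; weighting by `num_examples` is what keeps a small client from pulling the global model as hard as a large one.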

## 🌸 2. Installing Dependencies

```python
# Uncomment to install Flower with its simulation extra (pulls in Ray), plus PyTorch
# !pip install "flwr[simulation]" torch torchvision
```
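Flower's simulation API has changed across releases, so it can help to check which versions are installed before running the cells below. This sketch uses only the standard library (`importlib.metadata`, Python 3.8+); the helper name `installed_version` is illustrative, and it returns `None` for packages that are not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for package in ("flwr", "torch", "torchvision"):
    print(f"{package}: {installed_version(package) or 'not installed'}")
```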

## ⚙️ 3. Setting Up the Model and Dataset

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import flwr as fl

# Simple CNN for 28x28 grayscale MNIST digits
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 28x28 -> 26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 26x26 -> 24x24
        self.fc1 = nn.Linear(9216, 128)       # 64 * 12 * 12 after 2x2 max-pooling
        self.fc2 = nn.Linear(128, 10)         # one logit per digit class

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)   # 24x24 -> 12x12, required to match fc1's 9216 inputs
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)       # raw logits; CrossEntropyLoss applies log-softmax

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Split the 60,000 training images evenly among 5 clients (IID partition).
# random_split requires the lengths to sum to len(train_dataset).
client_datasets = random_split(
    train_dataset, [12000] * 5, generator=torch.Generator().manual_seed(42)
)
```
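The `9216` in `fc1` is not arbitrary. One way to check it is to trace the spatial size through the layers with the standard output-size formula for convolution and pooling, `out = (in + 2*padding - kernel) // stride + 1`. The helper `conv_out` below is a hypothetical plain-Python sketch (not a PyTorch function) that assumes no padding and a 2×2 max-pool with stride 2 after the second convolution.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a square convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


after_conv1 = conv_out(28, kernel=3)                    # 28 -> 26
after_conv2 = conv_out(after_conv1, kernel=3)           # 26 -> 24
after_pool = conv_out(after_conv2, kernel=2, stride=2)  # 2x2 max-pool: 24 -> 12

channels = 64  # conv2 output channels
flat_without_pool = channels * after_conv2 ** 2  # does not match nn.Linear(9216, 128)
flat_with_pool = channels * after_pool ** 2      # matches nn.Linear(9216, 128)
print(flat_without_pool, flat_with_pool)  # → 36864 9216
```

Without the pooling step, `fc1` would receive 36,864 features per image and PyTorch would raise a shape-mismatch error on the first forward pass.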

## 🤖 4. Define the Flower Client

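Each client subclasses `fl.client.NumPyClient` and implements **three core methods**:
- `get_parameters(config)` → Returns the current model parameters as a list of NumPy arrays.
- `fit(parameters, config)` → Loads the global parameters, trains locally, and returns `(updated_parameters, num_examples, metrics)`.
- `evaluate(parameters, config)` → Evaluates the global parameters on local data and returns `(loss, num_examples, metrics)`.

The `set_parameters()` helper in the client below is not part of the Flower interface; it simply loads incoming weights into the PyTorch model.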
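To make this contract concrete without pulling in PyTorch or Flower, here is a hypothetical, framework-free mock that mirrors the three method signatures and return shapes. The one-weight linear model `y = w * x`, the learning rate, and the class name `ToyClient` are illustrative assumptions; a real Flower client subclasses `fl.client.NumPyClient` instead.

```python
class ToyClient:
    """Mimics the NumPyClient contract for a 1-parameter model y = w * x."""

    def __init__(self, train, test, lr=0.01):
        self.w = 0.0
        self.train, self.test, self.lr = train, test, lr

    def get_parameters(self, config):
        return [self.w]

    def _mse(self, data):
        return sum((self.w * x - y) ** 2 for x, y in data) / len(data)

    def fit(self, parameters, config):
        self.w = parameters[0]  # start from the global model
        for x, y in self.train:  # one pass of per-example gradient descent
            grad = 2 * (self.w * x - y) * x
            self.w -= self.lr * grad
        # (updated parameters, number of training examples, metrics)
        return self.get_parameters(config), len(self.train), {}

    def evaluate(self, parameters, config):
        self.w = parameters[0]
        # (loss, number of evaluation examples, metrics)
        return self._mse(self.test), len(self.test), {}


# Local data generated from y = 2x
client = ToyClient(train=[(1.0, 2.0), (2.0, 4.0)] * 50, test=[(3.0, 6.0)])
new_params, n, _ = client.fit([0.0], {})
loss, n_test, _ = client.evaluate(new_params, {})
print(new_params, n, loss)
```

The number of examples returned from `fit` is what the server later uses as the client's weight during aggregation.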

```python
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, train_data, test_data):
        self.model = model
        self.train_data = DataLoader(train_data, batch_size=64, shuffle=True)
        self.test_data = DataLoader(test_data, batch_size=64)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(model.parameters(), lr=0.01)

    def get_parameters(self, config=None):
        # Serialize the model weights as a list of NumPy arrays (state_dict order)
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        # Load weights received from the server, matching state_dict keys by position
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config=None):
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(1):  # one local epoch per round
            for X, y in self.train_data:
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(X), y)
                loss.backward()
                self.optimizer.step()
        # Updated weights, number of local examples (used as the FedAvg weight), metrics
        return self.get_parameters(), len(self.train_data.dataset), {}

    def evaluate(self, parameters, config=None):
        self.set_parameters(parameters)
        self.model.eval()
        loss, correct = 0.0, 0
        with torch.no_grad():
            for X, y in self.test_data:
                preds = self.model(X)
                # CrossEntropyLoss returns the batch mean; weight it by the batch size
                loss += self.criterion(preds, y).item() * y.size(0)
                correct += (preds.argmax(1) == y).sum().item()
        num_examples = len(self.test_data.dataset)
        return loss / num_examples, num_examples, {"accuracy": correct / num_examples}
```
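Because `nn.CrossEntropyLoss` defaults to `reduction='mean'`, each batch loss is already an average over that batch. Summing those values across batches makes the reported loss grow with the number of batches, and dividing by the number of batches overweights a smaller final batch. The plain-Python sketch below, using made-up per-example losses purely for illustration, compares these options with the example-weighted mean.

```python
def batches(values, batch_size):
    """Split a list into consecutive chunks, like a non-shuffled DataLoader."""
    return [values[i:i + batch_size] for i in range(0, len(values), batch_size)]


per_example_losses = [1.0, 1.0, 1.0, 1.0, 4.0]  # illustrative values only
groups = batches(per_example_losses, 4)          # two batches: sizes 4 and 1
batch_means = [sum(b) / len(b) for b in groups]  # what a mean-reduced loss reports

sum_of_batch_means = sum(batch_means)                      # grows with #batches
mean_of_batch_means = sum(batch_means) / len(batch_means)  # overweights last batch
example_weighted = sum(
    m * len(b) for m, b in zip(batch_means, groups)
) / len(per_example_losses)                                 # true per-example mean

print(sum_of_batch_means, mean_of_batch_means, example_weighted)  # → 5.0 2.5 1.6
```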

## 🧩 5. Launching a Flower Simulation

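Flower makes it easy to **simulate multiple clients** on a single machine using `flwr.simulation.start_simulation()`. This requires the simulation extra (`pip install "flwr[simulation]"`), which installs Ray as the default backend. Depending on your Flower version, `start_simulation()` may emit a deprecation warning that points to newer entry points such as `flwr.simulation.run_simulation`.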
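Conceptually, a FedAvg simulation repeats a simple loop: broadcast the global parameters, let each sampled client train locally, then take the example-weighted average of the returned parameters. The sketch below reproduces only that control flow in plain Python, using hypothetical one-weight "clients" whose local update moves halfway toward a client-specific optimum; it is not how Flower is implemented internally.

```python
def local_fit(global_w, target, n):
    """Hypothetical client: one local step halfway toward its own optimum."""
    new_w = global_w + 0.5 * (target - global_w)
    return new_w, n


def run_rounds(clients, num_rounds, w=0.0):
    for rnd in range(1, num_rounds + 1):
        # 1) broadcast w, 2) train locally on every client (like fraction_fit=1.0)
        results = [local_fit(w, target, n) for target, n in clients]
        # 3) aggregate with an example-weighted average (FedAvg)
        total = sum(n for _, n in results)
        w = sum(new_w * n for new_w, n in results) / total
        print(f"round {rnd}: global w = {w:.4f}")
    return w


# (client optimum, number of examples) for three clients with different data
clients = [(1.0, 100), (2.0, 100), (6.0, 200)]
final_w = run_rounds(clients, num_rounds=3)
```

Here the global weight drifts toward the example-weighted mean of the client optima (3.75), with larger clients pulling harder, which mirrors how FedAvg weights updates by `num_examples`.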

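```python
# Give each client its own fixed 2,000-image slice of the MNIST test set
client_testsets = random_split(
    test_dataset, [2000] * 5, generator=torch.Generator().manual_seed(42)
)

def client_fn(cid: str):
    # Called by the simulation engine whenever client `cid` is needed
    model = Net()
    train_data = client_datasets[int(cid)]
    test_data = client_testsets[int(cid)]
    # .to_client() exists in recent Flower 1.x releases; on older ones, return the NumPyClient directly
    return FlowerClient(model, train_data, test_data).to_client()

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,          # train on all available clients every round
    min_fit_clients=5,
    min_available_clients=5,
)

# Simulate 5 clients for 3 rounds
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=5,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)

print("\n✅ Federated Training Complete!")
print(history.losses_distributed)  # [(round, aggregated evaluation loss), ...]
```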
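By default, `FedAvg` aggregates only the evaluation losses; the per-client `accuracy` values returned by `evaluate()` are not combined unless you pass an aggregation function, and Flower logs a warning about this. A minimal sketch, assuming the `evaluate_metrics_aggregation_fn` parameter of `FedAvg`, is an example-weighted average. The function name `weighted_average` is illustrative.

```python
def weighted_average(metrics):
    """Combine client metrics into one example-weighted accuracy.

    metrics: list of (num_examples, {"accuracy": value}) tuples, the shape
    Flower passes to evaluate_metrics_aggregation_fn.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}


# Three clients of different sizes (made-up accuracies for illustration)
client_metrics = [
    (2000, {"accuracy": 0.90}),
    (2000, {"accuracy": 0.80}),
    (1000, {"accuracy": 0.95}),
]
print(weighted_average(client_metrics))
```

It would be passed as `fl.server.strategy.FedAvg(..., evaluate_metrics_aggregation_fn=weighted_average)`, after which the aggregated accuracy per round should appear in `history.metrics_distributed`.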

## 📊 6. Observations
- Each client trains **locally on its data subset**.
- The server aggregates updates via **Federated Averaging (FedAvg)**.
- Flower’s built-in simulation allows rapid prototyping before real-world deployment.
- Flower works with PyTorch, TensorFlow, scikit-learn, and custom ML frameworks.

## 🧭 7. Summary
- **Flower (FLwr)** is a flexible framework for Federated Learning.
- Supports both **simulation** and **real-world deployments**.
- Handles client orchestration, aggregation, and communication seamlessly.

Next, we’ll explore **Privacy-Preserving ML**: combining Federated Learning with techniques such as **Differential Privacy** and **Homomorphic Encryption** to further secure model updates.