# Introduction to Flower Framework

This notebook provides a hands-on introduction to the Flower federated learning framework.

## What You'll Learn

1. What is Flower and why use it
2. Core concepts: Clients, Server, and Strategies
3. Building a simple federated learning client
4. Understanding the federated learning workflow
5. Running your first federated learning simulation

## Setup

First, let's import the necessary libraries:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import flwr as fl
import numpy as np
from collections import OrderedDict

print(f"Torch version: {torch.__version__}")
print(f"Flower version: {fl.__version__}")

## 1. What is Flower?

Flower (flwr) is a federated learning framework that makes it easy to:
- Build federated learning systems
- Work with any ML framework (PyTorch, TensorFlow, JAX, etc.)
- Scale from simulation to production

### Key Components:

1. **Client**: Trains model on local data
2. **Server**: Coordinates training and aggregates updates
3. **Strategy**: Defines how to aggregate client updates (e.g., FedAvg)

## 2. Simple Example: Federated Averaging

Let's create a simple neural network and train it using federated learning.

In [None]:
# Simple neural network
class SimpleModel(nn.Module):
    def __init__(self, input_dim=10, hidden_dim=20, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create model
model = SimpleModel()
print("Model created:")
print(model)

## 3. Creating a Flower Client

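A Flower client built on `NumPyClient` implements three main methods:
- `get_parameters()`: Return the current model parameters as a list of NumPy arrays
- `fit()`: Train the model on local data and return the updated parameters, the number of training examples, and a metrics dictionary
- `evaluate()`: Evaluate the model on local data and return the loss, the number of test examples, and a metrics dictionary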

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Simple Flower client for demonstration."""
    
    def __init__(self, model, train_data, test_data):
        self.model = model
        self.train_data = train_data
        self.test_data = test_data
    
    def get_parameters(self, config):
        """Return model parameters as NumPy arrays."""
        # Read from state_dict() so the order matches the keys used in set_parameters()
        return [val.cpu().numpy() for val in self.model.state_dict().values()]
    
    def set_parameters(self, parameters):
        """Set model parameters from NumPy arrays."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)
    
    def fit(self, parameters, config):
        """Train model on local data."""
        self.set_parameters(parameters)
        
        # Simple training loop (1 epoch)
        optimizer = optim.SGD(self.model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        
        self.model.train()
        for X, y in self.train_data:
            optimizer.zero_grad()
            output = self.model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        
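        # Report the number of training examples (not batches) so the
        # server can weight this client's update correctly
        num_examples = sum(len(y) for _, y in self.train_data)
        return self.get_parameters(config={}), num_examples, {}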
    
    def evaluate(self, parameters, config):
        """Evaluate model on local data."""
        self.set_parameters(parameters)
        
        criterion = nn.CrossEntropyLoss()
        self.model.eval()
        
        total_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
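            for X, y in self.test_data:
                output = self.model(X)
                loss = criterion(output, y)
                # criterion returns the batch mean, so scale by batch size
                total_loss += loss.item() * y.size(0)
                
                _, predicted = torch.max(output, 1)
                total += y.size(0)
                correct += (predicted == y).sum().item()
        
        accuracy = correct / total
        avg_loss = total_loss / total
        # Return the mean loss and the number of test examples (not batches)
        return avg_loss, total, {"accuracy": accuracy}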

print("FlowerClient class defined!")

## 4. Creating Synthetic Data

Let's create some synthetic data to demonstrate federated learning:

In [None]:
def create_synthetic_data(n_samples=100, n_features=10, n_classes=2):
    """Create synthetic data for demonstration."""
    X = torch.randn(n_samples, n_features)
    y = torch.randint(0, n_classes, (n_samples,))
    return [(X, y)]  # A list containing one batch that holds all samples

# Create data for 3 clients
print("Creating data for 3 clients...")
client_datasets = []
for i in range(3):
    train_data = create_synthetic_data(n_samples=100)
    test_data = create_synthetic_data(n_samples=20)
    client_datasets.append((train_data, test_data))
    print(f"  Client {i}: 100 train samples, 20 test samples")

print("Data created!")

## 5. Understanding the Workflow

### Federated Learning Process:

```
Round 1:
  Server → Clients: Send initial model
  Clients: Train locally
  Clients → Server: Send updates
  Server: Aggregate updates (FedAvg)
  
Round 2:
  Server → Clients: Send updated model
  ...
  
Repeat for N rounds
```

### Key Points:

1. **Data stays local**: Never leaves the client
2. **Model updates only**: Only weights/gradients are shared
3. **Aggregation**: Server combines updates (weighted average)
4. **Iterative**: Process repeats for multiple rounds
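
The weighted average in point 3 can be sketched in plain NumPy. This is a minimal illustration of the idea, not Flower's internal implementation; the helper name `fedavg_aggregate` and the toy client values are hypothetical.

```python
import numpy as np

def fedavg_aggregate(client_updates):
    """Weighted average of client parameters.

    client_updates: list of (parameters, num_examples) tuples, where
    parameters is a list of NumPy arrays (one array per layer).
    """
    total_examples = sum(num_examples for _, num_examples in client_updates)
    num_layers = len(client_updates[0][0])
    aggregated = []
    for layer_idx in range(num_layers):
        # Weight each client's layer by its number of training examples
        weighted_sum = sum(
            params[layer_idx] * num_examples
            for params, num_examples in client_updates
        )
        aggregated.append(weighted_sum / total_examples)
    return aggregated

# Toy example: two clients with one "layer" each (hypothetical values)
client_a = ([np.array([1.0, 1.0])], 100)  # trained on 100 examples
client_b = ([np.array([4.0, 4.0])], 300)  # trained on 300 examples
averaged = fedavg_aggregate([client_a, client_b])
print(averaged[0])
```

Because the second client holds three times as much data, the result sits closer to its parameters than to the first client's. This is also why `fit()` should report the number of examples it trained on.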

## 6. Flower Simulation

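Flower provides a simulation mode to test federated learning on a single machine. It requires the simulation extra (`pip install "flwr[simulation]"`). The `client_fn(cid)` signature used here follows older Flower releases; newer releases may pass a `Context` object instead, so check the documentation for your installed version: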

In [None]:
from flwr.simulation import start_simulation

# Create client factory function
def client_fn(cid: str):
    """Create a client with given ID."""
    client_id = int(cid)
    train_data, test_data = client_datasets[client_id]
    
    # Create new model for this client
    client_model = SimpleModel()
    
    return FlowerClient(client_model, train_data, test_data).to_client()

print("Client factory function created!")
print("\nYou can now run simulation with:")
print("""\nstart_simulation(
    client_fn=client_fn,
    num_clients=3,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=fl.server.strategy.FedAvg(),
)""")
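
Each client's `evaluate()` returns an `accuracy` entry in its metrics dictionary. As far as I know, `FedAvg` only combines such custom metrics when it is given an aggregation function through its `evaluate_metrics_aggregation_fn` parameter. The sketch below shows one possible function; the name `weighted_average` and the sample results are hypothetical.

```python
def weighted_average(metrics):
    """Aggregate per-client accuracy, weighted by number of test examples.

    metrics: list of (num_examples, metrics_dict) tuples, one per client.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    weighted_accuracy = sum(
        num_examples * client_metrics["accuracy"]
        for num_examples, client_metrics in metrics
    )
    return {"accuracy": weighted_accuracy / total_examples}

# Hypothetical evaluation results from three clients
results = [
    (20, {"accuracy": 0.50}),
    (20, {"accuracy": 0.60}),
    (60, {"accuracy": 0.70}),
]
aggregated = weighted_average(results)
print(aggregated)

# Assumed usage with the strategy from the simulation call:
# strategy = fl.server.strategy.FedAvg(
#     evaluate_metrics_aggregation_fn=weighted_average,
# )
```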

## 7. Key Takeaways

### What We Learned:

1. **Flower is framework-agnostic**: Works with PyTorch, TensorFlow, etc.
2. **Simple API**: Only three methods to implement (`get_parameters`, `fit`, `evaluate`)
3. **Privacy-preserving**: Raw data never leaves clients; only model updates are shared
4. **Flexible**: Can customize aggregation strategies
5. **Scalable**: From simulation to production

### Client Responsibilities:
- Hold local data
- Train model locally
- Send updates (not data!)

### Server Responsibilities:
- Coordinate training rounds
- Select clients
- Aggregate updates
- Maintain global model
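
Client selection can be pictured as sampling a fraction of the available clients each round. The sketch below is a simplified illustration, not Flower's code; the helper `sample_clients` is hypothetical, and its `fraction_fit` and `min_fit_clients` arguments only loosely mirror the `FedAvg` options of the same names.

```python
import random

def sample_clients(client_ids, fraction_fit=0.5, min_fit_clients=2, seed=None):
    """Pick a random subset of clients for one training round."""
    rng = random.Random(seed)
    # Sample a fraction of the clients, but never fewer than the minimum
    num_to_sample = max(int(len(client_ids) * fraction_fit), min_fit_clients)
    # Cannot sample more clients than are available
    num_to_sample = min(num_to_sample, len(client_ids))
    return rng.sample(client_ids, num_to_sample)

available = ["0", "1", "2"]
selected = sample_clients(available, fraction_fit=0.5, min_fit_clients=2, seed=42)
print(selected)
```

With three clients and `fraction_fit=0.5`, the fraction alone would select one client, so the minimum of two takes effect.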

### Strategies:
- **FedAvg**: Weighted average of client models
- **FedProx**: Adds a proximal term to each client's local loss to limit drift from the global model when client data is heterogeneous
- **Custom**: Build your own!
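
To make the FedProx idea concrete: each client adds a penalty of (mu / 2) times the squared distance between its local weights and the global weights to its training loss. The NumPy sketch below only computes that penalty; the helper name `proximal_term`, the value of `mu`, and the toy weights are hypothetical.

```python
import numpy as np

def proximal_term(local_weights, global_weights, mu):
    """Return (mu / 2) * ||w_local - w_global||^2, summed over all layers."""
    squared_distance = sum(
        np.sum((w_local - w_global) ** 2)
        for w_local, w_global in zip(local_weights, global_weights)
    )
    return 0.5 * mu * squared_distance

# Toy weights for a model with two "layers" (hypothetical values)
w_global = [np.zeros(2), np.zeros(1)]
w_local = [np.array([1.0, 2.0]), np.array([2.0])]
penalty = proximal_term(w_local, w_global, mu=0.1)
print(penalty)
```

A client would add this penalty to its task loss in each local training step. Setting `mu` to 0 removes the penalty, which leaves plain FedAvg-style local training.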

## 8. Next Steps

Ready to move forward? Check out:

1. **Notebook 2**: Complete credit fraud detection example
2. **Documentation**: `docs/FLOWER_BASICS.md` for detailed reference
3. **Source code**: `src/client.py` and `src/server.py` for production-ready implementation
4. **Run script**: `./run_federated_learning.sh` to run the full simulation

### Try it yourself:

```bash
# Start server
python src/server.py

# In separate terminals, start clients
python src/client.py --client-id 0
python src/client.py --client-id 1
python src/client.py --client-id 2
```

## Resources

- [Flower Documentation](https://flower.dev/docs/)
- [Flower Examples](https://github.com/adap/flower/tree/main/examples)
- [Federated Learning Paper](https://arxiv.org/abs/1602.05629)
- [This Repository](https://github.com/Omega-Makena/intro-to-federated-learning)