# Federated Learning for Decentralized AI

This notebook demonstrates federated learning (FL), a decentralized approach to training AI models in which data remains on local devices (clients) and only model updates are shared with a central server. We use the client interface of the Flower (`flwr`) framework and simulate the server's aggregation step by hand, with multiple clients running in a single process. The example trains a simple neural network for binary classification on a synthetic dataset and covers the core concepts of FL: local training, model aggregation, and evaluation.

## 1. Import Libraries

Let's start by importing the necessary libraries for data handling, model building, federated learning, and visualization.

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
import flwr as fl
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## 2. Load and Prepare Data

We use a synthetic dataset generated by `make_classification` to simulate decentralized data across multiple clients. The data is split into disjoint subsets so that each client holds its own portion.

In [None]:
# Generate a synthetic dataset for classification
X, y = make_classification(n_samples=5000, n_features=20, n_informative=15, n_redundant=5, random_state=42)

# Scale the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Convert to DataFrame for better handling
X = pd.DataFrame(X_scaled, columns=[f'feature_{i}' for i in range(X.shape[1])])
y = pd.Series(y, name='target')

# Display the first few rows of the data
print("First 5 rows of features:")
print(X.head())
print("\nFirst 5 rows of target:")
print(y.head())

# Hold out a test set first so it never overlaps with any client's training data
X_train_all, X_test, y_train_all, y_test = train_test_split(
    X, y, test_size=0.1, stratify=y, random_state=42)

# Simulate data distribution across multiple clients (e.g., 5 clients)
num_clients = 5
client_data = []
data_per_client = len(X_train_all) // num_clients

for i in range(num_clients):
    start_idx = i * data_per_client
    # The last client also receives any remainder rows
    end_idx = (i + 1) * data_per_client if i < num_clients - 1 else len(X_train_all)
    client_X = X_train_all.iloc[start_idx:end_idx]
    client_y = y_train_all.iloc[start_idx:end_idx]
    client_data.append((client_X, client_y))
    print(f"Client {i+1} data shape: {client_X.shape}")

print("Test set shape:", X_test.shape)
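The contiguous split above hands any remainder rows to the last client. As a minimal sketch of an alternative, the helper below spreads the remainder evenly with `np.array_split` and reports the class balance per client, which is a quick way to check how similar the client datasets are. The names `partition_indices` and `class_fractions` are hypothetical, and the labels are random stand-ins rather than the notebook's data.

```python
import numpy as np

def partition_indices(num_samples, num_clients, seed=None):
    """Split sample indices into disjoint, near-equal parts (optionally shuffled)."""
    indices = np.arange(num_samples)
    if seed is not None:
        indices = np.random.default_rng(seed).permutation(num_samples)
    return np.array_split(indices, num_clients)

def class_fractions(labels, parts, num_classes=2):
    """Return one row per client with the fraction of samples in each class."""
    return np.array([
        np.bincount(labels[p], minlength=num_classes) / len(p) for p in parts
    ])

# Stand-in labels (assumption: binary classes, as in the notebook)
rng = np.random.default_rng(42)
labels = rng.integers(0, 2, size=1003)

parts = partition_indices(len(labels), num_clients=5, seed=0)
fractions = class_fractions(labels, parts)

print([len(p) for p in parts])  # [201, 201, 201, 200, 200]
print(fractions.round(3))
```

Rows of `fractions` that differ strongly from one another would indicate label skew between clients, a situation in which plain FedAvg tends to converge more slowly.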

## 3. Define the Neural Network Model

We create a simple feedforward neural network using TensorFlow/Keras for classification. Each client trains its own copy of this model locally, and the server keeps a copy with the same architecture to hold the aggregated weights.

In [None]:
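def create_model(input_shape, num_classes):
    """Build and compile a small feedforward classifier."""
    model = models.Sequential([
        # An explicit Input layer avoids the warning that newer Keras versions
        # emit when input_shape is passed to the first Dense layer
        layers.Input(shape=input_shape),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(32, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(num_classes, activation='softmax')
    ])
    # Integer class labels pair with the sparse categorical loss
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model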

# Test the model creation
sample_model = create_model(input_shape=(X.shape[1],), num_classes=2)
sample_model.summary()
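As a sanity check on the summary, the parameter count can be worked out by hand: a dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and dropout layers add none. The sketch below assumes the 20 input features and 2 classes used in this notebook; `dense_params` is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Input features -> 64 -> 32 -> output classes
layer_sizes = [20, 64, 32, 2]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer, total)  # [1344, 2080, 66] 3490
```

This total is also the number of values each client sends to the server per round, which gives a rough idea of the communication cost.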

## 4. Define Federated Learning Client

Using the Flower framework, we define a client class that handles local training and evaluation for each client. Each client will train on its own subset of data.

In [None]:
class FederatedClient(fl.client.NumPyClient):
    def __init__(self, model, X_train, y_train, X_val, y_val):
        self.model = model
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val

    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        history = self.model.fit(self.X_train, self.y_train,
                                 epochs=1, batch_size=32,
                                 validation_data=(self.X_val, self.y_val),
                                 verbose=0)
        return self.model.get_weights(), len(self.X_train), {'loss': float(history.history['loss'][0])}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.X_val, self.y_val, verbose=0)
        return loss, len(self.X_val), {'accuracy': float(accuracy)}
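To make the client contract easier to see in isolation, here is a minimal sketch that mirrors the three methods above without importing Flower or TensorFlow. `ToyClient` is a hypothetical stand-in that trains a NumPy logistic regression, and the `local_steps` config key is an assumption made for this example. For brevity it evaluates on its own training data.

```python
import numpy as np

class ToyClient:
    """Hypothetical stand-in for a NumPyClient, using logistic regression."""

    def __init__(self, X, y, lr=0.1):
        self.X, self.y, self.lr = X, y, lr
        self.w = np.zeros(X.shape[1])
        self.b = np.zeros(1)

    def _predict(self):
        return 1.0 / (1.0 + np.exp(-(self.X @ self.w + self.b)))

    def _loss(self):
        p = np.clip(self._predict(), 1e-7, 1 - 1e-7)
        return float(-np.mean(self.y * np.log(p) + (1 - self.y) * np.log(1 - p)))

    def get_parameters(self, config):
        # Parameters travel as a list of NumPy arrays
        return [self.w.copy(), self.b.copy()]

    def fit(self, parameters, config):
        # Start from the global parameters, then run local gradient steps
        self.w, self.b = parameters[0].copy(), parameters[1].copy()
        for _ in range(config.get("local_steps", 10)):
            err = self._predict() - self.y
            self.w -= self.lr * self.X.T @ err / len(self.y)
            self.b -= self.lr * err.mean()
        # Returns (updated parameters, number of examples, metrics dict)
        return self.get_parameters(config), len(self.y), {"loss": self._loss()}

    def evaluate(self, parameters, config):
        self.w, self.b = parameters[0].copy(), parameters[1].copy()
        acc = float(np.mean((self._predict() >= 0.5) == self.y))
        # Returns (loss, number of examples, metrics dict)
        return self._loss(), len(self.y), {"accuracy": acc}

rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 3))
y_toy = (X_toy[:, 0] + 0.5 * X_toy[:, 1] > 0).astype(float)

toy = ToyClient(X_toy, y_toy)
start = toy.get_parameters({})
start_loss, _, _ = toy.evaluate(start, {})
params, n_fit, fit_metrics = toy.fit(start, {"local_steps": 50})
end_loss, n_eval, eval_metrics = toy.evaluate(params, {})
print(f"loss {start_loss:.3f} -> {end_loss:.3f}, accuracy {eval_metrics['accuracy']:.3f}")
```

The example count returned by `fit` matters because the server uses it to weight each client's contribution during aggregation.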

## 5. Simulate Clients and Server Setup

We simulate multiple clients, each with its own data, and aggregate their updates using Federated Averaging (FedAvg). Running a full server-client setup inside a notebook is awkward, so the loop below performs the aggregation by hand. The `FedAvg` strategy object shows how the same setup would be configured for a Flower server, but the manual loop does not use it. In a real deployment, Flower typically runs clients and servers in separate processes.
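For reference, the FedAvg aggregation step can be written as follows. The notation is an assumption introduced here: $K$ is the number of participating clients, $n_k$ the number of training examples on client $k$, and $w_{t+1}^{k}$ the weights client $k$ returns after local training in round $t$.

$$
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k
$$

When every client holds the same number of training examples, as in this notebook, each coefficient $n_k / n$ reduces to $1/K$ and the update is a plain average.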

In [None]:
# Function to create client instances
def create_client_fn(cid):
    client_idx = int(cid)
    X_data, y_data = client_data[client_idx]
    # Split client data into train and validation
    X_train, X_val, y_train, y_val = train_test_split(X_data, y_data, test_size=0.2, random_state=42)
    model = create_model(input_shape=(X.shape[1],), num_classes=2)
    return FederatedClient(model, X_train, y_train, X_val, y_val)

# Define the federated learning strategy (FedAvg)
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Use all clients for training
    fraction_evaluate=1.0,  # Use all clients for evaluation
    min_fit_clients=num_clients,
    min_evaluate_clients=num_clients,
    min_available_clients=num_clients
)

# Simulate federated learning rounds (for demonstration, we limit to a few rounds)
num_rounds = 3
print(f"Starting federated learning simulation for {num_rounds} rounds...")

# Since running a full FL server in a notebook is not straightforward, we simulate the process
# In a real setup, use fl.server.start_server() and fl.client.start_numpy_client()
# Here, we manually simulate the process for educational purposes

global_model = create_model(input_shape=(X.shape[1],), num_classes=2)
history = {'round': [], 'loss': [], 'accuracy': []}

# Use round_num rather than `round` to avoid shadowing the Python built-in
for round_num in range(1, num_rounds + 1):
    print(f"Round {round_num}/{num_rounds}")
    client_weights = []
    client_sizes = []

    # Simulate client training, starting each client from the current global weights
    for cid in range(num_clients):
        client = create_client_fn(str(cid))
        weights, size, _ = client.fit(global_model.get_weights(), {})
        client_weights.append(weights)
        client_sizes.append(size)

    # Aggregate weights using FedAvg (weighted average based on data size)
    total_size = sum(client_sizes)
    aggregated_weights = [np.zeros_like(w) for w in client_weights[0]]
    for idx, weights in enumerate(client_weights):
        for i, layer_weights in enumerate(weights):
            aggregated_weights[i] += layer_weights * (client_sizes[idx] / total_size)

    # Update global model
    global_model.set_weights(aggregated_weights)

    # Simulate evaluation on one client. `client` still refers to the last client
    # created in the loop above, so this uses that client's validation split.
    loss, _, metrics = client.evaluate(aggregated_weights, {})
    history['round'].append(round_num)
    history['loss'].append(loss)
    history['accuracy'].append(metrics['accuracy'])
    print(f"Round {round_num} - Loss: {loss:.4f}, Accuracy: {metrics['accuracy']:.4f}")
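The aggregation step in the loop above can be isolated into a small function and checked on numbers that are easy to verify by hand. This is a minimal sketch using only NumPy; `fedavg` is a hypothetical helper name, not a Flower API.

```python
import numpy as np

def fedavg(client_weights, client_sizes):
    """Average per-layer weights, weighting each client by its example count."""
    total = float(sum(client_sizes))
    num_layers = len(client_weights[0])
    return [
        sum(w[layer] * (n / total) for w, n in zip(client_weights, client_sizes))
        for layer in range(num_layers)
    ]

# Two toy clients, each with one weight matrix and one bias vector
client_a = [np.array([[0.0, 0.0]]), np.array([0.0])]
client_b = [np.array([[4.0, 8.0]]), np.array([4.0])]

# Client B holds three times as much data, so it gets weight 3/4
global_weights = fedavg([client_a, client_b], client_sizes=[1, 3])
print(global_weights[0], global_weights[1])  # [[3. 6.]] [3.]
```

With equal client sizes the result is the plain mean of the client weights.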

## 6. Visualize Federated Learning Results

Let's plot the loss and accuracy recorded after each round to see how the global model evolved during federated learning. These values come from one client's validation split, not from the training data.

In [None]:
# Plot loss and accuracy over rounds
plt.figure(figsize=(12, 5))

# Plot loss
plt.subplot(1, 2, 1)
plt.plot(history['round'], history['loss'], marker='o', color='red')
plt.title('Federated Learning Loss Over Rounds')
plt.xlabel('Round')
plt.ylabel('Loss')
plt.grid(True)

# Plot accuracy
plt.subplot(1, 2, 2)
plt.plot(history['round'], history['accuracy'], marker='o', color='blue')
plt.title('Federated Learning Accuracy Over Rounds')
plt.xlabel('Round')
plt.ylabel('Accuracy')
plt.grid(True)

plt.tight_layout()
plt.show()

## 7. Evaluate the Global Model on Test Data

Finally, we evaluate the aggregated global model on the reserved test set to assess its performance.

In [None]:
# Evaluate the global model on the test set
test_loss, test_accuracy = global_model.evaluate(X_test, y_test, verbose=0)
print(f"Global Model Test Loss: {test_loss:.4f}")
print(f"Global Model Test Accuracy: {test_accuracy:.4f}")

# Generate predictions for confusion matrix
y_pred = np.argmax(global_model.predict(X_test, verbose=0), axis=1)
cm = tf.math.confusion_matrix(y_test, y_pred)

# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix for Global Model on Test Set')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()
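To read the heatmap, it helps to see how summary metrics follow from the matrix entries. The sketch below builds a confusion matrix with NumPy in the same orientation as the plot (rows are actual classes, columns are predicted classes) and derives accuracy, precision, and recall for class 1. The labels are made up for illustration, and `confusion_matrix` here is a hypothetical helper rather than the TensorFlow function used above.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes=2):
    """Rows index the actual class, columns the predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

# Made-up labels for illustration only
y_true_toy = np.array([0, 0, 1, 1, 1, 0])
y_pred_toy = np.array([0, 1, 1, 1, 0, 0])

cm_toy = confusion_matrix(y_true_toy, y_pred_toy)
accuracy = np.trace(cm_toy) / cm_toy.sum()          # correct / all
precision_1 = cm_toy[1, 1] / cm_toy[:, 1].sum()     # of predicted 1s, how many are 1
recall_1 = cm_toy[1, 1] / cm_toy[1, :].sum()        # of actual 1s, how many were found

print(cm_toy.tolist())  # [[2, 1], [1, 2]]
print(round(accuracy, 3), round(precision_1, 3), round(recall_1, 3))  # 0.667 0.667 0.667
```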

## 8. Save the Global Model

We will save the final global model for future use or deployment.

In [None]:
# Save the global model
global_model.save('federated_global_model.h5')
print("Global model saved as 'federated_global_model.h5'")

## 9. Summary and Next Steps

In this notebook, we implemented a federated learning setup built on Flower's `NumPyClient` interface, with the server-side aggregation simulated by hand in a notebook environment. Key aspects covered include:
- Simulating decentralized data across multiple clients.
- Local training on client data and aggregation of model updates using FedAvg.
- Visualization of training progress and evaluation of the global model.

**Limitations in Notebook Environment:**
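- This notebook simulates the FL process manually due to the complexity of running a full Flower server and clients in a single notebook. In a real setup, you would run `fl.server.start_server()` and `fl.client.start_numpy_client()` in separate scripts or processes. Recent Flower releases deprecate `start_numpy_client()` in favor of `fl.client.start_client()`, so check the documentation for your installed version.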

**Next Steps:**
- Deploy a full federated learning system with Flower using separate server and client scripts.
- Experiment with different datasets (e.g., MNIST, CIFAR-10) or model architectures.
- Explore advanced FL strategies (e.g., FedProx) and privacy techniques such as differential privacy.
- Integrate with Web3 or Solana for decentralized AI in a blockchain context.