In [None]:
# Import Required Libraries
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import matplotlib.pyplot as plt
import sys
import os

# Add src directory to path
sys.path.append('../src')
from generate_data import generate_synthetic_data

## Generate Synthetic Dataset

Create synthetic datasets for multiple clients, simulating private data distributions.

In [None]:
# Parameters
num_clients = 10
samples_per_client = 1000
num_features = 10
num_classes = 2  # the model below ends in a single sigmoid unit, so this must stay 2

# Fix random seeds so Restart & Run All reproduces the same run.
# - NumPy's global RNG is seeded on the assumption that generate_synthetic_data
#   draws from it (TODO confirm in src/generate_data.py); it also drives the
#   default client sampling in federated training.
# - TensorFlow's global seed covers Keras weight initialisation and tf.data
#   shuffling.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Generate synthetic data: one (features, labels) pair per client, as the
# indexing below and the pooling/federated cells assume.
client_data = generate_synthetic_data(num_clients, samples_per_client, num_features, num_classes)

print(f"Generated data for {len(client_data)} clients")
print(f"Each client has {client_data[0][0].shape[0]} samples with {client_data[0][0].shape[1]} features")

## Define Global Model

Define a small feed-forward Keras network for binary classification. The same architecture is trained both centrally and with federated averaging.

In [None]:
def create_keras_model():
    """Build the feed-forward classifier shared by centralized and federated runs.

    The network has ReLU hidden layers of 64 and 32 units followed by a single
    sigmoid output unit. The input width comes from the module-level
    ``num_features``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    hidden_units = (64, 32)
    stack = [tf.keras.layers.Dense(hidden_units[0], activation='relu', input_shape=(num_features,))]
    stack += [tf.keras.layers.Dense(width, activation='relu') for width in hidden_units[1:]]
    stack.append(tf.keras.layers.Dense(1, activation='sigmoid'))
    return tf.keras.Sequential(stack)

def model_fn():
    """Wrap a freshly built Keras model as a TFF learning model.

    TFF documents that ``model_fn`` must build its model from scratch on every
    invocation. This factory therefore creates a new Keras model, loss and
    metric object on each call.

    Returns:
        The TFF model produced by ``tff.learning.from_keras_model``.
    """
    # Client datasets must yield batches matching these specs exactly:
    # float32 features of width num_features and 1-D int32 labels.
    features_spec = tf.TensorSpec(shape=[None, num_features], dtype=tf.float32)
    labels_spec = tf.TensorSpec(shape=[None], dtype=tf.int32)
    return tff.learning.from_keras_model(
        create_keras_model(),
        input_spec=(features_spec, labels_spec),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.BinaryAccuracy()],
    )

def create_tf_dataset(x, y, batch_size=32, shuffle_buffer=1000, seed=None):
    """Create a shuffled, batched ``tf.data.Dataset`` from one client's arrays.

    Features are cast to float32 and labels to int32 so the dataset's element
    types match the ``input_spec`` declared in ``model_fn``. NumPy arrays
    default to float64/int64, and TFF rejects client datasets whose element
    types differ from the model's input spec.

    Args:
        x: Feature array; ``from_tensor_slices`` slices it along the first axis,
            so that axis indexes samples.
        y: Label array with the same number of samples as ``x``. ``model_fn``
            declares batched labels as shape ``[None]``, so ``y`` is presumably
            1-D — TODO confirm against generate_synthetic_data.
        batch_size: Number of samples per batch.
        shuffle_buffer: Size of the shuffle buffer.
        seed: Optional op-level shuffle seed for a reproducible ordering.

    Returns:
        A dataset yielding ``(features, labels)`` batches.
    """
    features = np.asarray(x, dtype=np.float32)
    labels = np.asarray(y, dtype=np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.shuffle(buffer_size=shuffle_buffer, seed=seed).batch(batch_size)

## Set Up Centralized Training

Prepare the dataset for centralized training by aggregating all client data.

In [None]:
# Pool every client's (features, labels) pair into one centralized dataset,
# preserving client order so features and labels stay aligned.
feature_parts = [features for features, _ in client_data]
label_parts = [labels for _, labels in client_data]
all_x = np.concatenate(feature_parts)
all_y = np.concatenate(label_parts)

print(f"Centralized dataset: {all_x.shape[0]} samples, {all_x.shape[1]} features")

## Execute Centralized Training

Train the model centrally on the aggregated dataset and log performance metrics.

In [None]:
def centralized_training(all_x, all_y, num_epochs=10, batch_size=32, verbose=1):
    """Train a fresh model on the pooled data as the centralized baseline.

    Args:
        all_x: Pooled feature array from every client.
        all_y: Pooled label array aligned with ``all_x``.
        num_epochs: Number of full passes over the pooled data.
        batch_size: Mini-batch size. The default of 32 matches the default
            batch size used when building the per-client datasets.
        verbose: Keras ``fit`` verbosity (0 = silent, 1 = progress bar,
            2 = one line per epoch).

    Returns:
        ``(model, losses, accuracies)``: the trained Keras model plus the
        per-epoch training loss and accuracy lists from the fit history.
    """
    model = create_keras_model()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    history = model.fit(all_x, all_y, epochs=num_epochs, batch_size=batch_size, verbose=verbose)

    # With metrics=['accuracy'], Keras records the metric under 'accuracy'.
    return model, history.history['loss'], history.history['accuracy']


# Run centralized training
num_epochs = 20
cent_model, cent_losses, cent_accuracies = centralized_training(all_x, all_y, num_epochs)

## Set Up Federated Learning Environment

Initialize the TensorFlow Federated components: one `tf.data` dataset per client and the federated averaging process.

In [None]:
# Create federated data: one tf.data.Dataset per client, built from that
# client's (features, labels) arrays.
federated_data = [create_tf_dataset(x, y) for x, y in client_data]

# Initialize the federated averaging process.
# build_federated_averaging_process has no default for client_optimizer_fn, so
# omitting it raises a TypeError. The client optimizer drives local training
# on each sampled client; the server optimizer keeps TFF's default.
# The 0.02 learning rate follows TFF's image-classification tutorial and may
# need tuning for this data.
# NOTE(review): this legacy API has been removed in recent TFF releases in
# favour of tff.learning.algorithms.build_weighted_fed_avg. Pin a TFF version
# that still provides it, or migrate; the metrics layout read in
# federated_training changes with the new API.
CLIENT_LEARNING_RATE = 0.02
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=CLIENT_LEARNING_RATE),
)
state = iterative_process.initialize()

print("Federated learning environment initialized")

## Execute Federated Training

Run federated training rounds. Each round samples a configurable number of clients, and the global model is updated after every round.

In [None]:
def federated_training(federated_data, state, iterative_process, num_rounds=10,
                       num_clients_per_round=5, rng=None):
    """Run ``num_rounds`` rounds of federated training and record metrics.

    Each round samples ``num_clients_per_round`` distinct clients uniformly at
    random. It then passes their datasets to ``iterative_process.next``.

    Args:
        federated_data: Sequence of per-client datasets.
        state: Server state from ``iterative_process.initialize()`` or from a
            previous call to this function.
        iterative_process: Object whose ``next(state, data)`` returns
            ``(new_state, metrics)``. ``metrics`` must provide
            ``metrics['train']['loss']`` and
            ``metrics['train']['binary_accuracy']``.
        num_rounds: Number of training rounds.
        num_clients_per_round: Clients sampled per round. Clients are drawn
            without replacement, so this must lie in
            ``[1, len(federated_data)]``.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used for
            client sampling. Defaults to NumPy's global RNG.

    Returns:
        ``(state, losses, accuracies)``: the final server state and the
        per-round training loss and accuracy as Python floats.

    Raises:
        ValueError: If ``num_clients_per_round`` is outside
            ``[1, len(federated_data)]``.
    """
    num_available = len(federated_data)
    # Fail before any round runs instead of part-way through training.
    if not 1 <= num_clients_per_round <= num_available:
        raise ValueError(
            f'num_clients_per_round must be between 1 and {num_available}, '
            f'got {num_clients_per_round}'
        )
    # Fall back to the legacy global RNG so unseeded calls behave as before.
    sampler = np.random if rng is None else rng

    training_losses = []
    training_accuracies = []

    for round_num in range(1, num_rounds + 1):
        # Select distinct random clients for this round.
        selected_clients = sampler.choice(num_available, num_clients_per_round, replace=False)
        round_data = [federated_data[i] for i in selected_clients]

        # Perform one round of federated training.
        state, metrics = iterative_process.next(state, round_data)

        # NOTE(review): in legacy TFF FedAvg these 'train' metrics are presumably
        # aggregated during the clients' local training rather than evaluated on
        # the updated global model — confirm for the installed TFF version.
        train_metrics = metrics['train']
        loss = float(train_metrics['loss'])
        accuracy = float(train_metrics['binary_accuracy'])
        training_losses.append(loss)
        training_accuracies.append(accuracy)

        print(f'Round {round_num}: loss={loss:.4f}, accuracy={accuracy:.4f}')

    return state, training_losses, training_accuracies

# Run federated training: 20 rounds, sampling 5 clients in each round
num_rounds = 20
num_clients_per_round = 5
fed_state, fed_losses, fed_accuracies = federated_training(
    federated_data,
    state,
    iterative_process,
    num_rounds=num_rounds,
    num_clients_per_round=num_clients_per_round,
)

## Compare Performance Metrics

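Collect and compare the training loss and accuracy logged by the centralized and federated runs. Both are training-set metrics. A federated round trains only on the sampled clients' data, whereas a centralized epoch covers all pooled data. The comparison is therefore indicative rather than controlled.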

In [None]:
print("Federated Training Results:")
print(f"Final Loss: {fed_losses[-1]:.4f}, Final Accuracy: {fed_accuracies[-1]:.4f}")

print("\nCentralized Training Results:")
print(f"Final Loss: {cent_losses[-1]:.4f}, Final Accuracy: {cent_accuracies[-1]:.4f}")

# Signed relative accuracy difference of federated vs. centralized training.
# - Positive means federated ended higher; negative means it ended lower.
#   It is labelled as a difference, not an "improvement", for that reason.
# - Both values are final training accuracies.
# - Each federated round trains only on the sampled clients' data, whereas
#   each centralized epoch covers all pooled data.
# Treat this as a rough indicator, not a controlled comparison.
fed_final_acc = fed_accuracies[-1]
cent_final_acc = cent_accuracies[-1]
if cent_final_acc > 0:
    improvement = ((fed_final_acc - cent_final_acc) / cent_final_acc) * 100
    print(f"\nFederated vs. centralized accuracy (relative difference): {improvement:+.2f}%")
else:
    # A relative difference is undefined here; report the absolute gap instead.
    improvement = None
    print(f"\nFederated vs. centralized accuracy (absolute difference): "
          f"{fed_final_acc - cent_final_acc:+.4f}")

## Visualize Results

Generate charts comparing performance metrics between centralized and federated training.

In [None]:
def _plot_metric(ax, fed_values, cent_values, ylabel, title):
    """Draw one federated-vs-centralized curve pair, indexed from 1, on ``ax``."""
    ax.plot(range(1, len(fed_values) + 1), fed_values, label='Federated', marker='o')
    ax.plot(range(1, len(cent_values) + 1), cent_values, label='Centralized', marker='s')
    ax.set_xlabel('Training Rounds/Epochs')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)


def plot_comparison(fed_losses, fed_accuracies, cent_losses, cent_accuracies,
                    output_path='../federated_vs_centralized_comparison.png'):
    """Plot federated vs. centralized training loss and accuracy side by side.

    The federated curves are indexed by round and the centralized curves by
    epoch, sharing one x-axis.

    Args:
        fed_losses: Per-round federated training loss.
        fed_accuracies: Per-round federated training accuracy.
        cent_losses: Per-epoch centralized training loss.
        cent_accuracies: Per-epoch centralized training accuracy.
        output_path: Where to save the figure (300 dpi). Pass ``None`` to skip
            saving.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    _plot_metric(loss_ax, fed_losses, cent_losses, 'Loss', 'Training Loss Comparison')
    _plot_metric(acc_ax, fed_accuracies, cent_accuracies, 'Accuracy', 'Training Accuracy Comparison')

    fig.tight_layout()
    # Save this specific figure rather than whatever pyplot considers "current".
    if output_path is not None:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()

# Draw the side-by-side comparison (plot_comparison also writes the PNG to disk)
plot_comparison(fed_losses, fed_accuracies, cent_losses, cent_accuracies)

for message in ("Federated Learning Simulation Complete!",
                "Check federated_vs_centralized_comparison.png for the comparison chart."):
    print(message)