## 1. Setup and Configuration

In [None]:
# Import required modules
import numpy as np
import tensorflow as tf
import warnings
warnings.filterwarnings('ignore')

# Import project modules
from config import ExperimentConfig
from dataset import generate_synthetic_dataset, partition_data_for_clients, create_centralized_datasets, get_dataset_statistics
from model import create_compiled_model, get_model_summary, count_model_parameters
from centralized import CentralizedTrainer, run_centralized_experiment
from visualization import plot_training_convergence, plot_final_comparison, plot_client_participation, create_performance_summary_table, generate_all_visualizations

# The federated module is optional: its import may fail when TensorFlow
# Federated is not available. TFF_AVAILABLE records the outcome.
TFF_AVAILABLE = False
try:
    from federated import FederatedTrainer, run_federated_experiment
except ImportError as import_error:
    print("⚠️  TensorFlow Federated not available - Federated experiments will be skipped")
    print(f"   Error: {import_error}")
    print("   Note: This is expected on Python 3.12. Use Python 3.9-3.11 for full functionality.")
else:
    TFF_AVAILABLE = True
    print("✅ TensorFlow Federated available - Full functionality enabled")

# Confirm the load and report the versions of the core libraries.
print("\n✅ All available modules loaded successfully")
for library_label, library_module in (("TensorFlow", tf), ("NumPy", np)):
    print(f"{library_label} version: {library_module.__version__}")

In [None]:
# Print the experiment configuration as a dot-padded key/value listing
# framed by 70-character rules.
config_rule = "=" * 70
print(config_rule)
print("EXPERIMENT CONFIGURATION")
print(config_rule)

config = ExperimentConfig.get_config_summary()
for setting_name, setting_value in config.items():
    print(f"{setting_name:.<30} {setting_value}")

print(config_rule)

## 2. Dataset Generation and Analysis

In [None]:
# Seed first so the generated data is reproducible
ExperimentConfig.set_random_seeds()

# Build the synthetic features (X) and labels (y)
print("Generating synthetic dataset...")
X, y = generate_synthetic_dataset()

# Report the basic shape and label statistics of the dataset
print("\nDataset Properties:")
print(f"  - Total samples: {len(X)}")
print(f"  - Feature dimension: {X.shape[1]}")
print(f"  - Number of classes: {len(np.unique(y))}")

# Occurrences of each label value, computed once and reused for both lines
label_counts = np.bincount(y)
print(f"  - Class distribution: {label_counts}")
print(f"  - Class balance: {label_counts / len(y)}")

In [None]:
# Partition data for federated clients
print("Partitioning data for federated clients...")
client_datasets = partition_data_for_clients(X, y, ExperimentConfig.NUM_CLIENTS)

# Analyze client data distribution
stats = get_dataset_statistics(client_datasets)

print("\nFederated Client Statistics:")
print(f"  - Number of clients: {stats['num_clients']}")
print(f"  - Total samples distributed: {stats['total_samples']}")
samples_per_client = stats['samples_per_client']
print(
    "  - Samples per client (min/max/mean): "
    f"{min(samples_per_client)}/{max(samples_per_client)}/{np.mean(samples_per_client):.1f}"
)

# Preview only the first few clients to keep the output short.
print("\nClass Distribution per Client (non-IID):")
preview = stats['class_distribution_per_client'][:5]
for i, dist in enumerate(preview):
    print(f"  Client {i}: {dist}")
# Report how many clients were actually listed: with fewer than five clients
# a hard-coded "5" would misstate what is shown above.
print(f"  ... (showing first {len(preview)} of {stats['num_clients']} clients)")

In [None]:
# Build the train/validation/test splits used by the centralized baseline
print("Creating centralized train/validation/test splits...")
centralized_data = create_centralized_datasets(X, y)

# Each entry is unpacked into a (features, labels) pair
X_train, y_train = centralized_data['train']
X_val, y_val = centralized_data['validation']
X_test, y_test = centralized_data['test']

# Report the feature-array shape of every split
print("\nCentralized Splits:")
for split_label, split_features in (
    ("Training", X_train),
    ("Validation", X_val),
    ("Test", X_test),
):
    print(f"  - {split_label}: {split_features.shape}")

## 3. Model Architecture

In [None]:
# Print the textual model summary
print("Model Architecture:")
print(get_model_summary())

# Print the parameter counts with thousands separators
params = count_model_parameters()
print("\nModel Parameters:")
for count_label, count_key in (
    ("Total", 'total'),
    ("Trainable", 'trainable'),
    ("Non-trainable", 'non_trainable'),
):
    print(f"  - {count_label}: {params[count_key]:,}")

## 4. Centralized Training Experiment

In [None]:
# Announce the start of the centralized run
training_rule = "=" * 70
print(f"\n{training_rule}")
print("STARTING CENTRALIZED TRAINING")
print(f"{training_rule}\n")

# Stack the train and validation splits back into one training set.
# NOTE(review): validation_split is also passed to train() below, so the
# explicit validation split is merged and then re-split - confirm intended.
X_full_train = np.vstack([X_train, X_val])
y_full_train = np.concatenate([y_train, y_val])

centralized_trainer = CentralizedTrainer()

# Hyperparameters come from the shared experiment configuration
centralized_fit_options = dict(
    epochs=ExperimentConfig.CENTRALIZED_EPOCHS,
    batch_size=ExperimentConfig.BATCH_SIZE,
    validation_split=ExperimentConfig.VALIDATION_SPLIT,
    verbose=0,  # Suppress epoch-by-epoch output
)
centralized_history = centralized_trainer.train(
    X_full_train, y_full_train, **centralized_fit_options
)

print("\nCentralized training complete!")

In [None]:
# Score the centralized model on the held-out test split
centralized_test_metrics = centralized_trainer.evaluate(X_test, y_test)

# Print every metric to four decimal places between 50-character rules
results_rule = "=" * 50
print("\nCentralized Model - Test Set Results:")
print(results_rule)
for metric_name, metric_value in centralized_test_metrics.items():
    print(f"{metric_name:.<35} {metric_value:.4f}")
print(results_rule)

In [None]:
# Persist the centralized metrics and model; both calls return a location
centralized_results_path = centralized_trainer.save_results(centralized_test_metrics)
centralized_model_path = centralized_trainer.save_model()

# Report where each artifact was written
saved_centralized_artifacts = (
    ("\nCentralized results saved to", centralized_results_path),
    ("Centralized model saved to", centralized_model_path),
)
for artifact_label, artifact_location in saved_centralized_artifacts:
    print(f"{artifact_label}: {artifact_location}")

## 5. Federated Learning Experiment

In [None]:
# Federated training runs only when the federated module was imported;
# otherwise the trainer and metrics are set to None.
if TFF_AVAILABLE:
    federated_rule = "=" * 70
    print(f"\n{federated_rule}")
    print("STARTING FEDERATED TRAINING")
    print(f"{federated_rule}\n")

    # The trainer receives the per-client partitions and the shared test split
    federated_trainer = FederatedTrainer(
        client_datasets=client_datasets,
        test_data=(X_test, y_test),
    )

    # Round count, client fraction and local epochs come from the configuration
    federated_metrics = federated_trainer.train(
        num_rounds=ExperimentConfig.NUM_ROUNDS,
        client_fraction=ExperimentConfig.CLIENT_FRACTION,
        local_epochs=ExperimentConfig.LOCAL_EPOCHS,
    )

    print("\nFederated training complete!")
else:
    for skip_line in (
        "⚠️  SKIPPING FEDERATED TRAINING - TensorFlow Federated not available",
        "   To run federated experiments, use Python 3.9-3.11 with:",
        "   pip install tensorflow-federated",
    ):
        print(skip_line)
    federated_trainer = None
    federated_metrics = None

In [None]:
# Fetch and print the final federated test metrics, or record None when
# federated training did not run.
if not (TFF_AVAILABLE and federated_trainer):
    print("⚠️  Skipped - TensorFlow Federated not available")
    federated_final_metrics = None
else:
    federated_final_metrics = federated_trainer.get_final_test_metrics()

    federated_rule_50 = "=" * 50
    print("\nFederated Model - Final Test Set Results:")
    print(federated_rule_50)
    for metric_name, metric_value in federated_final_metrics.items():
        # Integers are shown as-is; everything else gets four decimal places
        if isinstance(metric_value, int):
            rendered_value = metric_value
        else:
            rendered_value = f'{metric_value:.4f}'
        print(f"{metric_name:.<35} {rendered_value}")
    print(federated_rule_50)

In [None]:
# Persist the federated results when a trainer exists; otherwise record None
# so later cells can test the path.
if not (TFF_AVAILABLE and federated_trainer):
    print("⚠️  Skipped - TensorFlow Federated not available")
    federated_results_path = None
else:
    federated_results_path = federated_trainer.save_results()
    print(f"\nFederated results saved to: {federated_results_path}")

## 6. Results Comparison and Visualization

In [None]:
# Side-by-side summary when federated metrics exist; otherwise list only the
# centralized test metrics.
if not (TFF_AVAILABLE and federated_final_metrics):
    print("⚠️  Skipped - TensorFlow Federated not available")
    print("\nCentralized Model Performance:")
    fallback_rule = "=" * 50
    print(fallback_rule)
    for metric_name, metric_value in centralized_test_metrics.items():
        print(f"{metric_name:.<35} {metric_value:.4f}")
    print(fallback_rule)
else:
    summary_table = create_performance_summary_table(
        centralized_test_metrics,
        federated_final_metrics,
    )
    print(summary_table)

In [None]:
# Plot training convergence: a centralized-vs-federated comparison when
# federated metrics exist, otherwise the centralized curves alone.
import matplotlib.pyplot as plt

if not (TFF_AVAILABLE and federated_metrics):
    print("⚠️  Skipping convergence comparison - TensorFlow Federated not available")
    print("   Can only show centralized training metrics")

    # Centralized-only fallback: accuracy on the left, loss on the right
    history = centralized_trainer.get_training_metrics()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # (axis, training key, validation key, display name) for each panel
    panel_specs = (
        (ax1, 'accuracy', 'val_accuracy', 'Accuracy'),
        (ax2, 'loss', 'val_loss', 'Loss'),
    )
    for panel_axis, train_key, val_key, quantity in panel_specs:
        panel_axis.plot(history[train_key], label=f'Train {quantity}', linewidth=2)
        panel_axis.plot(history[val_key], label=f'Val {quantity}', linewidth=2)
        panel_axis.set_xlabel('Epoch')
        panel_axis.set_ylabel(quantity)
        panel_axis.set_title(f'Centralized Training - {quantity}')
        panel_axis.legend()
        panel_axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    centralized_training_metrics = centralized_trainer.get_training_metrics()

    fig = plot_training_convergence(
        centralized_training_metrics,
        federated_metrics,
    )

    plt.show()

In [None]:
# Compare final test metrics of both models; needs federated results
if not (TFF_AVAILABLE and federated_final_metrics):
    print("⚠️  Skipped - TensorFlow Federated not available")
    print("   Cannot compare without federated results")
else:
    fig = plot_final_comparison(centralized_test_metrics, federated_final_metrics)
    plt.show()

In [None]:
# Draw the client participation plot from the federated training metrics
if not (TFF_AVAILABLE and federated_metrics):
    print("⚠️  Skipped - TensorFlow Federated not available")
else:
    fig = plot_client_participation(federated_metrics)
    plt.show()

In [None]:
# Build every comparison figure from the two saved result files; this
# requires a federated results path.
if not (TFF_AVAILABLE and federated_results_path):
    print("⚠️  Skipped - TensorFlow Federated not available")
    print("   Cannot generate comparison visualizations without federated results")
else:
    print("Generating all visualizations...")

    figures = generate_all_visualizations(
        centralized_results_path,
        federated_results_path,
    )

    # List each generated figure with its location
    print("\nGenerated figures:")
    for figure_name, figure_path in figures.items():
        print(f"  - {figure_name}: {figure_path}")

## 7. Experiment Analysis and Insights

### Key Questions to Address:

1. **Convergence Behavior**
   - How does the convergence of federated learning compare to that of centralized training?
   - Are there oscillations or instability in federated training?
   - How many rounds are needed to reach comparable performance?

2. **Generalization Performance**
   - What is the accuracy gap between the centralized and federated models?
   - Is the gap acceptable for privacy-preserving scenarios?
   - How does test loss compare?

3. **Client Heterogeneity Impact**
   - How does non-IID data distribution affect learning?
   - Are certain clients more influential?
   - Does client sampling strategy matter?

4. **Practical Implications**
   - When is federated learning a viable alternative?
   - What are the trade-offs between communication cost and model performance?
   - How can we improve federated learning performance?

### Next Steps for Research:

- Experiment with different client fractions
- Vary the degree of non-IID data
- Test different aggregation strategies
- Analyze communication efficiency
- Implement differential privacy mechanisms
- Compare with other federated algorithms (FedProx, FedAdam, etc.)

---

## 8. Experiment Summary

In [None]:
# Final recap: configuration, headline accuracies, and generated artifacts
summary_rule = "=" * 70
print(summary_rule)
print("EXPERIMENT COMPLETE")
print(summary_rule)

print("\nExperiment Configuration:")
print(f"  - Dataset size: {len(X)}")
# Remaining settings are read from ExperimentConfig by attribute name
for setting_label, setting_attr in (
    ("Number of clients", "NUM_CLIENTS"),
    ("Federated rounds", "NUM_ROUNDS"),
    ("Centralized epochs", "CENTRALIZED_EPOCHS"),
    ("Client fraction", "CLIENT_FRACTION"),
    ("Local epochs", "LOCAL_EPOCHS"),
):
    print(f"  - {setting_label}: {getattr(ExperimentConfig, setting_attr)}")

print("\nFinal Results:")
centralized_accuracy = centralized_test_metrics['test_accuracy']
print(f"  Centralized Test Accuracy: {centralized_accuracy:.4f}")

if not (TFF_AVAILABLE and federated_final_metrics):
    print("  Federated Test Accuracy:   N/A (TensorFlow Federated not available)")
    print("  Performance Gap:           N/A")
else:
    federated_accuracy = federated_final_metrics['test_accuracy']
    # The gap is reported as an absolute difference
    accuracy_gap = abs(centralized_accuracy - federated_accuracy)
    print(f"  Federated Test Accuracy:   {federated_accuracy:.4f}")
    print(f"  Performance Gap:           {accuracy_gap:.4f}")

print("\nArtifacts Generated:")
print(f"  - Centralized results: {centralized_results_path}")
print(f"  - Centralized model: {centralized_model_path}")

if not (TFF_AVAILABLE and federated_results_path):
    print("  - Federated results: N/A")
    print("  - Visualizations: Limited (centralized only)")
else:
    print(f"  - Federated results: {federated_results_path}")
    print(f"  - Visualizations: {ExperimentConfig.FIGURES_DIR}")

# Closing reminder when the federated half of the experiment was skipped
if not TFF_AVAILABLE:
    warning_strip = "⚠️ " * 35
    print("\n" + warning_strip)
    print("NOTE: Federated learning experiments were skipped")
    print("      TensorFlow Federated is not available in this environment")
    print("      To run full experiments, use Python 3.9-3.11 with:")
    print("      pip install tensorflow-federated")
    print(warning_strip)

print("\n" + summary_rule)
print("Centralized training results saved and ready for analysis")
print(summary_rule)