# In [None]:
# -----------------------------------------------------------------------------
# BatteryMind – Federated Learning Demo
# -----------------------------------------------------------------------------
# Cell 1 – Imports & Configuration
import json, time, logging, numpy as np, pandas as pd
from ai_models.federated_learning.simulation_framework.federated_simulator import FederatedSimulator
from ai_models.federated_learning.server.federated_server import FederatedServer
from ai_models.federated_learning.client_models.client_manager import ClientManager
from ai_models.training_data.synthetic_datasets import generate_battery_telemetry_data

# Simulation size: number of federated clients, number of federated rounds,
# and (presumably) local training epochs each client runs per round.
NUM_CLIENTS = 10
ROUNDS = 20
LOCAL_EPOCHS = 3

# Configure the root logger so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)

# Cell 2 – Generate Per-Client Data
# One synthetic telemetry file per simulated client, keyed by client index.
# The value stored is whatever generate_battery_telemetry_data returns
# (presumably the CSV path it wrote — confirm against the generator).
client_datasets = {
    client_idx: generate_battery_telemetry_data(
        num_batteries=50,
        duration_days=30,
        output_path=f"./client_{client_idx}_data.csv",
    )
    for client_idx in range(NUM_CLIENTS)
}

# Cell 3 – Instantiate Server & Clients
# Aggregation server using FedAvg; `dp_noise=True` presumably enables
# differential-privacy noise during aggregation (confirm in FederatedServer).
server = FederatedServer(dp_noise=True, aggregation_algorithm="FedAvg")

# Client pool: one client per dataset produced in Cell 2.
clients = ClientManager(
    client_dataset_paths=client_datasets,
    local_epochs=LOCAL_EPOCHS,
    num_clients=NUM_CLIENTS,
)

# Simulator that drives ROUNDS rounds between the server and the clients.
sim = FederatedSimulator(server, clients, rounds=ROUNDS)

# Cell 4 – Run Simulation
# Run the federated simulation and tabulate its history. Cell 5 reads the
# "round" and "global_accuracy" columns from this frame; presumably there is
# one record per round — confirm against FederatedSimulator.run().
history = sim.run()
history_df = pd.DataFrame(history)

# A bare `history_df.head()` is rendered only when it is the LAST expression
# of a notebook cell. This whole script is a single cell, so print the
# preview explicitly instead of silently discarding it.
print(history_df.head())

# Cell 5 – Visualise Global Accuracy
import plotly.graph_objects as go

# Global-model accuracy per federated round, drawn as lines with markers.
accuracy_trace = go.Scatter(
    x=history_df["round"],
    y=history_df["global_accuracy"],
    mode="lines+markers",
    name="Global Accuracy",
)
fig = go.Figure(data=[accuracy_trace])
fig.update_layout(title="Federated Learning Convergence")
fig.show()

# Cell 6 – Compare With Centralised Model
from ai_models.transformers.battery_health_predictor.trainer import (
    train_centralised_transformer
)

# Pool every client's shard into a single CSV so the centralised baseline is
# trained on exactly the data the federation saw (the files from Cell 2);
# otherwise the accuracy comparison below is not like-for-like.
# NOTE(review): assumes all client CSVs share one column schema. Each shard
# was generated independently, so battery IDs may repeat across clients —
# confirm the trainer does not split/group by battery ID before relying on it.
pooled_csv_path = "./centralised_data.csv"
pooled_df = pd.concat(
    [pd.read_csv(path) for path in client_datasets.values()],
    ignore_index=True,
)
pooled_df.to_csv(pooled_csv_path, index=False)

centralised_metrics = train_centralised_transformer(
    csv_path=pooled_csv_path,
    # Same total epoch budget as LOCAL_EPOCHS over all ROUNDS of the
    # federated run.
    epochs=LOCAL_EPOCHS * ROUNDS
)
print("Centralised accuracy:", centralised_metrics["val_accuracy"])
# NOTE(review): assumes "global_accuracy" and "val_accuracy" are computed on
# comparable validation data — confirm before drawing conclusions.
print("Federated accuracy (final round):", history_df["global_accuracy"].iloc[-1])
