In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# ======== Helpers ========
def load_metrics(path):
    """Load a metrics CSV and normalize its column names.

    Headers are stripped of surrounding whitespace and lower-cased, so the
    column lookups below do not depend on how the logging code padded or
    capitalized them.

    Args:
        path: Any path or buffer accepted by ``pandas.read_csv``.

    Returns:
        The loaded ``DataFrame`` with normalized column names.
    """
    df = pd.read_csv(path)
    df.columns = [str(c).strip().lower() for c in df.columns]
    return df


def require_column(df, name, source):
    """Return column *name* of *df*, failing with a descriptive error.

    Args:
        df: DataFrame whose columns were normalized by ``load_metrics``.
        name: Normalized (lower-case) column name to fetch.
        source: Human-readable origin of *df* (e.g. its file name), used
            only in the error message.

    Returns:
        The ``Series`` for column *name*.

    Raises:
        KeyError: if *name* is not a column of *df*; the message names
            *source* and lists the columns that are available.
    """
    if name not in df.columns:
        raise KeyError(
            f"{source}: missing column {name!r}; "
            f"available columns: {list(df.columns)}"
        )
    return df[name]


def plot_comparison(ax, series, title, ylabel):
    """Draw one comparison panel on the given Matplotlib axes.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        series: Iterable of ``(label, x, y, marker)`` tuples, one line each.
        title: Panel title.
        ylabel: Y-axis label.
    """
    for label, x, y, marker in series:
        ax.plot(x, y, label=label, marker=marker)
    ax.set_title(title)
    # The base run is indexed by row number and the federated runs by their
    # "round" column, yet they share one x-axis.
    ax.set_xlabel("Round / Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


# The driver below runs when this is executed as a script or a notebook cell
# (both have __name__ == "__main__"). Importing the module only defines the
# helpers, without reading CSVs or opening a figure.
if __name__ == "__main__":
    # ======== Load CSV files ========
    base = load_metrics("base.csv")
    fedavg = load_metrics("fedavg.csv")
    fedprox = load_metrics("fedprox.csv")

    # ======== Extract metrics ========
    # Base rows are numbered 1..N; the federated logs carry an explicit
    # "round" column.
    base_rounds = base.index + 1
    fedavg_rounds = require_column(fedavg, "round", "fedavg.csv")
    fedprox_rounds = require_column(fedprox, "round", "fedprox.csv")

    # Training Loss (the base log calls this column "loss", the federated
    # logs call it "train_loss")
    base_loss = require_column(base, "loss", "base.csv")
    fedavg_loss = require_column(fedavg, "train_loss", "fedavg.csv")
    fedprox_loss = require_column(fedprox, "train_loss", "fedprox.csv")

    # Accuracy
    base_acc = require_column(base, "accuracy", "base.csv")
    fedavg_acc = require_column(fedavg, "accuracy", "fedavg.csv")
    fedprox_acc = require_column(fedprox, "accuracy", "fedprox.csv")

    # ======== Plot ========
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # --- Plot 1: Training Loss ---
    plot_comparison(
        loss_ax,
        [
            ("Base", base_rounds, base_loss, "o"),
            ("FedAvg", fedavg_rounds, fedavg_loss, "s"),
            ("FedProx", fedprox_rounds, fedprox_loss, "^"),
        ],
        "Training Loss Comparison",
        "Loss",
    )

    # --- Plot 2: Accuracy ---
    plot_comparison(
        acc_ax,
        [
            ("Base", base_rounds, base_acc, "o"),
            ("FedAvg", fedavg_rounds, fedavg_acc, "s"),
            ("FedProx", fedprox_rounds, fedprox_acc, "^"),
        ],
        "Accuracy Comparison",
        "Accuracy",
    )

    fig.tight_layout()
    plt.show()


: 