In [None]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt

def _nc_sort_key(nc):
    """Order Nc labels numerically when they parse as ints; others go last."""
    try:
        return (0, int(nc), "")
    except ValueError:
        return (1, 0, nc)


def _collect_runs(folder):
    """Load every matching result CSV once, grouped by its Nc label.

    Returns a dict mapping the Nc token (a string taken from the file name)
    to a list of ``(J, path, DataFrame)`` tuples sorted by J, then path.
    Files whose J token is not an integer, or that contain no data rows,
    are skipped with a printed message instead of aborting the whole run.
    """
    pattern = os.path.join(
        folder, "results_only_split_by_role_gamma_uniform_nc_*_J_*.csv"
    )
    runs = {}
    for path in glob.glob(pattern):
        parts = os.path.basename(path).replace(".csv", "").split("_")
        if "nc" not in parts or "J" not in parts:
            continue
        try:
            # The value of each field is the token right after its tag.
            nc = parts[parts.index("nc") + 1]
            j = int(parts[parts.index("J") + 1])
        except (IndexError, ValueError):
            print(f"Skipping {path}: cannot parse Nc/J from file name")
            continue
        try:
            df = pd.read_csv(path)
        except pd.errors.EmptyDataError:
            print(f"Skipping {path}: file is empty")
            continue
        if df.empty:
            # No last round to summarise; df.iloc[-1] would raise IndexError.
            print(f"Skipping {path}: no data rows")
            continue
        runs.setdefault(nc, []).append((j, path, df))

    for entries in runs.values():
        # Sort on (J, path) only: comparing DataFrames on a tie would fail.
        entries.sort(key=lambda entry: entry[:2])
    return runs


def _plot_curves(entries, column, ylabel, title, out_path):
    """Plot ``column`` against ``round`` with one line per J, then save."""
    plt.figure(figsize=(10, 6))
    for j, _path, df in entries:
        plt.plot(df["round"], df[column], label=f"J={j}")
    plt.xlabel("Rounds")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def _plot_final(summary, column, ylabel, title, out_path, tight=False):
    """Plot the last-round ``column`` against J with one line per Nc, then save."""
    plt.figure(figsize=(10, 6))
    for nc in sorted(summary["Nc"].unique(), key=_nc_sort_key):
        sub = summary[summary["Nc"] == nc]
        plt.plot(sub["J"], sub[column], marker="o", label=f"Nc={nc}")
    plt.xlabel("Local steps J")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    if tight:
        plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def plot_shakespeare_results(folder=".", out_folder="figs"):
    """Plot test accuracy/loss curves and final-round summaries from result CSVs.

    Looks in ``folder`` for files named
    ``results_only_split_by_role_gamma_uniform_nc_<Nc>_J_<J>.csv``. Each CSV
    must provide the columns ``round``, ``test_acc`` and ``test_loss``.

    Written to ``out_folder`` (created if missing):

    * per Nc: ``shakespeare_acc_vs_rounds_<Nc>.png`` and
      ``shakespeare_loss_vs_rounds_<Nc>.png`` (one line per J);
    * across all Nc: ``final_acc_vs_j.png``, ``final_loss_vs_j.png``,
      ``shakespeare_final_acc_vs_j.png`` and
      ``shakespeare_final_loss_vs_j.png`` (last-row metric vs J).

    If no usable result file is found, a message is printed and no figure
    is produced.

    Args:
        folder: Directory that holds the result CSVs.
        out_folder: Directory the PNG files are saved into.

    Raises:
        KeyError: If a CSV lacks one of the required columns.
    """
    os.makedirs(out_folder, exist_ok=True)

    runs = _collect_runs(folder)
    if not runs:
        # Without this guard the summary frame has no "Nc" column and the
        # final plots would fail with KeyError.
        print(f"No usable result CSVs found in {folder!r}; nothing to plot.")
        return

    # Accuracy and loss vs rounds, one figure pair per Nc.
    for nc, entries in runs.items():
        _plot_curves(
            entries,
            "test_acc",
            "Test Accuracy",
            f"Shakespeare – Test Accuracy vs Rounds (Nc={nc})",
            os.path.join(out_folder, f"shakespeare_acc_vs_rounds_{nc}.png"),
        )
        _plot_curves(
            entries,
            "test_loss",
            "Test Loss",
            f"Shakespeare – Test Loss vs Rounds (Nc={nc})",
            os.path.join(out_folder, f"shakespeare_loss_vs_rounds_{nc}.png"),
        )

    # Final summary across Nc and J: metrics from the last row of each run.
    summary = pd.DataFrame(
        [
            {
                "Nc": nc,
                "J": j,
                "TestAcc": df.iloc[-1]["test_acc"],
                "TestLoss": df.iloc[-1]["test_loss"],
            }
            for nc, entries in runs.items()
            for j, _path, df in entries
        ]
    )

    # Both the plain and the "shakespeare_"-prefixed summary files are kept
    # so that anything consuming either set of names continues to work.
    _plot_final(
        summary,
        "TestAcc",
        "Final Test Accuracy",
        "Final Test Accuracy vs J",
        os.path.join(out_folder, "final_acc_vs_j.png"),
    )
    _plot_final(
        summary,
        "TestLoss",
        "Final Test Loss",
        "Final Test Loss vs J",
        os.path.join(out_folder, "final_loss_vs_j.png"),
    )
    _plot_final(
        summary,
        "TestAcc",
        "Final Test Accuracy",
        "Shakespeare – Final Test Accuracy vs J",
        os.path.join(out_folder, "shakespeare_final_acc_vs_j.png"),
        tight=True,
    )
    _plot_final(
        summary,
        "TestLoss",
        "Final Test Loss",
        "Shakespeare – Final Test Loss vs J",
        os.path.join(out_folder, "shakespeare_final_loss_vs_j.png"),
        tight=True,
    )

    print(f"Plots saved in {out_folder}/")

# Generate the figures when this file is executed as a script or notebook
# cell (where __name__ is "__main__"); importing it as a module no longer
# triggers plotting as a side effect.
if __name__ == "__main__":
    plot_shakespeare_results(folder=".", out_folder="figs")


Plots saved in figs/
