# In [None]:
import os
import json
import ast
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np

# =========================
# CONFIG
# =========================
# Folder scanned for *_results.json / *_round_times.json run files.
DATA_DIR = "../src/plots"
# Not referenced in this section; presumably maps algorithm keys to display
# names for plotting — confirm against later cells.
NAME_MAP_FILE = "algorithm_names.json"

# Only runs recorded for this dataset and this client count are kept.
dataset_name = "fashion_mnist"
total_clients = 5

# Each run's curve is cut at the first round whose mean accuracy reaches this.
target_accuracy = 0.794


# =========================
# HELPERS
# =========================
def load_dict(filepath):
    """Load a mapping stored as text in *filepath*.

    The content is parsed as JSON first. If that fails, it is parsed as a
    Python literal (e.g. a ``repr``'d dict with single quotes or int keys)
    via :func:`ast.literal_eval`, which only evaluates literals and never
    executes code.

    Args:
        filepath: Path of the file to read (decoded as UTF-8).

    Returns:
        The parsed object; callers here use it as a dict.

    Raises:
        OSError: if the file cannot be opened.
        ValueError, SyntaxError: (from ``ast.literal_eval``) if the content is
            neither valid JSON nor a valid Python literal.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    # Parse after the file is closed; the handle is not needed any more.
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return ast.literal_eval(content)


def parse_filename(filename):
    """Split a run-file stem into ``(algo, clients, classes, dataset)``.

    The stem is expected to look like ``<algo>_<clients>_<classes>_<dataset>``,
    where the dataset name may itself contain underscores::

        >>> parse_filename("default_5_7_claudiogsc_emnist_balanced")
        ('default', 5, 7, 'claudiogsc_emnist_balanced')

    Args:
        filename: File name with its ``_results.json`` / ``_round_times.json``
            suffix already removed.

    Returns:
        Tuple ``(algo, clients, classes, dataset)`` with ``clients`` and
        ``classes`` as ints; ``dataset`` is ``""`` when the stem has exactly
        three parts.

    Raises:
        ValueError: if the stem has fewer than three ``_``-separated parts, or
            the clients/classes fields are not integers.
    """
    parts = filename.split("_")
    if len(parts) < 3:
        # Give a clear error instead of a bare IndexError.
        raise ValueError(f"Unexpected file name format: {filename!r}")
    algo, clients, classes, *dataset_parts = parts
    return algo, int(clients), int(classes), "_".join(dataset_parts)


# =========================
# COLLECT FILES
# =========================
_RESULTS_SUFFIX = "_results.json"
_ROUND_TIMES_SUFFIX = "_round_times.json"

# Run stem -> full path of its results / round-times file.
results_files = {}
round_time_files = {}

for file in os.listdir(DATA_DIR):
    # Strip only the trailing suffix; str.replace would also delete any
    # earlier occurrence of the suffix text inside the name.
    if file.endswith(_RESULTS_SUFFIX):
        results_files[file[: -len(_RESULTS_SUFFIX)]] = os.path.join(DATA_DIR, file)
    elif file.endswith(_ROUND_TIMES_SUFFIX):
        round_time_files[file[: -len(_ROUND_TIMES_SUFFIX)]] = os.path.join(DATA_DIR, file)

# Stems that have both a results file and a round-times file (a set).
common_keys = results_files.keys() & round_time_files.keys()

# =========================
# PROCESS DATA
# =========================
def _mean_accuracy_per_round(results):
    """Average the per-client accuracies of each round.

    Args:
        results: Mapping ``client_id -> {round: {"metrics": {"accuracy": x}}}``
            as loaded by ``load_dict``; round keys may be str or int.

    Returns:
        Dict ``{round (int): mean accuracy over the clients reporting it}``.
    """
    per_round = defaultdict(list)
    for rounds in results.values():  # the client id itself is not needed
        for rnd, data in rounds.items():
            per_round[int(rnd)].append(data["metrics"]["accuracy"])
    return {rnd: np.mean(accs) for rnd, accs in per_round.items()}


def _round_time(times, rnd):
    """Return the duration recorded for round ``rnd`` in ``times``.

    JSON files yield string keys, but ``load_dict`` falls back to
    ``ast.literal_eval``, which can yield int keys, so both are accepted.
    Raises KeyError if the round is missing under either form.
    """
    key = str(rnd)
    return times[key] if key in times else times[rnd]


def _curve_until_target(mean_acc, times, target):
    """Build accuracy and cumulative-time curves, stopping at ``target``.

    Rounds are visited in ascending order. The first round whose mean
    accuracy is ``>= target`` is still included, then iteration stops.

    Returns:
        ``(accuracy, cumulative_time, final_round)`` where ``final_round`` is
        the last round included (0 if there were no rounds).
    """
    accuracy, elapsed = [], []
    total = 0
    final_round = 0
    for rnd in sorted(mean_acc):
        acc = mean_acc[rnd]
        final_round = rnd
        accuracy.append(acc)
        total += _round_time(times, rnd)
        elapsed.append(total)
        if acc >= target:
            break
    return accuracy, elapsed, final_round


# "<algo>_<classes>" -> per-run summary and curves.
algo_data = {}

# Iterate in sorted order: set iteration order of str keys varies between
# interpreter runs, which would make algo_data's insertion order unstable.
for key in sorted(common_keys):
    try:
        algo, clients, classes, dataset = parse_filename(key)
    except (ValueError, IndexError):
        # Name is not <algo>_<clients>_<classes>_<dataset>; not a run file.
        continue

    if dataset != dataset_name or clients != total_clients:
        continue

    results = load_dict(results_files[key])
    times = load_dict(round_time_files[key])

    accuracy, elapsed, final_round = _curve_until_target(
        _mean_accuracy_per_round(results), times, target_accuracy
    )

    algo_data[f"{algo}_{classes}"] = {
        "algo": algo,
        "clients": clients,
        "classes": classes,
        "final_round": final_round,
        "accuracy": np.array(accuracy),
        "time": np.array(elapsed),
    }