In [57]:
import seaborn as sns;sns.set_theme()
import matplotlib.pyplot as plt
# % matplotlib inline
import pandas as pd
import numpy as np

from typing import List

# Round settings to process; each value selects a "<dataset>-round-<r>.txt"
# log and labels an "HC round = <r>" column.
rounds = [1, 3, 5, 10]
# Dataset names used when building log paths and figure names.
datasets = ["mnist", "cifar"]

class acc_data:
    """One parsed log record: sample count, test loss and accuracy.

    The constructor accepts any values that ``int`` / ``float`` can convert
    (typically the string tokens read from a log line) and stores the
    converted numbers.
    """

    samples: int
    test_loss: float
    acc: float

    def __init__(self, samples, test_loss, acc) -> None:
        # Convert all three raw values up front, then bind them together.
        self.samples, self.test_loss, self.acc = (
            int(samples),
            float(test_loss),
            float(acc),
        )

def load_acc_txt(file_name: str):
    """Parse an accuracy log into records and a flat list of accuracies.

    Every non-blank line is split on commas. From each comma-separated item
    only the last space-separated token is kept, and the first three tokens
    of the line become ``samples``, ``test_loss`` and ``acc`` of one
    ``acc_data`` record.

    Args:
        file_name: Path of the text file to read.

    Returns:
        A ``(points, all_acc)`` tuple: the ``acc_data`` records in file
        order, and the list of their ``acc`` values in the same order.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a non-blank line has fewer than three items.
        ValueError: If a token cannot be converted to a number.
    """
    points = []
    all_acc = []
    with open(file_name) as f:
        for line in f:
            # Blank lines (e.g. an empty line at the end of the file) hold no
            # record; without this guard they would raise IndexError below.
            if not line.strip():
                continue
            # Strip each item before taking its last token so trailing
            # spaces or the line ending never become the "value".
            nums = [item.strip().split(" ")[-1] for item in line.split(",")]
            point = acc_data(nums[0], nums[1], nums[2])
            all_acc.append(point.acc)
            points.append(point)
    return points, all_acc

# Load the accuracy values of the "noniid-one" MNIST round-1 log at module level.
# NOTE(review): `all_acc` is not read again in the visible code; this looks like
# a leftover sanity check -- confirm nothing else uses it before removing.
_, all_acc = load_acc_txt("../singleclient/client/acc-data/noniid-one/mnist-round-1.txt")

# Directory (with trailing slash) that holds the accuracy logs of each
# experiment case, keyed by case name.
prefix = {
    "noniid-one": "../singleclient/client/acc-data/noniid-one/",
    "noniid-one-rand": "../singleclient/client/acc-data/noniid-one-rand/",
    "noniid-two": "../singleclient/client/acc-data/noniid-two/",
}
# File extension shared by every accuracy log.
suffix = ".txt"

def get_data_txt_paths(datasets, rounds, case="noniid-one"):
    """Return the per-round log paths for every dataset of one case.

    Paths are ordered dataset-major: all rounds of the first dataset, then
    all rounds of the next one, and so on. Each path comes from ``get_path``.

    Args:
        datasets: Iterable of dataset names.
        rounds: Iterable of round settings.
        case: Key of the ``prefix`` table selecting the log directory.

    Returns:
        A new list of path strings.
    """
    return [
        get_path(name, rnd, case)
        for name in datasets
        for rnd in rounds
    ]

def get_path(dataset, round, case="noniid-one"):
    """Build the log path ``<prefix[case]><dataset>-round-<round><suffix>``.

    Args:
        dataset: Dataset name (string).
        round: Round setting; converted with ``str``.
        case: Key of the ``prefix`` table selecting the log directory.

    Returns:
        The path as a string.

    Raises:
        KeyError: If ``case`` is not a key of ``prefix``.
    """
    pieces = [prefix[case], dataset, "-round-", str(round), suffix]
    return "".join(pieces)

def get_fed_avg_path(dataset, case="noniid-one", sub_case="seqs"):
    """Return the FedAvg log path for each dataset name in ``dataset``.

    With a non-empty ``sub_case`` the file name is
    ``fedavg-<name>0-<sub_case><suffix>``; with ``sub_case == ""`` it is
    ``fedavg-<name><suffix>``. Both live under ``prefix[case]``.

    Args:
        dataset: Iterable of dataset names (despite the singular name).
        case: Key of the ``prefix`` table selecting the log directory.
        sub_case: Tag appended to the file name, or ``""`` for none.

    Returns:
        A list with one path string per dataset name, in input order.
    """
    tagged = sub_case != ""
    return [
        prefix[case] + "fedavg-" + name + "0-" + sub_case + suffix
        if tagged
        else prefix[case] + "fedavg-" + name + suffix
        for name in dataset
    ]

def get_df(datasets, rounds, case="noniid-one", sub_case="seqs"):
    """Collect the accuracy series of one case into a DataFrame.

    One column is produced per (dataset, round) log plus one FedAvg column
    per dataset. With a single dataset the columns are named
    ``"HC round = <r>"`` and ``"FedAvg"``. With several datasets each name
    is prefixed by ``"<dataset> "`` so that columns stay distinct (the
    previous implementation raised IndexError in that situation because it
    only generated names for one dataset).

    Args:
        datasets: Iterable of dataset names.
        rounds: Iterable of round settings.
        case: Key of the ``prefix`` table selecting the log directory.
        sub_case: FedAvg file tag; ignored (treated as ``""``) when ``case``
            ends with ``"two"``.

    Returns:
        A ``pandas.DataFrame`` whose columns hold the accuracy values read
        from each log, in the order: all round logs (dataset-major), then
        the FedAvg logs. All series are expected to have the same length,
        since pandas rejects unequal-length list columns.
    """
    datasets = list(datasets)
    rounds = list(rounds)

    # Cases ending in "two" use a FedAvg log without a sub-case tag.
    fed_sub_case = "" if case.endswith("two") else sub_case

    data_paths = get_data_txt_paths(datasets, rounds, case)
    data_paths.extend(
        get_fed_avg_path(datasets, case=case, sub_case=fed_sub_case)
    )

    # Only qualify column names when needed, so single-dataset callers keep
    # the exact column names they already rely on.
    qualify = len(datasets) > 1

    def _label(dataset, text):
        return dataset + " " + text if qualify else text

    # Names are generated in the same order as `data_paths`:
    # dataset-major round logs first, then one FedAvg log per dataset.
    names = [
        _label(dataset, "HC round = " + str(r))
        for dataset in datasets
        for r in rounds
    ]
    names.extend(_label(dataset, "FedAvg") for dataset in datasets)

    series = {}
    for name, path in zip(names, data_paths):
        _, accuracies = load_acc_txt(path)
        series[name] = accuracies
    return pd.DataFrame(series)

# Experiment cases to process; each name is a key of the `prefix` table.
cases = [
    "noniid-one",
    "noniid-one-rand",
    "noniid-two",
]
def save_acc_figs():
    """Save one accuracy line plot per (case, dataset) combination.

    For every entry of ``cases`` and ``datasets`` the frame from ``get_df``
    is drawn with ``sns.lineplot`` and written to a file named
    ``<dataset>-<case>`` at 600 dpi; the current axes are cleared afterwards
    so the next plot starts empty. The "cifar" dataset is skipped for cases
    ending in "two" (presumably no such log exists -- confirm).
    """
    for case in cases:
        # "rand" cases read the FedAvg log tagged "rand"; every other case
        # relies on get_df's own default sub-case.
        extra = {"sub_case": "rand"} if case.endswith("rand") else {}
        skip_cifar = case.endswith("two")
        for dataset in datasets:
            if skip_cifar and dataset == "cifar":
                continue
            frame = get_df([dataset], rounds, case, **extra)
            sns.lineplot(data=frame)
            plt.xlabel("Number of Rounds")
            plt.ylabel("Test Accuracy(%)")
            plt.savefig(dataset + "-" + case, dpi=600)
            plt.cla()

def get_improvement():
    """Compute accuracy ratios of each "HC round" series over FedAvg.

    For every (case, dataset) combination (skipping "cifar" for cases ending
    in "two"), and for every ``r`` in ``rounds``, the value at index
    ``r + 1`` of the ``"HC round = <r>"`` column is divided by the value at
    the same index of the ``"FedAvg"`` column. The combination label and its
    ratio array are printed.

    Returns:
        A NumPy array built from the per-combination ratio arrays with the
        first two combinations dropped.
        NOTE(review): the ``[2:]`` slice silently excludes the first two
        processed combinations; the reason is not visible here -- confirm
        it is intentional.
    """
    ratios_per_run = []
    for case in cases:
        # "rand" cases read the FedAvg log tagged "rand"; every other case
        # relies on get_df's own default sub-case.
        extra = {"sub_case": "rand"} if case.endswith("rand") else {}
        for dataset in datasets:
            if dataset == "cifar" and case.endswith("two"):
                continue
            frame = get_df([dataset], rounds, case, **extra)

            print(dataset + '-' + case)
            hc_acc = np.array(
                [frame["HC round = " + str(r)][r + 1] for r in rounds],
                dtype=float,
            )
            fed_acc = np.array(
                [frame["FedAvg"][r + 1] for r in rounds],
                dtype=float,
            )
            ratio = hc_acc / fed_acc
            ratios_per_run.append(ratio)
            print(ratio)
    return np.array(ratios_per_run[2:])



# Disabled one-off experiments (single MNIST plot / regenerate all figures),
# kept for reference:
# dfs = get_df(["mnist"], rounds)
# sns.lineplot(data=dfs)
# plt.xlabel("Number of Rounds")
# plt.ylabel("Test Accuracy(%)")
# plt.show()
# plt.savefig("mnist-noniid-one", dpi=600)
# save_acc_figs()

# Print every returned improvement ratio as one flat, ascending array.
res = np.sort(get_improvement(), axis=None)
print("all res: ")
print(res)

        

mnist-noniid-one
[ 2.56374419  2.69540474  9.00326851 11.40061896]
cifar-noniid-one
[5.         5.26       4.87329435 5.        ]
mnist-noniid-one-rand
[1.33912669 1.2956196  1.14515399 1.12997993]
cifar-noniid-one-rand
[4.66076421 4.953      4.19798658 4.999     ]
mnist-noniid-two
[1.34109764 1.35472653 1.34170932 1.30958084]
all res: 
[1.12997993 1.14515399 1.2956196  1.30958084 1.33912669 1.34109764
 1.34170932 1.35472653 4.19798658 4.66076421 4.953      4.999     ]
