In [122]:
import json
import os
from glob import glob
import pandas as pd
import re
import plotly.graph_objs as go

# Utility Functions to read saved experimental statistics and plotting

In [123]:
def read_experiment_stats(regex):
    all_files = glob("./stats/*.json")
    files = []
    for f in all_files:
        temp = os.path.split(f)[-1]
        if re.search(regex, temp) is not None:
            files.append(f)
    # print(files)
    method_acc = []
    method_total_flops = []
    method_client_flops = []
    method_comm = []
    steps = []
    names = []

    # hparams = {}
    # for k in json.load(open(files[0], "r"))['hparams'].keys():
    #     hparams[k] = []
        

    for f in files:
        j = json.load(open(f, "r"))
        method_acc.append(j['Method Accuracy'])
        method_total_flops.append(j['Method Flops'])
        method_client_flops.append(j['Method Client Flops'])
        method_comm.append(j['Method Comm Cost'])
        steps.append(j['Steps'])
        names.append(os.path.split(f)[-1].replace(".json", ""))
        
        # for k, v in j['hparams'].items():
        #     hparams[k].append(v)

    df = pd.DataFrame(
        zip(names, method_acc, method_comm, method_total_flops, method_client_flops), 
        columns=["Experiment Name", "Method Accuracy", "Method Comm Cost", "Total Flops", "Client Flops"]
    )

    # Make main results table
    table = pd.DataFrame()
    table['Experiment Name'] = df['Experiment Name']
    table['Method Accuracy'] = df['Method Accuracy'].apply(lambda x: max(eval(x)))
    table['Client TFLOPS'] = df['Client Flops'].apply(lambda x: eval(x)[-1] * 1e-12)
    table['Total TFLOPS'] = df['Total Flops'].apply(lambda x: eval(x)[-1] * 1e-12)
    table['Comm Cost (GBs)'] = df['Method Comm Cost'].apply(lambda x: eval(x)[-1] * 1e-3)

    # Plot Accuracy Figure
    fig = go.Figure()
    for i in range(len(df)):
        fig.add_trace(
            go.Scatter(x=eval(steps[i]), y=eval(df['Method Accuracy'][i]), name=df['Experiment Name'][i])
        )
    fig.update_layout(title_text="",
            paper_bgcolor='rgba(255,255,255,1)',\
        plot_bgcolor='rgba(255,255,255,1)',)

    fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True,ticks='outside', showgrid=True, gridcolor="LightGray")
    fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True, ticks='outside', showgrid=True, gridcolor="LightGray")

    fig.update_layout(template=None, xaxis_title="Training Steps", yaxis_title="Per-Client Test Accuracy", font=dict(family='Times New Roman', size=14, color='Black'))

    return df, table, fig


# SplitCIFAR (non-i.i.d) Results

In [124]:
regex = r'cifar\_NIID'

df, table, fig = read_experiment_stats(regex)
table

Unnamed: 0,Experiment Name,Method Accuracy,Client TFLOPS,Total TFLOPS,Comm Cost (GBs)
0,cifar_NIID-Vanilla-10-1010,0.459333,140.530471,386.27899,420.872192
1,cifar_NIID-CESL-Random-60-100100,0.527,141.380813,178.461211,63.504384
2,cifar_NIID-CESL-30-100100,0.8625,141.380813,159.921012,31.752192
3,cifar_NIID-CESL-3-1010,0.913833,140.549161,158.9803,31.565414
4,cifar_NIID-CESL-Random-30-100100,0.392167,141.380813,159.921012,31.752192
5,cifar_NIID-CESL-60-100100,0.8846,141.380813,178.461211,63.504384
6,cifar_NIID-CESL-Random-6-1010,0.899667,140.549161,177.411439,63.130829
7,cifar_NIID-CESL-Random-3-1010,0.912833,140.549161,158.9803,31.565414
8,cifar_NIID-CESL-6-1010,0.927833,140.549161,177.411439,63.130829
9,cifar_NIID-Vanilla-100-100100,0.389167,141.362012,388.564664,423.36256


In [125]:
fig.show()

# CIFAR10 (i.i.d) Results

In [126]:
df, table, fig = read_experiment_stats(r'cifar\_IID')
table

Unnamed: 0,Experiment Name,Method Accuracy,Client TFLOPS,Total TFLOPS,Comm Cost (GBs)
0,cifar_IID-Vanilla-100-100100,0.889237,133.0466,365.707919,398.45888
1,cifar_IID-CESL-Random-30-100100,0.676645,133.064294,150.513893,29.8856
2,cifar_IID-CESL-30-100100,0.670717,133.064294,150.513893,29.8856
3,cifar_IID-CESL-Random-3-1010,0.70524,130.569339,147.691758,29.325245
4,cifar_IID-CESL-6-1010,0.71686,130.569339,164.814177,58.649234
5,cifar_IID-CESL-Random-6-1010,0.71378,130.569339,164.814177,58.649234
6,cifar_IID-CESL-3-1010,0.7087,130.569339,147.691758,29.325245
7,cifar_IID-CESL-Random-60-100100,0.705408,133.064294,167.963492,59.76992
8,cifar_IID-CESL-60-100100,0.685502,133.064294,167.963492,59.76992
9,cifar_IID-Vanilla-10-1010,0.89221,130.551976,358.850896,390.987776


In [127]:
fig.show()

# Interrupt Duration Ablation on SplitCIFAR
- All settings are the same as CESL-6/10 except for Interrupt duration.

In [128]:
df, table, fig = read_experiment_stats(r'interrupt\_range\_expt')
table

Unnamed: 0,Experiment Name,Method Accuracy,Client TFLOPS,Total TFLOPS,Comm Cost (GBs)
0,interrupt_range_expt_0.910,0.897167,140.549161,155.294072,25.252332
1,interrupt_range_expt_0.4510,0.829333,140.549161,221.646172,138.887823
2,interrupt_range_expt_0.7510,0.936667,140.549161,177.411439,63.130829
3,interrupt_range_expt_0.610,0.897333,140.549161,199.528805,101.009326
4,interrupt_range_expt_0.310,0.9185,140.549161,243.763539,176.766321


In [129]:
fig.show()

In [130]:
df, table, fig = read_experiment_stats(r'interrupt\_range\_iid')
table

Unnamed: 0,Experiment Name,Method Accuracy,Client TFLOPS,Total TFLOPS,Comm Cost (GBs)
0,interrupt_range_iid_expt_0.910,0.70995,130.569339,144.267274,23.459267
1,interrupt_range_iid_expt_0.7510,0.72246,130.569339,164.814177,58.648166
2,interrupt_range_iid_expt_0.610,0.72466,130.569339,185.36108,93.837066
3,interrupt_range_iid_expt_0.310,0.74686,130.569339,226.454885,164.214866
4,interrupt_range_iid_expt_0.4510,0.72992,130.569339,205.907982,129.025966


In [131]:
fig.show()