In [None]:
import csv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import os

def parse_data_csv(data_path:str, model:str, systems=("dgl", "p3", "quiver")):
    """Collect per-graph epoch-time components from a benchmark CSV.

    A row is kept when its ``model`` column equals ``model`` and its
    ``system`` column is one of ``systems``. For every graph the kept rows are
    ordered by the position of their system in ``systems`` (rows of the same
    system keep their file order), so when the file holds exactly one row per
    (graph, system) pair, index ``i`` of each returned array belongs to
    ``systems[i]``.

    Args:
        data_path (str): Path of the CSV file to read. Kept rows must provide
            the columns ``model``, ``system``, ``graph_name``,
            ``sampling (s)``, ``feature (s)``, ``forward (s)`` and
            ``backward (s)``.
        model (str): Value matched against the ``model`` column
            (e.g. "sage" or "gat").
        systems: Iterable of ``system`` column values to keep, in the order
            their measurements should appear in the returned arrays.

    Returns:
        dict: Maps graph name -> {"Sample", "Load", "Train"} -> 1-D float
        ``np.ndarray``. "Train" is forward + backward time rounded to one
        decimal. "friendster", "orkut" and "papers100M" are always present
        (with empty arrays if no row matched); any other graph name found in
        a kept row gets its own entry.

    Raises:
        KeyError: If a required column is missing from the file.
        ValueError: If a duration cell of a kept row is not a valid float.
    """
    components = ("Sample", "Load", "Train")

    # These three graphs are always exposed, even when nothing matches, so
    # callers can index them unconditionally.
    collected = {
        graph: {component: [] for component in components}
        for graph in ("friendster", "orkut", "papers100M")
    }

    # Read the file a single time and bucket the matching rows per system;
    # the per-system ordering is imposed afterwards instead of re-parsing the
    # whole file once for every system.
    rows_by_system = {system: [] for system in systems}
    # newline='' is what the csv module recommends so that quoted fields
    # containing line breaks are parsed correctly.
    with open(data_path, 'r', newline='') as file:
        for row in csv.DictReader(file):
            if row['model'] == model and row['system'] in rows_by_system:
                rows_by_system[row['system']].append(row)

    # dicts preserve insertion order, so this walks the systems in the
    # requested order and each system's rows in file order.
    for system_rows in rows_by_system.values():
        for row in system_rows:
            # Graphs outside the default trio get an entry on first sight
            # instead of failing with a KeyError.
            graph_data = collected.setdefault(
                row['graph_name'], {component: [] for component in components}
            )
            graph_data['Sample'].append(float(row['sampling (s)']))
            graph_data['Load'].append(float(row['feature (s)']))
            graph_data['Train'].append(
                round(float(row['forward (s)']) + float(row['backward (s)']), 1)
            )

    # Convert once at the end; growing arrays with np.append would copy the
    # whole array on every row.
    return {
        graph: {
            component: np.array(values, dtype=float)
            for component, values in parts.items()
        }
        for graph, parts in collected.items()
    }


In [None]:

# Draws, for each graph, one stacked bar per system that splits the epoch
# time into the components returned by parse_data_csv, then saves the figure
# to plots/<model>_epoch_breakdown.png.

model = "sage"
data_path="../experiment/logs/main.csv"
graph2durations = parse_data_csv(data_path, model)

# Display names for the x axis. They are matched to the data purely by
# position, so their order must follow the order in which parse_data_csv
# emits the systems (dgl, p3, quiver).
systems = (
    "DGL",
    "P3*",
    "Quiver",
)
item2color = {"Sample": "tab:green",
             "Train": "tab:blue",
             "Load": "tab:orange"}

font = {'size' : 50, "family": "serif"}
plt.rc('font', **font)  # note: changes the global matplotlib font settings
nsys = len(systems)
width = 0.7
fig, axes = plt.subplots(ncols=3, figsize=(30, 8))
for idx, graph in enumerate(["friendster","papers100M", "orkut" ]):
    ax = axes[idx]
    durations = graph2durations[graph]

    # Every component needs exactly one value per system. Without this check
    # a graph with a single matching row would be silently broadcast across
    # all bars, and other length mismatches would surface as an opaque
    # broadcasting error from inside matplotlib.
    for item, d in durations.items():
        if len(d) != nsys:
            raise ValueError(
                f"{graph}/{item}: expected {nsys} values (one per system), got {len(d)}"
            )

    # Stack the components in the dict's iteration order; `bottom` tracks
    # the running top of each system's bar.
    bottom = np.zeros(nsys)
    for item, d in durations.items():
        ax.bar(systems, d, width, label=item, bottom=bottom, color=item2color[item], alpha=0.7)
        bottom += d

    ax.set_title(graph.title())
    # Data aspect = x-range / y-range, which makes the plotting box square
    # regardless of how tall the bars of this graph are.
    asp = np.diff(ax.get_xlim())[0] / np.diff(ax.get_ylim())[0]
    ax.set_aspect(asp)
    if idx == 0:
        # Only the left-most panel carries the y label.
        ax.set_ylabel("Epoch Time (s)")

# One shared legend built from proxy patches, so the "Train" component can be
# shown under a more descriptive label than its dictionary key.
patches = []
patches.append(mpatches.Patch(color=item2color["Sample"], label='Sample', alpha=0.7))
patches.append(mpatches.Patch(color=item2color["Load"], label='Load', alpha=0.7))
patches.append(mpatches.Patch(color=item2color["Train"], label='Forward & Backward', alpha=0.7))
# plt.legend attaches the legend to the current axes; bbox_to_anchor is in
# that axes' coordinates and places the legend below the panels.
plt.legend(handles=patches, ncol=len(patches), bbox_to_anchor=(1,-0.08))

os.makedirs("plots", exist_ok=True)
# bbox_inches="tight" keeps the legend, which sits outside the axes, inside
# the saved image.
fig.savefig(f"plots/{model}_epoch_breakdown.png", dpi=300,bbox_inches="tight" )
plt.show()