In [1]:
import re
import pandas as pd
import plotly.express as px

from glob import glob

In [2]:
# Regexes used to pull benchmark metadata out of torch.log files and their paths.
# pt_time: group(1) = epoch number, group(2) = elapsed time for that epoch.
#   `\d+` (not `\d`) so epochs >= 10 are still captured; the fractional part is
#   optional so integer timings such as "Time 300" are not silently dropped.
pt_time = re.compile(r"Done epoch (\d+): Time (\d+(?:\.\d+)?)")
# pt_dali: group(1) = DALI mode flag encoded in the log path as "da<N>".
pt_dali = re.compile(r"da(\d+)")
# pt_node: group(1) = node number encoded in the log path as "node<N>".
pt_node = re.compile(r"node(\d+)")

In [3]:
# Node number (from the log path) -> GPU configuration label used in the report.
NODE_GPUS = {5: "H100 * 4", 7: "l40 * 4", 8: "a40 * 4", 9: "l4 * 4"}
# DALI flag (from the log path) -> readable mode; any other value is reported as "gpu".
DALI_MODES = {0: "off", 1: "cpu"}
# Epochs reported as columns epoch_1 .. epoch_5.
EPOCHS = range(1, 6)


def _path_int(pattern, path):
    """Return the first capture group of `pattern` found in `path`, as an int.

    Raises:
        ValueError: if `path` does not contain the pattern (clearer than the
            AttributeError a bare `.search(...).group(1)` would give).
    """
    match = pattern.search(path)
    if match is None:
        raise ValueError(f"{path!r} does not match {pattern.pattern!r}")
    return int(match.group(1))


def _epoch_times(path):
    """Return {epoch: elapsed time} parsed from the "Done epoch" lines of a log.

    Epochs in EPOCHS with no matching line stay NaN rather than 0, so an
    incomplete run does not contribute a fabricated 0-second epoch to the mean.
    """
    times = {epoch: float("nan") for epoch in EPOCHS}
    with open(path, "r") as f:
        for line in f:
            match = pt_time.search(line)
            if match:
                times[int(match.group(1))] = float(match.group(2))
    return times


def _gpu_label(path):
    """Return the GPU configuration label for the node encoded in `path`.

    Raises:
        ValueError: for a node with no entry in NODE_GPUS (previously this
            raised NameError or silently reused the previous log's label).
    """
    node_num = _path_int(pt_node, path)
    if node_num not in NODE_GPUS:
        raise ValueError(f"{path!r}: no GPU label configured for node{node_num}")
    return NODE_GPUS[node_num]


# Without recursive=True, "**" behaves like "*": this matches torch.log files
# exactly two directory levels below the current directory.
logs = sorted(glob("./**/**/torch.log"))
rows = []

for log in logs:
    times = _epoch_times(log)
    rows.append(
        {
            "gpus": _gpu_label(log),
            # NOTE(review): storage type is taken positionally from the path;
            # assumes a fixed "-"-separated naming scheme — confirm.
            "storage": log.split("-")[3],
            "dali": DALI_MODES.get(_path_int(pt_dali, log), "gpu"),
            # Keys stay in epoch_1 .. epoch_5 order; downstream cells rely on it.
            **{f"epoch_{epoch}": round(times[epoch], 2) for epoch in EPOCHS},
        }
    )

In [4]:
# Build the summary table: one row per (gpus, storage, dali) run, with
# per-run mean and sample standard deviation across its epoch timings.
# "storage" is a tiebreaker so row order does not depend on glob order.
df = (
    pd.DataFrame(rows)
    .sort_values(by=["gpus", "dali", "storage"])
    .reset_index(drop=True)
)
# Select epoch columns by name rather than by position, so adding or
# reordering non-epoch columns cannot silently change what is averaged.
epoch_times = df.filter(regex=r"^epoch_\d+$")
df["mean"] = epoch_times.mean(axis=1)
df["standard_deviation"] = epoch_times.std(axis=1)
df.to_csv("data_storage.csv", index=False)
df

Unnamed: 0,gpus,storage,dali,epoch_1,epoch_2,epoch_3,epoch_4,epoch_5,mean,standard_deviation
0,H100 * 4,local,cpu,272.3,267.38,267.18,265.94,265.15,267.59,2.787131
1,H100 * 4,ontap,cpu,337.24,328.43,326.03,328.4,326.68,329.356,4.53189
2,H100 * 4,local,gpu,228.54,214.52,214.21,214.18,212.58,216.806,6.603157
3,H100 * 4,ontap,gpu,302.28,303.1,301.31,300.98,301.77,301.888,0.835925
4,H100 * 4,local,off,415.46,406.8,397.05,397.92,392.85,402.016,9.070305
5,H100 * 4,ontap,off,437.13,431.06,428.32,426.55,427.0,430.012,4.349847
6,a40 * 4,local,cpu,451.48,446.35,446.03,445.94,444.95,446.95,2.585894
7,a40 * 4,ontap,cpu,447.8,442.66,442.52,443.67,442.89,443.908,2.220624
8,a40 * 4,local,gpu,513.77,505.92,504.56,504.97,507.28,507.3,3.765043
9,a40 * 4,ontap,gpu,519.76,511.65,512.18,511.88,510.79,513.252,3.674693


In [None]:
# Mean epoch time per GPU configuration, one line per (DALI mode, storage)
# combination; error bars show the per-run standard deviation across epochs.
# x was "number_of_nodes", which is not a column of df, so the call raised.
fig = px.line(
    df,
    x="gpus",
    y="mean",
    color="dali",
    # Without this, local and ontap rows sharing a DALI mode are joined into one zig-zag line.
    line_dash="storage",
    error_y="standard_deviation",
    markers=True,
    labels={"mean": "elapsed_time_per_epoch", "gpus": "GPU configuration"},
    title="Elapsed time per epoch by GPU configuration, DALI mode and storage",
)
fig.show()