In [34]:
import os
import json
import pandas as pd
import altair as alt

In [35]:
def load_reference_jsons(root_path, filter = ""):
    """Load every ``*.json`` file below ``root_path`` into a single DataFrame.

    The tree is walked recursively. A directory contributes its files only
    when ``filter`` occurs as a substring of that directory's path (the
    default ``""`` matches every directory). Each file that parses becomes
    one record; a file that cannot be read or parsed is reported with
    ``print`` and skipped, so one bad file does not abort the whole load.

    Directories and files are visited in sorted name order, which makes the
    row order of the result independent of the filesystem's listing order.

    Args:
        root_path: Directory to search.
        filter: Substring a directory path must contain for its files to be
            loaded. (The name shadows the builtin; kept for compatibility
            with existing callers.)

    Returns:
        pandas.DataFrame built from the list of parsed JSON documents; empty
        when nothing was loaded.
    """
    records = []

    for dirpath, dirnames, filenames in os.walk(root_path):
        # Sorting in place steers os.walk's descent, giving a stable traversal.
        dirnames.sort()
        if filter not in dirpath:
            continue
        for file in sorted(filenames):
            if not file.endswith(".json"):
                continue
            file_path = os.path.join(dirpath, file)
            try:
                # JSON text is UTF-8; do not depend on the locale's default encoding.
                with open(file_path, "r", encoding="utf-8") as f:
                    records.append(json.load(f))
            except Exception as e:
                # Best-effort loader: report the failure and keep going.
                print(f"Error loading {file_path}: {e}")

    return pd.DataFrame(records)

In [36]:
# Relative path of the directory tree that is scanned for JSON result files.
root_path = "../../../slurm_logs/latest"

# Empty filter: files from every sub-directory are loaded.
df = load_reference_jsons(root_path=root_path, filter="")

# Last expression of the cell, so the frame is rendered as the cell output.
df

Unnamed: 0,world_size,m,n,k,debug,validate,trace_tiles,benchmark,datatype,algorithm,...,streamk_registers,streamk_spills,success,success_partial,triton_tflops,triton_ms,streamk_ms,streamk_experiments,communication_ms,communication_experiments
0,1,8192,4096,20480,True,True,False,True,fp32,all_scatter,...,244,0,True,True,130.097174,10.56433,8.858744,126,10.318066,126
1,4,8192,1024,22528,True,True,False,True,fp32,all_scatter,...,244,0,True,True,208.743414,7.242521,5.097331,126,6.898155,126
2,8,8192,512,20480,True,True,False,True,fp32,all_scatter,...,244,0,True,True,205.87944,6.675701,4.350475,126,6.319639,126
3,4,8192,1024,20480,True,True,False,True,fp32,all_scatter,...,244,0,True,True,211.089421,6.510935,4.367122,126,6.144581,126
4,4,8192,1024,14336,True,True,False,True,fp32,all_scatter,...,244,0,True,True,176.617999,5.447195,3.315763,126,5.148321,126
5,2,8192,2048,20480,True,True,False,True,fp32,all_scatter,...,244,0,True,True,183.545478,7.488005,4.653128,126,7.234178,126
6,2,8192,2048,22528,True,True,False,True,fp32,all_scatter,...,244,0,True,True,188.330243,8.02754,5.214358,126,7.77159,126
7,1,8192,4096,22528,True,True,False,True,fp32,all_scatter,...,244,0,True,True,129.907239,11.637754,10.148404,126,11.386985,126
8,1,8192,4096,14336,True,True,False,True,fp32,all_scatter,...,244,0,True,True,115.294281,8.344496,6.371719,126,8.105937,126
9,8,8192,512,14336,True,True,False,True,fp32,all_scatter,...,244,0,True,True,171.599914,5.606487,3.250437,126,5.307569,126


In [37]:
# One faceted bar chart per algorithm: one panel per problem shape, one bar per world size.
for algorithm in df['algorithm'].unique():
    filtered_df = df[df['algorithm'] == algorithm].copy()

    # Shape label of the form "M<M>N<N>K<K>", built column-wise instead of with a row-wise apply.
    filtered_df["shape"] = (
        "M" + filtered_df["M"].astype(str)
        + "N" + filtered_df["N"].astype(str)
        + "K" + filtered_df["K"].astype(str)
    )

    # Translate the keywords found in the algorithm name into a readable title.
    # Joining with a space keeps the parts apart when more than one keyword matches.
    title_parts = [
        label
        for keyword, label in (
            ('all_scatter', 'All Scatter'),
            ('all_reduce', 'All Reduce'),
            ('one_shot', 'One Shot'),
        )
        if keyword in algorithm
    ]
    title = ' '.join(title_parts) + ' (Iris)'

    # Row order decides the left-to-right order of the shape panels below.
    filtered_df = filtered_df.sort_values(by=["K", "world_size"])
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("world_size:O", title="World Size"),
        # The plotted column is triton_tflops, so the axis unit is TFLOP/s.
        y=alt.Y("triton_tflops:Q", title="FLOPS (TFLOP/s)"),
        color=alt.Color("world_size:N", title="World Size"),
        column=alt.Column("shape:N", title="", sort=shape_order),
        tooltip=["shape", "world_size", "triton_tflops"]
    ).properties(
        title=title,
        height=300
    ).configure_axisX(
        labelAngle=0
    ).configure_title(
        anchor="middle",
        fontSize=18,
        font='Helvetica'
    )

    chart.display()


In [None]:
# One faceted bar chart per algorithm: one panel per world size, one bar per problem shape.
for algorithm in df['algorithm'].unique():
    filtered_df = df[df['algorithm'] == algorithm].copy()

    # Shape label of the form "M<M>N<N>K<K>", built column-wise instead of with a row-wise apply.
    filtered_df["shape"] = (
        "M" + filtered_df["M"].astype(str)
        + "N" + filtered_df["N"].astype(str)
        + "K" + filtered_df["K"].astype(str)
    )

    # Translate the keywords found in the algorithm name into a readable title.
    # Joining with a space keeps the parts apart when more than one keyword matches.
    title_parts = [
        label
        for keyword, label in (
            ('all_scatter', 'All Scatter'),
            ('all_reduce', 'All Reduce'),
            ('one_shot', 'One Shot'),
        )
        if keyword in algorithm
    ]
    title = ' '.join(title_parts) + ' (Iris)'

    # Row order decides the left-to-right order of the bars inside each panel.
    filtered_df = filtered_df.sort_values(by=["K", "shape"])
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("shape:N", title="", sort=shape_order),
        # The plotted column is triton_tflops, so the axis unit is TFLOP/s.
        y=alt.Y("triton_tflops:Q", title="FLOPS (TFLOP/s)"),
        color=alt.Color("shape:N", legend=None),
        column=alt.Column("world_size:N", title="World Size"),
        tooltip=["shape", "world_size", "triton_tflops"]
    ).properties(
        title=title,
        height=300
    ).configure_axisX(
        labelAngle=45
    ).configure_title(
        anchor="middle",     # Center the title
        fontSize=18,
        font='Helvetica'
    )

    # Altair's own display method; avoids relying on the IPython-injected `display` global.
    chart.display()
