In [None]:
import os
import sys
sys.path.insert(0, '..') 

import torch
import torch.sparse

from tqdm.notebook import tqdm

from src.sae_function import topk_sae, topk_sae_fwd_sparse_coo, topk_sae_fwd_sparse_csr, topk_sae_sparse_fused
from src.utils.compare import compare_reconstruct, compare_reconstruct_grad, compare_latency
from src.config import SAETestConfig

# Target device name and tensor dtype (not referenced by the other cells visible here).
device = 'cuda'
dtype = torch.float32

In [None]:
# Problem size forwarded to SAETestConfig: model width and batch size.
d_model = batch_size = 4096

# Absolute and relative tolerances (the same values the gradient comparison uses).
atol = rtol = 1e-6

# Top-k value, passed to SAETestConfig as `topk`.
k = 128

# Expansion factors swept by the benchmark: 4, 8, ..., 32.
expansion_factor_list = range(4, 33, 4)

In [None]:
# Timing series collected by the sweep below: one entry is appended to each
# list per expansion factor, in the order of `expansion_factor_list`.
fwd_time_base = []
fwd_time_sparse = []
bwd_time_base = []
bwd_time_sparse = []
all_times_base = []
all_times_sparse = []

for expansion_factor in tqdm(expansion_factor_list):
    # Same problem size for every run; only the expansion factor varies.
    # NOTE(review): preheat_repeat / timing_repeat presumably set the number of
    # warm-up and timed iterations inside compare_latency -- confirm in src.config.
    config = SAETestConfig(
        batch_size=batch_size,
        d_model=d_model,
        expansion_factor=expansion_factor,
        topk=k,
        preheat_repeat=10,
        timing_repeat=10,
    )

    # Compare the fused sparse implementation against the reference one before
    # timing it. The tolerances come from the configuration cell instead of being
    # repeated as literals, so changing them there actually takes effect here.
    # NOTE(review): the return value is ignored; presumably the helper raises or
    # reports on mismatch -- confirm in src.utils.compare.
    compare_reconstruct_grad(
        sae_fwd_1=topk_sae,
        sae_fwd_2=topk_sae_sparse_fused,
        config=config,
        verbose=False,
        atol=atol,
        rtol=rtol,
    )

    # compare_latency returns four values; only the first two (indexed by
    # 'forward' / 'backward') are consumed below. The other two are kept under
    # their original names so they stay inspectable from the notebook.
    fwd_bwd_times_base, fwd_bwd_times_sparse, times_base, times_sparse = compare_latency(
        sae_fwd_1=topk_sae,
        sae_fwd_2=topk_sae_sparse_fused,
        config=config,
        verbose=False,
    )

    fwd_time_base.append(fwd_bwd_times_base['forward'])
    fwd_time_sparse.append(fwd_bwd_times_sparse['forward'])
    bwd_time_base.append(fwd_bwd_times_base['backward'])
    bwd_time_sparse.append(fwd_bwd_times_sparse['backward'])
    # End-to-end time is the sum of the forward and backward measurements.
    all_times_base.append(fwd_bwd_times_base['forward'] + fwd_bwd_times_base['backward'])
    all_times_sparse.append(fwd_bwd_times_sparse['forward'] + fwd_bwd_times_sparse['backward'])

In [None]:
# Output directory for the serialized timing series; created if it does not exist.
save_result_path = '/inspire/hdd/global_user/hezhengfu-240208120186/jx_project/ml_sys_course_pj/result/overall_speed_up'
os.makedirs(save_result_path, exist_ok=True)

# Each series goes to its own .pt file inside save_result_path.
for _file_name, _series in (
    ('fwd_time_base.pt', fwd_time_base),
    ('fwd_time_sparse.pt', fwd_time_sparse),
    ('bwd_time_base.pt', bwd_time_base),
    ('bwd_time_sparse.pt', bwd_time_sparse),
    ('all_times_base.pt', all_times_base),
    ('all_times_sparse.pt', all_times_sparse),
):
    torch.save(_series, os.path.join(save_result_path, _file_name))


In [None]:
import os

import torch

# Directory holding the previously saved timing files.
# NOTE(review): this path differs from save_result_path used by the save cell;
# presumably the .pt files are copied between machines -- confirm.
load_result_path = '/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up'

# Load the six series in one pass; targets and file names are listed in matching order.
(
    fwd_time_base,
    fwd_time_sparse,
    bwd_time_base,
    bwd_time_sparse,
    all_times_base,
    all_times_sparse,
) = (
    torch.load(os.path.join(load_result_path, file_name))
    for file_name in (
        'fwd_time_base.pt',
        'fwd_time_sparse.pt',
        'bwd_time_base.pt',
        'bwd_time_sparse.pt',
        'all_times_base.pt',
        'all_times_sparse.pt',
    )
)


In [69]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# MLSys-style palette: one colour per implementation.
mlsys_colors = {
    'baseline': '#2E86AB',  # dark blue
    'sparse': '#E94F37',    # red
}

# x-axis values shared by both curves.
expansion_factors = list(expansion_factor_list)

# Styling pieces reused by the layout below.
font_family = 'Times New Roman'
axis_style = dict(
    tickfont=dict(size=18, family=font_family),
    gridcolor='lightgray',
    gridwidth=0.5,
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
)

# One lines+markers trace per implementation: baseline first, then the sparse one.
fig = go.Figure(data=[
    go.Scatter(
        x=expansion_factors,
        y=latencies,
        mode='lines+markers',
        name=legend_label,
        line=dict(color=mlsys_colors[color_key], width=2.5),
        marker=dict(size=8, symbol=marker_symbol),
    )
    for legend_label, latencies, color_key, marker_symbol in (
        ('Dense (Baseline)', fwd_time_base, 'baseline', 'circle'),
        ('FlashSAE (Ours)', fwd_time_sparse, 'sparse', 'diamond'),
    )
])

# MLSys-style layout: serif fonts, boxed axes, light grid, legend in the top-left corner.
fig.update_layout(
    title=dict(
        text='Forward Pass Latency Comparison<br>',
        font=dict(size=32, family=font_family),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family=font_family)),
        **axis_style
    ),
    yaxis=dict(
        title=dict(text='Latency (ms)', font=dict(size=24, family=font_family)),
        **axis_style
    ),
    legend=dict(
        x=0.02,
        y=0.98,
        bgcolor='rgba(255,255,255,0.8)',
        bordercolor='black',
        borderwidth=1,
        font=dict(size=24, family=font_family),
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
)

# Display inline, then export the same figure as a PDF.
fig.show()
pdf_path = "/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up/forward_pass_latency_comparison.pdf"
fig.write_image(pdf_path, format="pdf")




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




In [70]:
# Compute and plot the speedup: baseline forward time divided by sparse forward time.
speedup = [
    dense_time / sparse_time
    for dense_time, sparse_time in zip(fwd_time_base, fwd_time_sparse)
]

# Styling pieces reused by the trace and the layout below.
font_family = 'Times New Roman'
axis_frame = dict(
    tickfont=dict(size=18, family=font_family),
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
)

# One bar per expansion factor, coloured by its speedup and labelled with the value.
speedup_bars = go.Bar(
    x=expansion_factors,
    y=speedup,
    marker=dict(
        color=speedup,
        colorscale=[[0, '#E94F37'], [0.5, '#F4A261'], [1, '#2A9D8F']],
        showscale=True,
        colorbar=dict(
            title=dict(text='Speedup', font=dict(size=16, family=font_family)),
            tickfont=dict(size=18, family=font_family),
        ),
    ),
    text=[f'{ratio:.2f}x' for ratio in speedup],
    textposition='outside',
    textfont=dict(size=16, family=font_family),
)
fig_speedup = go.Figure(data=[speedup_bars])

# Boxed axes on a white background; the y axis starts at zero and carries the grid.
fig_speedup.update_layout(
    title=dict(
        text='Forward Pass Speedup (FlashSAE vs Dense)',
        font=dict(size=32, family=font_family),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family=font_family)),
        **axis_frame
    ),
    yaxis=dict(
        title=dict(text='Speedup (×)', font=dict(size=24, family=font_family)),
        gridcolor='lightgray',
        gridwidth=0.5,
        rangemode='tozero',
        **axis_frame
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
    showlegend=False,
)

# Display inline, then export the same figure as a PDF.
fig_speedup.show()

pdf_path = "/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up/forward_pass_speedup_comparison.pdf"
fig_speedup.write_image(pdf_path, format="pdf")




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




In [71]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# MLSys-style palette: one colour per implementation.
mlsys_colors = {
    'baseline': '#2E86AB',  # dark blue
    'sparse': '#E94F37',    # red
}

# x-axis values shared by both curves.
expansion_factors = list(expansion_factor_list)

# Styling pieces reused by the layout below.
font_family = 'Times New Roman'
axis_style = dict(
    tickfont=dict(size=18, family=font_family),
    gridcolor='lightgray',
    gridwidth=0.5,
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
)

# One lines+markers trace per implementation: baseline first, then the sparse one.
fig = go.Figure(data=[
    go.Scatter(
        x=expansion_factors,
        y=latencies,
        mode='lines+markers',
        name=legend_label,
        line=dict(color=mlsys_colors[color_key], width=2.5),
        marker=dict(size=8, symbol=marker_symbol),
    )
    for legend_label, latencies, color_key, marker_symbol in (
        ('Dense (Baseline)', bwd_time_base, 'baseline', 'circle'),
        ('FlashSAE (Ours)', bwd_time_sparse, 'sparse', 'diamond'),
    )
])

# MLSys-style layout: serif fonts, boxed axes, light grid, legend in the top-left corner.
fig.update_layout(
    title=dict(
        text='Backward Pass Latency Comparison<br>',
        font=dict(size=32, family=font_family),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family=font_family)),
        **axis_style
    ),
    yaxis=dict(
        title=dict(text='Latency (ms)', font=dict(size=24, family=font_family)),
        **axis_style
    ),
    legend=dict(
        x=0.02,
        y=0.98,
        bgcolor='rgba(255,255,255,0.8)',
        bordercolor='black',
        borderwidth=1,
        font=dict(size=24, family=font_family),
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
)

# Display inline, then export the same figure as a PDF.
fig.show()
pdf_path = "/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up/backward_pass_latency_comparison.pdf"
fig.write_image(pdf_path, format="pdf")




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




In [72]:
# Compute and plot the speedup: baseline backward time divided by sparse backward time.
speedup = [
    dense_time / sparse_time
    for dense_time, sparse_time in zip(bwd_time_base, bwd_time_sparse)
]

# Styling pieces reused by the trace and the layout below.
font_family = 'Times New Roman'
axis_frame = dict(
    tickfont=dict(size=18, family=font_family),
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
)

# One bar per expansion factor, coloured by its speedup and labelled with the value.
speedup_bars = go.Bar(
    x=expansion_factors,
    y=speedup,
    marker=dict(
        color=speedup,
        colorscale=[[0, '#E94F37'], [0.5, '#F4A261'], [1, '#2A9D8F']],
        showscale=True,
        colorbar=dict(
            title=dict(text='Speedup', font=dict(size=16, family=font_family)),
            tickfont=dict(size=18, family=font_family),
        ),
    ),
    text=[f'{ratio:.2f}x' for ratio in speedup],
    textposition='outside',
    textfont=dict(size=16, family=font_family),
)
fig_speedup = go.Figure(data=[speedup_bars])

# Boxed axes on a white background; the y axis starts at zero and carries the grid.
fig_speedup.update_layout(
    title=dict(
        text='Backward Pass Speedup (FlashSAE vs Dense)',
        font=dict(size=32, family=font_family),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family=font_family)),
        **axis_frame
    ),
    yaxis=dict(
        title=dict(text='Speedup (×)', font=dict(size=24, family=font_family)),
        gridcolor='lightgray',
        gridwidth=0.5,
        rangemode='tozero',
        **axis_frame
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
    showlegend=False,
)

# Display inline, then export the same figure as a PDF.
fig_speedup.show()

pdf_path = "/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up/backward_pass_speedup_comparison.pdf"
fig_speedup.write_image(pdf_path, format="pdf")




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




In [73]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# MLSys-style palette: one colour per implementation.
mlsys_colors = {
    'baseline': '#2E86AB',  # dark blue
    'sparse': '#E94F37',    # red
}

# x-axis values shared by both curves.
expansion_factors = list(expansion_factor_list)

# Styling pieces reused by the layout below.
font_family = 'Times New Roman'
axis_style = dict(
    tickfont=dict(size=18, family=font_family),
    gridcolor='lightgray',
    gridwidth=0.5,
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
)

# One lines+markers trace per implementation: baseline first, then the sparse one.
fig = go.Figure(data=[
    go.Scatter(
        x=expansion_factors,
        y=latencies,
        mode='lines+markers',
        name=legend_label,
        line=dict(color=mlsys_colors[color_key], width=2.5),
        marker=dict(size=8, symbol=marker_symbol),
    )
    for legend_label, latencies, color_key, marker_symbol in (
        ('Dense (Baseline)', all_times_base, 'baseline', 'circle'),
        ('FlashSAE (Ours)', all_times_sparse, 'sparse', 'diamond'),
    )
])

# MLSys-style layout: serif fonts, boxed axes, light grid, legend in the top-left corner.
fig.update_layout(
    title=dict(
        # Alternative title that adds the benchmark parameters as a subtitle:
        # text=f'End-to-End Latency vs. Expansion Factor<br><sup>batch_size={batch_size}, d_model={d_model}, k={k}</sup>',
        text='End-to-End Latency vs. Expansion Factor<br>',
        font=dict(size=32, family=font_family),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family=font_family)),
        **axis_style
    ),
    yaxis=dict(
        title=dict(text='Latency (ms)', font=dict(size=24, family=font_family)),
        **axis_style
    ),
    legend=dict(
        x=0.02,
        y=0.98,
        bgcolor='rgba(255,255,255,0.8)',
        bordercolor='black',
        borderwidth=1,
        font=dict(size=24, family=font_family),
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
)

# Display inline, then export the same figure as a PDF.
fig.show()
pdf_path = "/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up/end_to_end_latency_comparison.pdf"
fig.write_image(pdf_path, format="pdf")




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




In [74]:
# Compute and plot the speedup: baseline end-to-end time divided by sparse end-to-end time.
speedup = [
    dense_time / sparse_time
    for dense_time, sparse_time in zip(all_times_base, all_times_sparse)
]

# Styling pieces reused by the trace and the layout below.
font_family = 'Times New Roman'
axis_frame = dict(
    tickfont=dict(size=18, family=font_family),
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=True,
)

# One bar per expansion factor, coloured by its speedup and labelled with the value.
speedup_bars = go.Bar(
    x=expansion_factors,
    y=speedup,
    marker=dict(
        color=speedup,
        colorscale=[[0, '#E94F37'], [0.5, '#F4A261'], [1, '#2A9D8F']],
        showscale=True,
        colorbar=dict(
            title=dict(text='Speedup', font=dict(size=16, family=font_family)),
            tickfont=dict(size=18, family=font_family),
        ),
    ),
    text=[f'{ratio:.2f}x' for ratio in speedup],
    textposition='outside',
    textfont=dict(size=16, family=font_family),
)
fig_speedup = go.Figure(data=[speedup_bars])

# Boxed axes on a white background; the y axis starts at zero and carries the grid.
fig_speedup.update_layout(
    title=dict(
        text='End-to-End Speedup (FlashSAE vs Dense)',
        font=dict(size=32, family=font_family),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family=font_family)),
        **axis_frame
    ),
    yaxis=dict(
        title=dict(text='Speedup (×)', font=dict(size=24, family=font_family)),
        gridcolor='lightgray',
        gridwidth=0.5,
        rangemode='tozero',
        **axis_frame
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
    showlegend=False,
)

# Display inline, then export the same figure as a PDF.
fig_speedup.show()

pdf_path = "/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/overall_speed_up/end_to_end_speedup_comparison.pdf"
fig_speedup.write_image(pdf_path, format="pdf")




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).


