In [1]:
import os
import sys
sys.path.insert(0, '..') 

import torch
import torch.sparse

from tqdm.notebook import tqdm

from src.sae_fuction import topk_sae, topk_sae_fwd_sparse_coo, topk_sae_fwd_sparse_csr, topk_sae_sparse_fused
from src.utils import compare_reconstruct, compare_reconstruct_grad, compare_latency
from src.config import TestConfig

# Default device string and floating-point dtype.
# NOTE(review): neither name is referenced by the cells visible in this
# notebook; presumably kept for ad-hoc use. Verify before removing.
device = 'cuda'
dtype = torch.float32

In [2]:
# Problem size; each of these is forwarded to TestConfig in the benchmark loop.
batch_size = 4096
d_model = 4096
k = 128

# Absolute / relative tolerances, intended for compare_reconstruct_grad
# (which accepts atol / rtol keyword arguments).
atol = rtol = 1e-6

# Expansion factors swept by the benchmark: 4, 8, ..., 32 (eight values).
expansion_factor_list = range(4, 32 + 1, 4)

In [6]:
# One entry per expansion factor, in the order of `expansion_factor_list`.
# Both lists are re-created here so re-running the cell never appends to
# results from a previous run.
time_base = []    # baseline: topk_sae_fwd_sparse_csr
time_fusion = []  # fused:    topk_sae_sparse_fused

for expansion_factor in tqdm(expansion_factor_list):
    config = TestConfig(
        batch_size=batch_size,
        d_model=d_model,
        expansion_factor=expansion_factor,
        k=k,
        preheat_repeat=10,
        timing_repeat=20,
    )

    # Compare the two implementations before timing them. The tolerances come
    # from the configuration cell instead of being repeated as literals here,
    # so they only need to be changed in one place.
    # NOTE(review): presumably this raises/asserts on mismatch; confirm in
    # src.utils.compare_reconstruct_grad.
    compare_reconstruct_grad(
        sae_fwd_1=topk_sae_fwd_sparse_csr,
        sae_fwd_2=topk_sae_sparse_fused,
        config=config,
        verbose=False,
        atol=atol,
        rtol=rtol,
    )

    # compare_latency returns four values; only the two per-stage timing
    # mappings (times_base / times_fusion) are consumed below.
    fwd_bwd_times_base, fwd_bwd_times_fusion, times_base, times_fusion = compare_latency(
        sae_fwd_1=topk_sae_fwd_sparse_csr,
        sae_fwd_2=topk_sae_sparse_fused,
        config=config,
        verbose=False,
    )

    # Sum the two stages each implementation reports for building the sparse
    # feature activations, so the two lists measure comparable work.
    time_base.append(
        times_base['fwd_build_feature_acts'] + times_base['fwd_feature_acts_to_sparse']
    )
    time_fusion.append(
        times_fusion['fwd_sort_topk_result'] + times_fusion['fwd_build_sparse_feature_acts']
    )

  0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
# Directory the benchmark results are written to. The default is the original
# cluster location; set the KERNEL_FUSION_RESULT_DIR environment variable to
# write somewhere else without editing the notebook.
save_result_path = os.environ.get(
    'KERNEL_FUSION_RESULT_DIR',
    '/inspire/hdd/global_user/hezhengfu-240208120186/jx_project/ml_sys_course_pj/result/kernel_fusion',
)

# Create the directory (and any missing parents); a no-op if it already exists.
os.makedirs(save_result_path, exist_ok=True)

# Persist the two latency lists produced by the benchmark loop.
torch.save(time_base, os.path.join(save_result_path, 'time_base.pt'))
torch.save(time_fusion, os.path.join(save_result_path, 'time_fusion.pt'))

In [1]:
import os

import torch

# Directory the saved benchmark results are read from. The default is the
# original local checkout; set the KERNEL_FUSION_RESULT_DIR environment
# variable to read from somewhere else without editing the notebook.
load_result_path = os.environ.get(
    'KERNEL_FUSION_RESULT_DIR',
    '/Users/jxwang/Programs/sii-courses/ai-system-25/course_project/result/kernel_fusion',
)

# NOTE: torch.load unpickles the file contents; only load result files that
# were produced by the save cell of this notebook.
time_base = torch.load(os.path.join(load_result_path, 'time_base.pt'))
time_fusion = torch.load(os.path.join(load_result_path, 'time_fusion.pt'))

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# MLSys style configuration
mlsys_colors = {
    'baseline': '#2E86AB',  # dark blue
    'sparse': '#E94F37',    # red
}

# Create the figure
fig = go.Figure()

# x-axis values; `expansion_factor_list` comes from the configuration cell,
# so that cell must have been run in the current kernel.
expansion_factors = list(expansion_factor_list)

# Baseline curve
fig.add_trace(go.Scatter(
    x=expansion_factors,
    y=time_base,
    mode='lines+markers',
    name='Baseline',
    line=dict(color=mlsys_colors['baseline'], width=2.5),
    marker=dict(size=8, symbol='circle'),
))

# Fusion curve (drawn with the 'sparse' palette colour)
fig.add_trace(go.Scatter(
    x=expansion_factors,
    y=time_fusion,
    mode='lines+markers',
    name='Fusion',
    line=dict(color=mlsys_colors['sparse'], width=2.5),
    marker=dict(size=8, symbol='diamond'),
))

# MLSys-style layout
fig.update_layout(
    title=dict(
        text='Build Sparse Latency Comparison',
        font=dict(size=32, family='Times New Roman'),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family='Times New Roman')),
        tickfont=dict(size=18, family='Times New Roman'),
        gridcolor='lightgray',
        gridwidth=0.5,
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=True,
    ),
    yaxis=dict(
        title=dict(text='Latency (ms)', font=dict(size=24, family='Times New Roman')),
        tickfont=dict(size=18, family='Times New Roman'),
        gridcolor='lightgray',
        gridwidth=0.5,
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=True,
    ),
    legend=dict(
        x=0.02,
        y=0.98,
        bgcolor='rgba(255,255,255,0.8)',
        bordercolor='black',
        borderwidth=1,
        font=dict(size=24, family='Times New Roman'),
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
)

fig.show()

# Export the figure next to the loaded result files; the directory is taken
# from `load_result_path` rather than repeated as a second absolute literal.
fig.write_image(
    os.path.join(load_result_path, 'build_sparse_latency_comparison.pdf'),
    format="pdf",
)




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




In [6]:
# Per-expansion-factor speedup of the fused path over the baseline
# (baseline latency / fusion latency, so values above 1 mean fusion is faster).
# zip() stops at the shorter list if the two lists differ in length.
speedup = [base / fusion for base, fusion in zip(time_base, time_fusion)]

fig_speedup = go.Figure()

# One bar per expansion factor, coloured by its speedup value and labelled
# with the value formatted to two decimals.
fig_speedup.add_trace(go.Bar(
    x=expansion_factors,
    y=speedup,
    marker=dict(
        color=speedup,
        colorscale=[[0, '#E94F37'], [0.5, '#F4A261'], [1, '#2A9D8F']],
        showscale=True,
        colorbar=dict(
            title=dict(text='Speedup', font=dict(size=16, family='Times New Roman')),
            tickfont=dict(size=18, family='Times New Roman'),
        ),
    ),
    text=[f'{s:.2f}x' for s in speedup],
    textposition='outside',
    textfont=dict(size=16, family='Times New Roman'),
))

fig_speedup.update_layout(
    title=dict(
        text='Build Sparse Speedup (Fusion vs Baseline)',
        font=dict(size=32, family='Times New Roman'),
        x=0.5,
    ),
    xaxis=dict(
        title=dict(text='Expansion Factor', font=dict(size=24, family='Times New Roman')),
        tickfont=dict(size=18, family='Times New Roman'),
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=True,
    ),
    yaxis=dict(
        title=dict(text='Speedup (×)', font=dict(size=24, family='Times New Roman')),
        tickfont=dict(size=18, family='Times New Roman'),
        gridcolor='lightgray',
        gridwidth=0.5,
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=True,
        rangemode='tozero',  # start the y-axis at zero so bar heights are comparable
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    width=700,
    height=500,
    margin=dict(l=60, r=40, t=80, b=60),
    showlegend=False,
)

fig_speedup.show()

# Export the figure next to the loaded result files; the directory is taken
# from `load_result_path` rather than repeated as a second absolute literal.
fig_speedup.write_image(
    os.path.join(load_result_path, 'build_sparse_speedup_comparison.pdf'),
    format="pdf",
)



Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).


