# MoE Kernel Benchmark Analysis

This notebook analyzes MoE kernel benchmark results, comparing kernel timings for the Silu and Gelu activations.

## Hypotheses to Test:
1. **Does the activation function (Silu vs. Gelu) affect kernel execution time and kernel selection? If so, why?**

In [1]:
# Imports (if not already run)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Display settings: set both display options to None via the attribute-style API
pd.options.display.max_columns = None
pd.options.display.width = None

print("Libraries imported successfully!")

Libraries imported successfully!


In [None]:
# Activation Comparison Table for Jupyter Notebook
# Copy these cells into moe_kernel_analysis.ipynb

# ============================================================================
# Cell 1: Load Data
# ============================================================================

# Load activation comparison results.
# NOTE(review): the path is relative to the notebook's working directory.
RESULTS_CSV = '../results_activation_comparison.csv'
MAX_ERROR_PCT = 50.0  # rows whose error is this percentage or higher are discarded

act_df = pd.read_csv(RESULTS_CSV)

# Filter valid results:
#   1. drop runs explicitly marked as 'failed';
#   2. parse the remaining error values (strings such as '1.5%') into floats;
#      anything that is not a number becomes NaN instead of raising, and a
#      column that pandas already parsed as numeric is handled as well;
#   3. keep rows whose error is below the threshold (NaN rows fall out here).
valid_df = act_df[act_df['error'] != 'failed'].copy()
valid_df['error_pct'] = pd.to_numeric(
    valid_df['error'].astype(str).str.rstrip('%'),
    errors='coerce',
)
# .copy() makes the filtered result an independent frame, so adding columns to
# valid_df later does not trigger SettingWithCopyWarning.
valid_df = valid_df[valid_df['error_pct'] < MAX_ERROR_PCT].copy()

n_silu = int((valid_df['act_type'] == 'ActivationType.Silu').sum())
n_gelu = int((valid_df['act_type'] == 'ActivationType.Gelu').sum())

print(f"Loaded {len(valid_df)} valid kernel results")
print(f"  Silu: {n_silu}")
print(f"  Gelu: {n_gelu}")

# ============================================================================
# Cell 2: Create Comparison Table
# ============================================================================

print("\nCreating Silu vs Gelu comparison table...")

# Normalize kernel names (remove activation)
def normalize_kernel_name(kernel_name):
    """Replace the activation token in a kernel name with ``_ACT_``.

    The name is checked for ``_silu_`` first and ``_gelu_`` second, ignoring
    case. Every occurrence of the first activation found is replaced, in
    whatever capitalisation it appears (``_silu_``, ``_Silu_``, ``_SILU_``, ...),
    so the Silu and Gelu variants of one kernel map to the same base name.

    Args:
        kernel_name: Kernel name string.

    Returns:
        The name with the activation token replaced by ``_ACT_``, or the name
        unchanged when it contains neither token.
    """
    import re  # local import keeps the helper self-contained

    for activation in ('silu', 'gelu'):
        token = re.compile(f'_{activation}_', re.IGNORECASE)
        if token.search(kernel_name):
            # Detection and replacement use the same case-insensitive pattern,
            # so a detected token can never be left behind.
            return token.sub('_ACT_', kernel_name)
    return kernel_name

# Add the activation-agnostic kernel name that serves as a join key.
# .assign() builds a new frame instead of writing a column into a filtered
# slice, which avoids pandas' SettingWithCopyWarning.
valid_df = valid_df.assign(
    kernel_base_name=valid_df['kernel_name'].apply(normalize_kernel_name)
)

# Define matching columns: a Silu row and a Gelu row are paired only when they
# agree on every one of these.
match_cols = [
    'token', 'model_dim', 'inter_dim', 'expert', 'topk',
    'dtype', 'q_dtype_a', 'q_dtype_w', 'q_type',
    'use_g1u1', 'doweight_stage1',
    'stage', 'block_m', 'kernel_type', 'tile_m', 'tile_n',
    'kernel_base_name'
]

# Separate Silu and Gelu
silu_df = valid_df[valid_df['act_type'] == 'ActivationType.Silu'].copy()
gelu_df = valid_df[valid_df['act_type'] == 'ActivationType.Gelu'].copy()

# Merge on matching columns; non-key columns get a _silu / _gelu suffix
comparison = silu_df.merge(
    gelu_df,
    on=match_cols,
    how='inner',
    suffixes=('_silu', '_gelu')
)

# Calculate differences (positive values mean the Gelu kernel took longer)
comparison['time_diff_us'] = comparison['time_us_gelu'] - comparison['time_us_silu']
comparison['time_diff_pct'] = (comparison['time_diff_us'] / comparison['time_us_silu']) * 100
# Vectorized labelling. Unlike a row-wise apply, this is also well defined when
# no pairs matched (apply on an empty frame does not return a Series).
# Ties are labelled 'Gelu', as before.
comparison['faster'] = np.where(
    comparison['time_us_silu'] < comparison['time_us_gelu'], 'Silu', 'Gelu'
)

# Select display columns
display_cols = [
    'config_idx_silu', 'token', 'model_dim', 'expert', 'topk',
    'stage', 'kernel_type', 'block_m',
    'kernel_name_silu', 'kernel_name_gelu',
    'time_us_silu', 'time_us_gelu',
    'time_diff_pct', 'faster',
    'error_silu', 'error_gelu'
]

result = comparison[display_cols].copy()
result = result.rename(columns={'config_idx_silu': 'config_idx'})

# Sort by absolute difference, largest first
result = result.sort_values('time_diff_pct', key=abs, ascending=False)

n_pairs = len(result)
n_silu_faster = int((result['faster'] == 'Silu').sum())
n_gelu_faster = int((result['faster'] == 'Gelu').sum())
# Guard the shares so an empty match set reports 0.0% instead of dividing by zero.
silu_share = n_silu_faster / n_pairs * 100 if n_pairs else 0.0
gelu_share = n_gelu_faster / n_pairs * 100 if n_pairs else 0.0

print(f"Matched kernel pairs: {n_pairs}")
print(f"Silu faster: {n_silu_faster} ({silu_share:.1f}%)")
print(f"Gelu faster: {n_gelu_faster} ({gelu_share:.1f}%)")
print(f"Average difference: {result['time_diff_pct'].abs().mean():.2f}%")

# ============================================================================
# Cell 3: Display Table
# ============================================================================

print("\nComparison Table (sorted by largest performance difference):")
print("="*80)
display(result.head(20))

# time_diff_pct is (time_us_gelu - time_us_silu) / time_us_silu * 100, so a
# positive value means the Gelu kernel took longer, i.e. Silu is the faster one.
print("\n**Table Columns:**")
print("- time_us_silu: Kernel execution time with Silu activation")
print("- time_us_gelu: Kernel execution time with Gelu activation")
print("- time_diff_pct: Percentage difference (Gelu relative to Silu)")
print("  - Negative = Gelu faster")
print("  - Positive = Silu faster")
print("- faster: Which activation is faster for this kernel")

# ============================================================================
# Cell 4: Summary Statistics
# ============================================================================

print("\nSummary Statistics:")
print("=" * 80)

# Absolute percentage difference, shared by every statistic below.
abs_diff_pct = result['time_diff_pct'].abs()

print("\nOverall:")
print(f"  Matched pairs: {len(result)}")
print(f"  Average |difference|: {abs_diff_pct.mean():.2f}%")
print(f"  Median |difference|: {abs_diff_pct.median():.2f}%")
print(f"  Max |difference|: {abs_diff_pct.max():.2f}%")

# Per-stage breakdown: number of matched kernels and their mean |difference|.
print("\nBy Stage:")
for stage_label in result['stage'].unique():
    in_stage = result['stage'] == stage_label
    n_in_stage = int(in_stage.sum())
    mean_abs_diff = abs_diff_pct[in_stage].mean()
    print(f"  {stage_label}: {n_in_stage} kernels, avg diff = {mean_abs_diff:.2f}%")

# Per-kernel-type breakdown: how often the Silu variant is the faster one.
if 'kernel_type' in result.columns:
    silu_is_faster = result['faster'] == 'Silu'
    print("\nBy Kernel Type:")
    for kernel_kind in result['kernel_type'].unique():
        of_kind = result['kernel_type'] == kernel_kind
        n_of_kind = int(of_kind.sum())
        n_silu_wins = (of_kind & silu_is_faster).sum()
        print(f"  {kernel_kind}: Silu wins {n_silu_wins}/{n_of_kind} ({n_silu_wins/n_of_kind*100:.1f}%)")

# ============================================================================
# Cell 5: Interactive Visualization (2x2 Grid with Plotly)
# ============================================================================

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Assemble the (label, rows) pairs to plot: stage1 is split per kernel type,
# while stage2 and asm_1stage each form a single group when they have rows.
stage1_rows = result[result['stage'] == 'stage1']
plot_groups = [
    (f'Stage1-{kind.upper()}', stage1_rows[stage1_rows['kernel_type'] == kind])
    for kind in sorted(stage1_rows['kernel_type'].unique())
]

for stage_name in ('stage2', 'asm_1stage'):
    stage_rows = result[result['stage'] == stage_name]
    if len(stage_rows):
        plot_groups.append((stage_name, stage_rows))

# Create the subplot grid: two columns, one panel per plot group.
# Two rows is the minimum (the usual 2x2 layout). Further rows are added when
# there are more than four groups (e.g. stage1 has more than two kernel types),
# so that no group is assigned to a row outside the grid.
n_cols = 2
n_rows = max(2, -(-len(plot_groups) // n_cols))  # ceiling division

fig = make_subplots(
    rows=n_rows, cols=n_cols,
    subplot_titles=[f'{label} ({len(data)} kernels)' for label, data in plot_groups],
    horizontal_spacing=0.12,
    # 0.15 for the two-row layout; shrinks as rows are added so the spacing
    # stays below the 1 / (rows - 1) upper bound plotly enforces.
    vertical_spacing=0.30 / n_rows
)

for idx, (label, data) in enumerate(plot_groups):
    # Fill the grid row by row (1-based indices, as plotly expects).
    row = (idx // n_cols) + 1
    col = (idx % n_cols) + 1

    # Create hover text with kernel details (one HTML snippet per point)
    has_block_m = 'block_m' in data.columns
    hover_text = []
    for _, r in data.iterrows():
        block_m_info = f"block_m={r['block_m']}<br>" if has_block_m else ""
        # Kernel names are truncated to 80 characters to keep the tooltip compact.
        text = (f"<b>Config #{r['config_idx']}</b><br>"
                f"Token={r['token']}, Model={r['model_dim']}, Expert={r['expert']}, TopK={r['topk']}<br>"
                f"{block_m_info}"
                f"<br><b>Silu Kernel:</b><br>{r['kernel_name_silu'][:80]}<br>"
                f"Time: {r['time_us_silu']:.2f} us (err={r['error_silu']})<br>"
                f"<br><b>Gelu Kernel:</b><br>{r['kernel_name_gelu'][:80]}<br>"
                f"Time: {r['time_us_gelu']:.2f} us (err={r['error_gelu']})<br>"
                f"<br><b>Performance:</b> {r['time_diff_pct']:.2f}% ({r['faster']} faster)")
        hover_text.append(text)

    # Add scatter trace: Silu time on x, Gelu time on y, coloured by the
    # percentage difference (colour range clamped to +/-10%).
    fig.add_trace(
        go.Scatter(
            x=data['time_us_silu'],
            y=data['time_us_gelu'],
            mode='markers',
            marker=dict(
                size=8,
                color=data['time_diff_pct'],
                colorscale='RdYlGn_r',
                cmin=-10,
                cmax=10,
                showscale=False,  # Remove colorbar
                line=dict(width=1, color='DarkSlateGray')
            ),
            hovertext=hover_text,
            hoverinfo='text',
            name=label,
            showlegend=False
        ),
        row=row, col=col
    )

    # Add diagonal line (equal performance); only the first panel contributes
    # a legend entry so it is not repeated for every subplot.
    if len(data) > 0:
        max_val = max(data['time_us_silu'].max(), data['time_us_gelu'].max())
        fig.add_trace(
            go.Scatter(
                x=[0, max_val],
                y=[0, max_val],
                mode='lines',
                line=dict(color='red', dash='dash', width=2),
                name='Equal Performance',
                showlegend=(idx == 0),
                hoverinfo='skip'
            ),
            row=row, col=col
        )

    # Update axes
    fig.update_xaxes(title_text='Silu Time (us)', row=row, col=col)
    fig.update_yaxes(title_text='Gelu Time (us)', row=row, col=col)

# Update layout (450 px per row keeps the 900 px height of the two-row layout)
fig.update_layout(
    height=450 * n_rows,
    width=1400,
    title_text='Interactive Kernel Performance: Silu vs Gelu<br><sub>Hover over points for details</sub>',
    title_font_size=18,
    showlegend=True,
    hovermode='closest'
)

fig.show()

# Usage notes for the interactive figure. The colour legend follows the
# 'RdYlGn_r' colorscale applied to time_diff_pct: red for large positive values
# (Gelu slower), green for large negative values, yellow around zero.
print("\n✅ Interactive Features:")
print("  - Hover over any point to see full kernel details")
print("  - Zoom: Click and drag to select area")
print("  - Pan: Hold shift and drag")
print("  - Reset: Double-click")
print("\n📊 Color Legend:")
print("  - Red (above diagonal) = Gelu is SLOWER")
print("  - Green (below diagonal) = Gelu is FASTER")
print("  - Yellow (on diagonal) = Equal performance")
