# Plot Attention Time

In [3]:
"""gpu,dtype,num_heads,head_size,batch_size,seq_len,time(us)"""
import pandas as pd
import matplotlib.pyplot as plt

# Load the raw attention timings once; `original_df` and `df` are bound to
# the very same DataFrame object at this point.
original_df = pd.read_csv('mha_0.csv')
df = original_df

In [7]:
import ipywidgets as widgets
from IPython.display import display

df = original_df.copy()

# One multi-select widget per configuration column. Options are sorted so
# numeric choices are listed in ascending order instead of CSV row order.
dtype_options = sorted(df['dtype'].unique())
num_heads_options = sorted(df['num_heads'].unique())
head_size_options = sorted(df['head_size'].unique())
batch_size_options = sorted(df['batch_size'].unique())

# Prefer 'half' as the initial dtype, but only if the data actually contains
# it; SelectMultiple rejects an initial value that is not one of its options.
# Slicing [:1] picks the first option, or nothing if the column is empty.
dtype_default = ['half'] if 'half' in dtype_options else dtype_options[:1]

dtype_select = widgets.SelectMultiple(
    options=dtype_options,
    value=dtype_default,
    description='Dtype',
    disabled=False
)

num_heads_select = widgets.SelectMultiple(
    options=num_heads_options,
    value=num_heads_options[:1],
    description='Num Heads',
    disabled=False
)

head_size_select = widgets.SelectMultiple(
    options=head_size_options,
    value=head_size_options[:1],
    description='Head Size',
    disabled=False
)

batch_size_select = widgets.SelectMultiple(
    options=batch_size_options,
    value=batch_size_options[:1],
    description='Batch Size',
    disabled=False
)

def update_plot(dtype, num_heads, head_size, batch_size):
    """Plot attention time against sequence length for the selected configs.

    Intended as the callback for ``widgets.interactive_output``: each argument
    is the current selection (a collection of values) of the matching
    SelectMultiple widget. Rows of the module-level ``df`` that match every
    selection are grouped by (gpu, dtype, num_heads, head_size, batch_size),
    and each group is drawn as one line of ``time(us)`` vs ``seq_len``.
    """
    plt.figure(figsize=(10, 6))
    filtered_df = df[
        df['dtype'].isin(dtype)
        & df['num_heads'].isin(num_heads)
        & df['head_size'].isin(head_size)
        & df['batch_size'].isin(batch_size)
    ]

    grouped = filtered_df.groupby(['gpu', 'dtype', 'num_heads', 'head_size', 'batch_size'])

    for name, group in grouped:
        # plt.plot connects points in row order; sort by x so the line runs
        # left-to-right even if the CSV rows are not ordered by seq_len.
        group = group.sort_values('seq_len')
        plt.plot(group['seq_len'], group['time(us)'], label=f'{name}')

    plt.xlabel('seq_len')
    plt.ylabel('time(us)')
    plt.title('Time vs Sequence Length for Different Configurations')
    if not filtered_df.empty:
        # Skip the legend when nothing was plotted (e.g. a selector is
        # cleared); matplotlib would otherwise warn and draw an empty box.
        plt.legend(title='(gpu, dtype, num_heads, head_size, batch_size)', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Map each update_plot keyword to the widget that drives it; the same mapping
# lays the selectors out left-to-right and wires them to the plot callback.
selectors = {
    'dtype': dtype_select,
    'num_heads': num_heads_select,
    'head_size': head_size_select,
    'batch_size': batch_size_select,
}
ui = widgets.HBox(list(selectors.values()))
out = widgets.interactive_output(update_plot, selectors)

# Plot output first, controls underneath.
display(out, ui)


Output()

HBox(children=(SelectMultiple(description='Dtype', index=(0,), options=('half', 'float'), value=('half',)), Se…

# Plot MLP and Attention Time

In [10]:
# gemm_0.csv columns: gpu,dtype,m,k,n,time(us)
gemm_df = pd.read_csv('gemm_0.csv')

# mha_0.csv columns: gpu,dtype,num_heads,head_size,batch_size,seq_len,time(us)
attn_df = pd.read_csv('mha_0.csv')

In [None]:
def MLP_Time(
    seq_len, batch_size,
    num_heads, head_dim,
):
    """Placeholder for an MLP-time estimate; not implemented yet.

    The cell above loads ``gemm_df`` (gpu,dtype,m,k,n,time(us)). That table is
    presumably meant to back this estimate — TODO confirm and implement.

    Raises:
        NotImplementedError: always. An accidental call then fails loudly
            instead of silently returning None into downstream plots.
    """
    raise NotImplementedError("MLP_Time is not implemented yet")

In [None]:
import ipywidgets as widgets
from IPython.display import display

df = original_df.copy()

# Only keep the dtype == "half"
df = df[df['dtype'] == 'half']

# One multi-select widget per configuration column, with options listed in
# ascending order. Slicing [:1] selects the first option by default, and
# yields an empty selection (rather than an IndexError) if no 'half' rows
# exist.
num_heads_options = sorted(df['num_heads'].unique())
head_size_options = sorted(df['head_size'].unique())
batch_size_options = sorted(df['batch_size'].unique())

num_heads_select = widgets.SelectMultiple(
    options=num_heads_options,
    value=num_heads_options[:1],
    description='Num Heads',
    disabled=False
)

head_size_select = widgets.SelectMultiple(
    options=head_size_options,
    value=head_size_options[:1],
    description='Head Size',
    disabled=False
)

batch_size_select = widgets.SelectMultiple(
    options=batch_size_options,
    value=batch_size_options[:1],
    description='Batch Size',
    disabled=False
)

def update_plot(num_heads, head_size, batch_size):
    """Plot half-precision attention time against sequence length.

    Intended as the callback for ``widgets.interactive_output``: each argument
    is the current selection of the matching SelectMultiple widget. Rows of
    the module-level ``df`` (already restricted to dtype == 'half' in this
    cell) that match every selection are grouped by
    (gpu, dtype, num_heads, head_size, batch_size). Each group is drawn as one
    line of ``time(us)`` vs ``seq_len``.
    """
    plt.figure(figsize=(10, 6))
    filtered_df = df[
        df['num_heads'].isin(num_heads)
        & df['head_size'].isin(head_size)
        & df['batch_size'].isin(batch_size)
    ]

    grouped = filtered_df.groupby(['gpu', 'dtype', 'num_heads', 'head_size', 'batch_size'])

    for name, group in grouped:
        # plt.plot connects points in row order; sort by x so the line runs
        # left-to-right even if the CSV rows are not ordered by seq_len.
        group = group.sort_values('seq_len')
        plt.plot(group['seq_len'], group['time(us)'], label=f'{name}')

    plt.xlabel('seq_len')
    plt.ylabel('time(us)')
    plt.title('Time vs Sequence Length for Different Configurations')
    if not filtered_df.empty:
        # Skip the legend when nothing was plotted (e.g. a selector is
        # cleared); matplotlib would otherwise warn and draw an empty box.
        plt.legend(title='(gpu, dtype, num_heads, head_size, batch_size)', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Map each update_plot keyword to the widget that drives it; the same mapping
# lays the selectors out left-to-right and wires them to the plot callback.
half_selectors = {
    'num_heads': num_heads_select,
    'head_size': head_size_select,
    'batch_size': batch_size_select,
}
ui = widgets.HBox(list(half_selectors.values()))
out = widgets.interactive_output(update_plot, half_selectors)

# Plot output first, controls underneath.
display(out, ui)
