In [None]:
import os
from io import StringIO
import pandas as pd
from scipy.stats import sem
import matplotlib.pyplot as plt

# Display name used in the figure titles for each model identifier that
# appears in the benchmark CSV filenames.
model_name = {
    'meta-llama_Llama-2-13b-hf': 'Llama-2-13B',
    'meta-llama_Llama-2-7b-hf': 'Llama-2-7B',
    'lmsys_vicuna-7b-v1.5': 'Vicuna-7B',
    'lmsys_vicuna-13b-v1.5': 'Vicuna-13B',
    'lmsys_vicuna-33b-v1.3': 'Vicuna-33B',
}

# --- Where the CSVs live: {bench_result_dir}/{folder}<model>_<dataset>_... ---
bench_result_dir = "[PATH_TO_BENCH_RESULTS]"  # placeholder: point at the results directory
folder = ''              # prefix glued directly before each filename (e.g. 'subdir/')
dataset = 'sharegpt'     # dataset name, part of the filename and the figure title

# --- Benchmark configuration encoded in the filenames / titles ---
tp = 2                   # tensor-parallel size ("tp{tp}" in filenames, "Tensor_Parallel" in titles)
num_reqs = 1024          # request count encoded in the filenames
num_runs = 3             # repetitions expected per configuration; > 1 enables standard-error bands
drafter = 'Vicuna-160M'  # draft model name; only shown in the figure title

# For every model: one figure with one subplot per max_seq_nums value. Each
# subplot plots `target` (mean over runs, plus a standard-error band when
# num_runs > 1) against the number of speculative tokens, for every method
# whose benchmark CSV exists.

# Legend label for each method.
labels = {'baseline_sd': 'Standard SD', 'tetris': 'Tetris', 'dsd': 'DSD', 'no_sd': 'w/o SD'}
# Y-axis label for each supported metric column.
ylabels = {'throughput': 'Mean Throughput', 'mean_TTFT': 'Mean TTFT', 'mean_TPOT': 'Mean TPOT',
           'mean_e2el_latency': 'Mean End-to-end Latency'}
# Methods drawn as a dashed horizontal reference line (in this colour) instead of a curve.
reference_colors = {'dsd': 'purple', 'no_sd': 'black'}

target = 'throughput'  # any key of `ylabels`: 'throughput', 'mean_TTFT', 'mean_TPOT', 'mean_e2el_latency'
methods = ['tetris', 'baseline_sd', 'dsd']  # add 'no_sd' to also draw the no-SD reference line
max_seq_nums_options = [64, 128, 256]  # one subplot per value
tetris_extras = [1, 2, 4]  # 'tetris' is benchmarked once per `extra` value; other methods once


def _load_bench_csv(csv_file):
    """Load a benchmark CSV that may contain repeated header rows.

    The first line is kept as the header. Every later line whose first field
    equals the header's first field is treated as a repeated header (e.g. from
    appending several runs to one file) and dropped.

    Returns the parsed DataFrame, or None if the file is empty.
    """
    with open(csv_file, 'r') as file:
        lines = file.readlines()
    if not lines:
        return None
    header = lines[0]
    header_key = header.split(',', 1)[0].strip()
    body = [line for line in lines[1:] if line.split(',', 1)[0].strip() != header_key]
    return pd.read_csv(StringIO(''.join([header] + body)))


def _aggregate_runs(df, metric, expected_runs):
    """Summarise `metric` per num_speculative_tokens value.

    Returns a DataFrame indexed by num_speculative_tokens with columns 'mean',
    'count' and, when expected_runs > 1, 'sem' (scipy.stats.sem of the runs).
    """
    stats = ['mean', sem, 'count'] if expected_runs > 1 else ['mean', 'count']
    return df.groupby('num_speculative_tokens')[metric].agg(stats)


def _plot_method(ax, grouped, method, label):
    """Draw one method's aggregated results (and SEM band, if present) on `ax`."""
    mean = grouped['mean']
    has_sem = 'sem' in grouped.columns
    if method in reference_colors:
        color = reference_colors[method]
        # Only the first num_speculative_tokens row is used for the reference line.
        y = mean.iloc[0]
        ax.axhline(y=y, color=color, linestyle='--', label=label)
        if has_sem:
            err = grouped['sem'].iloc[0]
            if pd.notna(err):
                # Full-width band matching the horizontal line. fill_between over
                # grouped.index is zero-width (invisible) when the method has a
                # single num_speculative_tokens entry.
                ax.axhspan(y - err, y + err, color=color, alpha=0.2)
    else:
        line, = ax.plot(grouped.index, mean, marker='o', linestyle='-', label=label)
        if has_sem:
            # Tie the band colour to its line rather than to matplotlib's
            # separate patch colour cycle.
            ax.fill_between(grouped.index, mean - grouped['sem'], mean + grouped['sem'],
                            color=line.get_color(), alpha=0.2)


# for model in ['lmsys_vicuna-13b-v1.5', 'meta-llama_Llama-2-13b-hf', 'lmsys_vicuna-33b-v1.3', 'lmsys_vicuna-7b-v1.5']:
for model in ['meta-llama_Llama-2-7b-hf', 'meta-llama_Llama-2-13b-hf']:
    n_panels = len(max_seq_nums_options)
    # squeeze=False keeps `axs` 2-D even when there is a single panel.
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)
    fig.suptitle(f'{drafter}/{model_name[model]}: Dataset={dataset}, Tensor_Parallel={tp}', fontsize=16)

    for ax, max_seq_nums in zip(axs[0], max_seq_nums_options):
        for method in methods:
            extras = tetris_extras if method == 'tetris' else [None]
            for extra in extras:
                extra_str = f'_extra{extra}' if extra is not None else ''
                csv_file = f'{bench_result_dir}/{folder}{model}_{dataset}_{method}_tp{tp}_{num_reqs}_max{max_seq_nums}{extra_str}.csv'

                if not os.path.exists(csv_file):
                    print(f'Error: {csv_file} does not exist')
                    continue

                df = _load_bench_csv(csv_file)
                if df is None or df.empty:
                    print(f'Error: {csv_file} is empty')
                    continue

                if extra is not None:
                    # NOTE(review): presumably the tetris CSV records base + extra
                    # speculative tokens; subtracting `extra` aligns its x-axis with
                    # the other methods — confirm.
                    df['num_speculative_tokens'] = df['num_speculative_tokens'] - extra

                grouped = _aggregate_runs(df, target, num_runs)

                # Every num_speculative_tokens setting should have num_runs rows.
                if (grouped['count'] != num_runs).any():
                    print(f'Error: {csv_file} has missing runs')

                extra_label = f' (extra={extra})' if extra is not None else ''
                _plot_method(ax, grouped, method, labels[method] + extra_label)

        ax.set_xlabel('No. Speculative Tokens')
        ax.set_ylabel(ylabels[target])
        ax.set_title(f'max_seq_nums={max_seq_nums}')
        # legend() warns when nothing was drawn (e.g. every CSV was missing).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig.tight_layout()
    plt.show()