In [1]:
import os, sys, warnings
# Put the parent of the notebook's working directory at the front of sys.path
# so the project-local `src` / `Systems` imports that follow can resolve.
script_dir = os.getcwd()
module_path = script_dir
for _ in range(1):  # number of parent directory levels to climb
    module_path = os.path.abspath(os.path.join(module_path, os.pardir))
    if module_path in sys.path:
        continue
    sys.path.insert(0, module_path)
        
from src import decode_moddeling, prefill_moddeling
import pandas as pd
from plotnine import *
import plotnine as p9
from tqdm import tqdm

from Systems.system_configs import *
# Internal model keys understood by `prefill_moddeling` / `decode_moddeling`,
# paired with the display names used in plot legends. Keeping them as pairs
# guarantees the two public lists below stay index-aligned; display names
# should match the labels of the `model_box` widget.
_MODEL_KEY_TO_NAME = [
    ('opt_125m', 'facebook/opt-125m'),
    ('opt_350m', 'facebook/opt-350m'),
    ('opt_1b', 'facebook/opt-1.3b'),
    ('opt_175b', 'facebook/opt-175b'),
    ('gemma_7b', 'google/gemma-7b'),
    ('LLaMA_7b', 'meta-llama/Llama-2-7b'),
    ('llama3_8b', 'meta-llama/Meta-Llama-3-8B'),
    ('llama_13b', 'meta-llama/Llama-2-13b'),
    ('mixtral_7x8', 'mistralai/Mixtral-8x7B'),
    ('LLaMA_70b', 'meta-llama/Llama-2-70b'),
    ('dbrx', 'databricks/dbrx-base'),
    ('grok-1', 'xai-org/grok-1'),
    ('gpt-3', 'openai/gpt-3'),
    ('gpt-4', 'openai/gpt-4'),
]
All_model_list = [key for key, _name in _MODEL_KEY_TO_NAME]
All_models_name = [name for _key, name in _MODEL_KEY_TO_NAME]



Matplotlib is building the font cache; this may take a moment.


In [2]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import plotly.express as px
from plotnine import *
import plotnine as p9


# Set up interactive widgets for the variables
from ipywidgets import interact, IntSlider, Checkbox, BoundedIntText, BoundedFloatText, Dropdown
import ipywidgets as widgets


# Define the function to generate the demand curve
# Batch sizes swept for every selected model; only entries <= the user's
# "Max Batch Size" are actually evaluated.
_BATCH_SIZE_SWEEP = (1, 2, 4, 8, 16, 32, 48, 64, 80, 96, 112, 128, 136, 144, 160,
                     172, 180, 200, 224, 240, 256)

# Column layout of each result row. The last three columns come from the
# modelling functions' 'Runtime_breakdown' (assumed to be exactly these three
# components in this order — TODO confirm against src).
_RESULT_COLUMNS = ['Model', 'Stage', 'Batch', 'Latency(ms)', 'Tokens/s',
                   'GEMM Time', 'Attn Time', 'Communication Time']


def _stage_row(model, stage, batch_size, outputs):
    """Flatten one modelling result dict into a row laid out as `_RESULT_COLUMNS`."""
    return ([model, stage, batch_size, outputs['Latency'], outputs['Throughput']]
            + list(outputs['Runtime_breakdown']))


def _collect_throughput_rows(system_name, tensor_parallel, models, bits, max_batch,
                             input_tokens, output_tokens):
    """Run prefill and decode modelling for every (batch size, model) pair.

    Returns:
        A list of rows laid out as `_RESULT_COLUMNS`. Pairs the modelling code
        rejects are skipped. If prefill succeeds but decode then fails for the
        same pair, the prefill row is kept.
    """
    rows = []
    # Filter once up front so the progress bar only counts evaluated sizes.
    batch_sizes = [b for b in _BATCH_SIZE_SWEEP if b <= max_batch]
    for batch_size in tqdm(batch_sizes):
        for model in models:
            common_kwargs = dict(
                model=model, batch_size=batch_size,
                input_tokens=input_tokens, output_tokens=output_tokens, FLAT=True,
                system_name=system_name, bits=bits,
                tensor_parallel=tensor_parallel, debug=False, time_breakdown=True,
            )
            try:
                prefill_outputs = prefill_moddeling(**common_kwargs)
                rows.append(_stage_row(model, 'Prefill', batch_size, prefill_outputs))
                # `Bb` is passed only to decode modelling; its meaning is defined in src.
                decode_outputs = decode_moddeling(Bb=4, **common_kwargs)
                rows.append(_stage_row(model, 'Decode', batch_size, decode_outputs))
            except Exception:
                # Deliberately best-effort: skip configurations the modelling code
                # rejects (presumably e.g. a model that does not fit on
                # `tensor_parallel` devices — TODO confirm which errors src raises).
                # Catch `Exception`, not everything, so KeyboardInterrupt can still
                # cancel a long sweep.
                continue
    return rows


def _plot_throughput(data_df):
    """Plot tokens/s vs. batch size: one facet row per stage, one line per model."""
    fig = px.line(data_df, x="Batch", y="Tokens/s", line_group="Model", color="Model", facet_row='Stage',
                  labels={"Batch": "Batch", "Tokens/s": "Tokens/s", "Model": "Model"},
                  width=1800, height=600, markers=True)
    # Enlarge axis titles and tick labels.
    fig.update_xaxes(title_font=dict(size=24), tickfont=dict(size=24))
    # `matches=None` unlinks the facets' y-axes so each stage gets its own range.
    fig.update_yaxes(title_font=dict(size=24), tickfont=dict(size=24), matches=None)
    fig.update_layout(
        font_color="black",
        title_font_color="black",
        legend_title_font_color="black",
        font_size=24
    )
    fig.show()


def generate_demand_curve(system_box, num_nodes_slider, model_box, quantization_box, batch_slider, input_token_slider, output_token_slider):
    """Model prefill/decode throughput across batch sizes and plot it.

    Intended to be driven by ``ipywidgets.interact``, which passes the current
    *values* of the widgets of the same names.

    Args:
        system_box: System name forwarded as ``system_name`` (e.g. 'H100').
        num_nodes_slider: Forwarded as ``tensor_parallel``.
        model_box: Iterable of internal model keys (e.g. 'LLaMA_7b') to evaluate.
        quantization_box: Forwarded as ``bits`` (e.g. 'bf16', 'int8', 'int4').
        batch_slider: Largest batch size to evaluate from the fixed sweep.
        input_token_slider: Forwarded as ``input_tokens``.
        output_token_slider: Forwarded as ``output_tokens``.

    Raises:
        AssertionError: if no (model, batch size) pair could be modelled.
    """
    # Suppress warnings only for the duration of this call, instead of
    # silencing them for the rest of the kernel session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        data = _collect_throughput_rows(
            system_name=system_box, tensor_parallel=num_nodes_slider, models=model_box,
            bits=quantization_box, max_batch=batch_slider,
            input_tokens=input_token_slider, output_tokens=output_token_slider,
        )
        assert len(data) > 0, "No Model fits in the given # of GPUs. Increase GPUs or use different Model"
        data_df = pd.DataFrame(data, columns=_RESULT_COLUMNS)
        # Swap internal model keys for their display names (legend labels).
        data_df = data_df.replace(All_model_list, All_models_name)
        # Make Stage a categorical with Prefill listed before Decode.
        data_df['Stage'] = pd.Categorical(data_df['Stage'], categories=['Prefill', 'Decode'])
        _plot_throughput(data_df)



# Input widgets for `generate_demand_curve`. `interact` (below) passes each
# widget's current value to the function parameter of the same name.

# max=256 matches the largest batch size in generate_demand_curve's sweep;
# a lower cap would make the larger sweep entries unreachable.
batch_slider = BoundedIntText(value=8, min=1, max=256, step=1, description='Max Batch Size:', disabled=False, style={'description_width': 'initial'})
input_token_slider = BoundedIntText(value=512, min=1, max=100000, step=1, description='Input Tokens:', disabled=False, style={'description_width': 'initial'})
output_token_slider = BoundedIntText(value=128, min=1, max=100000, step=1, description='Output Tokens:', disabled=False, style={'description_width': 'initial'})

quantization_box = Dropdown(options=['bf16', 'int8', 'int4'], value='int8', description='Quantization:', disabled=False, style={'description_width': 'initial'})
# (label, value) pairs: the label is shown to the user, the value is the
# internal model key passed to the modelling functions.
model_box = widgets.SelectMultiple(options=[
    ('facebook/opt-125m', 'opt_125m'),
    ('facebook/opt-350m', 'opt_350m'),
    ('facebook/opt-1.3b', 'opt_1b'),
    ('facebook/opt-175b', 'opt_175b'),
    ('google/gemma-7b', 'gemma_7b'),
    ('meta-llama/Llama-2-7b', 'LLaMA_7b'),
    ('meta-llama/Meta-Llama-3-8B', 'llama3_8b'),
    ('meta-llama/Llama-2-13b', 'llama_13b'),
    ('mistralai/Mixtral-8x7B', 'mixtral_7x8'),
    ('meta-llama/Llama-2-70b', 'LLaMA_70b'),
    ('databricks/dbrx-base', 'dbrx'),
    ('xai-org/grok-1', 'grok-1'),
    ('openai/gpt-3', 'gpt-3'),
    ('openai/gpt-4', 'gpt-4'),
], value=['LLaMA_7b'], description='Models:', disabled=False)
system_box = Dropdown(options=['A100-40GB', 'A100-80GB', 'H100', 'GH200', 'TPUv4', 'TPUv5e', 'MI300X', 'Gaudi3'], value='H100', description='System:', disabled=False)
# Forwarded to the modelling functions as `tensor_parallel`.
num_nodes_slider = BoundedIntText(value=2, min=1, max=128, step=1, description='# Nodes:', disabled=False)


# Create the interactive plot. The trailing `;` suppresses the echoed
# function repr that `interact` would otherwise leave as the cell output.
interact(generate_demand_curve,
         system_box=system_box, num_nodes_slider=num_nodes_slider, model_box=model_box, quantization_box=quantization_box,
         batch_slider=batch_slider, input_token_slider=input_token_slider, output_token_slider=output_token_slider);

interactive(children=(Dropdown(description='System:', index=2, options=('A100-40GB', 'A100-80GB', 'H100', 'GH2…

<function __main__.generate_demand_curve(system_box, num_nodes_slider, model_box, quantization_box, batch_slider, input_token_slider, output_token_slider)>