In [1]:
# Enable the autoreload extension so that edits to imported modules are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:

from GenZ import get_model_df, System, get_summary_table, simplify_df, get_configs, ParallelismConfig, get_runtime_breakdown
from GenZ.Models.get_language_model import create_full_chunked_model, create_full_decode_model
from GenZ.Models.attention import mha_flash_attention_chunked
from GenZ.utils.plot_rooflines import display_df
import os
import pandas as pd


In [3]:
# Hardware description shared by every model evaluation further down.
system = System(
    frequency=1000,
    flops=2000,
    off_chip_mem_size=(80 * 1024),
    compute_efficiency=0.8,
    memory_efficiency=0.8,
    offchip_mem_bw=3500,
    bits='int8',
    external_mem_bw=128,
    interchip_link_bw=256,
    interchip_link_latency=2,
    num_nodes=8,
)

In [None]:
# Fetch the model configuration for Llama 3.1 405B. The result is not stored,
# so this cell only displays the returned object.
get_configs('meta-llama/meta-llama-3.1-405b')

<GenZ.Models.default_models.ModelConfig at 0x77b3e7f2c0d0>

In [5]:
def get_chunked_model_runtime(model_name = 'gpt-3', chunk_size = 256, system = None, input_tokens = 1024, output_tokens = None, batch_size = 1, tensor_parallel = 8):
    """Build a chunked model and return its summary table and runtime breakdown.

    Args:
        model_name: Model identifier forwarded to ``create_full_chunked_model``.
        chunk_size: Chunk size forwarded to ``create_full_chunked_model``.
        system: ``System`` to evaluate on. When ``None`` a fresh ``System()``
            is created for this call.
        input_tokens: Input token count forwarded to ``create_full_chunked_model``.
        output_tokens: List forwarded to ``create_full_chunked_model``.
            Defaults to ``[1024] * 8`` when ``None``.
        batch_size: Batch size forwarded to ``get_model_df``.
        tensor_parallel: Tensor-parallel degree forwarded to
            ``create_full_chunked_model``.

    Returns:
        Tuple ``(summary_table, runtime_df)`` holding the results of
        ``get_summary_table`` and ``get_runtime_breakdown`` for the model.
    """
    # Defaults are built per call: a default written in the signature is
    # evaluated once at definition time and then shared by every call.
    if system is None:
        system = System()
    if output_tokens is None:
        output_tokens = [1024] * 8

    model = create_full_chunked_model(chunk_size, model_name, input_tokens, output_tokens, tensor_parallel=tensor_parallel)
    # Honour the caller's batch_size (it used to be silently replaced by 1).
    model_df = get_model_df(model, system, batch_size=batch_size, model_characterstics=False)
    summary_table = get_summary_table(model_df, model_characterstics=False)
    runtime_df = get_runtime_breakdown(model_df)
    return summary_table, runtime_df

In [6]:
import ipywidgets as widgets
from ipywidgets import interact

def interactive_chunked_model_runtime(input_tokens, output_tokens, output_batch):
    """Plot the per-layer runtime breakdown for a sweep of chunk sizes.

    For each chunk size in 256, 512, ..., 2048 the chunked model is evaluated
    on the notebook-level ``system`` and the Embedding, Collective, LA_layers,
    QKVO_layers and FFN_layers entries of its runtime breakdown are collected.
    The result is shown as a grouped bar chart.

    Args:
        input_tokens: Input token count passed to ``get_chunked_model_runtime``.
        output_tokens: Value repeated ``output_batch`` times to form the
            ``output_tokens`` list passed to ``get_chunked_model_runtime``.
        output_batch: Number of repetitions of ``output_tokens`` in that list.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    import plotly.express as px

    columns = ['Layer Name', 'Runtime', 'Chunk Size']
    layer_names = ['Embedding', 'Collective', 'LA_layers', 'QKVO_layers', 'FFN_layers']

    # Collect plain rows and build the DataFrame once at the end. Growing a
    # frame with pd.concat per iteration is quadratic, and seeding it with an
    # empty frame relies on deprecated concat behaviour.
    rows = []
    for chunk_size in range(256, 2049, 256):
        _, runtime_df = get_chunked_model_runtime(
            chunk_size=chunk_size,
            input_tokens=input_tokens,
            output_tokens=[output_tokens] * output_batch,
            system=system,
        )
        rows.extend([name, getattr(runtime_df, name), chunk_size] for name in layer_names)

    runtime_plot_df = pd.DataFrame(rows, columns=columns)

    fig = px.bar(runtime_plot_df, x='Chunk Size', y='Runtime', color='Layer Name', barmode='group',
                title='Runtime Breakdown by Chunk Size, Input Tokens = {}, Output Tokens = {}, Output Batch = {}'.format(input_tokens, output_tokens, output_batch))
    fig.show()

# Bounded integer entry boxes that feed the interactive plot below.
input_tokens_slider = widgets.BoundedIntText(
    min=1,
    max=20480,
    step=10,
    value=1024,
    description='Input Tokens:',
)
output_tokens_slider = widgets.BoundedIntText(
    min=1,
    max=20480,
    step=10,
    value=1024,
    description='Output Tokens:',
)
batch_slider = widgets.BoundedIntText(
    min=1,
    max=512,
    step=8,
    value=8,
    description='Output Batch:',
)

# Wire the widgets to the plotting function; it is re-run on every change.
interact(
    interactive_chunked_model_runtime,
    input_tokens=input_tokens_slider,
    output_tokens=output_tokens_slider,
    output_batch=batch_slider,
)

interactive(children=(BoundedIntText(value=1024, description='Input Tokens:', max=20480, min=1, step=10), Boun…

<function __main__.interactive_chunked_model_runtime(input_tokens, output_tokens, output_batch)>

In [7]:
# get_chunked_model_runtime returns a (summary_table, runtime_df) pair; unpack
# it so that summary_table holds the table itself rather than the whole tuple.
summary_table, runtime_df = get_chunked_model_runtime()

In [8]:
# Chunked 'gpt-3' run on the shared `system`. The positional arguments follow
# the order used in get_chunked_model_runtime: chunk size, model name,
# input tokens, output-token list.
model = create_full_chunked_model(
    512,
    'gpt-3',
    512,
    [1024] * 250,
    tensor_parallel=8,
)
model_df = get_model_df(model, system, batch_size=1, model_characterstics=False)
summary_table = get_summary_table(model_df, model_characterstics=False)
runtime_df = get_runtime_breakdown(model_df)

In [9]:
# Display the summary table currently bound to summary_table.
summary_table

Unnamed: 0,MACs (MFLOP),Total Data (MB),Total Weights (MB),Unused Weights (MB),KV Cache (MB),On-chip Memory Footprint (MB),Latency (msec),Cycles,Attn Latency (msec),Linear Latency (msec),Comm Latency (msec)
0,22655250.0,97434.75666,20883.28125,0.0,72217.6875,82.708984,47.455851,47455850.0,25.286364,7.988627,14.18086


In [10]:
# Display summary_table again; nothing has been recomputed since the previous
# display, so the values are identical.
summary_table

Unnamed: 0,MACs (MFLOP),Total Data (MB),Total Weights (MB),Unused Weights (MB),KV Cache (MB),On-chip Memory Footprint (MB),Latency (msec),Cycles,Attn Latency (msec),Linear Latency (msec),Comm Latency (msec)
0,22655250.0,97434.75666,20883.28125,0.0,72217.6875,82.708984,47.455851,47455850.0,25.286364,7.988627,14.18086


In [11]:
# Decode 'gpt-3' run on the shared `system`, generating one output token with
# a batch of 250. The leading 1024 is presumably the existing context length
# (TODO confirm against create_full_decode_model).
model = create_full_decode_model(
    1024,
    'gpt-3',
    output_gen_tokens=1,
    tensor_parallel=8,
)
model_df = get_model_df(model, system, batch_size=250, model_characterstics=False)
summary_table = get_summary_table(model_df, model_characterstics=False)
runtime_df = get_runtime_breakdown(model_df)

In [12]:
# Display the summary table currently bound to summary_table.
summary_table

Unnamed: 0,MACs (MFLOP),Total Data (MB),Total Weights (MB),Unused Weights (MB),KV Cache (MB),On-chip Memory Footprint (MB),Latency (msec),Cycles,Attn Latency (msec),Linear Latency (msec),Comm Latency (msec)
0,11061390.0,95134.930351,20809.640625,0.0,72070.3125,378.295898,42.698226,42698230.0,25.183851,7.364275,10.150099


In [13]:
# Simplify the current model_df and render it with display_df; the output has
# one row per layer.
display_df(simplify_df(model_df))

Unnamed: 0,Layer Name,Op Type,Dimension,Op Intensity,Latency (msec),Num ops (MFLOP),Input_a (MB),Input_w (MB),Output (MB),Total Data (MB),Compute time (msec),Memory time (msec),Communication time (msec),Bound,C/M ratio,Cycles,% of total time,Throughput (Tflops),Compute cycle,Memory cycle,C Effcy,Communication cycle
0,QKV,GEMM,"[((250, 12288, 1), (4608, 12288), (250, 4608, 1))]",465.290049,1.834262,2717908.992,281.25,5184.0,105.46875,5570.71875,0.849347,1.834262,0.0,Memory,0.463045,1834261.757987,4.295873,1481.745438,849346.56,1834261.757987,0.8,0.0
1,Logit Pre,Logit,"((250, 12, 1, 128), (250, 12, 1024, 128), (250, 12, 1, 1024))",1.982575,12.577261,75497.472,35.15625,36000.0,281.25,36316.40625,0.023593,12.577261,0.0,Memory,0.001876,12577261.243548,29.456168,6.002696,23592.96,12577261.243548,0.8,0.0
2,Logit Suf,Logit,"((250, 12, 1, 128), (250, 12, 1, 128), (250, 12, 1, 1))",0.996109,0.014664,73.728,35.15625,35.15625,0.274658,70.587158,2.3e-05,0.014664,0.0,Memory,0.001571,14664.339168,0.034344,5.027707,23.04,14664.339168,0.8,0.0
3,Attend Pre,Attend,"((250, 12, 1, 1024), (250, 12, 1024, 128), (250, 12, 1, 128))",1.982575,12.577261,75497.472,281.25,36000.0,35.15625,36316.40625,0.023593,12.577261,0.0,Memory,0.001876,12577261.243548,29.456168,6.002696,23592.96,12577261.243548,0.8,0.0
4,Attend Suf,Attend,"((250, 12, 1, 1), (250, 12, 1, 128), (250, 12, 1, 128))",0.996109,0.014664,73.728,0.274658,35.15625,35.15625,70.587158,2.3e-05,0.014664,0.0,Memory,0.001571,14664.339168,0.034344,5.027707,23.04,14664.339168,0.8,0.0
5,Out Proj,GEMM,"[((250, 1536, 1), (12288, 1536), (250, 12288, 1))]",422.616591,0.624136,905969.664,35.15625,1728.0,281.25,2044.40625,0.283116,0.624136,0.0,Memory,0.453612,624136.243548,1.461738,1451.55753,283115.52,624136.243548,0.8,0.0
6,MHA AR,Sync,"(250, 1, 12288)",0.0,5.045546,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.045546,Collective,0.0,5045546.310425,11.816759,0.0,0.0,0.0,0.8,5045546.310425
7,up+gate,GEMM,"[((250, 12288, 1), (6144, 12288), (250, 6144, 1))]",471.23792,2.439325,3623878.656,281.25,6912.0,140.625,7333.875,1.132462,2.439325,0.0,Memory,0.464252,2439324.515206,5.712941,1485.607443,1132462.08,2439324.515206,0.8,0.0
8,down,GEMM,"[((250, 6144, 1), (12288, 6144), (250, 12288, 1))]",471.23792,2.439325,3623878.656,140.625,6912.0,281.25,7333.875,1.132462,2.439325,0.0,Memory,0.464252,2439324.515206,5.712941,1485.607443,1132462.08,2439324.515206,0.8,0.0
9,FFN AR,Sync,"(250, 1, 12288)",0.0,5.045546,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.045546,Collective,0.0,5045546.310425,11.816759,0.0,0.0,0.0,0.8,5045546.310425
