In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell execution so edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [30]:

from GenZ import get_model_df, System, get_summary_table, simplify_df, get_configs, ParallelismConfig, get_runtime_breakdown
from GenZ.Models.get_language_model import create_full_chunked_model, create_full_decode_model
from GenZ.Models.attention import mha_flash_attention_chunked
from GenZ.utils.plot_rooflines import display_df
import os
import pandas as pd


In [3]:
# Hardware description shared by the cells below (passed to get_model_df and to
# get_chunked_model_runtime via the interactive plot).
# NOTE(review): the units of these fields are not visible in this notebook
# (e.g. whether flops is TFLOPS or off_chip_mem_size is MB) — confirm against
# the GenZ System definition before changing any value.
system = System(
    frequency=1000,
    flops=2000,
    off_chip_mem_size=80 * 1024,
    compute_efficiency=0.8,
    memory_efficiency=0.8,
    offchip_mem_bw=3500,
    bits='int8',
    external_mem_bw=128,
    interchip_link_bw=256,
    interchip_link_latency=2,
    num_nodes=8,
)

In [4]:
# Look up the GenZ model configuration registered under this model name.
# The result is not stored; as the cell's last expression it is only echoed,
# and the saved output is just the ModelConfig object's default repr.
get_configs('meta-llama/meta-llama-3.1-405b')

<GenZ.Models.default_models.ModelConfig at 0x797d3c226c10>

In [24]:
def get_chunked_model_runtime(model_name = 'gpt-3', chunk_size = 256, system = None, input_tokens = 1024, output_tokens = None, batch_size = 1):
    """Build a chunked model and return its summary table and runtime breakdown.

    Args:
        model_name: Model identifier forwarded to ``create_full_chunked_model``.
        chunk_size: Chunk size forwarded to ``create_full_chunked_model``.
        system: ``System`` to evaluate on. ``None`` (the default) creates a
            fresh ``System()`` on every call instead of sharing one instance
            built when the function was defined.
        input_tokens: Input token count forwarded to ``create_full_chunked_model``.
        output_tokens: List forwarded to ``create_full_chunked_model``
            (presumably one entry per concurrent decode request — confirm
            against GenZ). ``None`` (the default) means ``[1024] * 8``.
        batch_size: Batch size forwarded to ``get_model_df``.

    Returns:
        A ``(summary_table, runtime_df)`` tuple: the results of
        ``get_summary_table`` and ``get_runtime_breakdown`` for the model.
    """
    # Resolve the defaults per call: a default evaluated in the signature is
    # created once and then shared (and possibly mutated) across calls.
    if system is None:
        system = System()
    if output_tokens is None:
        output_tokens = [1024] * 8

    # Tensor parallelism is fixed at 8 for every model built by this helper.
    model = create_full_chunked_model(chunk_size, model_name, input_tokens, output_tokens, tensor_parallel=8)
    # Forward the caller's batch size; it was previously hard-coded to 1, which
    # silently ignored the ``batch_size`` argument.
    model_df = get_model_df(model, system, batch_size=batch_size, model_characterstics=False)
    summary_table = get_summary_table(model_df, model_characterstics=False)
    runtime_df = get_runtime_breakdown(model_df)
    return summary_table, runtime_df

In [28]:
import ipywidgets as widgets
from ipywidgets import interact

def interactive_chunked_model_runtime(input_tokens, output_tokens, output_batch):
    """Plot the runtime breakdown per layer group across a sweep of chunk sizes.

    For every chunk size in 256, 512, ..., 2048 this calls
    ``get_chunked_model_runtime`` on the notebook-level ``system`` and collects
    the Embedding, Collective, LA_layers, QKVO_layers and FFN_layers entries of
    the returned runtime breakdown, then shows them as a grouped bar chart.

    Args:
        input_tokens: Input token count forwarded to ``get_chunked_model_runtime``.
        output_tokens: Value repeated ``output_batch`` times to form the
            ``output_tokens`` list passed to ``get_chunked_model_runtime``.
        output_batch: Length of that ``output_tokens`` list.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Imported up front so a missing plotly install fails before the sweep runs.
    import plotly.express as px

    layer_groups = ['Embedding', 'Collective', 'LA_layers', 'QKVO_layers', 'FFN_layers']

    # Collect plain rows and build the DataFrame once at the end; repeatedly
    # concatenating onto an empty frame is quadratic, duplicates the index and
    # triggers pandas warnings about empty entries in concat.
    rows = []
    for chunk_size in range(256, 2049, 256):
        _, runtime_df = get_chunked_model_runtime(chunk_size=chunk_size, input_tokens=input_tokens, output_tokens=[output_tokens]*output_batch, system=system)
        rows.extend([name, getattr(runtime_df, name), chunk_size] for name in layer_groups)

    runtime_plot_df = pd.DataFrame(rows, columns=['Layer Name', 'Runtime', 'Chunk Size'])

    fig = px.bar(runtime_plot_df, x='Chunk Size', y='Runtime', color='Layer Name', barmode='group',
                title='Runtime Breakdown by Chunk Size, Input Tokens = {}, Output Tokens = {}, Output Batch = {}'.format(input_tokens, output_tokens, output_batch))
    fig.show()

# Integer inputs that drive the plot. Despite the `_slider` suffix these are
# BoundedIntText fields (typed integers bounded by min/max), not sliders.
input_tokens_slider = widgets.BoundedIntText(min=1, max=20480, step=10, value=1024, description='Input Tokens:')
output_tokens_slider = widgets.BoundedIntText(min=1, max=20480, step=10, value=1024, description='Output Tokens:')
batch_slider = widgets.BoundedIntText(min=1, max=512, step=8, value=8, description='Output Batch:')

# interact() displays the controls itself and returns the wrapped function; the
# trailing semicolon keeps the notebook from also echoing that function's repr.
interact(interactive_chunked_model_runtime, input_tokens=input_tokens_slider, output_tokens=output_tokens_slider, output_batch=batch_slider);

interactive(children=(BoundedIntText(value=1024, description='Input Tokens:', max=20480, min=1, step=10), Boun…

<function __main__.interactive_chunked_model_runtime(input_tokens, output_tokens, output_batch)>

In [7]:
# get_chunked_model_runtime returns a (summary_table, runtime_df) pair; unpack
# it so `summary_table` holds the table itself rather than the whole tuple.
# NOTE(review): with no arguments this run uses the function's default System,
# not the `system` configured earlier in the notebook — confirm that is intended.
summary_table, runtime_df = get_chunked_model_runtime()

In [26]:
# Chunked gpt-3 scenario evaluated on the notebook-level `system`.
# Positional arguments follow the same order as the call inside
# get_chunked_model_runtime: chunk size, model name, input tokens, output tokens.
model = create_full_chunked_model(
    512,
    'gpt-3',
    512,
    [1024] * 250,
    tensor_parallel=8,
)
model_df = get_model_df(model, system, batch_size=1, model_characterstics=False)

# Derive the aggregate summary and the runtime breakdown from the same model_df.
summary_table = get_summary_table(model_df, model_characterstics=False)
runtime_df = get_runtime_breakdown(model_df)

In [21]:
# Echo the current `summary_table`.
# NOTE(review): this cell's execution count (21) is lower than that of the cell
# directly above it (26), so the saved output was produced before that cell's
# last run and may reflect different parameters — re-run top to bottom to refresh.
summary_table

Unnamed: 0,MACs (MFLOP),Total Data (MB),Total Weights (MB),Unused Weights (MB),KV Cache (MB),On-chip Memory Footprint (MB),Latency (msec),Cycles,Attn Latency (msec),Linear Latency (msec),Comm Latency (msec)
0,22732034.31014,25939.9962,20883.28125,0.0,575.71875,82.70898,22.51949,22519492.33828,0.32899,7.99016,14.20034


In [27]:
# Echo the current `summary_table`. This cell's execution count (27) directly
# follows the chunked gpt-3 cell (26), so the saved output presumably belongs
# to that run — verify by re-running in order.
summary_table

Unnamed: 0,MACs (MFLOP),Total Data (MB),Total Weights (MB),Unused Weights (MB),KV Cache (MB),On-chip Memory Footprint (MB),Latency (msec),Cycles,Attn Latency (msec),Linear Latency (msec),Comm Latency (msec)
0,22655247.77779,97434.75666,20883.28125,0.0,72217.6875,82.70898,47.45585,47455850.6137,25.28636,7.98863,14.18086


In [31]:
# Decode-only gpt-3 scenario on the notebook-level `system`, generating one
# output token with a batch of 250.
# NOTE(review): the leading 1024 is passed positionally; presumably it is the
# input/context length — confirm against create_full_decode_model's signature.
model = create_full_decode_model(
    1024,
    'gpt-3',
    output_gen_tokens=1,
    tensor_parallel=8,
)
model_df = get_model_df(model, system, batch_size=250, model_characterstics=False)

# Derive the aggregate summary and the runtime breakdown from the same model_df.
summary_table = get_summary_table(model_df, model_characterstics=False)
runtime_df = get_runtime_breakdown(model_df)

In [32]:
# Echo the current `summary_table`. This cell's execution count (32) directly
# follows the decode-only cell (31), so the saved output presumably belongs to
# that run.
summary_table

Unnamed: 0,MACs (MFLOP),Total Data (MB),Total Weights (MB),Unused Weights (MB),KV Cache (MB),On-chip Memory Footprint (MB),Latency (msec),Cycles,Attn Latency (msec),Linear Latency (msec),Comm Latency (msec)
0,11061387.264,95134.93035,20809.64062,0.0,72070.3125,378.2959,42.69823,42698225.8153,25.18385,7.36428,10.1501


In [16]:
# Render the per-layer table for whatever `model_df` currently holds.
# NOTE(review): this cell's execution count (16) predates the cells above that
# reassign `model_df`, so the saved output comes from an earlier model; on a
# fresh top-to-bottom run it would show the most recently built model instead.
display_df(simplify_df(model_df))

Unnamed: 0,Layer Name,Op Type,Dimension,Op Intensity,Latency (msec),Num ops (MFLOP),Input_a (MB),Input_w (MB),Output (MB),Total Data (MB),Compute time (msec),Memory time (msec),Communication time (msec),Bound,C/M ratio,Cycles,% of total time,Throughput (Tflops),Compute cycle,Memory cycle,C Effcy,Communication cycle
0,embeddings,GEMM,"[((1, 16032, 480), (16384, 16032), (1, 16384, 480))]",906.312756,0.092543,252161.55648,7.338867,250.5,7.5,265.338867,0.0788,0.092543,0.0,Memory,0.851503,92542.852674,0.226627,2724.808553,78800.4864,92542.852674,0.8,0.0
1,Emb_AR,Sync,"(1, 480, 16384)",0.0,0.083068,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.083068,Collective,0.0,83067.901611,0.203424,0.0,0.0,0.0,0.8,83067.901611
2,QKV,GEMM,"[((1, 16384, 512), (2304, 16384), (1, 2304, 512))]",816.930748,1.660004,4870492.913664,1008.0,4536.0,141.75,5685.75,1.522029,1.660004,0.0,Memory,0.916883,1660003.662109,4.065164,2934.025403,1522029.03552,1660003.662109,0.8,0.0
3,Logit Pre,Logit,"((1, 16, 480, 128), (1, 1, 992, 128), (1, 16, 480, 992))",223.444392,0.108604,245744.27136,118.125,15.257812,915.46875,1048.851562,0.076795,0.108604,0.0,Memory,0.707108,108604.431152,0.26596,2262.746269,76795.0848,108604.431152,0.8,0.0
4,Attend Pre,Attend,"((1, 16, 480, 992), (1, 1, 992, 128), (1, 16, 480, 128))",223.444392,0.108604,245744.27136,915.46875,15.257812,118.125,1048.851562,0.076795,0.108604,0.0,Memory,0.707108,108604.431152,0.26596,2262.746269,76795.0848,108604.431152,0.8,0.0
5,Logit Dec,Logit,"((1, 16, 1, 128), (1, 1, 1024, 128), (1, 16, 1, 1024))",28.054795,0.005643,528.482304,0.246094,15.75,1.96875,17.964844,0.000165,0.005643,0.0,Memory,0.029265,5643.367767,0.01382,93.646618,165.15072,5643.367767,0.8,0.0
6,Attend Dec,Attend,"((1, 16, 1, 1024), (1, 1, 1024, 128), (1, 16, 1, 128))",28.054795,0.005643,528.482304,1.96875,15.75,0.246094,17.964844,0.000165,0.005643,0.0,Memory,0.029265,5643.367767,0.01382,93.646618,165.15072,5643.367767,0.8,0.0
7,Logit Dec,Logit,"((1, 16, 1, 128), (1, 1, 1024, 128), (1, 16, 1, 1024))",28.054795,0.005643,528.482304,0.246094,15.75,1.96875,17.964844,0.000165,0.005643,0.0,Memory,0.029265,5643.367767,0.01382,93.646618,165.15072,5643.367767,0.8,0.0
8,Attend Dec,Attend,"((1, 16, 1, 1024), (1, 1, 1024, 128), (1, 16, 1, 128))",28.054795,0.005643,528.482304,1.96875,15.75,0.246094,17.964844,0.000165,0.005643,0.0,Memory,0.029265,5643.367767,0.01382,93.646618,165.15072,5643.367767,0.8,0.0
9,Logit Dec,Logit,"((1, 16, 1, 128), (1, 1, 1024, 128), (1, 16, 1, 1024))",28.054795,0.005643,528.482304,0.246094,15.75,1.96875,17.964844,0.000165,0.005643,0.0,Memory,0.029265,5643.367767,0.01382,93.646618,165.15072,5643.367767,0.8,0.0
