In [1]:
# Torch
import torch

# Triton 

# Utils
import time

# Local
from attention_impl.naive_mla import MLA as DS_MLA
from attention_impl.Shared_Args import Args
from attention_impl.General_Layers import precompute_freqs_cis, EmbeddingLayer

In [2]:
default_dtype = "bf16"
Naive_Args = Args(dtype=default_dtype)
Absorb_Args = Args(attn_impl = "absorb", dtype=default_dtype)

In [3]:
Naive_MLA = DS_MLA(Naive_Args).cuda().eval()
Absorb_MLA = DS_MLA(Absorb_Args).cuda().eval()

In [4]:
# Shared Args
qk_rope_dim = Naive_Args.qk_rope_head_dim
vocab_size = Naive_Args.vocab_size
dim = Naive_Args.dim

if Naive_Args.dtype == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float8_e4m3fn

# Set Default Torch dtype and device
torch.set_default_device('cuda')
torch.set_default_dtype(dtype)
torch.manual_seed(0)

<torch._C.Generator at 0x7ff337556cb0>

In [5]:
# Setup Args
bsz = 4
seqlen = 512

# Create Embedding Layer
embedding_layer = EmbeddingLayer(vocab_size,dim)

# Create X then Embedd it
x = torch.randint(0, vocab_size, (bsz, seqlen), device='cuda')
x_emb = embedding_layer(x)
print(x_emb.size())

# Create a Freq_Cis and Mask
freqs_cis = precompute_freqs_cis(Naive_Args)
mask = torch.full((seqlen, seqlen), float("-inf"), device=x.device).triu_(1)

# Set Start Position
start_pos = 0

torch.Size([4, 512, 2048])


In [6]:
def simple_benchmark(MLA_impl, x_emb, start_pos, freqs_cis, mask, num_warmup, num_run):
    # Warm-up
    for _ in range(num_warmup):
        freqs_cis = freqs_cis[start_pos:start_pos+seqlen]
        _ = MLA_impl(x_emb, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)

    # Timed benchmarking
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_run):
        _ = MLA_impl(x_emb, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Avg forward pass time: {(end - start)/num_run:.6f} sec")

simple_benchmark(Naive_MLA, x_emb, start_pos, freqs_cis, mask, 10, 500)
simple_benchmark(Absorb_MLA, x_emb, start_pos, freqs_cis, mask, 10, 500)


RuntimeError: expected scalar type Float but found BFloat16

In [None]:
simple_benchmark(Naive_MLA, x_emb, start_pos, freqs_cis, mask, 10, 50)
simple_benchmark(Absorb_MLA, x_emb, start_pos, freqs_cis, mask, 10, 50)