In [None]:
# Torch
import torch

# Triton 

# Utils
import time

# Local
from attention_impl.naive_mla import MLA as DS_MLA
from attention_impl.Shared_Args import Args
from attention_impl.General_Layers import precompute_freqs_cis, EmbeddingLayer

In [None]:
# Two attention configs sharing one precision setting; they differ only in `attn_impl`.
default_dtype = "bf16"
Naive_Args = Args(dtype=default_dtype)  # Args' default attn_impl (presumably the naive path — confirm in Shared_Args)
Absorb_Args = Args(dtype=default_dtype, attn_impl="absorb")  # the "absorb" attention implementation

In [None]:
# One MLA module per config (naive first, then absorb), each moved to the GPU and put in eval mode.
# NOTE(review): these are built before the later cell sets torch's default dtype/device and
# the RNG seed, so parameter dtype/initialization come from DS_MLA itself — confirm they
# match the bf16 activations fed to them.
Naive_MLA, Absorbed_MLA = [DS_MLA(cfg).cuda().eval() for cfg in (Naive_Args, Absorb_Args)]

In [None]:
# Model dimensions, read from the naive config. Absorb_Args is built with the same
# default fields; only attn_impl differs.
qk_rope_dim = Naive_Args.qk_rope_head_dim
vocab_size = Naive_Args.vocab_size
dim = Naive_Args.dim

# "bf16" selects bfloat16; any other value falls back to FP8 (e4m3fn).
# NOTE(review): float8_e4m3fn has no infinity encoding, so the -inf causal mask built later
# cannot be represented if the fallback is taken. Confirm whether FP8 was meant for weights
# only, with activations staying bf16.
dtype = torch.bfloat16 if Naive_Args.dtype == "bf16" else torch.float8_e4m3fn

# From here on, factory calls without an explicit device/dtype create CUDA tensors, and
# floating-point ones use `dtype`. The seed makes the embedding init and the random token
# ids below reproducible.
torch.set_default_device('cuda')
torch.set_default_dtype(dtype)
torch.manual_seed(0)

In [None]:
# Benchmark problem size.
# NOTE(review): presumably these must fit within the Args max batch size / max sequence
# length used to size the attention caches — confirm.
bsz, seqlen = 4, 512

# Random token ids -> embeddings. The attention modules consume `x_emb`, not `x`.
embedding_layer = EmbeddingLayer(vocab_size, dim)
x = torch.randint(0, vocab_size, (bsz, seqlen), device='cuda')
x_emb = embedding_layer(x)
print(x_emb.shape)

# Rotary frequency table (full length; a [start_pos, start_pos + seqlen) window is used
# per forward). Additive causal mask: -inf strictly above the diagonal, 0 on and below it.
freqs_cis = precompute_freqs_cis(Naive_Args)
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf"), device=x.device), diagonal=1)

# All benchmarked forwards start at position 0.
start_pos = 0

In [None]:
# Warm up, then time forward passes of each MLA variant on the same embedded input.
N_WARMUP_ITERS = 10
N_TIMED_ITERS = 50

# Slice the rotary table once into a new name. Reassigning `freqs_cis` inside the loop
# (as before) was not idempotent: re-running this cell, or any start_pos > 0, kept
# shrinking the shared table.
freqs_cis_window = freqs_cis[start_pos:start_pos + seqlen]


def benchmark_forward(model, inputs, start_pos, freqs_cis, mask,
                      n_warmup=N_WARMUP_ITERS, n_iters=N_TIMED_ITERS):
    """Return the mean wall-clock time, in seconds, of one ``model`` forward pass.

    Args:
        model: module invoked as ``model(inputs, start_pos=..., freqs_cis=..., mask=...)``.
        inputs: embedded activations (the embedding layer's output, not raw token ids).
        start_pos: start position forwarded unchanged to every call.
        freqs_cis: rotary table window forwarded unchanged to every call.
        mask: attention mask forwarded unchanged to every call.
        n_warmup: number of untimed passes run first.
        n_iters: number of timed passes averaged into the result; must be positive.

    Raises:
        ValueError: if ``n_iters`` is not positive.
    """
    if n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {n_iters}")
    # Forward-only benchmark: skip autograd graph construction and its overhead.
    with torch.inference_mode():
        for _ in range(n_warmup):
            model(inputs, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
        # CUDA kernels launch asynchronously; synchronize so the timer brackets finished work only.
        torch.cuda.synchronize()
        t_start = time.perf_counter()
        for _ in range(n_iters):
            model(inputs, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t_start
    return elapsed / n_iters


# Both variants get the embedded input `x_emb`; the old timed loop passed raw token ids `x`.
for name, model in (("Naive", Naive_MLA), ("Absorbed", Absorbed_MLA)):
    avg_sec = benchmark_forward(model, x_emb, start_pos, freqs_cis_window, mask)
    print(f"{name} MLA avg forward pass time: {avg_sec:.6f} sec")