```python
def flash_mla():
    torch.cuda.synchronize()
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )
```
I added a `torch.cuda.synchronize()` call and found that performance got much worse: with the sync it took 360us, while without it, it took only 50us.

Does this mean the 50us measurement is really just the CPU's time? (The kernel is launched asynchronously and hasn't finished executing yet.)
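Yes, that is the likely explanation: CUDA kernel launches are asynchronous, so without a synchronize the wall clock only captures the CPU-side launch overhead, not the kernel's execution. As a CPU-only analogy (no GPU or FlashMLA required; `slow_kernel` is a hypothetical stand-in for the real kernel), submitting work to a thread pool returns almost immediately, just like a kernel launch:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_kernel() -> int:
    time.sleep(0.05)  # stands in for GPU kernel execution time
    return 42

with ThreadPoolExecutor(max_workers=1) as pool:
    t0 = time.perf_counter()
    fut = pool.submit(slow_kernel)  # "launch": returns immediately
    launch_us = (time.perf_counter() - t0) * 1e6

    t0 = time.perf_counter()
    result = fut.result()  # "synchronize": blocks until the work finishes
    sync_us = (time.perf_counter() - t0) * 1e6

# launch_us is tiny; sync_us reflects the actual work, just as the 50us
# vs 360us gap does in the issue report.
```

For accurate GPU timing, either call `torch.cuda.synchronize()` before reading the clock, or use `torch.cuda.Event(enable_timing=True)` pairs, which measure elapsed time on the device itself without stalling the whole stream around the measurement.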