In [1]:
import torch

  import pynvml  # type: ignore[import]


In [9]:
import torch


@torch.no_grad()
def bench_gemv_us(
    D=128,
    dtype=torch.float32,
    ks=(128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536),
    iters=10000,
    warmup=200,
):
    """Time the matrix-vector product ``M @ q`` on CUDA for several row counts.

    For each ``k`` in ``ks`` a random ``(k, D)`` matrix is multiplied by one
    shared random length-``D`` vector. The product is launched ``iters`` times
    back to back between two CUDA timing events, and the elapsed time is
    divided by ``iters``.

    Args:
        D: Length of the vector and number of columns of each matrix.
        dtype: Torch dtype used for both the matrix and the vector.
        ks: Iterable of row counts to benchmark, processed in order.
        iters: Number of timed products per row count; must be >= 1.
        warmup: Number of untimed products issued before timing each row
            count; must be >= 0.

    Returns:
        A list of ``(k, us)`` tuples in the order of ``ks``, where ``us`` is
        the average time per product in microseconds.

    Raises:
        ValueError: If ``iters < 1`` or ``warmup < 0``.
        RuntimeError: If no CUDA device is available.

    Note:
        The timed region is the whole Python loop between the two events, so
        the reported figure is an average over everything the stream did in
        that window, not an isolated kernel time.
        NOTE(review): for small ``k`` this is presumably dominated by
        per-launch overhead rather than arithmetic -- confirm with a profiler
        before reading the numbers as kernel cost.
    """
    # Validate before touching the GPU so a bad argument fails immediately
    # instead of after warmup (iters == 0) or by silently yielding negative
    # timings (iters < 0).
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if not torch.cuda.is_available():
        raise RuntimeError(
            "bench_gemv_us requires a CUDA device (timing uses torch.cuda.Event)"
        )

    device = "cuda"
    q = torch.randn(D, device=device, dtype=dtype)

    # One pair of timing events is reused for every row count.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    out = []
    for k in ks:
        M = torch.randn((k, D), device=device, dtype=dtype)

        # Untimed warmup, then drain the stream so none of it lands inside
        # the timed window.
        for _ in range(warmup):
            _ = M @ q
        torch.cuda.synchronize()

        start.record()
        for _ in range(iters):
            _ = M @ q
        end.record()
        # The end event must have completed before elapsed_time() is queried.
        torch.cuda.synchronize()

        us = (start.elapsed_time(end) * 1000.0) / iters  # ms -> us
        out.append((k, us))

    return out


# Run the sweep with D=128 and the function's default row counts, then print
# one line per k: the average time per product in microseconds, followed by
# that time divided by k (microseconds per matrix row).
results = bench_gemv_us(D=128)
for k, us in results:
    print(f"k={k:6d}  {us:8.2f} us  ({us/k:10.6f} us/row)")

k=   128      7.33 us  (  0.057266 us/row)
k=   256      5.68 us  (  0.022170 us/row)
k=   512      5.67 us  (  0.011073 us/row)
k=  1024      5.67 us  (  0.005540 us/row)
k=  2048      5.89 us  (  0.002877 us/row)
k=  4096      5.89 us  (  0.001437 us/row)
k=  8192      5.69 us  (  0.000695 us/row)
k= 16384      5.69 us  (  0.000347 us/row)
k= 32768      8.19 us  (  0.000250 us/row)
k= 65536     12.28 us  (  0.000187 us/row)
