In [2]:
import torch

import time

In [None]:


# -------------------------------
# config
# -------------------------------
# Problem size for the GEMM  C[M, N] = A[M, K] @ B[K, N].
M = 4096
N = 4096
K = 4096

dtype = torch.bfloat16
device = "cuda"

# Number of untimed matmuls run first, and number of timed matmuls.
warmup_iters = 20
iters = 200

# Disable TF32 for matmul / cuDNN. NOTE(review): per the PyTorch docs these
# flags govern float32 math, so they presumably do not change the bfloat16
# GEMM measured below; they are kept so the process-wide settings stay the
# same as before.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# -------------------------------
# tensors
# -------------------------------
A = torch.randn(M, K, device=device, dtype=dtype)
B = torch.randn(K, N, device=device, dtype=dtype)

# -------------------------------
# warmup
# -------------------------------
# Run the same op untimed first so one-off startup costs are not part of the
# measurement.
for _ in range(warmup_iters):
    C = A @ B

# CUDA kernels are launched asynchronously: wait for the warmup work to finish
# so none of it leaks into the timed region.
torch.cuda.synchronize()

# -------------------------------
# timing
# -------------------------------
# perf_counter() is monotonic and uses the highest-resolution clock available,
# unlike time.time(), which follows the (adjustable) system wall clock.
start = time.perf_counter()
for _ in range(iters):
    C = A @ B
# Block until every queued matmul has actually completed before stopping the
# clock; otherwise only the launch overhead would be measured.
torch.cuda.synchronize()
end = time.perf_counter()

# Average seconds per matmul.
elapsed = (end - start) / iters

# -------------------------------
# FLOPs
# -------------------------------
# GEMM = 2*M*N*K  (one multiply and one add per inner-product term)
flops = 2 * M * N * K
tflops = flops / elapsed / 1e12

print("=========================================")
print(f"Matrix: {M} x {K} @ {K} x {N}")
print(f"Average time: {elapsed * 1e3:.6f} ms")
print(f"Throughput: {tflops:.2f} TFLOP/s")
print("=========================================")

Matrix: 4096 x 4096 @ 4096 x 4096
Average time: 0.761187 ms
Throughput: 180.56 TFLOP/s
