# Description

In this notebook, we will implement the normal matmul and multi-head attention to check the correctness and time

In [1]:
import os 
import torch 
import torch.nn as nn
import gemm_int8

# 1. Matmul

In [2]:
def quantize_row_int8_symmetric(mat: torch.Tensor):
    """
    Symmetric int8 quantization per row.
    mat: (N, M) float tensor
    Returns:
      q_mat: (N, M) int8
      scales: (N,) float32
    """
    qmin, qmax = -128, 127
    
    max_vals = mat.abs().amax(dim=1, keepdim=True)  # (N, 1)
    max_vals = max_vals.clamp(min=1e-8)

    scales = (max_vals / qmax).squeeze(1)          # (N,)
    q_mat = torch.clamp(torch.round(mat / scales.unsqueeze(1)), qmin, qmax).to(torch.int8)

    return q_mat, scales.to(torch.float32)


def quantize_col_int8_symmetric(mat: torch.Tensor):
    """
    Symmetric int8 quantization per column.
    mat: (N, M) float tensor
    Returns:
      q_mat: (N, M) int8
      scales: (M,) float32
    """
    qmin, qmax = -128, 127

    max_vals = mat.abs().amax(dim=0, keepdim=True)  # (1, M)
    max_vals = max_vals.clamp(min=1e-8)

    scales = (max_vals / qmax).squeeze(0)           # (M,)
    q_mat = torch.clamp(torch.round(mat / scales.unsqueeze(0)), qmin, qmax).to(torch.int8)

    return q_mat, scales.to(torch.float32)

def quantize_tensor_int8_symmetric(tensor: torch.Tensor):
    """
    Symmetric int8 quantization for entire tensor.
    tensor: float tensor
    Returns:
      q_tensor: int8
      scale: float32
    """
    qmin, qmax = -128, 127

    max_val = tensor.abs().amax()  # scalar
    max_val = max_val.clamp(min=1e-8)

    scale = (max_val / qmax)          # scalar
    q_tensor = torch.clamp(torch.round(tensor / scale), qmin, qmax).to(torch.int8)

    return q_tensor, scale.to(torch.float32)

def dequant_int8_gemm(out_int: torch.Tensor,
                      x_scale: torch.Tensor,
                      w_scale: torch.Tensor,
                      out_dtype=torch.float16):
    """
    Dequantize result of INT8 matmul:
      out_int: (B, out_features) int32 or float32
      x_scale: (B,) from input rows
      w_scale: (out_features,) from weight rows/cols
    """
    if out_int.dtype == torch.int32:
        out_float = out_int.to(torch.float32)
    else:
        out_float = out_int

    # scale = x_scale.unsqueeze(1) * w_scale.unsqueeze(0)  # (B, out_features)
    # out = out_float * scale
    
    out = out_float * x_scale[:, None] * w_scale[None, :]
    
    return out.to(out_dtype)


In [3]:
def benchmark_func(func, *args, n_warmup=10, n_repeat=100):
    torch._dynamo.reset()
    # Warm-up
    for _ in range(n_warmup):
        func(*args)

    # Benchmark
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(n_repeat):
        func(*args)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / n_repeat

    return avg_time_ms

In [4]:
N = 4096
M = 2048
K = 4096
device = 'cuda'
d_type = torch.float16

X = torch.randn(N, M, device=device, dtype=d_type)
W = torch.randn(K, M, device=device, dtype=d_type)

Y_true = torch.matmul(X, W.T)
print("Y_true:", Y_true.dtype, Y_true.shape)

torch_time = benchmark_func(torch.matmul, X, W.T)
print(f"PyTorch matmul time: {torch_time:.3f} ms")

Y_true: torch.float16 torch.Size([4096, 4096])
PyTorch matmul time: 0.417 ms


In [5]:
X_q, X_scale = quantize_row_int8_symmetric(X)
# X_q, X_scale = quantize_tensor_int8_symmetric(X)
W_q, W_scale = quantize_row_int8_symmetric(W)

if X_scale.numel() == 1:
    X_scale = X_scale.item()

In [6]:
@torch.compile(dynamic=True)
def gemm_int8_func(X_q, W_q, x_scale, w_scale):
    
    # if x_scale is not tensor
    if type(x_scale) is float:
        Y_int = gemm_int8.matmul(X_q, W_q, x_scale)
        Y_deq = Y_int * w_scale.unsqueeze(0)
    else:
        Y_int = gemm_int8.matmul(X_q, W_q, 1.0)
        Y_deq = dequant_int8_gemm(Y_int, x_scale, w_scale, out_dtype=d_type)
    return Y_deq

Y_deq = gemm_int8_func(X_q, W_q, X_scale, W_scale)
print("Y_deq:", Y_deq.dtype, Y_deq.shape)

int8_time = benchmark_func(gemm_int8_func, X_q, W_q, X_scale, W_scale)
print(f"INT8 matmul time: {int8_time:.3f} ms")

Y_deq: torch.float16 torch.Size([4096, 4096])
INT8 matmul time: 0.277 ms


In [7]:
threshold = 3.0
if torch.allclose(Y_true, Y_deq, atol=threshold, rtol=threshold):
    print("passed!")
else:
    print("================ INT8 matmul failed! ================")

passed!


# 2. Batched matmul

In [13]:
H = 8
D_k = 1024
L = 4096

Q = torch.randint(-127, 127, (H, L, D_k), device=device, dtype=torch.int8)
K = torch.randint(-127, 127, (H, L, D_k), device=device, dtype=torch.int8)

Q_fp = Q.to(torch.float16)
K_fp = K.to(torch.float16)

score = torch.matmul(Q_fp, K_fp.transpose(-2, -1)) 
print("score:", score.dtype, score.shape)

torch_time = benchmark_func(torch.matmul, Q_fp, K_fp.transpose(-2, -1))
print(f"PyTorch matmul time: {torch_time:.3f} ms")

score: torch.float16 torch.Size([8, 4096, 4096])
PyTorch matmul time: 1.811 ms


In [14]:
@torch.compile(dynamic=True)
def gemm_batched_int8_func(X_q, W_q):
    Y_int = gemm_int8.bmm_int8_matmul(X_q, W_q, 1.0)
    return Y_int

score_int = gemm_int8.bmm_int8_matmul(Q, K, 1.0)

print("score_int:", score_int.dtype, score_int.shape)

score_int: torch.bfloat16 torch.Size([8, 4096, 4096])


In [15]:
int8_time = benchmark_func(gemm_batched_int8_func, Q, K)
print(f"INT8 batched matmul time: {int8_time:.3f} ms")

INT8 batched matmul time: 1.401 ms


# 3. Scale dot product

In [18]:
def scale_dot_product(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        """
        Scaled dot-product attention.
        Q: (H, L, D_k)
        K: (H, L, D_k)
        V: (H, L, D_v)
        Returns:
        out: (B, N, D_v)
        """
        d_k = Q.size(-1)
        scale_dk = torch.sqrt(torch.tensor(d_k, dtype=Q.dtype, device=Q.device))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale_dk  # (B, N, M)

        attn_weights = torch.softmax(scores, dim=-1)  # (B, N, M)
        out = torch.matmul(attn_weights, V)           # (B, N, D_v)

        return out

In [19]:
L = 2048
H = 8
D_k = 2048

Q = torch.randn(H, L, D_k, device=device, dtype=d_type)
K = torch.randn(H, L, D_k, device=device, dtype=d_type)
V = torch.randn(H, L, D_k, device=device, dtype=d_type)

out_true = scale_dot_product(Q, K, V)
print("out_true:", out_true.dtype, out_true.shape)

torch_scale_dot_product_time = benchmark_func(scale_dot_product, Q, K, V)
print(f"PyTorch scaled dot-product attention time: {torch_scale_dot_product_time:.3f} ms")

out_true: torch.float16 torch.Size([8, 2048, 2048])
PyTorch scaled dot-product attention time: 2.566 ms


In [20]:
def quantize_row_matrix_int8_symmetric_batched(mat: torch.Tensor):
    """
    Symmetric per-row quantization for batched 3D tensor.
    mat: [B, N, D]  (float tensor)
    
    Returns:
        q_mat:   [B, N, D] int8
        scales:  [B, N]    float32  (scale per row within each batch)
    """
    qmin, qmax = -128, 127

    # Compute max abs per row (per batch) - Result shape: [B, N, 1]
    max_vals, _ = torch.max(torch.abs(mat), dim=2, keepdim=True)

    # Compute scales per row
    scales = (max_vals / qmax).clamp(min=1e-12)  # avoid div-by-zero, shape [B, N, 1]

    # Quantize
    q_mat = torch.clamp(torch.round(mat / scales), qmin, qmax).to(torch.int8)

    # Return float scales of shape [B, N]
    scales = scales.squeeze(2).to(torch.float32)
    return q_mat, scales      

@torch.compile(dynamic=True)
def scale_dot_product_int8(Q_q, Q_scale, K_q, K_scale, V_q, V_scale):
    """
    Scaled dot-product attention with INT8 matmul.
    Q: (H, L, D_k) int8
    K: (H, L, D_k) int8
    V: (H, L, D_v) float16
    Returns:
    out: (H, L, D_v) float16
    """
    d_k = Q_q.size(-1)
    scale_dk = torch.sqrt(torch.tensor(d_k, dtype=torch.float32, device=Q.device))
    
    scores_int = gemm_int8.bmm_int8_matmul(Q_q, K_q, 1.0)  # (H, L, L) int32
    scores = Q_scale.unsqueeze(-1) * scores_int * K_scale.unsqueeze(1)
    scores = scores / scale_dk   # (H, L, L) float32
    scores = scores.to(torch.float16)

    attn_weights = torch.softmax(scores, dim=-1)       # (H, L, L) float32
    
    attn_weights_q, attn_weights_scale = quantize_row_matrix_int8_symmetric_batched(attn_weights)
    out_int = gemm_int8.bmm_int8_matmul(attn_weights_q, V_q)     # (H, L, D_v) float16
    out = attn_weights_scale.unsqueeze(-1) * out_int * V_scale.unsqueeze(1)

    return out.to(torch.float16)

In [21]:
Q_q, Q_scale = quantize_row_matrix_int8_symmetric_batched(Q)
K_q, K_scale = quantize_row_matrix_int8_symmetric_batched(K)
V_q, V_scale = quantize_row_matrix_int8_symmetric_batched(V)

out = scale_dot_product_int8(Q_q, Q_scale, K_q, K_scale, V_q, V_scale)
print("out:", out.dtype, out.shape)

out: torch.float16 torch.Size([8, 2048, 2048])


In [22]:
scale_dot_product_int8_time = benchmark_func(scale_dot_product_int8, Q_q, Q_scale, K_q, K_scale, V_q, V_scale)
print(f"INT8 scaled dot-product attention time: {scale_dot_product_int8_time:.3f} ms")

INT8 scaled dot-product attention time: 1.902 ms


In [None]:
threshold = 2.0
if torch.allclose(out_true, out, atol=threshold, rtol=threshold):
    print("passed!")
else:
    print("================ INT8 scaled dot-product attention failed! ================")

passed!


: 