# Description

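In this notebook, I implement the scaled dot-product attention operation and compare the speed and GPU memory usage of three variants:
- Torch's built-in `scaled_dot_product_attention` function
- A plain float16 scaled dot-product (matmul + softmax)
- An int8 scaled dot-product (int8 matmul for the attention scores)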

In [None]:
import os 
import math
import time
import functools
import numpy as np 

import torch
from torch.nn.functional import scaled_dot_product_attention

import torch_cuda_ext

In [None]:
def cuda_memory_profiler(device="cuda"):
    """
    Decorator factory that reports the peak GPU memory used while a function runs.

    Before the call the wrapper synchronizes the device, empties the CUDA
    caching allocator and resets the peak-memory statistics. After the call
    it synchronizes again and prints:

      - Δpeak: ``max_memory_allocated`` after the call minus
        ``memory_allocated`` before it, in MB (1 MB = 1e6 bytes).

    Args:
        device: CUDA device forwarded to every ``torch.cuda`` call.

    Returns:
        A decorator. The decorated function keeps the original name and
        docstring and returns whatever the original function returns.
    """
    def decorator(func):
        # Preserve func.__name__/__doc__ so stacked decorators and printed
        # labels still refer to the profiled function, not to "wrapper".
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # synchronize before measuring so pending kernels are not
            # attributed to this call
            torch.cuda.synchronize(device)
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)

            before = torch.cuda.memory_allocated(device)

            result = func(*args, **kwargs)  # run the function

            # wait for the work launched by func before reading the peak
            torch.cuda.synchronize(device)
            peak = torch.cuda.max_memory_allocated(device)

            delta_peak = peak - before

            msg = (f"[{func.__name__}] Δpeak: {delta_peak/1e6:.2f} MB")
            print(msg)

            return result
        return wrapper
    return decorator

def cuda_time_profiler(print_time=True, n_iter=10, n_warmup=1):
    """
    Decorator factory that measures the average GPU time of a function
    using CUDA events.

    The decorated function is executed ``n_warmup + n_iter`` times per call,
    so any side effects of the function are repeated that many times.
    Requires a working CUDA device, because the wrapper uses
    ``torch.cuda.Event`` and ``torch.cuda.synchronize``.

    Args:
        print_time: When true, print the average time per iteration in ms.
        n_iter: Number of timed iterations (must be >= 1).
        n_warmup: Number of untimed warm-up calls made first (must be >= 0).

    Returns:
        A decorator. The decorated function returns the result of the last
        timed call.

    Raises:
        ValueError: If ``n_iter < 1`` or ``n_warmup < 0``.
    """
    if n_iter < 1:
        raise ValueError("n_iter must be >= 1")
    if n_warmup < 0:
        raise ValueError("n_warmup must be >= 0")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Warm-up: keep one-off costs out of the timed region
            for _ in range(n_warmup):
                func(*args, **kwargs)
            torch.cuda.synchronize()  # Ensure all previous CUDA ops are done

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            for _ in range(n_iter):
                result = func(*args, **kwargs)
            end.record()

            # elapsed_time is only valid once both events have completed
            torch.cuda.synchronize()

            elapsed_ms = start.elapsed_time(end) / n_iter  # Average time per iteration
            if print_time:
                print(f"[{func.__name__}] elapsed: {elapsed_ms:.3f} ms")

            return result
        return wrapper
    return decorator

def l2_norm(tensor1, tensor2):
    """
    Return the l2 norm of ``tensor1 - tensor2`` as a 0-dim tensor.

    NaN entries of the squared difference are ignored (``torch.nansum``).
    Half-precision differences (float16 / bfloat16) are promoted to float32
    before squaring: in float16 a squared difference below ~6e-8 rounds to
    zero and a sum above 65504 becomes inf, which would make the reported
    distance meaningless.
    """
    diff = tensor1 - tensor2
    if diff.dtype in (torch.float16, torch.bfloat16):
        diff = diff.float()
    return torch.sqrt(torch.nansum(diff ** 2))

# 1. Torch built-in scaled dot product

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Floating-point type used for the Q, K and V tensors created below.
dtype = torch.float16

In [None]:
# Problem size shared by every benchmark in this notebook.
BATCH_SIZE, SEQUENCE_LENGTH, D_MODEL = 16, 1024, 512
# N_HEADS = 8

# Draw Q, K and V (in that order) as random
# (BATCH_SIZE, SEQUENCE_LENGTH, D_MODEL) tensors.
Q, K, V = (
    torch.randn(BATCH_SIZE, SEQUENCE_LENGTH, D_MODEL, device=device, dtype=dtype)
    for _ in range(3)
)

# warmup: run the built-in attention a few times before any measurement
for _ in range(10):
    _ = scaled_dot_product_attention(
        Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False
    )

In [None]:
@cuda_memory_profiler()
@cuda_time_profiler()
def torch_built_in_scaled_dot_product(Q, K, V):
    """Baseline: attention computed by torch's ``scaled_dot_product_attention``
    with its default arguments."""
    return scaled_dot_product_attention(Q, K, V)

# Reference output; the later sections compare their results against `z`.
z = torch_built_in_scaled_dot_product(Q, K, V)
print(f"Shape of output: {z.shape}", f"Dtype of output: {z.dtype}", sep="\n")

# 2. Normal (float) scaled-dot product

In [None]:
@cuda_memory_profiler()
@cuda_time_profiler()
def custom_scaled_dot_product(Q, K, V):
    """
    Explicit attention: softmax(Q @ K^T / sqrt(d_k)) @ V, where d_k is the
    size of the last dimension of Q and the softmax runs over the last axis.
    """
    inv_sqrt_dk = 1 / math.sqrt(Q.size(-1))

    # Scaled similarity between every query and every key.
    logits = (Q @ K.transpose(-2, -1)) * inv_sqrt_dk
    probs = logits.softmax(dim=-1)
    return probs @ V

In [None]:
# Run the hand-written attention and report what it produced.
z_custom = custom_scaled_dot_product(Q, K, V)
print(
    f"Shape of custom output: {z_custom.shape}",
    f"Dtype of custom output: {z_custom.dtype}",
    sep="\n",
)

In [None]:
print()
# A distance of 1.0 or more (or one that fails the comparison, e.g. NaN)
# is treated as a mismatch between the built-in and the custom result.
if not (l2_norm(z, z_custom) < 1.0):
    raise Exception("[ERROR] Outputs differ !!!")
print("Outputs are close! - l2 norm difference:", l2_norm(z, z_custom).item())

# 3. Int8 scaled-dot product

In [None]:
def quantize_tensor_asymmetric_int8(mat:torch.Tensor):
    """
    Asymmetric (affine) per-tensor quantization to int8.

    The scale maps the tensor's value range ``max - min`` onto the 255 steps
    of [-128, 127]; the zero point is chosen so that ``min`` maps to -128.

    mat: input float tensor (e.g., torch.float32 or torch.bfloat16)

    Returns:
        (q_mat, scale, zero_point) where q_mat is an int8 tensor shaped like
        ``mat``, scale is a 0-dim tensor and zero_point is the Python number
        produced by ``.item()``.
    """
    qmin = -128
    qmax = 127

    min_val = mat.min()
    max_val = mat.max()

    value_range = max_val - min_val
    if value_range == 0:
        # Constant tensor: a zero scale would give inf/NaN below. Use the
        # value's magnitude (or 1 for an all-zero tensor) so the constant
        # is still reproduced exactly after dequantization.
        magnitude = max_val.abs()
        scale = magnitude if magnitude > 0 else torch.ones_like(magnitude)
    else:
        scale = value_range / (qmax - qmin)

    zero_point = qmin - torch.round(min_val / scale).item()

    q_mat = torch.clamp(torch.round(mat / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q_mat, scale, zero_point

def quantize_tensor_symmetric_int8(mat:torch.Tensor):
    """
    Symmetric per-tensor quantization to int8.

    The scale is ``max(|mat|) / 127`` and the zero point is always 0, so a
    float value x is stored as ``round(x / scale)``.

    mat: input float tensor (e.g., torch.float32 or torch.bfloat16)

    Returns:
        (q_mat, scale, zero_point) where q_mat is an int8 tensor shaped like
        ``mat``, scale is a 0-dim tensor and zero_point is the int 0.
    """
    qmin = -128
    qmax = 127

    max_val = torch.max(torch.abs(mat))
    scale = max_val / qmax
    # An all-zero tensor yields scale == 0 and mat / scale == 0/0 == NaN.
    # Fall back to a scale of 1 (done with torch.where so no host sync is
    # needed); every element then quantizes to 0.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    zero_point = 0  # For symmetric quantization, zero_point is typically 0

    q_mat = torch.clamp(torch.round(mat / scale), qmin, qmax).to(torch.int8)
    return q_mat, scale, zero_point

def dequantize_tensor_int8(q_mat, scale:float, zero_point:int):
    """
    Map an int8 tensor back to floating point: ``scale * (q - zero_point)``.

    q_mat: quantized int8 tensor
    scale: scaling factor that was used for quantization
    zero_point: zero point that was used for quantization
    """
    # Remove the zero-point offset in float first, then rescale.
    centered = q_mat.float() - zero_point
    return scale * centered

def quantization_error(original, dequantized):
    """
    Relative l2 error of a quantize/dequantize round trip:
    ``||original - dequantized|| / ||original||``.
    """
    residual_norm = torch.norm(original - dequantized)
    reference_norm = torch.norm(original)
    return residual_norm / reference_norm

In [None]:
# @cuda_memory_profiler()
# @cuda_time_profiler()
# def scaled_dot_product_int8(Q, K, V):
#     dk = Q.size(-1)
#     scale = 1.0 / math.sqrt(dk)
    
#     # Quantize Q, K, V
#     Q_q, Q_scale, Q_zp = quantize_tensor_symmetric_int8(Q)
#     K_q, K_scale, K_zp = quantize_tensor_symmetric_int8(K)
    
#     Q_q = Q_q.view(-1, D_MODEL)
#     K_q = K_q.view(-1, D_MODEL)

#     scores_int32 = int8_linear_matmul(Q_q, K_q, dtype=torch.int32)
#     scores = scores_int32.view(BATCH_SIZE, SEQUENCE_LENGTH, SEQUENCE_LENGTH)
#     scores = scores.to(dtype) * (Q_scale * K_scale) * scale # dequantize
    
#     attn_weights = torch.softmax(scores, dim=-1)
#     output = torch.matmul(attn_weights, V)
#     return output


@cuda_memory_profiler()
@cuda_time_profiler()
def scaled_dot_product_int8(Q_q, Q_scale, K_q, K_scale, V):
    """
    Attention whose Q·K^T scores come from an int8 batched matmul.

    Args:
        Q_q: quantized queries; the size of its last dimension is used as d_k.
        Q_scale: quantization scale of Q_q.
        K_q: quantized keys, passed straight to ``torch_cuda_ext.bmm_int8``
            as the right-hand operand (the caller in this notebook supplies
            it already transposed and contiguous).
        K_scale: quantization scale of K_q.
        V: floating-point values; its dtype is used for the softmax input
            and the final matmul.

    Returns:
        softmax(dequantized scores / sqrt(d_k)) @ V.
    """
    dk = Q_q.size(-1)
    scale = 1.0 / math.sqrt(dk)

    # presumably returns int32 accumulators — TODO confirm against the extension
    scores_int32 = torch_cuda_ext.bmm_int8(Q_q, K_q)

    # Dequantize in float32 and only then cast down: the raw accumulators
    # can exceed the float16 maximum (65504) and would turn into inf if they
    # were cast to half precision before the scales are applied.
    scores = scores_int32.to(torch.float32) * Q_scale * K_scale * scale
    scores = scores.to(V.dtype)

    attn_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)

    return output

In [None]:
# Quantize Q and K (V stays in floating point)
Q_q, Q_scale, Q_zp = quantize_tensor_symmetric_int8(Q)
K_q, K_scale, K_zp = quantize_tensor_symmetric_int8(K)
# Hand the int8 matmul a contiguous, already transposed copy of K.
K_q_transpose = torch.transpose(K_q, -2, -1).contiguous()

z_int8 = scaled_dot_product_int8(Q_q, Q_scale, K_q_transpose, K_scale, V)
print(
    f"Shape of int8 output: {z_int8.shape}",
    f"Dtype of int8 output: {z_int8.dtype}",
    sep="\n",
)

print()
# The tolerance is applied to the l2 distance divided by BATCH_SIZE.
if not (l2_norm(z, z_int8) / BATCH_SIZE < 1.0):
    print("L2 norm difference:", l2_norm(z, z_int8).item())
    raise Exception("[ERROR] Int8 Outputs differ !!!")
print("Int8 Outputs are close! - l2 norm difference:", l2_norm(z, z_int8).item())