# Description

In this notebook, I verify the custom torch extension kernels with respect to:
- Correctness
- Peak memory usage
- Execution time

In [21]:
import os 
import numpy as np
import time
import functools
import torch
import torch_cuda_ext

In [2]:
# Report the PyTorch build and, when CUDA is usable, the name of GPU 0.
print(
    "\n=== PyTorch Info ===",
    f"Torch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    f"torch.version.cuda: {torch.version.cuda}",
    sep="\n",
)
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")


=== PyTorch Info ===
Torch version: 2.5.1+cu121
CUDA available: True
torch.version.cuda: 12.1
GPU name: NVIDIA RTX A5000


In [3]:
def cuda_memory_profiler(device="cuda"):
    """
    Decorator factory that reports the peak GPU memory a function allocates.

    Around each call the wrapper synchronizes the device, releases unused
    cached blocks, resets the peak-memory statistics, runs the function and
    prints

        Δpeak = max_memory_allocated(after) - memory_allocated(before)

    i.e. the largest amount of additional memory that was live at any point
    during the call (temporaries plus whatever the function returns).

    Note: the numbers come from ``torch.cuda.memory_allocated`` /
    ``torch.cuda.max_memory_allocated``, so presumably only allocations made
    through PyTorch's caching allocator are visible here — memory an extension
    obtains by other means would not be counted (TODO confirm for the kernels
    under test).

    Args:
        device: CUDA device forwarded to the ``torch.cuda`` synchronize and
            memory-statistics calls. Defaults to ``"cuda"``.

    Returns:
        A decorator. The decorated function behaves like the original and
        returns its result unchanged, after printing one ``Δpeak`` line.
    """
    def decorator(func):
        # Keep the wrapped function's __name__/__doc__, as cuda_time_profiler does.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Drain pending work and start from a clean peak counter so the
            # measurement only covers this call.
            torch.cuda.synchronize(device)
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)

            before = torch.cuda.memory_allocated(device)

            result = func(*args, **kwargs)  # run the function

            # Kernels are launched asynchronously; wait before reading the peak.
            torch.cuda.synchronize(device)
            peak = torch.cuda.max_memory_allocated(device)

            delta_peak = peak - before

            print(f"[{func.__name__}] Δpeak: {delta_peak/1e6:.2f} MB")

            return result
        return wrapper
    return decorator

def cuda_time_profiler(print_time=True, n_iter=10):
    """
    Decorator factory that measures the average GPU time of a function using
    CUDA events.

    The decorated function is executed ``n_iter + 1`` times per call: once as
    an untimed warm-up and ``n_iter`` times between two CUDA events. Any side
    effects of the function are therefore repeated. Requires a usable CUDA
    device, since ``torch.cuda.Event`` and ``torch.cuda.synchronize`` are
    called unconditionally.

    Args:
        print_time: When true, print ``[<name>] elapsed: <ms> ms`` with the
            average time per iteration.
        n_iter: Number of timed iterations to average over (default 10, the
            value that used to be hard-coded). Must be at least 1.

    Returns:
        A decorator. The decorated function returns the result of the last
        timed invocation.

    Raises:
        ValueError: If ``n_iter`` is smaller than 1.
    """
    if n_iter < 1:
        # With zero iterations there would be no result to return and the
        # average would divide by zero.
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):

            # Warm-up run (not timed), then make sure it has finished.
            func(*args, **kwargs)
            torch.cuda.synchronize()

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            for _ in range(n_iter):
                result = func(*args, **kwargs)
            end.record()

            # elapsed_time is only valid once both events have completed.
            torch.cuda.synchronize()

            elapsed_ms = start.elapsed_time(end) / n_iter  # Average time per iteration
            if print_time:
                print(f"[{func.__name__}] elapsed: {elapsed_ms:.3f} ms")

            return result
        return wrapper
    return decorator

def l2_norm(tensor1, tensor2):
    """Return the L2 norm of ``tensor1 - tensor2`` as a 0-dim tensor.

    Squared differences that are NaN are treated as zero (``torch.nansum``),
    so NaN entries do not poison the result.
    """
    squared_diff = (tensor1 - tensor2).pow(2)
    return torch.nansum(squared_diff).sqrt()

# 1. Check correctness

In [4]:
# Correctness: custom float32 matmul against torch.matmul on random inputs.
A = torch.randn((1_000, 2_000), device="cuda")
B = torch.randn((2_000, 3_000), device="cuda")

custom_matmul = torch_cuda_ext.matmul_f32(A, B)
torch_matmul = A @ B

# allclose uses atol=1e-3 on top of torch's default relative tolerance.
print(
    "--- Correct float matmul implementation! ---"
    if torch.allclose(custom_matmul, torch_matmul, atol=1e-3)
    else "[ERROR] Incorrect float matmul implementation !!!"
)

--- Correct float matmul implementation! ---


In [5]:
# Correctness: CUTLASS float32 matmul against torch.matmul on random inputs.
A = torch.randn((1_000, 2_000), device="cuda")
B = torch.randn((2_000, 3_000), device="cuda")

custom_matmul_cutlass = torch_cuda_ext.matmul_f32_cutlass(A, B)
torch_matmul = A @ B

# allclose uses atol=1e-3 on top of torch's default relative tolerance.
print(
    "--- Correct CUTLASS implementation! ---"
    if torch.allclose(custom_matmul_cutlass, torch_matmul, atol=1e-3)
    else "[ERROR] Incorrect float matmul CUTLASS implementation !!!"
)

--- Correct CUTLASS implementation! ---


In [6]:
# Correctness: custom int8 matmul against a float32 torch reference.
# torch.randint excludes `high`, so the inputs lie in [-127, 126].
A = torch.randint(low=-127, high=127, size=(1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(2_000, 3_000), dtype=torch.int8, device="cuda")

custom_matmul_int8 = torch_cuda_ext.matmul_int8(A, B)
# Reference product is computed on float32 copies of the int8 inputs.
torch_matmul_int8 = A.float() @ B.float()

print(
    "--- Correct implementation! ---"
    if torch.allclose(custom_matmul_int8.float(), torch_matmul_int8, atol=1e-2)
    else "[ERROR] Incorrect int8 matmul implementation !!!"
)

--- Correct implementation! ---


In [7]:
# Correctness: CUTLASS int8 matmul against a float32 torch reference.
# torch.randint excludes `high`, so the inputs lie in [-127, 126].
A = torch.randint(low=-127, high=127, size=(1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(2_000, 3_000), dtype=torch.int8, device="cuda")

custom_matmul_int8_cutlass = torch_cuda_ext.matmul_int8_cutlass(A, B)
# Reference product is computed on float32 copies of the int8 inputs.
torch_matmul_int8 = A.float() @ B.float()

print(
    "--- Correct CUTLASS implementation! ---"
    if torch.allclose(custom_matmul_int8_cutlass.float(), torch_matmul_int8, atol=1e-2)
    else "[ERROR] Incorrect int8 matmul CUTLASS implementation !!!"
)

--- Correct CUTLASS implementation! ---


In [8]:
# Smoke test of the fused int8 matmul + float conversion kernel.
# Only the dtype and shape of the result are inspected here; values are not
# compared against a reference.
A = torch.randint(low=-127, high=127, size=(1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(2_000, 3_000), dtype=torch.int8, device="cuda")

# Third argument is presumably the output scale factor — TODO confirm against the extension.
custom_matmul_int8_f16 = torch_cuda_ext.matmul_int8_to_fp16_scaled_forward_noc(A, B, 1.0)
print(custom_matmul_int8_f16.dtype, custom_matmul_int8_f16.shape, sep="\n")

torch.float16
torch.Size([1000, 3000])


In [9]:
# Correctness: batched int8 matmul (stream-based CUTLASS entry point) against
# torch.bmm on float32 copies of the inputs.
A = torch.randint(low=-127, high=127, size=(10, 1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(10, 2_000, 3_000), dtype=torch.int8, device="cuda")

# An earlier revision of this cell called torch_cuda_ext.bmm_int8(A, B) instead.
custom_bmatmul_int8 = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
torch_bmatmul_int8 = A.float().bmm(B.float())

print(
    "--- Correct implementation! ---"
    if torch.allclose(custom_bmatmul_int8.float(), torch_bmatmul_int8, atol=1e-2)
    else "[ERROR] Incorrect int8 batched matmul implementation !!!"
)

--- Correct implementation! ---


# 2. Check memory peak

In [10]:
# Peak memory: CUTLASS float32 matmul vs. torch.matmul on the same inputs.
A = torch.randn((1_000, 2_000), device="cuda")
B = torch.randn((2_000, 3_000), device="cuda")


@cuda_memory_profiler(device="cuda")
def custom_matmul_f32(A, B):
    """Float32 matmul through the extension's CUTLASS entry point."""
    return torch_cuda_ext.matmul_f32_cutlass(A, B)


@cuda_memory_profiler(device="cuda")
def torch_built_int_matmul_f32(A, B):
    """Float32 matmul through PyTorch's built-in operator."""
    return A @ B


# Each call prints one "[<name>] Δpeak: ..." line; the blank print separates them.
_ = custom_matmul_f32(A, B)
print()
_ = torch_built_int_matmul_f32(A, B)

[custom_matmul_f32] Δpeak: 12.45 MB

[torch_built_int_matmul_f32] Δpeak: 12.00 MB


In [11]:
# Peak memory: CUTLASS int8 matmul vs. torch.matmul on float16 copies.
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.int8, device="cuda")

@cuda_memory_profiler()
def custom_matmul_int8_cutlass(A, B):
    """Int8 matmul through the extension's CUTLASS entry point."""
    return torch_cuda_ext.matmul_int8_cutlass(A, B)

_ = custom_matmul_int8_cutlass(A, B)
print()

# =============== Torch matmul (float16 baseline) =================
# The casts get their own names so the int8 inputs stay intact and the cell is
# idempotent. They happen outside the profiled call, so they do not contribute
# to the reported Δpeak.
A_fp16 = A.to(torch.float16)
B_fp16 = B.to(torch.float16)

# Renamed from `torch_batched_matmul_int8`: this baseline is a plain
# (non-batched) float16 matmul, and the old name was also redefined by the
# batched comparison cell. The printed label follows the function name, so
# re-run the cell to refresh the recorded output.
@cuda_memory_profiler()
def torch_matmul_fp16(A, B):
    """Float16 matmul through PyTorch's built-in operator."""
    return torch.matmul(A, B)

_ = torch_matmul_fp16(A_fp16, B_fp16)

[custom_matmul_int8_cutlass] Δpeak: 12.00 MB

[torch_batched_matmul_int8] Δpeak: 6.45 MB


In [12]:
# Peak memory: batched CUTLASS int8 matmul vs. torch.bmm on float16 copies.
A = torch.randint(low=-127, high=127, size=(10, 1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(10, 2_000, 3_000), dtype=torch.int8, device="cuda")


@cuda_memory_profiler()
def custom_batched_matmul_int8_cutlass(A, B):
    """Batched int8 matmul through the extension's stream-based CUTLASS entry point."""
    return torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)


_ = custom_batched_matmul_int8_cutlass(A, B)
print()

# ---------------- torch baseline ----------------
# The inputs are cast to float16 before the profiled call, so the casts
# themselves are not part of the reported Δpeak.
A = A.half()
B = B.half()


@cuda_memory_profiler()
def torch_batched_matmul_int8(A, B):
    """Batched matmul through torch.bmm on the float16 copies."""
    return A.bmm(B)


_ = torch_batched_matmul_int8(A, B)

[custom_batched_matmul_int8_cutlass] Δpeak: 132.00 MB

[torch_batched_matmul_int8] Δpeak: 60.82 MB


# 3. Check time

In [13]:
# Compare time of the CUTLASS int8 matmul against torch.matmul on float16 copies.
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.int8, device="cuda")

# Cast once, up front, so the casts are outside both timed regions and the
# warm-up can exercise exactly the float16 matmul that is timed below.
A_fp16 = A.to(torch.float16)
B_fp16 = B.to(torch.float16)

# Warm-up: run the same two operations that are measured. (Warming up a
# float32 matmul instead would leave the float16 path cold for its first
# timed iterations.)
for _ in range(5):
    _ = torch_cuda_ext.matmul_int8_cutlass(A, B)
    _ = torch.matmul(A_fp16, B_fp16)
torch.cuda.synchronize()

n_iter = 100

# =============== Custom int8 matmul =================
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_iter):
    _ = torch_cuda_ext.matmul_int8_cutlass(A, B)
end.record()
torch.cuda.synchronize()  # both events must have completed before elapsed_time
custom_time = start.elapsed_time(end) / n_iter  # in ms
# Label fixed: this is the non-batched matmul, not the bmm kernel.
print(f"[custom_matmul_int8] elapsed: {custom_time:.3f} ms \n")


# =============== Torch matmul =================
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_iter):
    _ = torch.matmul(A_fp16, B_fp16)
end.record()
torch.cuda.synchronize()
torch_time = start.elapsed_time(end) / n_iter  # in ms
print(f"[torch matmul] elapsed: {torch_time:.3f} ms")

[custom_bmm_int8] elapsed: 0.264 ms 

[torch matmul] elapsed: 0.197 ms


In [14]:
# Compare wall-clock time of the batched int8 matmul against torch.bmm on
# float16 copies (host-side timer, one synchronize per call).
A = torch.randint(-127, 127, (16, 1024, 512), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (16, 512, 1024), dtype=torch.int8, device="cuda")

# Cast once, up front, so the casts are not timed and the warm-up can run the
# same float16 bmm that is measured below.
A_fp16 = A.to(torch.float16)
B_fp16 = B.to(torch.float16)

# Warm-up: run the same two operations that are measured.
for _ in range(5):
    _ = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
    _ = torch.bmm(A_fp16, B_fp16)
torch.cuda.synchronize()

n_iter = 100

# perf_counter is the monotonic, high-resolution clock intended for measuring
# short durations; time.time() can jump with system clock adjustments.
start = time.perf_counter()
for _ in range(n_iter):
    _ = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
    torch.cuda.synchronize()  # wait per call: measures end-to-end latency of each launch
end = time.perf_counter()
custom_time = (end - start) / n_iter * 1e3  # in ms
print(f"[custom_bmm_int8] elapsed: {custom_time:.3f} ms \n")

start = time.perf_counter()
for _ in range(n_iter):
    _ = torch.bmm(A_fp16, B_fp16)
    torch.cuda.synchronize()
end = time.perf_counter()
torch_time = (end - start) / n_iter * 1e3  # in ms
print(f"[torch batched matmul] elapsed: {torch_time:.3f} ms")

[custom_bmm_int8] elapsed: 0.685 ms 

[torch batched matmul] elapsed: 0.256 ms


In [15]:
# Compare GPU time (CUDA events) of the batched int8 matmul against torch.bmm
# on float16 copies.
A = torch.randint(-127, 127, (16, 1024, 512), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (16, 512, 1024), dtype=torch.int8, device="cuda")

# Cast before the warm-up so the warm-up can exercise the float16 bmm that is
# actually timed (the casts themselves stay outside the timed regions).
A_float = A.to(torch.float16)
B_float = B.to(torch.float16)

# Warm-up: run the same two operations that are measured.
for _ in range(5):
    _ = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
    _ = torch.bmm(A_float, B_float)
torch.cuda.synchronize()

n_iter = 100

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_iter):
    _ = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
end.record()
torch.cuda.synchronize()  # both events must have completed before elapsed_time

custom_time = start.elapsed_time(end) / n_iter  # in ms
print(f"[custom_bmm_int8] elapsed: {custom_time:.3f} ms \n")

# ================ Torch batched matmul =================
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_iter):
    _ = torch.bmm(A_float, B_float)
end.record()
torch.cuda.synchronize()
torch_time = start.elapsed_time(end) / n_iter  # in ms
print(f"[torch batched matmul] elapsed: {torch_time:.3f} ms")

[custom_bmm_int8] elapsed: 0.669 ms 

[torch batched matmul] elapsed: 0.251 ms


## 3.1. Check the fusion

In [39]:
# Time the fused kernel: int8 matmul with conversion to float16 in one call.
A = torch.randint(low=-127, high=127, size=(1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(2_000, 3_000), dtype=torch.int8, device="cuda")

# Untimed warm-up runs, followed by a sync so they cannot leak into the timing.
for _ in range(5):
    _ = torch_cuda_ext.matmul_int8_to_fp16_scaled_forward_noc(A, B, 1.0)
torch.cuda.synchronize()

# Host-side timing: launch all iterations, then wait once for the device.
n_iter = 100
start = time.perf_counter()
for _ in range(n_iter):
    out = torch_cuda_ext.matmul_int8_to_fp16_scaled_forward_noc(A, B, 1.0)
torch.cuda.synchronize()
end = time.perf_counter()

custom_time = (end - start) / n_iter * 1e3  # in ms
print(f"[custom_matmul_int8_to_fp16] elapsed: {custom_time:.3f} ms \n")

print(out.dtype, out.shape, sep="\n")

[custom_matmul_int8_to_fp16] elapsed: 0.275 ms 

torch.float16
torch.Size([1000, 3000])


In [17]:
# Time the unfused path for comparison with the fused kernel: CUTLASS int8
# matmul followed by a separate cast of the result to float16.
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.int8, device="cuda")

# Warm-up: include the cast, since it is part of what is timed below.
for _ in range(5):
    _ = torch_cuda_ext.matmul_int8_cutlass(A, B).to(torch.float16)
torch.cuda.synchronize()

n_iter = 100
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(n_iter):
    out = torch_cuda_ext.matmul_int8_cutlass(A, B)
    out = out.to(torch.float16)
end.record()
torch.cuda.synchronize()  # both events must have completed before elapsed_time
custom_time = start.elapsed_time(end) / n_iter  # in ms
# Label fixed: this cell measures the non-batched matmul plus the cast, not bmm.
# NOTE(review): this cell uses CUDA events while the neighbouring fusion cells
# use time.perf_counter; keep that in mind when comparing the numbers.
print(f"[custom_matmul_int8 + fp16 cast] elapsed: {custom_time:.3f} ms \n")

[custom_bmm_int8] elapsed: 0.290 ms 



In [44]:
# Baseline: torch.matmul on float16 copies of random int8-valued inputs.
A = torch.randint(low=-127, high=127, size=(1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(low=-127, high=127, size=(2_000, 3_000), dtype=torch.int8, device="cuda")

# Cast before timing so the conversion cost is not measured.
A = A.half()
B = B.half()

# Untimed warm-up runs, followed by a sync so they cannot leak into the timing.
for _ in range(5):
    _ = A @ B
torch.cuda.synchronize()

# Host-side timing: launch all iterations, then wait once for the device.
n_iter = 100
start = time.perf_counter()
for _ in range(n_iter):
    _ = A @ B
torch.cuda.synchronize()
end = time.perf_counter()

torch_time = (end - start) / n_iter * 1e3  # in ms
print(f"[torch matmul] elapsed: {torch_time:.3f} ms")

[torch matmul] elapsed: 0.186 ms
