# Description

In this notebook, I will verify the custom torch extension kernel with
- The correctness.
- The peak usage memory

In [1]:
import os 
import numpy as np
import time
import functools
import torch
import torch_cuda_ext

In [2]:
print("\n=== PyTorch Info ===")
print(f"Torch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"torch.version.cuda: {torch.version.cuda}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")


=== PyTorch Info ===
Torch version: 2.5.1+cu121
CUDA available: True
torch.version.cuda: 12.1
GPU name: NVIDIA RTX A5000


In [3]:
def cuda_memory_profiler(device="cuda"):
    """
    Decorator that measures GPU memory usage (and runtime) for any function.
    
    Reports:
      - Δpeak (max temporary memory used)
      - Δcurrent (net memory retained after execution)
      - runtime (optional)
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            # synchronize before measuring
            torch.cuda.synchronize(device)
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)
            
            before = torch.cuda.memory_allocated(device)

            result = func(*args, **kwargs)  # run the function

            torch.cuda.synchronize(device)
            peak = torch.cuda.max_memory_allocated(device)
            
            delta_peak = peak - before

            msg = (f"[{func.__name__}] Δpeak: {delta_peak/1e6:.2f} MB")
            print(msg)

            return result
        return wrapper
    return decorator

def cuda_time_profiler(print_time=True):
    """
    Decorator to measure GPU execution time of a function using CUDA events.
    Works only if at least one tensor is on the CUDA device.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            
            # Warm-up
            func(*args, **kwargs)
            torch.cuda.synchronize() # Ensure all previous CUDA ops are done

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            n_iter = 10

            start.record()
            for i in range(n_iter):
                result = func(*args, **kwargs)
            end.record()

            torch.cuda.synchronize()

            elapsed_ms = start.elapsed_time(end) / n_iter  # Average time per iteration
            if print_time:
                print(f"[{func.__name__}] elapsed: {elapsed_ms:.3f} ms")

            return result
        return wrapper
    return decorator

def l2_norm(tensor1, tensor2):
    return torch.sqrt(torch.nansum((tensor1 - tensor2) ** 2))

# 1. Check correctness

In [4]:
# The dot product test
a = torch.randn(100, device="cuda")
b = torch.randn(100, device="cuda")

custom_dot = torch_cuda_ext.dot_forward(a, b)
torch_dot = torch.dot(a, b)

if torch.allclose(custom_dot, torch_dot, atol=1e-6):
    print("--- Correct dot product implementation! ---")
else:
    print("[ERROR] Incorrect dot product implementation !!!")

--- Correct dot product implementation! ---


In [5]:
# The float matmul test
A = torch.randn(1_000, 2_000, device="cuda")
B = torch.randn(2_000, 3_000, device="cuda")

custom_matmul = torch_cuda_ext.matmul_f32(A, B)
torch_matmul = torch.matmul(A, B)

if torch.allclose(custom_matmul, torch_matmul, atol=1e-3):
    print("--- Correct float matmul implementation! ---")
else:
    print("[ERROR] Incorrect float matmul implementation !!!")

--- Correct float matmul implementation! ---


In [6]:
# The float matmul CUTLASS test
A = torch.randn(1_000, 2_000, device="cuda")
B = torch.randn(2_000, 3_000, device="cuda")

custom_matmul_cutlass = torch_cuda_ext.matmul_f32_cutlass(A, B)
torch_matmul = torch.matmul(A, B)

if torch.allclose(custom_matmul_cutlass, torch_matmul, atol=1e-3):
    print("--- Correct CUTLASS implementation! ---")
else:
    print("[ERROR] Incorrect float matmul CUTLASS implementation !!!")

--- Correct CUTLASS implementation! ---


In [7]:
# The int8 matmul test
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.int8, device="cuda")
custom_matmul_int8 = torch_cuda_ext.matmul_int8(A, B)
torch_matmul_int8 = torch.matmul(A.float(), B.float())

if torch.allclose(custom_matmul_int8.float(), torch_matmul_int8, atol=1e-2):
    print("--- Correct implementation! ---")
else:
    print("[ERROR] Incorrect int8 matmul implementation !!!")

--- Correct implementation! ---


In [8]:
# The int8 matmul CUTLASS test
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.int8, device="cuda")
custom_matmul_int8_cutlass = torch_cuda_ext.matmul_int8_cutlass(A, B)
torch_matmul_int8 = torch.matmul(A.float(), B.float())

if torch.allclose(custom_matmul_int8_cutlass.float(), torch_matmul_int8, atol=1e-2):
    print("--- Correct CUTLASS implementation! ---")
else:
    print("[ERROR] Incorrect int8 matmul CUTLASS implementation !!!")

--- Correct CUTLASS implementation! ---


In [9]:
# The batched int8 matmul test
A = torch.randint(-127, 127, (10, 1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (10, 2_000, 3_000), dtype=torch.int8, device="cuda")
# custom_bmatmul_int8 = torch_cuda_ext.bmm_int8(A, B)
custom_bmatmul_int8 = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
torch_bmatmul_int8 = torch.bmm(A.float(), B.float())        

if torch.allclose(custom_bmatmul_int8.float(), torch_bmatmul_int8, atol=1e-2):
    print("--- Correct implementation! ---")
else:
    print("[ERROR] Incorrect int8 batched matmul implementation !!!")

--- Correct implementation! ---


# 2. Check memory peak

In [10]:
# Compare float matmul and torch matmul
A = torch.randn(1_000, 2_000, device="cuda")
B = torch.randn(2_000, 3_000, device="cuda")

@cuda_memory_profiler(device="cuda")
def custom_matmul_f32(A, B):
    return torch_cuda_ext.matmul_f32_cutlass(A, B)

_ = custom_matmul_f32(A, B)
print()

@cuda_memory_profiler(device="cuda")
def torch_built_int_matmul_f32(A, B):
    return torch.matmul(A, B)

_ = torch_built_int_matmul_f32(A, B)

[custom_matmul_f32] Δpeak: 12.00 MB

[torch_built_int_matmul_f32] Δpeak: 12.58 MB


In [11]:
# Compare matmul memory peak
A = torch.randint(-127, 127, (10, 1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (10, 2_000, 3_000), dtype=torch.int8, device="cuda")

@cuda_memory_profiler()
def custom_batched_matmul_int8(A, B):
    return torch_cuda_ext.bmm_int8(A, B)

_ = custom_batched_matmul_int8(A, B)
print()

@cuda_memory_profiler()
def custom_batched_matmul_int8_cutlass(A, B):
    return torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)

_ = custom_batched_matmul_int8_cutlass(A, B)
print()

@cuda_memory_profiler()
def torch_batched_matmul_int8(A, B):
    return torch.bmm(A.float(), B.float())

_ = torch_batched_matmul_int8(A, B)

[custom_batched_matmul_int8] Δpeak: 120.00 MB

[custom_batched_matmul_int8_cutlass] Δpeak: 132.00 MB

[torch_batched_matmul_int8] Δpeak: 440.00 MB


# 3. Check time

In [12]:
# Compare time of float matmul
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.float32, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.float32, device="cuda")

# Warm-up
for _ in range(5):
    _ = torch_cuda_ext.matmul_f32_cutlass(A, B)
    _ = torch.matmul(A, B)
torch.cuda.synchronize()
    
n_iter = 100
for _ in range(n_iter):
    start = time.time()
    _ = torch_cuda_ext.matmul_f32_cutlass(A, B)
    torch.cuda.synchronize()
end = time.time()
custom_time = (end - start) / n_iter * 1e3  # in ms
print(f"[custom_bmm_int8] elapsed: {custom_time:.3f} ms \n")

for _ in range(n_iter):
    start = time.time()
    _ = torch.matmul(A, B)
    torch.cuda.synchronize()
end = time.time()
torch_time = (end - start) / n_iter * 1e3  # in ms
print(f"[torch] elapsed: {torch_time:.3f} ms")

[custom_bmm_int8] elapsed: 0.009 ms 

[torch] elapsed: 0.008 ms


In [13]:
# Compare time of int8 matmul
A = torch.randint(-127, 127, (1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (2_000, 3_000), dtype=torch.int8, device="cuda")

# Warm-up
for _ in range(5):
    _ = torch_cuda_ext.matmul_int8_cutlass(A, B)
    _ = torch.matmul(A.float(), B.float())
torch.cuda.synchronize()

n_iter = 100
for _ in range(n_iter):
    start = time.time()
    _ = torch_cuda_ext.matmul_int8_cutlass(A, B)
    torch.cuda.synchronize()
end = time.time()
custom_time = (end - start) / n_iter * 1e3  # in ms
print(f"[custom_bmm_int8] elapsed: {custom_time:.3f} ms \n")

for _ in range(n_iter):
    start = time.time()
    _ = torch.matmul(A.float(), B.float())
    torch.cuda.synchronize()
end = time.time()
torch_time = (end - start) / n_iter * 1e3  # in ms
print(f"[torch batched matmul] elapsed: {torch_time:.3f} ms")

[custom_bmm_int8] elapsed: 0.003 ms 

[torch batched matmul] elapsed: 0.008 ms


In [14]:
# Compare time of batched int8 matmul
A = torch.randint(-127, 127, (10, 1_000, 2_000), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (10, 2_000, 3_000), dtype=torch.int8, device="cuda")

# Warm-up
for _ in range(5):
    _ = torch_cuda_ext.bmm_int8(A, B)
    _ = torch.bmm(A.float(), B.float())
torch.cuda.synchronize()

n_iter = 100
for _ in range(n_iter):
    start = time.time()
    # _ = torch_cuda_ext.bmm_int8(A, B)
    _ = torch_cuda_ext.bmm_int8_cutlass_forward_streams(A, B)
    torch.cuda.synchronize()
end = time.time()
custom_time = (end - start) / n_iter * 1e3  # in ms
print(f"[custom_bmm_int8] elapsed: {custom_time:.3f} ms \n")

for _ in range(n_iter):
    start = time.time()
    _ = torch.bmm(A.float(), B.float())
    torch.cuda.synchronize()
end = time.time()
torch_time = (end - start) / n_iter * 1e3  # in ms
print(f"[torch batched matmul] elapsed: {torch_time:.3f} ms")

[custom_bmm_int8] elapsed: 0.032 ms 

[torch batched matmul] elapsed: 0.089 ms
