In [13]:
import time
import torch

In [7]:
from matmul_kernel import matmul_fp16_int4
from final_quantization_kernel import quantize_rowwise_int4, dequantize_rowwise_int4

In [18]:
import torch
import time

torch.manual_seed(42)

# Each entry is (in_features, out_features); the weight is created as
# W[out_features, in_features] and applied as X @ W^T.
layer_configs = [
    (2048, 2048),
    (2048, 512),
    (2048, 8192),
    (8192, 2048),
]

# Number of rows (M) in the activation matrix X for each benchmark sweep.
token_counts = [128, 512, 2048]
# Untimed calls issued per kernel before measuring.
num_warmup = 10
# Timed calls per kernel; reported times are the mean over these calls.
num_iterations = 100

# results[M][(K, N)] -> {'fp16_time_ms', 'int4_time_ms', 'speedup'}
results = {}


def _time_cuda_ms(fn, iterations):
    """Time repeated calls of ``fn`` on the GPU using CUDA events.

    Args:
        fn: Zero-argument callable that launches the work to be measured.
        iterations: Number of back-to-back calls to time; must be positive.

    Returns:
        A ``(mean_ms, last_output)`` tuple: the mean elapsed time per call in
        milliseconds, and the value returned by the final call to ``fn``.

    Raises:
        ValueError: If ``iterations`` is not positive (the mean would
            otherwise be a division by zero).
    """
    if iterations <= 0:
        raise ValueError(f"iterations must be positive, got {iterations}")

    # A fresh event pair per measurement keeps the two timings independent.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    output = None
    start_event.record()
    for _ in range(iterations):
        output = fn()
    end_event.record()
    # Wait for the recorded events to complete before reading elapsed time.
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / iterations, output


for M in token_counts:
    results[M] = {}
    print(f"\n{'='*60}")
    print(f"Benchmarking with {M} tokens")
    print(f"{'='*60}")

    for in_features, out_features in layer_configs:
        K, N = in_features, out_features

        print(f"\nLayer: ({K}, {N}) - X[{M},{K}] @ W[{N},{K}]^T")

        X = torch.randn(M, K, device='cuda', dtype=torch.float16)
        W_fp16 = torch.randn(N, K, device='cuda', dtype=torch.float16)

        # Quantization happens once, outside the timed region, so only the
        # matmul itself is measured.
        W_int4_packed, scales = quantize_rowwise_int4(W_fp16)

        # Warm up both code paths before timing so one-time costs are not
        # attributed to the measured iterations.
        for _ in range(num_warmup):
            _ = torch.matmul(X, W_fp16.t())
            _ = matmul_fp16_int4(X, W_int4_packed, scales)

        # Drain the warm-up work so it cannot overlap with the first timing.
        torch.cuda.synchronize()

        # Baseline: fp16 matmul against the unquantized weights.
        fp16_time, C_fp16 = _time_cuda_ms(
            lambda: torch.matmul(X, W_fp16.t()), num_iterations
        )

        # Candidate: the int4 kernel fed the packed weights and their scales.
        int4_time, C_int4 = _time_cuda_ms(
            lambda: matmul_fp16_int4(X, W_int4_packed, scales), num_iterations
        )

        # speedup > 1 means the int4 kernel is faster than the fp16 baseline.
        speedup = fp16_time / int4_time

        results[M][(K, N)] = {
            'fp16_time_ms': fp16_time,
            'int4_time_ms': int4_time,
            'speedup': speedup,
        }

        print(f"  fp16: {fp16_time:.3f} ms")
        print(f"  Int4: {int4_time:.3f} ms")
        print(f"  Speedup: {speedup:.2f}x")


Benchmarking with 128 tokens

Layer: (2048, 2048) - X[128,2048] @ W[2048,2048]^T
  fp16: 0.052 ms
  Int4: 0.361 ms
  Speedup: 0.14x

Layer: (2048, 512) - X[128,2048] @ W[512,2048]^T
  fp16: 0.016 ms
  Int4: 0.129 ms
  Speedup: 0.13x

Layer: (2048, 8192) - X[128,2048] @ W[8192,2048]^T
  fp16: 0.179 ms
  Int4: 1.296 ms
  Speedup: 0.14x

Layer: (8192, 2048) - X[128,8192] @ W[2048,8192]^T
  fp16: 0.179 ms
  Int4: 1.288 ms
  Speedup: 0.14x

Benchmarking with 512 tokens

Layer: (2048, 2048) - X[512,2048] @ W[2048,2048]^T
  fp16: 0.151 ms
  Int4: 1.162 ms
  Speedup: 0.13x

Layer: (2048, 512) - X[512,2048] @ W[512,2048]^T
  fp16: 0.046 ms
  Int4: 0.324 ms
  Speedup: 0.14x

Layer: (2048, 8192) - X[512,2048] @ W[8192,2048]^T
  fp16: 0.612 ms
  Int4: 4.547 ms
  Speedup: 0.13x

Layer: (8192, 2048) - X[512,8192] @ W[2048,8192]^T
  fp16: 0.588 ms
  Int4: 4.641 ms
  Speedup: 0.13x

Benchmarking with 2048 tokens

Layer: (2048, 2048) - X[2048,2048] @ W[2048,2048]^T
  fp16: 0.591 ms
  Int4: 4.534 ms
  