In [1]:
import os 
import torch
import torch_cuda_ext

In [2]:
# Draw two length-100 random vectors on the GPU (a first, then b) and compare
# the extension's dot product against PyTorch's own implementation.
a, b = (torch.randn(100, device="cuda") for _ in range(2))

custom_dot = torch_cuda_ext.dot_forward(a, b)
torch_dot = a.dot(b)

# The tensors are passed as separate print arguments (not through an f-string)
# so they are rendered with their full tensor repr.
print("Custom dot product:", custom_dot)
print("PyTorch dot product:", torch_dot)
print("Difference:", (custom_dot - torch_dot).abs().item())

Custom dot product: tensor(8.2927, device='cuda:0')
PyTorch dot product: tensor(8.2927, device='cuda:0')
Difference: 0.0


In [3]:
# Random (10 x 20) and (20 x 30) operands on the GPU, generated in that order.
A, B = (torch.randn(*dims, device="cuda") for dims in ((10, 20), (20, 30)))

custom_matmul = torch_cuda_ext.matmul_f32(A, B)
torch_matmul = A @ B

# Report whether the extension agrees with PyTorch within atol=1e-6
# (allclose's default relative tolerance still applies on top of that).
print(
    "--- Correct matmul implementation! ---"
    if torch.allclose(custom_matmul, torch_matmul, atol=1e-6)
    else "[ERROR] Incorrect matmul implementation !!!"
)

--- Correct matmul implementation! ---


In [4]:
# int8 operands for a (10 x 20) @ (20 x 30) product.
# torch.randint excludes its upper bound, so entries are drawn from [-127, 126].
A = torch.randint(-127, 127, (10, 20), dtype=torch.int8, device="cuda")
B = torch.randint(-127, 127, (20, 30), dtype=torch.int8, device="cuda")
custom_matmul_int8 = torch_cuda_ext.matmul_int8(A, B)

print(f"Shape of int8 matmul result: {custom_matmul_int8.shape}")
print(f"Dtype of int8 matmul result: {custom_matmul_int8.dtype}")

# An integer matmul has one exact answer, so compare exactly instead of with
# allclose: its default rtol=1e-5 would forgive errors of a few units on
# large-magnitude entries. The float64 reference is exact here because every
# product and partial sum is an integer far below 2**53.
if torch.equal(custom_matmul_int8.double(), torch.matmul(A.double(), B.double())):
    print("--- Correct int8 matmul implementation! ---")
else:
    # Without this branch a wrong kernel would produce no message at all.
    print("[ERROR] Incorrect int8 matmul implementation !!!")

Shape of int8 matmul result: torch.Size([10, 30])
Dtype of int8 matmul result: torch.int32
--- Correct int8 matmul implementation! ---


In [5]:
# Batched int8 operands: A is (5, 10, 20) and B is (5, 20, 30), generated in
# that order. torch.randint excludes its upper bound, so entries are drawn
# from [-127, 126].
A, B = (
    torch.randint(low=-127, high=127, size=dims, dtype=torch.int8, device="cuda")
    for dims in ((5, 10, 20), (5, 20, 30))
)
custom_bmatmul_int8 = torch_cuda_ext.bmm_int8(A, B)

# Show the shape and dtype the extension returns, one per line.
print(
    f"Shape of int8 batched matmul result: {custom_bmatmul_int8.shape}",
    f"Dtype of int8 batched matmul result: {custom_bmatmul_int8.dtype}",
    sep="\n",
)

Shape of int8 batched matmul result: torch.Size([5, 10, 30])
Dtype of int8 batched matmul result: torch.int32


In [7]:
# Display the batched result converted to float32 (PyTorch abbreviates the repr).
custom_bmatmul_int8.to(torch.float32)

tensor([[[ 25187.,  -7588.,  -2830.,  ..., -44194.,   5854.,  63146.],
         [  5141.,  -7088., -23081.,  ...,   8831.,    118., -29748.],
         [ 17396.,    999., -31703.,  ...,  29564.,  -5070., -12355.],
         ...,
         [  6864.,  -8054.,  22061.,  ...,  32973., -35613., -22143.],
         [-10806.,  10242.,  13992.,  ...,  -6467.,  25065.,  12208.],
         [-15481.,  -3666., -15239.,  ...,    247.,  -2213.,  -4801.]],

        [[-44666.,  31742.,  13281.,  ...,  -5506., -15045.,  -5820.],
         [-22354.,  14386., -25246.,  ..., -83878., -24350., -21194.],
         [ 14668.,  -5803.,  33727.,  ...,  30368., -28923.,  -5780.],
         ...,
         [ 48894.,   8602.,   6306.,  ...,  20136.,  -5907., -13481.],
         [  4869.,  36882.,   2609.,  ...,  -9629.,  12767., -19111.],
         [-27534., -19762.,   9586.,  ...,   -234., -15371.,   -904.]],

        [[ 35440.,  -8161.,  -9512.,  ...,  34583.,   5621., -17477.],
         [ 25596.,  24572.,  13527.,  ..., -2

In [10]:
# Validate the batched int8 result against torch.bmm. The comparison is exact:
# an integer matmul has one correct answer, and allclose's default rtol=1e-5
# would forgive errors of a few units on large-magnitude entries. The float64
# reference is exact because every product and partial sum is an integer far
# below 2**53.
if torch.equal(custom_bmatmul_int8.double(), torch.bmm(A.double(), B.double())):
    print("--- Correct int8 batched matmul implementation! ---")
else:
    print("[ERROR] Incorrect int8 batched matmul implementation !!!")

--- Correct int8 batched matmul implementation! ---
