In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# JIT-compile the custom SGEMM sources and import the result as the `sgemm`
# module. A caller-supplied build directory is used as-is by `load`, so create
# it up front; otherwise a fresh checkout fails before compilation starts.
os.makedirs('build/', exist_ok=True)

sgemm = load(
    "sgemm",
    ['sgemm_variant/gemm.cpp', 'sgemm_variant/gemm_nt.cu', 'sgemm_variant/gemm_tn.cu'],
    build_directory='build/',
)

In [None]:
# Fixed-size sanity check: run the custom GEMM once and compare it against
# F.linear(a, b), i.e. a @ b.T, on the same inputs.
device = torch.device('cuda:0')

M, N, K = 256, 512, 1024

# Inputs are sampled on the host and then moved to the GPU.
a = torch.randn(M, K).to(device)
b = torch.randn(N, K).to(device)

mat = sgemm.gemm(a, b)
print(mat.shape)
gt = F.linear(a, b)
print(gt.shape)

# Largest element-wise absolute deviation from the reference result.
print(torch.max(torch.abs(mat - gt)).item())

In [39]:
import random

# Randomised correctness sweep: for 1000 randomly sized problems, compare the
# custom GEMM against F.linear(a, b) (a @ b.T) and stop at the first mismatch.
MAX_ABS_DIFF = 0.0003  # absolute tolerance on the largest element-wise error

for i in range(1000):
    M = random.randint(500, 800)
    N = random.randint(500, 800)
    K = random.randint(500, 800)

    # Sample directly on the GPU instead of allocating on the host and copying.
    a = torch.randn(M, K, device=device)
    b = torch.randn(N, K, device=device)

    mat = sgemm.gemm(a, b)
    gt = F.linear(a, b)

    # Guard against broadcasting in `mat - gt` hiding a wrong output shape.
    assert mat.shape == gt.shape, (
        f'exp {i}: shape mismatch {tuple(mat.shape)} vs {tuple(gt.shape)}'
    )

    # Compute the reduction once and reuse it for both the log and the check.
    diff = (mat - gt).abs().max().item()
    print(f'exp {i}, (M, N, K) = ({M}, {N}, {K}), max diff: {diff}')
    assert diff < MAX_ABS_DIFF, (
        f'exp {i}: max diff {diff} exceeds {MAX_ABS_DIFF} for (M, N, K) = ({M}, {N}, {K})'
    )

exp 0, (M, N, K) = (506, 539, 537), max diff: 0.00011444091796875
exp 1, (M, N, K) = (689, 744, 652), max diff: 0.00014495849609375
exp 2, (M, N, K) = (769, 539, 637), max diff: 0.0001373291015625
exp 3, (M, N, K) = (688, 504, 624), max diff: 0.0001220703125
exp 4, (M, N, K) = (655, 758, 592), max diff: 0.0001220703125
exp 5, (M, N, K) = (795, 662, 650), max diff: 0.00016021728515625
exp 6, (M, N, K) = (544, 689, 610), max diff: 0.0001220703125
exp 7, (M, N, K) = (556, 545, 728), max diff: 0.000152587890625
exp 8, (M, N, K) = (789, 722, 722), max diff: 0.00012969970703125
exp 9, (M, N, K) = (552, 670, 571), max diff: 0.0001068115234375
exp 10, (M, N, K) = (561, 766, 551), max diff: 0.00011444091796875
exp 11, (M, N, K) = (708, 708, 511), max diff: 0.0001068115234375
exp 12, (M, N, K) = (700, 758, 603), max diff: 0.0001220703125
exp 13, (M, N, K) = (645, 600, 693), max diff: 0.0001373291015625
exp 14, (M, N, K) = (715, 721, 503), max diff: 0.0001068115234375
exp 15, (M, N, K) = (772, 58

In [12]:
I = 1000  # iteration count used by the %%time loops below

In [34]:
%%time
for i in range(I):
    sgemm.gemm(a, b)

CPU times: user 52.2 ms, sys: 62.6 ms, total: 115 ms
Wall time: 114 ms


In [38]:
%%time
for i in range(I):
    F.linear(a, b)

CPU times: user 53.6 ms, sys: 152 ms, total: 205 ms
Wall time: 204 ms
