In [1]:
from kernel import hadamard
import tilelang
import tilelang.language as T
from tilelang.profiler.bench import do_bench
from fast_hadamard_transform import hadamard_transform
from ref import hadamard_transform_ref
import torch

In [2]:
def test(b=1024, d=8192, dtype=torch.float32, atol=1e-3, rtol=1e-3):
    """Compile the TileLang Hadamard kernel and compare it with ``hadamard_transform``.

    Draws a random ``(b, d)`` CUDA tensor, runs both the compiled ``hadamard``
    kernel and ``fast_hadamard_transform.hadamard_transform`` on it, prints
    ``Passed``/``Failed`` and returns the outcome.

    Args:
        b: Number of rows of the random input.
        d: Row length (transform size); the sweep below uses powers of two.
        dtype: torch dtype of the input. Its short name (e.g. ``'float32'``)
            is what gets passed to ``hadamard``.
        atol, rtol: Tolerances forwarded to ``torch.allclose``. The defaults
            match the original hard-coded values (which pass for float32);
            NOTE(review): lower-precision dtypes will probably need looser values.

    Returns:
        True if the two outputs agree within tolerance, else False.
    """
    x = torch.rand(b, d, dtype=dtype, device='cuda')
    # 'torch.float32' -> 'float32': the dtype-name string `hadamard` expects.
    dtype_name = str(dtype).split('.')[-1]
    fn = hadamard(b, d, dtype_name)
    kernel = tilelang.compile(fn, out_idx=1, target='cuda')
    print(f'Test for {d=} and {dtype=}:', end=' ')
    passed = torch.allclose(
        kernel(x), hadamard_transform(x), atol=atol, rtol=rtol
    )
    print('Passed' if passed else 'Failed')
    # Returning the result lets callers aggregate/assert instead of reading stdout.
    return passed

In [3]:
# Correctness sweep over power-of-two transform sizes: d = 2**1, 2**2, ..., 2**15.
for logd in range(1, 16):
    d = 1 << logd  # same integer as 2 ** logd
    test(d=d)

Test for d=2 and dtype=torch.float32: Passed
Test for d=4 and dtype=torch.float32: Passed
Test for d=8 and dtype=torch.float32: Passed
Test for d=16 and dtype=torch.float32: Passed
Test for d=32 and dtype=torch.float32: Passed
Test for d=64 and dtype=torch.float32: Passed
Test for d=128 and dtype=torch.float32: Passed
Test for d=256 and dtype=torch.float32: Passed
Test for d=512 and dtype=torch.float32: Passed
Test for d=1024 and dtype=torch.float32: Passed
Test for d=2048 and dtype=torch.float32: Passed
Test for d=4096 and dtype=torch.float32: Passed
Test for d=8192 and dtype=torch.float32: Passed
Test for d=16384 and dtype=torch.float32: Passed
Test for d=32768 and dtype=torch.float32: Passed


In [4]:
# Benchmark configuration: b=1024 rows, d=2**15, float32.
b, d, dtype = 1024, 32768, torch.float32
fn = hadamard(b, d, str(dtype).split('.')[-1])
kernel = tilelang.compile(fn, out_idx=1, target='cuda')
x = torch.rand(b, d, dtype=dtype, device='cuda')


def dao_impl():
    """Apply fast_hadamard_transform's ``hadamard_transform`` to the shared ``x``; result discarded."""
    hadamard_transform(x)


def my_impl():
    """Apply the compiled TileLang ``kernel`` to the shared ``x``; result discarded."""
    kernel(x)

In [5]:
# Time the fast_hadamard_transform baseline (NOTE(review): do_bench result presumably in ms — confirm).
do_bench(dao_impl, warmup=5, rep=10)

0.2274106740951538

In [6]:
# Time the compiled TileLang kernel with the same warmup/rep settings as the baseline.
do_bench(my_impl, warmup=5, rep=10)

0.7353367209434509