In [1]:
from kernel import hadamard
import tilelang
import tilelang.language as T
from fast_hadamard_transform import hadamard_transform
from ref import hadamard_transform_ref
import torch

In [2]:
def test(b=1024, d=8192, dtype=torch.float32):
    """Check the tilelang Hadamard kernel against `hadamard_transform`.

    Draws a random ``(b, d)`` CUDA tensor, builds the kernel program with
    ``hadamard(b, d, <dtype name>)``, compiles it with tilelang for CUDA and
    prints ``Passed`` or ``Failed`` depending on whether the kernel output
    matches ``hadamard_transform`` within ``atol = rtol = 1e-3``.

    Args:
        b: Number of rows of the random input.
        d: Number of columns of the random input.
        dtype: torch dtype of the input; only its short name (for example
            ``float32``) is handed to ``hadamard``.

    Returns:
        None. The verdict is only printed.
    """
    x = torch.rand(b, d, dtype=dtype, device='cuda')

    # str(torch.float32) == 'torch.float32'; keep only the part after the dot.
    dtype_name = str(dtype).rsplit('.', 1)[-1]
    program = hadamard(b, d, dtype_name)
    compiled = tilelang.compile(program, out_idx=1, target='cuda')

    # Emit the header before running anything so a crash is attributable.
    print(f'Test for {d=} and {dtype=}:', end=' ')

    actual = compiled(x)
    expected = hadamard_transform(x)
    matches = torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)
    print('Passed' if matches else 'Failed')

In [7]:
# Sweep the second dimension over the powers of two 2**1 .. 2**15 (2 .. 32768),
# leaving test()'s other arguments (b, dtype) at their defaults.
for logd in range(1, 16):
    d = 2 ** logd
    test(d=d)

Test for d=2 and dtype=torch.float32: Passed
Test for d=4 and dtype=torch.float32: Passed
Test for d=8 and dtype=torch.float32: Passed
Test for d=16 and dtype=torch.float32: Passed
Test for d=32 and dtype=torch.float32: Passed
Test for d=64 and dtype=torch.float32: Passed
Test for d=128 and dtype=torch.float32: Passed
Test for d=256 and dtype=torch.float32: Passed
Test for d=512 and dtype=torch.float32: Passed
Test for d=1024 and dtype=torch.float32: Passed
Test for d=2048 and dtype=torch.float32: Passed
Test for d=4096 and dtype=torch.float32: Passed
Test for d=8192 and dtype=torch.float32: Passed
Test for d=16384 and dtype=torch.float32: Passed
Test for d=32768 and dtype=torch.float32: Passed


In [11]:
# Profile a single call of `hadamard_transform` (from fast_hadamard_transform)
# on a random 1024 x 8192 float32 CUDA tensor, then print the ten entries with
# the largest total CUDA time.
x = torch.rand(1024, 8192, dtype=torch.float32, device='cuda')
# `use_device='cuda'` replaces the deprecated `use_cuda=True` flag, which makes
# the profiler emit a deprecation warning.
with torch.autograd.profiler.profile(use_device='cuda') as prof:
    hadamard_transform(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
      HadamardTransformFn        34.86%     142.321us        86.71%     354.023us     354.023us     207.000us        56.10%     369.000us     369.000us             1  
            aten::reshape         8.99%      36.712us        20.75%      84.723us      42.361us      45.000us        12.20%      84.000us      42.000us             2  
         aten::empty_like         5.56%      22.695us        16.67%      68.069us      68.069us      32.000us         8.67%      78.000us      78.000us        

  with torch.autograd.profiler.profile(use_cuda=True) as prof:
