In [4]:
import numpy as np
import numpy.testing
import hidet

def matmul_func(m_size, n_size, k_size):
    from hidet.lang import attr, f32
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim
    from hidet.transforms.tools import add_packed_func

    def ceil_div(a, b):
        return (a + b - 1) // b

    tile_size = 16

    with hidet.script_module() as script_module:

        @hidet.script
        def kernel(
                a: f32[m_size, k_size],
                b: f32[k_size, n_size],
                c: f32[m_size, n_size]
        ):
            attr.func_kind = 'cuda_kernel'
            attr.cuda_block_dim = (tile_size, tile_size)
            attr.cuda_grid_dim = ceil_div(m_size, tile_size), ceil_div(n_size, tile_size)
            i = threadIdx.x + blockIdx.x * blockDim.x
            j = threadIdx.y + blockIdx.y * blockDim.y
            if i < m_size and j < n_size:
                acc = f32(0.0)
                for k in range(k_size):
                    acc += a[i, k] * b[k, j]
                c[i, j] = acc

    ir_module = script_module.ir_module()
    add_packed_func(ir_module, func=kernel, pack_func_name='matmul')
    return hidet.driver.build_ir_module(ir_module, func_name='matmul')

m_size, n_size, k_size = 1024, 1024, 1024
matmul = matmul_func(m_size, n_size, k_size)
print(matmul.source(color=True))

[38;5;64m#[39m[38;5;64minclude[39m[38;5;250m [39m[38;5;248;03m<stdint.h>[39;00m
[38;5;64m#[39m[38;5;64minclude[39m[38;5;250m [39m[38;5;248;03m<cuda_fp16.h>[39;00m
[38;5;64m#[39m[38;5;64minclude[39m[38;5;250m [39m[38;5;248;03m<cuda_bf16.h>[39;00m
[38;5;64m#[39m[38;5;64minclude[39m[38;5;250m [39m[38;5;248;03m<hidet/runtime/cuda_context.h>[39;00m
[38;5;64m#[39m[38;5;64minclude[39m[38;5;250m [39m[38;5;248;03m<hidet/runtime/cpu_context.h>[39;00m
[38;5;19mtypedef[39m[38;5;250m [39m[38;5;37mfloat[39m[38;5;250m [39mtfloat32_t;
[38;5;64m#[39m[38;5;64mdefine __float_to_tf32(x) (x)[39m
[38;5;19mextern[39m[38;5;250m [39m[38;5;130m"[39m[38;5;130mC[39m[38;5;130m"[39m[38;5;250m [39m{

[38;5;19m__global__[39m[38;5;250m [39m[38;5;37mvoid[39m[38;5;250m [39m__launch_bounds__([38;5;30m256[39m)[38;5;250m [39mhidet_kernel([38;5;37mfloat[39m[38;5;250m [39m*[38;5;250m [39m[38;5;37m__restrict__[39m[38;5;250m [39ma,[38;5;250m

In [5]:
a = hidet.randn([m_size, k_size]).cuda()
b = hidet.randn([k_size, n_size]).cuda()
c = hidet.empty([m_size, n_size]).cuda()
matmul(a, b, c)

np_a = a.cpu().numpy()
np_b = b.cpu().numpy()
np_c = np.matmul(np_a, np_b)

numpy.testing.assert_allclose(c.cpu().numpy(), np_c, rtol=1e-4, atol=1e-4)
print('Correctness: Pass')

Correctness: Pass


In [6]:
latency = hidet.utils.benchmark_func(lambda: matmul(a, b, c))
print('Latency: {:.2f} ms'.format(latency))

Latency: 4.04 ms
