In [1]:
import hidet
import numpy as np

def vector_addition(n, verbose=True, block_dim=256):
    """Build and compile a hidet CUDA function computing ``c[i] = a[i] + b[i]``.

    A ``cuda_kernel`` is written with ``hidet.script``. Each thread handles the
    index ``threadIdx.x + blockIdx.x * blockDim.x`` and skips indices ``>= n``.
    The kernel is wrapped in a packed function named ``'add'`` and compiled
    with ``hidet.driver.build_ir_module``.

    Args:
        n: Length of each float32 vector. It is baked into the kernel's
            argument types and launch configuration at build time.
        verbose: If True (the default, matching the original behavior),
            print the generated IR module before compiling it.
        block_dim: Threads per CUDA block (default 256, the original value).

    Returns:
        The compiled ``'add'`` function returned by
        ``hidet.driver.build_ir_module``. It is invoked as
        ``add_func(a, b, c)`` and writes ``a + b`` into ``c``.
    """
    from hidet.lang import attr, f32
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim
    from hidet.transforms.tools import add_packed_func

    with hidet.script_module() as script_module:

        @hidet.script
        def kernel(a: f32[n], b: f32[n], c: f32[n]):
            attr.func_kind = 'cuda_kernel'
            attr.cuda_block_dim = block_dim
            # Ceil-division so the grid covers all n elements. `//` keeps the
            # grid size an integer; true division `/` would yield a float.
            attr.cuda_grid_dim = (n + block_dim - 1) // block_dim
            idx = threadIdx.x + blockIdx.x * blockDim.x
            # The last block may be only partially used; guard out-of-range threads.
            if idx < n:
                c[idx] = a[idx] + b[idx]

    ir_module = script_module.ir_module()
    add_packed_func(ir_module, func=kernel, pack_func_name='add')
    if verbose:
        print(ir_module)

    return hidet.driver.build_ir_module(ir_module, func_name='add')


# Problem size shared by the inputs and the compiled kernel.
n = 5


def _random_operand(length):
    """Return a float32 CUDA tensor of ``length`` values from ``hidet.randint(low=0, high=3)``."""
    ints = hidet.randint(low=0, high=3, shape=[length])
    return ints.to('float32').cuda()


# Draw a before b so the random sequence matches the original cell.
a = _random_operand(n)
b = _random_operand(n)
# Output buffer: its random initial contents are overwritten by the kernel.
c = hidet.randn([n]).cuda()
add_func = vector_addition(n)


def kernel fn(a: TensorPointerType(tensor(float32, [5])), 
   b: TensorPointerType(tensor(float32, [5])), 
   c: TensorPointerType(tensor(float32, [5])))
    # func_kind: cuda_kernel
    # cuda_block_dim: 256
    # cuda_grid_dim: 1
    declare idx: int32 = (threadIdx.x + (blockIdx.x * blockIdx.x))
    if (idx < 5)
        c[idx] = (a[idx] + b[idx])

def add fn(num_args: int32, 
   arg_types: PointerType(int32), 
   args: PointerType(PointerType(VoidType)))
    # func_name: add
    # func_kind: packed_func
    # packed_func: kernel
    assert((num_args == 3), Expect 3 arguments)
    assert((arg_types[0] == 3), The 0-th argument should be TensorPointerType(tensor(float32, [5])))
    declare a: PointerType(float32) = cast(PointerType(float32), args[0])
    assert((arg_types[1] == 3), The 1-th argument should be TensorPointerType(tensor(float32, [5])))
    declare b: PointerType(float32) = cast(PointerType(float32), args[1])
    assert((arg_types[2] == 3), The 2-th argument should be Tens

In [2]:
print(a)
print(b)
add_func(a, b, c)
print(c)

# Verify the kernel on the host instead of eyeballing the printout:
# every element of c must equal a + b.
np.testing.assert_allclose(c.cpu().numpy(), a.cpu().numpy() + b.cpu().numpy())

Tensor(shape=(5,), dtype='float32', device='cuda:0')
[1. 1. 2. 0. 2.]
Tensor(shape=(5,), dtype='float32', device='cuda:0')
[1. 1. 0. 0. 2.]
Tensor(shape=(5,), dtype='float32', device='cuda:0')
[2. 2. 2. 0. 4.]
