
[TIR, TVMScript] Add TIR - Triton integration #17395

Merged: 8 commits merged into apache:main on Sep 23, 2024

Conversation

vinx13 (Member) commented Sep 20, 2024

Added a macro `T.call_triton` to the TIR script parser, which expands to AOT compilation of the kernel plus the host TIR code that launches it.

cc @tqchen @cyx-6
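
For illustration, here is a minimal sketch of how the macro might be used from TVMScript. The vector-add kernel, the buffer shapes, and the launch-argument convention (grid=..., BLOCK=...) are assumptions for this example, not taken from the PR.

import triton
import triton.language as tl
from tvm.script import ir as I
from tvm.script import tir as T


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Standard Triton vector add: each program instance handles one block.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@I.ir_module
class Module:
    @T.prim_func
    def main(
        x: T.Buffer((256,), "float32"),
        y: T.Buffer((256,), "float32"),
        out: T.Buffer((256,), "float32"),
    ):
        # Expands at parse time into (a) AOT compilation of add_kernel and
        # (b) host TIR that launches the compiled device kernel.
        T.call_triton(add_kernel, x, y, out, 256, grid=(4,), BLOCK=64)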

tqchen (Member) commented Sep 20, 2024

This is a great mechanism for integrating kernel generators. Some notes on design that might help generalize it a bit:

Would be great to change the intrinsic to T.call_kernel, which checks the type of the first parameter and dispatches accordingly.

  • Let us think of a base class tir.kernel.BaseKernel which implements the base methods needed:
    • compile_to_device_module(), which calls get_meta_data to obtain the kernel metadata
    • a registry from class name to Kernel constructor (can be in a follow-up PR):
      • the import and construction of the related kernel wrapper then happens only when we see the related class name
      • call_kernel will look up the registry and construct the related kernel if needed
    • TritonKernel subclasses this; call_kernel looks up its constructor, constructs it, and calls the related class methods to generate the necessary downstream modules (see the sketch after this list)
  • Let us also add a tir.kernel.CUDAKernel as an example extension point:
    • CUDAKernel takes in CUDA C source device code and ensures the call_kernel mechanism works the same
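
A minimal sketch of what this hierarchy could look like; the class names follow the suggestion above, but the method signatures and bodies are assumptions, not actual TVM APIs:

from abc import ABC, abstractmethod


class BaseKernel(ABC):
    """Base class for external kernels callable through T.call_kernel."""

    @abstractmethod
    def get_meta_data(self):
        """Return kernel metadata (entry name, launch params, arg types, ...)."""

    @abstractmethod
    def compile_to_device_module(self, target):
        """AOT-compile the kernel source into a device module, using
        get_meta_data() to describe the entry point."""


class TritonKernel(BaseKernel):
    """Wraps a triton.JITFunction."""

    def __init__(self, jit_func):
        self.jit_func = jit_func

    def get_meta_data(self):
        ...  # derived from the JITFunction signature

    def compile_to_device_module(self, target):
        ...  # e.g. triton.compile(...) -> PTX, wrapped as a CUDA device module


class CUDAKernel(BaseKernel):
    """Example extension point: wraps raw CUDA C device source code."""

    def __init__(self, source, kernel_name):
        self.source = source
        self.kernel_name = kernel_name

    def get_meta_data(self):
        ...  # parsed from the kernel signature in `source`

    def compile_to_device_module(self, target):
        ...  # e.g. compile with nvrtc into a CUDA device module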

Overall Mechanism

  • A new DSL/kernel defines a subclass of tir.kernel.Kernel.
    • The subclass describes how to compile the source to a device module.
    • It also registers itself into the class name => Kernel type mapping.
  • The user can either manually construct an instance of tir.kernel.Kernel, or pass in the original data structure (e.g. triton.JITFunction); call_kernel then leverages the registered mapping to do the conversion automatically (see the sketch after this list).
  • call_kernel can carry possible specialization hints (constants, etc.).
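
A sketch of the registry idea, reusing the BaseKernel class from the sketch above; the helper names (KERNEL_REGISTRY, register_kernel, resolve_kernel) are hypothetical:

# Maps fully qualified class names to Kernel wrapper constructors, so a
# wrapper (and its heavyweight imports) is only instantiated when the
# matching class name is actually seen.
KERNEL_REGISTRY = {}


def register_kernel(class_name):
    def decorator(kernel_cls):
        KERNEL_REGISTRY[class_name] = kernel_cls
        return kernel_cls
    return decorator


def resolve_kernel(obj):
    # call_kernel would accept either a ready-made Kernel instance or the
    # original object (e.g. a triton.JITFunction) and convert automatically.
    if isinstance(obj, BaseKernel):
        return obj
    key = f"{type(obj).__module__}.{type(obj).__qualname__}"
    if key not in KERNEL_REGISTRY:
        raise TypeError(f"no kernel wrapper registered for {key}")
    return KERNEL_REGISTRY[key](obj)

TritonKernel would then be registered under the class name of Triton's JIT function wrapper, e.g. register_kernel("triton.runtime.jit.JITFunction")(TritonKernel); the exact qualified name is an assumption.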

try:
    import triton
except ImportError:
    # Skip the entire test module when Triton is not installed.
    pytestmark = pytest.skip("Triton is not available", allow_module_level=True)


@tvm.testing.requires_cuda
Review comment (Member):

add a test case for the rewrite use case

tqchen merged commit 48d3ada into apache:main on Sep 23, 2024 (17 of 18 checks passed)
if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
    lib = tvm.build(
        tir_mod,
        target=target,
        runtime=_autodetect_system_lib_req(target, system_lib),
    )
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params))  # type: ignore

for ext_mod in ext_libs:
    if ext_mod.type_key == "cuda":
Review comment (Member):

as a follow-up, add a function to check whether a module is a device module (is_device_module); this should include cuda, rocm, webgpu, vulkan, and opencl
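
A minimal sketch of such a helper, assuming the backend is identified by the module's type_key as in the ext_mod.type_key == "cuda" check above; the helper is hypothetical and not part of TVM at the time of this PR:

# Hypothetical helper; the set of backends follows the comment above.
_DEVICE_MODULE_TYPE_KEYS = {"cuda", "rocm", "webgpu", "vulkan", "opencl"}


def is_device_module(mod) -> bool:
    """Return True if `mod` holds device (GPU) code rather than host code."""
    return mod.type_key in _DEVICE_MODULE_TYPE_KEYS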
