Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions kernel_tuner/backends/hip/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,16 @@ def compile(self, kernel_instance):
# Format kernel string
kernel_string = kernel_instance.kernel_string
kernel_name = kernel_instance.name
if 'extern "C"' not in kernel_string:
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
expression_name = kernel_name.encode()

# Create program
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), kernel_name.encode(), 0, [], []))

try:
# Add the kernel as an expression. This forces hiprtc to instantiate the kernel if it
# is templated or if it is in a namespace.
hip_check(hiprtc.hiprtcAddNameExpression(prog, expression_name))

# Get device properties
props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
Expand All @@ -174,6 +177,10 @@ def compile(self, kernel_instance):
hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
raise RuntimeError(log.decode())

# Get the lowered name. This is the name that can be used in hipModuleGetFunction to
# get the kernel. For templated kernels, this differs from the original kernel name.
lowered_name = hip_check(hiprtc.hiprtcGetLoweredName(prog, expression_name))

# Get compiled code
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
code = bytearray(code_size)
Expand All @@ -182,7 +189,7 @@ def compile(self, kernel_instance):
# Load module and get function
module = hip_check(hip.hipModuleLoadData(code))
self.current_module = module
kernel = hip_check(hip.hipModuleGetFunction(module, kernel_name.encode()))
kernel = hip_check(hip.hipModuleGetFunction(module, lowered_name))

except Exception as e:
# Cleanup
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose)
)

# check for templated kernel
if kernel_source.lang in ["CUDA", "NVCUDA", "HIP"] and "<" in name and ">" in name:
if kernel_source.lang in ["CUDA", "NVCUDA"] and "<" in name and ">" in name:
kernel_string, name = wrap_templated_kernel(kernel_string, name)

# Preprocess GPU arguments. Required for handling `Tunable` arguments
Expand Down
Loading