diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 34a9c5915254d..6d4ea64e78aa6 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -457,15 +457,15 @@ def get_signature_value(idx: int, arg: Any) -> str: inspect.signature(backend.get_codegen_implementation).parameters ) if make_ir_sig_params == 2: - ttir_module = src.make_ir(options, context) + ttir_module = src.make_ir(target, options, context) elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() - ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module = src.make_ir(target, options, codegen_fns, context) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() - ttir_module = src.make_ir(options, codegen_fns, module_map, context) + ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed")