Summary
A minimal PyTorch model using `torch.prod(x, dtype=torch.bool)` can be exported to Relax, but fails during `tvm.compile(..., target="llvm")`.

PyTorch eager handles this case successfully:

```python
x = torch.zeros((1, 1, 16, 16), dtype=torch.bool)
torch.prod(x, dtype=torch.bool)
# tensor(False)
```
After `torch.export` and `from_exported_program`, TVM produces a Relax program containing:

```python
R.prod(x, axis=None, keepdims=False)
```

where `x` has dtype `bool`. However, `tvm.compile` fails during LLVM code generation with:

```
InternalError: Check failed: (t.is_float()) is false:
```
The stack trace shows the failure reaches `CodeGenLLVM::CreateMul`, suggesting that the bool reduction is lowered as a multiplication-based reduction over `bool` values. For bool `prod`, this should either be lowered to a valid logical-AND-style reduction, cast to a supported integer representation, or rejected earlier with a clear unsupported-dtype diagnostic.
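For context, a product reduction over bool values is semantically a logical AND: a product of 0/1 elements is 1 iff every element is 1. The NumPy sketch below illustrates the equivalence between the two candidate lowerings mentioned above (NumPy is used only to demonstrate the expected semantics; it is not part of the failing path):

```python
import numpy as np

# Same input as the repro: an all-False bool tensor.
x = np.zeros((1, 1, 16, 16), dtype=bool)

# Lowering option 1: reduce through a supported integer dtype.
via_int = bool(np.prod(x.astype(np.int32)))

# Lowering option 2: a logical-AND-style (all) reduction.
via_all = bool(np.all(x))

# Both agree with the PyTorch eager result, tensor(False).
print(via_int, via_all)  # False False
```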
Expected behavior
The exported Relax IR contains `R.prod` over a `bool` tensor:

```python
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 1, 16, 16), dtype="bool")) -> R.Tuple(R.Tensor((), dtype="bool")):
        with R.dataflow():
            lv: R.Tensor((), dtype="bool") = R.prod(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="bool")) = (lv,)
            R.output(gv)
        return gv
```
Actual behavior
`tvm.compile` fails for the LLVM target. The relevant part of the stack trace is:

```
tvm.tir.build
-> codegen_build
-> CodeGenLLVM::AddFunctionInternal
-> CodeGenLLVM::VisitStmt_(BufferStoreNode)
-> CodeGenLLVM::MakeValue
-> CodeGenLLVM::VisitExpr_(CastNode)
-> CodeGenLLVM::CreateMul
-> InternalError: Check failed: (t.is_float()) is false:
```

Full observed failure:

```
tvm.error.InternalError: Check failed: (t.is_float()) is false:
```
Environment
TVM: 0.23.0
LLVM: 17.0.6
Python: 3.10.16 (from stack paths)
NumPy: 2.2.6
Steps to reproduce
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import platform
import traceback

import torch
import tvm


class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.prod(x, dtype=torch.bool)


def main():
    print("=" * 80)
    print("Environment")
    print("=" * 80)
    print("python:", sys.version.replace("\n", " "))
    print("platform:", platform.platform())
    print("torch:", torch.__version__)
    print("tvm:", getattr(tvm, "__version__", "<unknown>"))
    print("tvm path:", getattr(tvm, "__file__", "<unknown>"))

    model = MyModel().eval()
    x = torch.zeros((1, 1, 16, 16), dtype=torch.bool)
    with torch.no_grad():
        eager = model(x)

    print("=" * 80)
    print("PyTorch eager")
    print("=" * 80)
    print("input shape:", tuple(x.shape), "dtype:", x.dtype)
    print("eager:", eager, eager.dtype, eager.shape)

    ep = torch.export.export(model, (x,))
    from tvm.relax.frontend.torch import from_exported_program

    ir_mod = from_exported_program(ep)
    print("=" * 80)
    print("Exported Relax IR")
    print("=" * 80)
    print(ir_mod.script(show_meta=True))

    print("=" * 80)
    print("tvm.compile with LLVM")
    print("=" * 80)
    ex = tvm.compile(
        ir_mod,
        target=tvm.target.Target("llvm"),
        relax_pipeline="default",
        tir_pipeline="default",
    )
    print("compile: OK")
    print(ex)


if __name__ == "__main__":
    try:
        main()
    except Exception:
        print("compile: FAILED")
        traceback.print_exc()
```
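Until the lowering is fixed, a possible model-level workaround (an assumption, not verified against this TVM build) is to keep the reduction out of the bool dtype, e.g. `torch.prod(x.to(torch.int32)).to(torch.bool)`, or use the semantically equivalent `torch.all(x)`. The NumPy sketch below mirrors the intended rewrite:

```python
import numpy as np

def bool_prod_workaround(x):
    """Reduce through int32 instead of bool, then cast back.

    Mirrors the hypothetical model-level rewrite
    torch.prod(x.to(torch.int32)).to(torch.bool); the reduction
    semantics are the same.
    """
    return np.prod(x.astype(np.int32)).astype(bool)

x = np.zeros((1, 1, 16, 16), dtype=bool)
print(bool_prod_workaround(x))  # False, matching the eager PyTorch result
```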
Triage
cc @junrushao