From db722f63f75de24011da856c479d928e1b0039ef Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Fri, 13 Feb 2026 11:48:05 +0700 Subject: [PATCH] [LLVM][Codegen] Cast NaN to bool gives true - Ensure consistency with the existing framework (C/C++/Python/Torch/Numpy/...). --- src/target/llvm/codegen_llvm.cc | 2 +- .../codegen/test_target_codegen_llvm.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 0a2ae8b09e04..b7004dec32e2 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -928,7 +928,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } else if (to.is_bool()) { if (from.is_float()) { llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.); - return builder_->CreateFCmpONE(value, zero); + return builder_->CreateFCmpUNE(value, zero); } else { llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0); return builder_->CreateICmpNE(value, zero); diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index da58f5bb459c..78f2abf523b9 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -370,6 +370,30 @@ def main(A: T.Buffer((64,), "int32"), C: T.Buffer((64,), "float32")): tvm.testing.assert_allclose(c.numpy(), c_np) +@tvm.testing.requires_llvm +def test_llvm_cast_float_to_bool(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "bool")): + T.func_attr({"tir.noalias": True}) + for i in range(4): + with T.sblock("C"): + v_i = T.axis.spatial(4, i) + T.reads(A[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast("bool", A[v_i]) + + n = 4 + f = tvm.compile(Module, target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.tensor(np.array([0.0, 1.0, np.nan, np.inf], dtype="float32"), dev) + c = tvm.runtime.empty((n,), dtype="bool", device=dev) + f(a, c) + c_np = np.array([False, True, True, True], dtype="bool") + tvm.testing.assert_allclose(c.numpy(), c_np) + + @tvm.testing.requires_llvm def test_rank_zero(): @I.ir_module