
[Bug] [Relax][LLVM] torch.prod with dtype=torch.bool lowers to bool R.prod and fails LLVM codegen #19551

@tinywisdom

Description


Summary

A minimal PyTorch model using torch.prod(x, dtype=torch.bool) can be exported to Relax, but fails during tvm.compile(..., target="llvm").

PyTorch eager handles this case successfully:

import torch

x = torch.zeros((1, 1, 16, 16), dtype=torch.bool)
torch.prod(x, dtype=torch.bool)
# tensor(False)
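A product over bool values is True only when every element is True, so it coincides with a logical-AND (torch.all-style) reduction. The pure-Python sketch below models that reference semantics on a flattened copy of the input. The helper name bool_prod_reference is hypothetical and the code does not call PyTorch or TVM; it only illustrates why the eager result above is False.

```python
from functools import reduce


def bool_prod_reference(values):
    """Hypothetical reference for a bool product: True iff every element is True.

    The multiplicative identity (1) corresponds to True, so an empty input yields True.
    """
    return reduce(lambda acc, v: acc and bool(v), values, True)


# Flattened (1, 1, 16, 16) inputs, mirroring torch.zeros/torch.ones with dtype=torch.bool.
zeros = [False] * (1 * 1 * 16 * 16)
ones = [True] * (1 * 1 * 16 * 16)

print(bool_prod_reference(zeros))  # False, matching tensor(False)
print(bool_prod_reference(ones))   # True
```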

After exporting with torch.export and converting with from_exported_program, TVM produces a Relax program containing:

R.prod(x, axis=None, keepdims=False)

where x has dtype bool. However, tvm.compile fails during LLVM code generation with:

InternalError: Check failed: (t.is_float()) is false:

The stack trace shows that the failure occurs in CodeGenLLVM::CreateMul, which suggests that the bool reduction is lowered to a multiplication-based reduction over bool values. For a bool prod, the compiler should instead do one of the following: lower it to a valid logical-AND-style reduction, cast it to a supported integer representation, or reject it earlier with a clear unsupported-dtype diagnostic.
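The first two remedies can be sketched as reference reductions in pure Python. The helpers prod_bool_logical_and and prod_bool_via_int_cast are hypothetical names that model the arithmetic only; they are not TVM APIs. In Relax, the cast-based variant would presumably amount to casting to an integer dtype (for example with R.astype) before R.prod and casting the result back to bool, but that rewrite has not been tested against this issue.

```python
import itertools


def prod_bool_logical_and(values):
    # Option 1 (sketch): reduce with logical AND; the identity element is True.
    acc = True
    for v in values:
        acc = acc and v
    return acc


def prod_bool_via_int_cast(values):
    # Option 2 (sketch): cast each element to an integer, multiply, then cast back to bool.
    # Assumption: keeping the multiply on an integer dtype avoids the bool path in CreateMul.
    acc = 1
    for v in values:
        acc = acc * int(v)
    return acc != 0


# Exhaustively check that both strategies agree with all() on every bool input up to length 4.
for n in range(5):
    for combo in itertools.product([False, True], repeat=n):
        assert prod_bool_logical_and(combo) == prod_bool_via_int_cast(combo) == all(combo)
print("both lowerings agree")
```

The exhaustive check includes the empty input (n = 0), where each strategy returns its identity element, True or integer 1 cast back to True.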

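Expected behavior

tvm.compile should either succeed (matching PyTorch's tensor(False)) or reject the bool prod early with a clear unsupported-dtype diagnostic. The import step already succeeds; the exported Relax IR contains R.prod over a bool tensor: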

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 1, 16, 16), dtype="bool")) -> R.Tuple(R.Tensor((), dtype="bool")):
        with R.dataflow():
            lv: R.Tensor((), dtype="bool") = R.prod(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="bool")) = (lv,)
            R.output(gv)
        return gv

Actual behavior

tvm.compile fails for the LLVM target.

The relevant part of the stack trace is:

tvm.tir.build
  -> codegen_build
  -> CodeGenLLVM::AddFunctionInternal
  -> CodeGenLLVM::VisitStmt_(BufferStoreNode)
  -> CodeGenLLVM::MakeValue
  -> CodeGenLLVM::VisitExpr_(CastNode)
  -> CodeGenLLVM::CreateMul
  -> InternalError: Check failed: (t.is_float()) is false:

Full observed failure:

tvm.error.InternalError: Check failed: (t.is_float()) is false:

Environment

TVM: 0.23.0
LLVM: 17.0.6
Python: 3.10.16 (inferred from stack trace paths)
NumPy: 2.2.6

Steps to reproduce

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import platform
import traceback

import torch
import tvm


class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.prod(x, dtype=torch.bool)


def main():
    print("=" * 80)
    print("Environment")
    print("=" * 80)
    print("python:", sys.version.replace("\n", " "))
    print("platform:", platform.platform())
    print("torch:", torch.__version__)
    print("tvm:", getattr(tvm, "__version__", "<unknown>"))
    print("tvm path:", getattr(tvm, "__file__", "<unknown>"))

    model = MyModel().eval()
    x = torch.zeros((1, 1, 16, 16), dtype=torch.bool)

    with torch.no_grad():
        eager = model(x)

    print("=" * 80)
    print("PyTorch eager")
    print("=" * 80)
    print("input shape:", tuple(x.shape), "dtype:", x.dtype)
    print("eager:", eager, eager.dtype, eager.shape)

    ep = torch.export.export(model, (x,))

    from tvm.relax.frontend.torch import from_exported_program

    ir_mod = from_exported_program(ep)

    print("=" * 80)
    print("Exported Relax IR")
    print("=" * 80)
    print(ir_mod.script(show_meta=True))

    print("=" * 80)
    print("tvm.compile with LLVM")
    print("=" * 80)

    ex = tvm.compile(
        ir_mod,
        target=tvm.target.Target("llvm"),
        relax_pipeline="default",
        tir_pipeline="default",
    )

    print("compile: OK")
    print(ex)


if __name__ == "__main__":
    try:
        main()
    except Exception:
        print("compile: FAILED")
        traceback.print_exc()

Triage

  • needs-triage
  • bug

cc @junrushao

Metadata


Assignees

No one assigned

    Labels

    needs-triage, type: bug
