Description
I would like to request support for the torch.logcumsumexp operation in the Torch dialect of Torch-MLIR.
I tested torch.logcumsumexp with fx.export_and_import, and it fails with the following error:
test_logcumsumexp
-----------------
loc("/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py":19:0):
error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
Traceback (most recent call last):
  File "/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py", line 13, in <module>
    def test_logcumsumexp():
  File "/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py", line 9, in run
    f()
  File "/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py", line 20, in test_logcumsumexp
    m = fx.export_and_import(Logcumsumexp(), torch.randn(3, 4),output_type="torch")
  File "/home/data/sharavana/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/fx.py", line 124, in export_and_import
    return _module_lowering(
  File "/home/data/sharavana/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/fx.py", line 61, in _module_lowering
    run_pipeline_with_repro_report(
  File "/home/data/sharavana/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchFX IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{ extra-library=})' /home/data/sharavana/tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Minimal Reproduction
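import torch
import torch.nn as nn
from torch_mlir import fx


def run(f):
    # Print the test name as a header, then execute the test immediately
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
def test_logcumsumexp():
    class LogcumsumOp(nn.Module):
        def forward(self, x):
            # Cumulative log-sum-exp along dim 1
            return torch.logcumsumexp(x, 1)

    input_tensor = torch.randn(2, 4)
    exported = fx.export_and_import(LogcumsumOp(), input_tensor, output_type="torch")
    print(exported)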
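For reference when implementing a lowering: torch.logcumsumexp(x, dim) computes out[i] = log(sum over j <= i of exp(x[j])) along dim. The sketch below is a pure-Python, 1-D illustration only (not Torch-MLIR code; the helper names logcumsumexp_reference and logcumsumexp_decomposed are hypothetical). It contrasts a numerically stable running log-add-exp with a max-shifted formulation built only from exp, a cumulative sum, log and max. Whether such a decomposition is the right way to add support in Torch-MLIR is an assumption, not something established in this issue.

```python
import math
from typing import List


def logcumsumexp_reference(xs: List[float]) -> List[float]:
    """Hypothetical helper: out[i] = log(sum(exp(xs[0..i]))) via running logaddexp."""
    out: List[float] = []
    acc = -math.inf  # log(0): the empty prefix
    for x in xs:
        hi, lo = max(acc, x), min(acc, x)
        if hi == -math.inf:
            acc = -math.inf
        else:
            # log(exp(hi) + exp(lo)) = hi + log1p(exp(lo - hi)); exp() never overflows here
            acc = hi + math.log1p(math.exp(lo - hi))
        out.append(acc)
    return out


def logcumsumexp_decomposed(xs: List[float]) -> List[float]:
    """Hypothetical helper: log(cumsum(exp(x - m))) + m with m = max(xs).

    Shifting by the global max avoids overflow in exp(). If the leading
    exp(x - m) terms underflow to 0.0 (x far below m), math.log(0.0) raises
    ValueError here; a tensor-level log would return -inf instead.
    """
    if not xs:
        return []
    m = max(xs)
    running = 0.0
    out: List[float] = []
    for x in xs:
        running += math.exp(x - m)  # cumulative sum of shifted exponentials
        out.append(math.log(running) + m)
    return out
```

With inputs such as [1000.0, 1001.0, 999.0], a naive log(cumsum(exp(x))) overflows in exp(), while both helpers return finite values, and the first output equals the first input.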