
[TORCH] Add support for logcumsumexp #4183

Open
@sharavana20

Description

I would like to request support for the torch.logcumsumexp operation in the Torch dialect of Torch-MLIR.
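For context, the op computes the log of the cumulative sum of exponentials along a dimension. A minimal sketch of its semantics in plain PyTorch (the naive log(cumsum(exp(x))) formulation shown here can overflow for large inputs, which is why PyTorch ships a dedicated, numerically stable kernel):

```python
import torch

# Naive reference for torch.logcumsumexp: log of the running sum of
# exponentials along `dim`. Matches the stable kernel for moderate inputs.
def naive_logcumsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    return torch.log(torch.cumsum(torch.exp(x), dim=dim))

x = torch.randn(3, 4)
expected = torch.logcumsumexp(x, dim=1)
assert torch.allclose(naive_logcumsumexp(x, dim=1), expected, atol=1e-5)
```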

I tested torch.logcumsumexp using fx.export_and_import, and the reproduced error is:

test_logcumsumexp
-----------------
loc("/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py":19:0): 
error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
Traceback (most recent call last):
  File "/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py", line 13, in <module>
    def test_logcumsumexp():
  File "/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py", line 9, in run
    f()
  File "/home/data/sharavana/torch-mlir/test/python/fx_importer/test_logcumsumexp.py", line 20, in test_logcumsumexp
    m = fx.export_and_import(Logcumsumexp(), torch.randn(3, 4),output_type="torch")
  File "/home/data/sharavana/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/fx.py", line 124, in export_and_import
    return _module_lowering(
  File "/home/data/sharavana/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/fx.py", line 61, in _module_lowering
    run_pipeline_with_repro_report(
  File "/home/data/sharavana/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchFX IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline


For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{ extra-library=})' /home/data/sharavana/tmp/UnnammedModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purposes.

Minimal Reproduction

import torch
import torch.nn as nn
from torch_mlir import fx


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
def test_logcumsumexp():
    class LogcumsumOp(nn.Module):
        def forward(self, x):
            return torch.logcumsumexp(x, 1)

    input_tensor = torch.randn(2, 4)
    exported = fx.export_and_import(LogcumsumOp(), input_tensor, output_type="torch")
    print(exported)
