
[Bug][Pytorch] Negative indexing not handled for start argument of flatten #10658

@AleksKnezevic

Description


When a flatten operator is imported from PyTorch, a negative end index is handled correctly, but a negative start index is not.

In python/tvm/relay/frontend/pytorch.py at line 1237, a negative end is normalized:

    if end < 0:
        end += ndim

There is no equivalent check for start.
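The following is a minimal sketch of the missing check, written in plain Python rather than against TVM's converter code. `normalize_flatten_dims` and `flatten_shape` are hypothetical helpers. They apply the same `+= ndim` rule to both arguments and follow PyTorch's semantics, where dims start through end (inclusive) are merged into one.

```python
import math


def normalize_flatten_dims(start, end, ndim):
    """Hypothetical helper: map possibly negative flatten dims to 0..ndim-1.

    Mirrors the existing `end += ndim` check and applies it to `start` too.
    """
    orig = (start, end)
    if start < 0:
        start += ndim
    if end < 0:
        end += ndim
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid flatten dims {orig} for a rank-{ndim} tensor")
    return start, end


def flatten_shape(shape, start=0, end=-1):
    """Output shape after merging dims start..end (inclusive) of `shape`."""
    start, end = normalize_flatten_dims(start, end, len(shape))
    merged = math.prod(shape[start:end + 1])
    return list(shape[:start]) + [merged] + list(shape[end + 1:])


# A rank-5 tensor: flatten(-2) must behave exactly like flatten(3).
print(flatten_shape([1, 3, 4, 2, 2], -2))  # [1, 3, 4, 4]
```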

Expected behavior

Flatten imported from PyTorch should work correctly with a negative index for either argument.

Actual behavior

Importing a PyTorch flatten fails when the start index is negative.

Steps to reproduce

Run the following module through the PyTorch frontend:

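    import torch

    # The class name is illustrative; only forward() comes from the original report.
    class Model(torch.nn.Module):
        def forward(self, x):
            # Split the last dimension into even and odd positions.
            x1 = x[:, :, :, ::2]
            x2 = x[:, :, :, 1::2]
            x = torch.stack((-x2, x1), dim=-1)
            # Negative start_dim: merge the last two dimensions.
            return x.flatten(-2)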
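For reference when checking the converted model, here is a hypothetical pure-Python trace of the shapes in the module above. `repro_output_shape` is illustrative and not part of PyTorch or TVM. It assumes a rank-4 input with an even last dimension, since otherwise x1 and x2 differ in size and torch.stack fails. The stacked tensor has rank 5, so flatten(-2) should resolve to start_dim=3 and restore the input shape.

```python
import math


def repro_output_shape(shape):
    """Hypothetical shape trace of Model.forward for a rank-4 input."""
    n, c, h, w = shape
    if w % 2:
        raise ValueError("x1 and x2 must match for torch.stack, so W must be even")
    stacked = [n, c, h, w // 2, 2]   # torch.stack((-x2, x1), dim=-1)
    start = -2 + len(stacked)        # start_dim=-2 normalizes to 3
    end = -1 + len(stacked)          # default end_dim=-1 normalizes to 4
    merged = math.prod(stacked[start:end + 1])
    return stacked[:start] + [merged] + stacked[end + 1:]


print(repro_output_shape((1, 3, 4, 8)))  # [1, 3, 4, 8]
```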
