
[Relax][PyTorch] Fix crash on dynamic shapes with identity slice in ExportedProgram importer#18903

Merged
tlopex merged 1 commit into apache:main from mshr-h:fix-swin_t-dynamic on Mar 10, 2026

Conversation

@mshr-h
Contributor

@mshr-h mshr-h commented Mar 10, 2026

Fixes `TypeError: 'NoneType' object is not iterable` when importing models with dynamic batch dimensions that contain identity slices (e.g., `x[:, :H, :W, :]` on a dynamic batch dim).

**Root cause:** `aten.slice.Tensor(x, 0, 0, INT_MAX)` (an identity slice on a dynamic dim `s`) produces a result with shape `[T.min(INT_MAX, s), ...]` instead of `[s, ...]`. When this result is combined with the original tensor via `add`, TVM cannot unify the two shapes, so the result's `struct_info.shape` becomes `None`. Any subsequent `view`/`reshape` then crashes calling `list(None)`.

This pattern appears in models like `swin_t`, where shifted window attention crops padded features with `x[:, :H, :W, :].contiguous()`.

**Changes:**

  • `exported_program_translator.py`: Skip `strided_slice` for identity slices (`start=0, end>=INT_MAX, step=1`) and return the input tensor directly.
  • `base_fx_graph_translator.py`: Guard the identity-reshape check in `_reshape` against a `None` shape.
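
As a standalone illustration of the first change (the function name here is hypothetical; the real check lives inside `_slice` in `exported_program_translator.py`), the identity-slice detection amounts to:

```python
import sys

def is_identity_slice(start, end, step):
    """Detect the pattern torch.export emits for a full slice like x[:]
    on a dimension: aten.slice.Tensor(x, dim, 0, sys.maxsize, 1).
    For such slices the importer can return the input unchanged instead
    of emitting a strided_slice whose end turns into T.min(INT_MAX, s)
    on a dynamic dim."""
    return (
        all(isinstance(v, int) for v in (start, end, step))
        and start == 0
        and end >= sys.maxsize
        and step == 1
    )

# x[:] on the batch dim exports as slice(x, 0, 0, sys.maxsize, 1):
print(is_identity_slice(0, sys.maxsize, 1))  # True  -> skip strided_slice
print(is_identity_slice(0, 4, 1))            # False -> real crop, keep it
```

The `isinstance` checks matter because on a dynamic dim `start`/`end` can be symbolic expressions rather than Python ints, in which case the shortcut must not fire.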

@mshr-h mshr-h changed the title [Frontend][PyTorch] Fix crash on dynamic shapes with identity slice in ExportedProgram importer [Relax][PyTorch] Fix crash on dynamic shapes with identity slice in ExportedProgram importer Mar 10, 2026
@mshr-h
Contributor Author

mshr-h commented Mar 10, 2026

Repro:

repro.py
import numpy as np
import torch
import torch.nn.functional as F
from torch.export import Dim, export

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class PadRollCropAdd(torch.nn.Module):
    """Mimics swin_t's shifted window attention pattern:
    pad → cyclic shift (roll) → crop (unpad) → residual add → reshape.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, C = x.shape
        # Pad spatial dims (as in swin_t before window partitioning).
        padded = F.pad(x, (0, 0, 0, 1, 0, 1))
        # Cyclic shift (torch.roll on spatial dims, same as swin_t).
        rolled = torch.roll(padded, shifts=(-1, -1), dims=(1, 2))
        # Unpad / crop back to original spatial size.
        # The [:] on dim 0 generates aten.slice.Tensor(rolled, 0, 0, INT_MAX),
        # an identity slice on the dynamic batch dim.
        cropped = rolled[:, :H, :W, :]
        # Residual add — TVM can't unify shapes:
        #   x has [s, H, W, C] but cropped has [T.min(INT_MAX, s), H, W, C]
        #   → result struct_info.shape = None.
        out = x + cropped
        # View on the shape=None tensor triggers the crash.
        return out.view(B, H * W * C)


def main():
    model = PadRollCropAdd().eval()
    x = torch.randn(2, 4, 4, 2)

    batch = Dim("batch")
    exported_program = export(model, (x,), dynamic_shapes={"x": {0: batch}})

    mod = from_exported_program(exported_program)

    mod = relax.transform.DecomposeOpsForInference()(mod)
    target = tvm.target.Target("llvm")
    exe = tvm.compile(mod, target=target)
    vm = relax.VirtualMachine(exe, tvm.cpu())

    tvm_input = tvm.runtime.from_dlpack(x.contiguous())
    tvm_output = vm["main"](tvm_input)
    tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()

    with torch.no_grad():
        torch_output = model(x).numpy()

    np.testing.assert_allclose(tvm_output_np, torch_output, rtol=1e-5, atol=1e-5)
    print("Numerical check passed")


if __name__ == "__main__":
    main()
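
The second half of the fix, the `None`-shape guard in `_reshape`, can be sketched in isolation like this (a simplified stand-in with illustrative names, not the translator's actual code):

```python
def reshape_is_identity(current_shape, target_shape):
    """Simplified stand-in for the identity-reshape check in _reshape.
    When upstream shape unification fails, struct_info.shape is None;
    without the guard, list(None) raises TypeError: 'NoneType' object
    is not iterable. With it, we report 'not an identity reshape' and
    fall through to emitting a real reshape op."""
    if current_shape is None:
        return False
    return list(current_shape) == list(target_shape)

print(reshape_is_identity(None, (2, 32)))     # False, no crash
print(reshape_is_identity((2, 32), (2, 32)))  # True
```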

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a TypeError occurring during the import of PyTorch models with dynamic batch dimensions and identity slices into Relax. The issue stemmed from aten.slice.Tensor producing complex shapes for identity slices, leading to None shapes and subsequent crashes during view/reshape operations. The fix involves explicitly handling identity slices to avoid problematic shape inference and guarding against None shapes in reshape logic, ensuring robust model import for dynamic shapes.

Highlights

  • Identity Slice Handling: Modified exported_program_translator.py to directly return the input tensor for identity slices (start=0, end>=INT_MAX, step=1) to prevent shape inference failures with dynamic shapes.
  • None Shape Guard: Added a null check in base_fx_graph_translator.py to prevent TypeError when current_shape is None during identity reshape checks.


Changelog
  • python/tvm/relax/frontend/torch/base_fx_graph_translator.py
    • Added a current_shape is not None check to the identity reshape condition in _reshape to prevent TypeError when current_shape is None.
  • python/tvm/relax/frontend/torch/exported_program_translator.py
    • Introduced a conditional block in _slice to detect identity slices (start=0, end>=sys.maxsize, step=1) and return the original input tensor x directly, bypassing strided_slice for these cases.

@mshr-h mshr-h marked this pull request as ready for review March 10, 2026 11:24
@mshr-h
Contributor Author

mshr-h commented Mar 10, 2026

cc @tlopex @guan404ming

@mshr-h mshr-h requested review from guan404ming and tlopex March 10, 2026 11:24
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a crash in the ExportedProgram importer when handling dynamic shapes with identity slices. The fix is twofold: first, it introduces a check in exported_program_translator.py to detect and bypass identity slices, which prevents downstream shape inference problems. Second, it adds a null check in base_fx_graph_translator.py to prevent crashes when handling None shapes during reshape operations. The changes are logical and effectively solve the issue. I have one suggestion to improve the readability of the new condition for detecting identity slices.

Comment on lines +948 to +955
        if (
            isinstance(start, int)
            and isinstance(end_val, int)
            and isinstance(step, int)
            and start == 0
            and end_val >= sys.maxsize
            and step == 1
        ):
gemini-code-assist bot (severity: medium):

For improved readability and conciseness, you could group the isinstance checks using the all() built-in function. This would make the condition for identifying an identity slice slightly easier to parse.

        if (
            all(isinstance(v, int) for v in (start, end_val, step))
            and start == 0
            and end_val >= sys.maxsize
            and step == 1
        ):

@tlopex tlopex merged commit 1499bda into apache:main Mar 10, 2026
13 checks passed