[Relax][PyTorch] Fix crash on dynamic shapes with identity slice in ExportedProgram importer#18903
Conversation
Fixes `TypeError: 'NoneType' object is not iterable` when importing models with dynamic batch dimensions that contain identity slices (e.g., `x[:, :H, :W, :]` on a dynamic batch dim).

**Root cause:** `aten.slice.Tensor(x, 0, 0, INT_MAX)` (an identity slice on a dynamic dim `s`) produces a result with shape `[T.min(INT_MAX, s), ...]` instead of `[s, ...]`. When this result is combined with the original tensor via `add`, TVM cannot unify the two shapes, so the result's `struct_info.shape` becomes `None`. Any subsequent `view`/`reshape` then crashes calling `list(None)`.

This pattern appears in models like `swin_t`, where shifted window attention crops padded features with `x[:, :H, :W, :].contiguous()`.

**Changes:**

- `exported_program_translator.py`: Skip `strided_slice` for identity slices (`start=0, end>=INT_MAX, step=1`) and return the input tensor directly.
- `base_fx_graph_translator.py`: Guard the identity-reshape check in `_reshape` against a `None` shape.
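The identity-slice detection described above can be sketched in isolation. This is a minimal standalone version of the check, not the exact importer code; the helper name `is_identity_slice` is illustrative, while the argument names `start`, `end_val`, `step` follow the snippet under review:

```python
import sys


def is_identity_slice(start, end_val, step) -> bool:
    """Return True for aten.slice.Tensor args that select the whole dim.

    PyTorch lowers x[:] on a dimension to slice(x, dim, 0, INT_MAX, 1),
    so start=0, end >= sys.maxsize, step=1 means "take everything" and
    the importer can return the input tensor unchanged instead of
    emitting a strided_slice with a T.min(INT_MAX, s) extent.
    """
    return (
        isinstance(start, int)
        and isinstance(end_val, int)
        and isinstance(step, int)
        and start == 0
        and end_val >= sys.maxsize
        and step == 1
    )


print(is_identity_slice(0, sys.maxsize, 1))  # True: skip strided_slice
print(is_identity_slice(0, 4, 1))            # False: a real crop, keep it
```

Keeping the `isinstance` checks first matters: on dynamic dims `end_val` may be a symbolic expression rather than an `int`, and the comparison `end_val >= sys.maxsize` should never run on it.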
Repro: `repro.py`

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.export import Dim, export

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class PadRollCropAdd(torch.nn.Module):
    """Mimics swin_t's shifted window attention pattern:
    pad → cyclic shift (roll) → crop (unpad) → residual add → reshape.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, C = x.shape
        # Pad spatial dims (as in swin_t before window partitioning).
        padded = F.pad(x, (0, 0, 0, 1, 0, 1))
        # Cyclic shift (torch.roll on spatial dims, same as swin_t).
        rolled = torch.roll(padded, shifts=(-1, -1), dims=(1, 2))
        # Unpad / crop back to original spatial size.
        # The [:] on dim 0 generates aten.slice.Tensor(rolled, 0, 0, INT_MAX),
        # an identity slice on the dynamic batch dim.
        cropped = rolled[:, :H, :W, :]
        # Residual add — TVM can't unify shapes:
        # x has [s, H, W, C] but cropped has [T.min(INT_MAX, s), H, W, C]
        # → result struct_info.shape = None.
        out = x + cropped
        # View on the shape=None tensor triggers the crash.
        return out.view(B, H * W * C)


def main():
    model = PadRollCropAdd().eval()
    x = torch.randn(2, 4, 4, 2)
    batch = Dim("batch")
    exported_program = export(model, (x,), dynamic_shapes={"x": {0: batch}})
    mod = from_exported_program(exported_program)
    mod = relax.transform.DecomposeOpsForInference()(mod)
    target = tvm.target.Target("llvm")
    exe = tvm.compile(mod, target=target)
    vm = relax.VirtualMachine(exe, tvm.cpu())
    tvm_input = tvm.runtime.from_dlpack(x.contiguous())
    tvm_output = vm["main"](tvm_input)
    tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
    with torch.no_grad():
        torch_output = model(x).numpy()
    np.testing.assert_allclose(tvm_output_np, torch_output, rtol=1e-5, atol=1e-5)
    print("Numerical check passed")


if __name__ == "__main__":
    main()
```
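The second half of the fix, the `None`-shape guard in `_reshape`, can be sketched the same way. This is a hedged, simplified model of the guard, not the actual TVM code: the function name `reshape_is_identity` and its list-based shapes are illustrative stand-ins for the importer's `struct_info.shape` handling.

```python
def reshape_is_identity(input_shape, target_shape) -> bool:
    """Identity-reshape check guarded against an unknown input shape.

    When upstream shape unification fails, struct_info.shape is None;
    calling list(None) is what raised the original TypeError. With the
    guard, an unknown shape simply means "not provably an identity
    reshape", and a real reshape op is emitted instead of crashing.
    """
    if input_shape is None:  # the guard added by this PR
        return False
    return list(input_shape) == list(target_shape)


print(reshape_is_identity(None, [2, 32]))     # False, no crash
print(reshape_is_identity([2, 32], [2, 32]))  # True, reshape can be elided
```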
Code Review
This pull request addresses a crash in the ExportedProgram importer when handling dynamic shapes with identity slices. The fix is twofold: first, it introduces a check in exported_program_translator.py to detect and bypass identity slices, which prevents downstream shape inference problems. Second, it adds a null check in base_fx_graph_translator.py to prevent crashes when handling None shapes during reshape operations. The changes are logical and effectively solve the issue. I have one suggestion to improve the readability of the new condition for detecting identity slices.
```python
if (
    isinstance(start, int)
    and isinstance(end_val, int)
    and isinstance(step, int)
    and start == 0
    and end_val >= sys.maxsize
    and step == 1
):
```
For improved readability and conciseness, you could group the isinstance checks using the all() built-in function. This would make the condition for identifying an identity slice slightly easier to parse.
```python
if (
    all(isinstance(v, int) for v in (start, end_val, step))
    and start == 0
    and end_val >= sys.maxsize
    and step == 1
):
```
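For reference, a quick standalone check (the function names here are hypothetical wrappers, not importer code) confirms the grouped `all()` form behaves identically to the original chain of `isinstance` checks, including the short-circuit that keeps the numeric comparisons away from non-`int` values:

```python
import sys


def cond_original(start, end_val, step):
    return (
        isinstance(start, int)
        and isinstance(end_val, int)
        and isinstance(step, int)
        and start == 0
        and end_val >= sys.maxsize
        and step == 1
    )


def cond_refactored(start, end_val, step):
    # all() short-circuits on the first non-int, so the comparisons
    # below never see a None or symbolic value, same as the original.
    return (
        all(isinstance(v, int) for v in (start, end_val, step))
        and start == 0
        and end_val >= sys.maxsize
        and step == 1
    )


cases = [
    (0, sys.maxsize, 1),      # identity slice
    (0, 4, 1),                # real crop
    (None, sys.maxsize, 1),   # non-int end of the dynamic-dim case
    (0, sys.maxsize, 2),      # strided, not identity
]
assert all(cond_original(*c) == cond_refactored(*c) for c in cases)
```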