From dfd192f3e5e41925caacc56dc1d9e06c36999ab6 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 03:59:41 +0000 Subject: [PATCH 01/10] [Relax][PyTorch] Fix InternalError when converting scaled_dot_product_attention with 2D inputs Fixes #18441 Previously, the TVM frontend incorrectly assumed 4D input dimensions for scaled_dot_product_attention, causing an InternalError when the actual input was 2D (seq_len, head_dim). This fix: - Detects input dimensionality (2D vs 4D) - For 2D inputs: expands to 4D, calls attention, then squeezes back - For 4D inputs: maintains existing behavior - Adds test case for 2D input scenario - Updates verify_model_numerically to use strict=False for export --- .../torch/base_fx_graph_translator.py | 54 ++++++++++++++++--- .../test_frontend_from_exported_program.py | 18 ++++++- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 1938355169f0..ab20366387ef 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1477,10 +1477,48 @@ def _pixel_shuffle(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) + query_tensor = self.env[node.args[0]] + key_tensor = self.env[node.args[1]] + value_tensor = self.env[node.args[2]] + + # Check the dimensionality of the input tensors + query_ndim = len(query_tensor.struct_info.shape) + + # TVM's nn.attention requires 4D inputs (batch, seq_len, num_heads, head_dim) + # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first + if query_ndim == 2: + # 2D input: (seq_len, head_dim) -> expand to (1, seq_len, 1, head_dim) + # Add batch dimension at axis 0 + query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0)) + key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0)) + value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0)) + # Add num_heads dimension at axis 2 + query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=2)) + key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=2)) + value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=2)) + + # No permutation needed for 2D inputs after expanding to 4D + # After attention, squeeze back to 2D: (1, seq_len, 1, head_dim) -> (seq_len, head_dim) + def transpose_and_reshape_back(tensor): + # Squeeze num_heads dimension (axis 2) + tensor_3d = self.block_builder.emit(relax.op.squeeze(tensor, axis=[2])) + # Squeeze batch dimension (axis 0) + return self.block_builder.emit(relax.op.squeeze(tensor_3d, axis=[0])) + elif query_ndim == 4: + # 4D input: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(query_tensor) + key = transpose_S_H(key_tensor) + value = transpose_S_H(value_tensor) + + # For 4D, transpose back after attention + def transpose_and_reshape_back(tensor): + return transpose_S_H(tensor) + else: + raise ValueError( + f"scaled_dot_product_attention 
expects 2D or 4D inputs, but got {query_ndim}D input" + ) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) assert dropout_p == 0.0, "Dropout is not supported" @@ -1492,11 +1530,11 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: msg = "Only a float mask is supported for the attn_mask input." assert "float" in attn_mask.struct_info.dtype, msg - return self.block_builder.emit( - transpose_S_H( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) - ) + attention_output = self.block_builder.emit( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) ) + + return transpose_and_reshape_back(attention_output) def _unbind(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8ff46bf611b2..0a72a2a901f4 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -62,7 +62,9 @@ def verify_model_numerically(torch_model, example_args, rtol=1e-7, atol=1e-7): with torch.no_grad(): pytorch_output = torch_model(*example_args) - exported_program = export(torch_model, args=example_args) + # Use strict=False to handle ops like scaled_dot_product_attention that may have + # internal non-exportable operations + exported_program = export(torch_model, args=example_args, strict=False) mod = from_exported_program(exported_program) target = tvm.target.Target("llvm") ex = relax.build(mod, target) @@ -4255,6 +4257,20 @@ def main( run_ep_decomposition=True, ) + # Test 2D input (seq_len, head_dim) - bug fix for #18441 + class Attention2D(Module): + def forward(self, x): + return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False) + + # For 2D input, we just verify that conversion succeeds without error + # The expected IR is complex due to reshape operations, so we use verify_model_numerically + verify_model_numerically( + Attention2D(), + (torch.randn(8, 32, dtype=torch.float32),), + rtol=1e-5, + atol=1e-5, + ) + def test_unbind(): class Unbind1(Module): From 0bce7d95d6701b9149abe6c96e525fbd05f4674e Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 03:59:41 +0000 Subject: [PATCH 02/10] [Relax][PyTorch] Fix scaled_dot_product_attention with 2D inputs --- .../torch/base_fx_graph_translator.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ab20366387ef..3cb3ce426726 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1480,37 +1480,35 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: query_tensor = self.env[node.args[0]] key_tensor = self.env[node.args[1]] value_tensor = self.env[node.args[2]] - + # Check the dimensionality of the input tensors query_ndim = len(query_tensor.struct_info.shape) - - # TVM's nn.attention requires 4D inputs (batch, seq_len, num_heads, head_dim) + + # TVM's nn.attention requires 4D inputs in format (batch, num_heads, seq_len, head_dim) # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first if query_ndim == 2: - # 2D input: 
(seq_len, head_dim) -> expand to (1, seq_len, 1, head_dim) + # 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim) # Add batch dimension at axis 0 query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0)) key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0)) value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0)) - # Add num_heads dimension at axis 2 - query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=2)) - key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=2)) - value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=2)) - + # Add num_heads dimension at axis 1 + query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1)) + key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1)) + value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1)) + # No permutation needed for 2D inputs after expanding to 4D - # After attention, squeeze back to 2D: (1, seq_len, 1, head_dim) -> (seq_len, head_dim) + # After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim) def transpose_and_reshape_back(tensor): - # Squeeze num_heads dimension (axis 2) - tensor_3d = self.block_builder.emit(relax.op.squeeze(tensor, axis=[2])) - # Squeeze batch dimension (axis 0) - return self.block_builder.emit(relax.op.squeeze(tensor_3d, axis=[0])) + # Squeeze batch and num_heads dimensions + return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1])) elif query_ndim == 4: # 4D input: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) query = transpose_S_H(query_tensor) key = transpose_S_H(key_tensor) value = transpose_S_H(value_tensor) - + # For 4D, transpose back after attention def transpose_and_reshape_back(tensor): return transpose_S_H(tensor) @@ -1518,7 +1516,7 @@ def transpose_and_reshape_back(tensor): raise ValueError( f"scaled_dot_product_attention expects 2D or 4D inputs, but got {query_ndim}D input" ) - + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) assert dropout_p == 0.0, "Dropout is not supported" @@ -1533,7 +1531,7 @@ def transpose_and_reshape_back(tensor): attention_output = self.block_builder.emit( relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) ) - + return transpose_and_reshape_back(attention_output) def _unbind(self, node: fx.Node) -> relax.Var: From a09daa898634239a3a1783031830e2c7cabacecf Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 04:28:51 +0000 Subject: [PATCH 03/10] [Relax][PyTorch] Fix formatting in _scaled_dot_product_attention --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3cb3ce426726..0938a89026d3 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1502,6 +1502,7 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: def transpose_and_reshape_back(tensor): # Squeeze batch and num_heads dimensions return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1])) + elif query_ndim == 4: # 4D input: 
(batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) @@ -1512,6 +1513,7 @@ def transpose_and_reshape_back(tensor): # For 4D, transpose back after attention def transpose_and_reshape_back(tensor): return transpose_S_H(tensor) + else: raise ValueError( f"scaled_dot_product_attention expects 2D or 4D inputs, but got {query_ndim}D input" From 9ecf7f24344884097968beb737d7a8ef380bb16b Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 04:40:56 +0000 Subject: [PATCH 04/10] [Relax][PyTorch] Fix pylint line length in _scaled_dot_product_attention --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 0938a89026d3..fd7ef3d57af4 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1504,7 +1504,8 @@ def transpose_and_reshape_back(tensor): return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1])) elif query_ndim == 4: - # 4D input: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim) + # 4D input: (batch, seq_len, num_heads, head_dim) + # -> (batch, num_heads, seq_len, head_dim) transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) query = transpose_S_H(query_tensor) key = transpose_S_H(key_tensor) From 8974714bb908d1d304c492669848ba3d24696aa3 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 11:01:00 +0000 Subject: [PATCH 05/10] [Relax][PyTorch] Fix bool max operation and attention output structure - Fix bool type handling in statistical operations by converting to int32 - Ensure attention transpose operations generate intermediate variables --- .../relax/frontend/torch/base_fx_graph_translator.py | 8 ++++---- .../tvm/relax/transform/legalize_ops/statistical.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index fd7ef3d57af4..77fe81008f67 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1507,13 +1507,13 @@ def transpose_and_reshape_back(tensor): # 4D input: (batch, seq_len, num_heads, head_dim) # -> (batch, num_heads, seq_len, head_dim) transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(query_tensor) - key = transpose_S_H(key_tensor) - value = transpose_S_H(value_tensor) + query = self.block_builder.emit(transpose_S_H(query_tensor)) + key = self.block_builder.emit(transpose_S_H(key_tensor)) + value = self.block_builder.emit(transpose_S_H(value_tensor)) # For 4D, transpose back after attention def transpose_and_reshape_back(tensor): - return transpose_S_H(tensor) + return self.block_builder.emit(transpose_S_H(tensor)) else: raise ValueError( diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index bdb79126f012..7ae2bb4c68da 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -18,6 +18,7 @@ """Default legalization function for statistical 
operators.""" from typing import List from tvm import topi, tir, te +from tvm import relax from ...block_builder import BlockBuilder from ...expr import Call, Expr from .common import TEFunc, LegalizeFunc, register_legalize @@ -25,7 +26,16 @@ def _statistical(te_func: TEFunc) -> LegalizeFunc: def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) + arg = call.args[0] + # Handle bool type by converting to int32 first + # since topi.max/min don't support bool directly + if arg.struct_info.dtype == "bool": + # Convert bool to int32 for statistical operations + arg_int = bb.emit(relax.op.astype(arg, "int32")) + result_int = bb.call_te(te_func, arg_int, call.attrs.axis, call.attrs.keepdims) + # Convert back to bool + return bb.emit(relax.op.astype(result_int, "bool")) + return bb.call_te(te_func, arg, call.attrs.axis, call.attrs.keepdims) return statistical_call_te From c0c7f928e18383c4900e8becb25ba4ffe1a01414 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 11:22:42 +0000 Subject: [PATCH 06/10] [Relax][PyTorch] Use structural equality for 2D attention test Replace verify_model_numerically with verify_model to avoid computational cost of E2E test --- .../test_frontend_from_exported_program.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 0a72a2a901f4..0d57c6c1795a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4262,13 +4262,33 @@ class Attention2D(Module): def forward(self, x): return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False) - # For 2D input, we just verify that conversion succeeds without error - # The expected IR is complex due to reshape operations, so we use verify_model_numerically - verify_model_numerically( + @I.ir_module + class Expected2D: + @R.function + def main( + x: R.Tensor((8, 32), dtype="float32"), + ) -> R.Tensor((8, 32), dtype="float32"): + with R.dataflow(): + # Expand to add batch dimension: (8, 32) -> (1, 8, 32) + lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=0) + # Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32) + lv1: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=1) + lv2: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=1) + lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=1) + # Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32) + lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention( + lv1, lv2, lv3, scale=None, causal_mask=None + ) + # Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32) + gv: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv4, axis=[0, 1]) + R.output(gv) + return gv + + verify_model( Attention2D(), (torch.randn(8, 32, dtype=torch.float32),), - rtol=1e-5, - atol=1e-5, + {}, + Expected2D, ) From 21b75fede9c1b3afe12cf4861cec8e4e4a8a4e32 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 29 Nov 2025 20:56:34 +0900 Subject: [PATCH 07/10] compute max op in int8 because max doesn't support boolean directly --- .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 77fe81008f67..33a22b34fcc0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1633,6 +1633,13 @@ def _any(self, node: fx.Node) -> relax.Var: x = args[0] dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + + # max doesn't support boolean tensors directly, so we compute it in int8 and cast back + if x.struct_info.dtype == "bool": + x = relax.op.astype(x, "int8") + ret = relax.op.max(x, dim, keepdims=keepdim) + return self.block_builder.emit(relax.op.astype(ret, "bool")) + # For boolean tensors, any is equivalent to max (checking if any element is True) return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim)) From a1c1075e017203e1ddc0f7675b021101dd98a98f Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 12:44:50 +0000 Subject: [PATCH 08/10] [Relax][PyTorch] Address code review feedback - Revert bool handling in statistical.py (handled in ep frontend) - Remove strict=False from verify_model_numerically (SDPA should be exportable) - Add run_ep_decomposition=False to 2D attention test --- .../tvm/relax/transform/legalize_ops/statistical.py | 12 +----------- .../relax/test_frontend_from_exported_program.py | 5 ++--- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index 7ae2bb4c68da..bdb79126f012 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -18,7 +18,6 @@ """Default legalization function for statistical operators.""" from typing import List from tvm import topi, tir, te -from tvm import relax from ...block_builder import BlockBuilder from ...expr import Call, Expr from .common import TEFunc, LegalizeFunc, register_legalize @@ -26,16 +25,7 @@ def _statistical(te_func: TEFunc) -> LegalizeFunc: def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: - arg = call.args[0] - # Handle bool type by converting to int32 first - # since topi.max/min don't support bool directly - if arg.struct_info.dtype == "bool": - # Convert bool to int32 for statistical operations - arg_int = bb.emit(relax.op.astype(arg, "int32")) - result_int = bb.call_te(te_func, arg_int, call.attrs.axis, call.attrs.keepdims) - # Convert back to bool - return bb.emit(relax.op.astype(result_int, "bool")) - return bb.call_te(te_func, arg, call.attrs.axis, call.attrs.keepdims) + return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) return statistical_call_te diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 0d57c6c1795a..e87fe4c69b70 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -62,9 +62,7 @@ def verify_model_numerically(torch_model, example_args, rtol=1e-7, atol=1e-7): with torch.no_grad(): pytorch_output = torch_model(*example_args) - # Use strict=False to handle ops like scaled_dot_product_attention that may have - # internal non-exportable operations - exported_program = export(torch_model, args=example_args, strict=False) + exported_program = export(torch_model, args=example_args) mod = 
from_exported_program(exported_program) target = tvm.target.Target("llvm") ex = relax.build(mod, target) @@ -4289,6 +4287,7 @@ def main( (torch.randn(8, 32, dtype=torch.float32),), {}, Expected2D, + run_ep_decomposition=False, ) From ea5440de370cace18bc156f5e352ab74b6e81f0d Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 29 Nov 2025 22:51:48 +0900 Subject: [PATCH 09/10] revert _any --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33a22b34fcc0..e554648c41ad 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1634,12 +1634,6 @@ def _any(self, node: fx.Node) -> relax.Var: dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) - # max doesn't support boolean tensors directly, so we compute it in int8 and cast back - if x.struct_info.dtype == "bool": - x = relax.op.astype(x, "int8") - ret = relax.op.max(x, dim, keepdims=keepdim) - return self.block_builder.emit(relax.op.astype(ret, "bool")) - # For boolean tensors, any is equivalent to max (checking if any element is True) return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim)) From 027dfa49eac756661b04a59b0cad9bfafd984f13 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 29 Nov 2025 15:05:50 +0000 Subject: [PATCH 10/10] [Relax][PyTorch] Fix expected IR for 2D attention test Match expected IR with actual generated IR where query, key, value are expanded separately --- .../test_frontend_from_exported_program.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e87fe4c69b70..f59784c3a2f0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4265,20 +4265,24 @@ class Expected2D: @R.function def main( x: R.Tensor((8, 32), dtype="float32"), - ) -> R.Tensor((8, 32), dtype="float32"): + ) -> R.Tuple(R.Tensor((8, 32), dtype="float32")): with R.dataflow(): - # Expand to add batch dimension: (8, 32) -> (1, 8, 32) - lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=0) + # Expand to add batch dimension for query, key, value separately + # (8, 32) -> (1, 8, 32) + lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + lv1: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + lv2: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) # Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32) - lv1: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=1) - lv2: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=1) - lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=1) + lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=[1]) + lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv1, axis=[1]) + lv5: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv2, axis=[1]) # Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32) - lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention( - lv1, lv2, lv3, scale=None, causal_mask=None + lv6: 
R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention( + lv3, lv4, lv5, scale=None, causal_mask=None, window_size=None ) # Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32) - gv: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv4, axis=[0, 1]) + lv7: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv6, axis=[0, 1]) + gv: R.Tuple(R.Tensor((8, 32), dtype="float32")) = (lv7,) R.output(gv) return gv
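The 2D lowering in this series rests on one equivalence: a rank-2 scaled_dot_product_attention call gives the same result as the same call on a (1, 1, seq_len, head_dim) view, once the singleton batch and num_heads dimensions are squeezed back off. Below is a minimal standalone sketch of that equivalence, checked purely on the PyTorch side with the same (8, 32) shape as the Attention2D test; it is not taken from the patch, it assumes (as the test does) that eager PyTorch accepts rank-2 inputs here, and the tolerances are illustrative only.

import torch
import torch.nn.functional as F

x = torch.randn(8, 32, dtype=torch.float32)

# Direct rank-2 call, as in the Attention2D test module.
ref = F.scaled_dot_product_attention(x, x, x, is_causal=False)

# Emulate the frontend rewrite: (seq_len, head_dim) -> (1, 1, seq_len, head_dim),
# run attention, then drop the singleton num_heads and batch dimensions again.
x4d = x.unsqueeze(0).unsqueeze(1)
out4d = F.scaled_dot_product_attention(x4d, x4d, x4d, is_causal=False)
out = out4d.squeeze(1).squeeze(0)

assert torch.allclose(ref, out, rtol=1e-5, atol=1e-5)

This is the same expand/attention/squeeze shape pattern that the Expected2D module above spells out in Relax IR: R.expand_dims on axes 0 and 1, R.nn.attention on the (1, 1, 8, 32) tensors, then R.squeeze over axis=[0, 1] back to (8, 32).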