56 changes: 48 additions & 8 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1477,10 +1477,49 @@ def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))

def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
query = transpose_S_H(self.env[node.args[0]])
key = transpose_S_H(self.env[node.args[1]])
value = transpose_S_H(self.env[node.args[2]])
query_tensor = self.env[node.args[0]]
key_tensor = self.env[node.args[1]]
value_tensor = self.env[node.args[2]]

# Check the dimensionality of the input tensors
query_ndim = len(query_tensor.struct_info.shape)

# relax.op.nn.attention expects 4D inputs laid out as (batch, seq_len, num_heads, head_dim),
# while torch's scaled_dot_product_attention provides (batch, num_heads, seq_len, head_dim).
# For 2D inputs (seq_len, head_dim), we need to expand to 4D first.
if query_ndim == 2:
# 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim)
# Add batch dimension at axis 0
query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0))
key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0))
value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0))
# Add a second unit dimension at axis 1
query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1))
key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1))
value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1))

# No permutation needed for 2D inputs after expanding to 4D
# After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim)
def transpose_and_reshape_back(tensor):
# Squeeze batch and num_heads dimensions
return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1]))

elif query_ndim == 4:
# 4D input from torch: (batch, num_heads, seq_len, head_dim)
# -> relax layout: (batch, seq_len, num_heads, head_dim)
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
query = self.block_builder.emit(transpose_S_H(query_tensor))
key = self.block_builder.emit(transpose_S_H(key_tensor))
value = self.block_builder.emit(transpose_S_H(value_tensor))

# For 4D, transpose back after attention
def transpose_and_reshape_back(tensor):
return self.block_builder.emit(transpose_S_H(tensor))

else:
raise ValueError(
f"scaled_dot_product_attention expects 2D or 4D inputs, but got {query_ndim}D input"
)

attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None)
dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0)
assert dropout_p == 0.0, "Dropout is not supported"
@@ -1492,12 +1531,12 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
msg = "Only a float mask is supported for the attn_mask input."
assert "float" in attn_mask.struct_info.dtype, msg

return self.block_builder.emit(
transpose_S_H(
relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
)
attention_output = self.block_builder.emit(
relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
)

return transpose_and_reshape_back(attention_output)

def _unbind(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
@@ -1594,6 +1633,7 @@ def _any(self, node: fx.Node) -> relax.Var:
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)

# For boolean tensors, any is equivalent to max (checking if any element is True)
return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))

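A quick PyTorch-side sanity sketch of the shape handling introduced above (not part of the diff; tensor names are illustrative). torch.nn.functional.scaled_dot_product_attention already accepts a 2D (seq_len, head_dim) input, and the translator reproduces that case by expanding to (1, 1, seq_len, head_dim), running the fused attention op, and squeezing the two unit axes back out:

import torch
import torch.nn.functional as F

q = torch.randn(8, 32)                                    # (seq_len, head_dim)
out_2d = F.scaled_dot_product_attention(q, q, q)          # (8, 32)

q4 = q[None, None]                                        # (1, 1, 8, 32): two unit axes prepended
out_4d = F.scaled_dot_product_attention(q4, q4, q4)       # (1, 1, 8, 32)
out_4d = out_4d.squeeze(0).squeeze(0)                     # back to (8, 32)

assert out_2d.shape == out_4d.shape == (8, 32)
assert torch.allclose(out_2d, out_4d, atol=1e-5)          # same values up to kernel-level float error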
39 changes: 39 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -4255,6 +4255,45 @@ def main(
run_ep_decomposition=True,
)

# Test 2D input (seq_len, head_dim) - bug fix for #18441
class Attention2D(Module):
def forward(self, x):
return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False)

@I.ir_module
class Expected2D:
@R.function
def main(
x: R.Tensor((8, 32), dtype="float32"),
) -> R.Tuple(R.Tensor((8, 32), dtype="float32")):
with R.dataflow():
# Expand to add batch dimension for query, key, value separately
# (8, 32) -> (1, 8, 32)
lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
lv1: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
lv2: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0])
# Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32)
lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=[1])
lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv1, axis=[1])
lv5: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv2, axis=[1])
# Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32)
lv6: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention(
lv3, lv4, lv5, scale=None, causal_mask=None, window_size=None
)
# Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32)
lv7: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv6, axis=[0, 1])
gv: R.Tuple(R.Tensor((8, 32), dtype="float32")) = (lv7,)
R.output(gv)
return gv

verify_model(
Attention2D(),
(torch.randn(8, 32, dtype=torch.float32),),
{},
Expected2D,
run_ep_decomposition=False,
)


def test_unbind():
class Unbind1(Module):
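For reference, a hedged end-to-end usage sketch (assumes a TVM build with the Relax PyTorch frontend; not part of the diff): exporting the same 2D-attention module and importing it through from_exported_program exercises the new expand/squeeze path.

import torch
from tvm.relax.frontend.torch import from_exported_program

class Attention2D(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False)

# Export with a 2D (seq_len, head_dim) input and import into Relax.
exported = torch.export.export(Attention2D(), (torch.randn(8, 32, dtype=torch.float32),))
mod = from_exported_program(exported)
mod.show()  # printed IRModule should show the expand_dims -> attention -> squeeze sequence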