[inductor] Fix shape mismatch in sdpa pattern matcher (pytorch#115038)
Fixes pytorch#100316

Pull Request resolved: pytorch#115038
Approved by: https://github.com/oulgen
Authored by jansel, committed by ZhiweiYan-96 on Dec 22, 2023
Parent: 8811baf · Commit: 18f750e
Showing 2 changed files with 58 additions and 1 deletion.
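Background for the fix: before applying an SDPA replacement, Inductor's pattern matcher re-traces the search pattern with tensors shaped like the matched subgraph. In the graph from pytorch#100316, the key tensor is pre-permuted to (batch, heads, head_size, seq_len) and multiplied directly, so whenever seq_len differs from head_size a transpose-based pattern presumably cannot be traced against those tensors, and tracing raises a RuntimeError that previously escaped as a compile failure. A minimal shape sketch (illustrative only, not part of the commit):

    # Illustrative sketch (not from the commit): how a pre-permuted key can
    # make a transpose-based SDPA pattern untraceable when seq_len != head_size.
    import torch

    batch, heads, seq_len, head_size = 1, 1, 4, 16
    query = torch.randn(batch, heads, seq_len, head_size)  # (1, 1, 4, 16)
    key = torch.randn(batch, heads, head_size, seq_len)    # pre-permuted: (1, 1, 16, 4)

    # The user graph multiplies the pre-permuted key directly; shapes line up:
    print(torch.matmul(query, key).shape)  # torch.Size([1, 1, 4, 4])

    # A pattern written as q @ k.transpose(-2, -1) sees (1, 1, 4, 16) for both
    # operands, so the inner dimensions (16 vs. 4) no longer match:
    try:
        torch.matmul(query, key.transpose(-2, -1))
    except RuntimeError as e:
        print(e)  # shape-mismatch error, the case the fix below now catches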
test/inductor/test_cpu_repro.py (49 additions, 0 deletions)
@@ -2463,6 +2463,55 @@ def fn(x1, x2):
         )
         self.assertEqual(metrics.generated_kernel_count, 1)
 
+    def test_attention_size_mismatch(self):
+        class Attention(torch.nn.Module):
+            def __init__(self, hidden_size, num_heads):
+                super().__init__()
+                self.hidden_size = hidden_size
+                self.num_heads = num_heads
+                self.head_size = hidden_size // num_heads
+                self.query = torch.nn.Linear(hidden_size, hidden_size)
+                self.key = torch.nn.Linear(hidden_size, hidden_size)
+                self.value = torch.nn.Linear(hidden_size, hidden_size)
+                self.inv_scale = torch.nn.Parameter(
+                    torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
+                )
+
+            def forward(self, x):
+                query = self.query(x)
+                key = self.key(x)
+                value = self.value(x)
+                (batch_size, seq_len, hidden_size) = query.size()
+                query = query.view(
+                    batch_size, seq_len, self.num_heads, self.head_size
+                ).permute(0, 2, 1, 3)
+                key = key.view(
+                    batch_size, seq_len, self.num_heads, self.head_size
+                ).permute(0, 2, 3, 1)
+                value = value.view(
+                    batch_size, seq_len, self.num_heads, self.head_size
+                ).permute(0, 2, 1, 3)
+                attention_weights = (
+                    torch.matmul(query, key).div(self.inv_scale).softmax(dim=-1)
+                )
+                output = torch.matmul(attention_weights, value)
+                return output
+
+        torch.manual_seed(123)
+        hidden_size = 16
+        num_heads = 1
+        seq_len = 4
+        batch_size = 1
+        x = torch.randn(batch_size, seq_len, hidden_size)
+
+        func = Attention(hidden_size, num_heads).to("cpu")
+
+        with torch.no_grad():
+            res1 = func(x)
+            jit_func = torch.compile(func)
+            res2 = jit_func(x)
+            self.assertEqual(res1, res2)
+
     def test_scalar_mul_bfloat16(self):
         def f(x):
             return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)
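For reference, a condensed standalone version of the scenario above (a hypothetical convenience script, not part of the commit; the free-function form is assumed to trip the same replacement pattern as the module in the test). With the fix, the mismatched pattern is skipped with an INFO log and the compiled result matches eager:

    # Condensed standalone sketch of the test above (illustrative).
    import logging
    import torch

    # Surface Inductor's INFO logs, including the new skip message
    # (assumes torch._logging.set_logs is available in this build).
    torch._logging.set_logs(inductor=logging.INFO)

    def attention(query, key, value, inv_scale):
        # key arrives pre-permuted to (batch, heads, head_size, seq_len)
        weights = torch.matmul(query, key).div(inv_scale).softmax(dim=-1)
        return torch.matmul(weights, value)

    batch, heads, seq_len, head_size = 1, 1, 4, 16
    q = torch.randn(batch, heads, seq_len, head_size)
    k = torch.randn(batch, heads, head_size, seq_len)
    v = torch.randn(batch, heads, seq_len, head_size)
    inv_scale = torch.tensor([1 / head_size**0.5])

    with torch.no_grad():
        eager = attention(q, k, v, inv_scale)
        compiled = torch.compile(attention)(q, k, v, inv_scale)
        torch.testing.assert_close(eager, compiled)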
torch/_inductor/pattern_matcher.py (9 additions, 1 deletion)
@@ -919,7 +919,15 @@ def check_fn(match: Match):
                     device=args[i].device,
                     requires_grad=grad,
                 )
-        specific_graph = trace_fn(search_fn, args)
+        try:
+            specific_graph = trace_fn(search_fn, args)
+        except RuntimeError as e:
+            log.info(
+                "Replacement pattern %s failed to apply due to shape mismatch: %s",
+                search_fn.__name__,
+                e,
+            )
+            return False
         specific_pattern = fx_to_pattern(
             specific_graph,
             argnames=argnames,
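The substance of the fix is this guard: a RuntimeError raised while re-tracing the search pattern now means 'this replacement does not apply here' rather than a hard compile failure. A minimal sketch of the idiom in isolation (trace_or_skip is a hypothetical helper, not the actual pattern_matcher API; only the try/except body mirrors the diff):

    # Minimal sketch of the guard added above: convert a trace-time
    # RuntimeError into a logged "no match" result instead of propagating it.
    import logging

    log = logging.getLogger(__name__)

    def trace_or_skip(trace_fn, search_fn, args):
        try:
            return trace_fn(search_fn, args)  # may raise on shape mismatch
        except RuntimeError as e:
            log.info(
                "Replacement pattern %s failed to apply due to shape mismatch: %s",
                search_fn.__name__,
                e,
            )
            return None  # caller treats None as "pattern does not apply"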