From 18f750ee17e54ba7e5539736334852f3e285b6ec Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Sun, 3 Dec 2023 11:12:07 -0800
Subject: [PATCH] [inductor] Fix shape mismatch in sdpa pattern matcher
 (#115038)

Fixes #100316

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115038
Approved by: https://github.com/oulgen
---
 test/inductor/test_cpu_repro.py    | 49 ++++++++++++++++++++++++++++++
 torch/_inductor/pattern_matcher.py | 10 +++++-
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 06dff1d34e319..d4126a0c58e1c 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -2463,6 +2463,55 @@ def fn(x1, x2):
         )
         self.assertEqual(metrics.generated_kernel_count, 1)
 
+    def test_attention_size_mismatch(self):
+        class Attention(torch.nn.Module):
+            def __init__(self, hidden_size, num_heads):
+                super().__init__()
+                self.hidden_size = hidden_size
+                self.num_heads = num_heads
+                self.head_size = hidden_size // num_heads
+                self.query = torch.nn.Linear(hidden_size, hidden_size)
+                self.key = torch.nn.Linear(hidden_size, hidden_size)
+                self.value = torch.nn.Linear(hidden_size, hidden_size)
+                self.inv_scale = torch.nn.Parameter(
+                    torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
+                )
+
+            def forward(self, x):
+                query = self.query(x)
+                key = self.key(x)
+                value = self.value(x)
+                (batch_size, seq_len, hidden_size) = query.size()
+                query = query.view(
+                    batch_size, seq_len, self.num_heads, self.head_size
+                ).permute(0, 2, 1, 3)
+                key = key.view(
+                    batch_size, seq_len, self.num_heads, self.head_size
+                ).permute(0, 2, 3, 1)
+                value = value.view(
+                    batch_size, seq_len, self.num_heads, self.head_size
+                ).permute(0, 2, 1, 3)
+                attention_weights = (
+                    torch.matmul(query, key).div(self.inv_scale).softmax(dim=-1)
+                )
+                output = torch.matmul(attention_weights, value)
+                return output
+
+        torch.manual_seed(123)
+        hidden_size = 16
+        num_heads = 1
+        seq_len = 4
+        batch_size = 1
+        x = torch.randn(batch_size, seq_len, hidden_size)
+
+        func = Attention(hidden_size, num_heads).to("cpu")
+
+        with torch.no_grad():
+            res1 = func(x)
+            jit_func = torch.compile(func)
+            res2 = jit_func(x)
+            self.assertEqual(res1, res2)
+
     def test_scalar_mul_bfloat16(self):
         def f(x):
             return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 4c506ae1824a0..33486b47b0559 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -919,7 +919,15 @@ def check_fn(match: Match):
                         device=args[i].device,
                         requires_grad=grad,
                     )
-            specific_graph = trace_fn(search_fn, args)
+            try:
+                specific_graph = trace_fn(search_fn, args)
+            except RuntimeError as e:
+                log.info(
+                    "Replacement pattern %s failed to apply due to shape mismatch: %s",
+                    search_fn.__name__,
+                    e,
+                )
+                return False
             specific_pattern = fx_to_pattern(
                 specific_graph,
                 argnames=argnames,
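
Note (not part of the patch): the sketch below is a minimal standalone repro adapted from the
test_attention_size_mismatch test added above. The attention module permutes the key to
(batch, heads, head_size, seq) and multiplies it directly with torch.matmul(query, key), rather
than using the canonical q @ k.transpose(-2, -1) layout the sdpa replacement expects, which is
what previously tripped the pattern matcher. With this patch, a RuntimeError raised while tracing
the candidate replacement is logged and check_fn simply returns False, so compilation falls back
to the original graph and compiled output matches eager. The file name and the slightly
reorganized forward() are illustrative assumptions, not code from the patch.

# repro_sdpa_mismatch.py -- illustrative sketch only, adapted from the new test.
import torch


class Attention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.query = torch.nn.Linear(hidden_size, hidden_size)
        self.key = torch.nn.Linear(hidden_size, hidden_size)
        self.value = torch.nn.Linear(hidden_size, hidden_size)
        # 1-element scale parameter, as in the test; applied via .div() below.
        self.inv_scale = torch.nn.Parameter(
            torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
        )

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        shape = (batch_size, seq_len, self.num_heads, self.head_size)
        query = self.query(x).view(shape).permute(0, 2, 1, 3)
        # Key is pre-permuted to (batch, heads, head_size, seq), so matmul(query, key)
        # works without a transpose -- the layout the sdpa matcher did not anticipate.
        key = self.key(x).view(shape).permute(0, 2, 3, 1)
        value = self.value(x).view(shape).permute(0, 2, 1, 3)
        weights = torch.matmul(query, key).div(self.inv_scale).softmax(dim=-1)
        return torch.matmul(weights, value)


if __name__ == "__main__":
    torch.manual_seed(123)
    model = Attention(hidden_size=16, num_heads=1).eval()
    x = torch.randn(1, 4, 16)
    with torch.no_grad():
        eager = model(x)
        compiled = torch.compile(model)(x)
    # With the fix, the sdpa replacement is skipped for this shape and the
    # compiled result matches eager numerics instead of failing to compile.
    torch.testing.assert_close(eager, compiled)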