# [Inductor] Flex attention supports dynamic shape (pytorch#125994)

## static shapes perf
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     0.692 |              |             |             |             |            |             |                |
| Max     |     0.855 |           16 |          16 |        4096 |        4096 |         64 | head_bias   | torch.bfloat16 |
| Min     |     0.419 |            8 |          16 |         512 |         512 |        256 | noop        | torch.bfloat16 |
```
## dynamic shapes perf
```
| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     0.670 |              |             |             |             |            |               |                |
| Max     |     0.864 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     0.376 |            8 |          16 |         512 |         512 |        256 | relative_bias | torch.bfloat16 |
```

Pull Request resolved: pytorch#125994
Approved by: https://github.com/Chillee
yanboliang authored and ZelboK committed May 19, 2024
1 parent aa17484 commit 685b207
Showing 6 changed files with 161 additions and 17 deletions.
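
Not part of the commit message itself, but as a rough sketch of what dynamic-shape support enables for users (mirroring the new `run_dynamic_test` below; the private import path and positional call signature are assumptions based on this diff):

```python
import torch
from torch.nn.attention._flex_attention import _flex_attention  # private API at this commit

def noop(score, b, h, m, n):
    return score

# Compile once with dynamic shapes; subsequent calls with different batch sizes
# and sequence lengths should reuse the same compiled kernel instead of recompiling.
flex = torch.compile(_flex_attention, dynamic=True)

q1 = k1 = v1 = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
out1 = flex(q1, k1, v1, noop)

q2 = k2 = v2 = torch.randn(16, 16, 256, 64, device="cuda", dtype=torch.float16)  # new B and S
out2 = flex(q2, k2, v2, noop)  # expected: no recompilation
```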

## benchmarks/transformer/score_mod.py (16 additions, 5 deletions)

```diff
@@ -1,3 +1,4 @@
+import argparse
 import itertools
 from collections import defaultdict
 from dataclasses import asdict, dataclass
@@ -98,7 +99,7 @@ def generate_inputs(
     return query, key, value


-def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
+def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
     query, key, value = generate_inputs(
         config.batch_size,
@@ -113,7 +114,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
     def eager_sdpa(query, key, value, _):
         return F.scaled_dot_product_attention(query, key, value)

-    compiled_sdpa = torch.compile(_flex_attention)
+    compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)

     score_mod = config.score_mod

@@ -242,16 +243,26 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     return all_configs


-def main():
+def main(dynamic=False):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
     for config in tqdm(generate_experiment_configs()):
-        results.append(Experiment(config, run_single_experiment(config)))
+        results.append(
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
+        )

     print_results(results)


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dynamic",
+        action="store_true",
+        help="Runs a dynamic shapes version of compiled flex attention.",
+    )
+
+    args = parser.parse_args()
+    main(args.dynamic)
```
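
With this change, `python benchmarks/transformer/score_mod.py --dynamic` runs the benchmark with flex attention compiled under `dynamic=True`, while omitting the flag keeps the static-shape baseline (assuming invocation from the repository root).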

## test/inductor/test_flex_attention.py (134 additions, 11 deletions)

```diff
@@ -126,6 +126,19 @@ def score_mod(score, b, h, m, n):


 class TestTemplatedSDPA(InductorTestCase):
+    def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
+        compiled_error = (golden_out - compiled_out).abs().mean()
+        ref_error = (golden_out - ref_out).abs().mean()
+        # Note, it seems like we really are less accurate than the float32
+        # computation, likely due to the online softmax
+        if dtype == torch.float32:
+            fudge_factor = 10.0
+        else:
+            fudge_factor = 1.1
+        if compiled_error > ref_error * fudge_factor:
+            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
+            self.assertTrue(False, msg)
+
     def run_test(
         self,
         score_mod: Callable,
@@ -145,25 +158,135 @@ def run_test(
         )
         ref_out = sdpa_partial(q, k, v)
         compiled_out = compiled_sdpa(q, k, v)
+        self._check_equal(golden_out, ref_out, compiled_out, dtype)

-        compiled_error = (golden_out - compiled_out).abs().mean()
-        ref_error = (golden_out - ref_out).abs().mean()
-        # Note, it seems like we really are less accurate than the float32
-        # computation, likely due to the online softmax
-        if dtype == torch.float32:
-            fudge_factor = 10.0
-        else:
-            fudge_factor = 1.1
-        if compiled_error > ref_error * fudge_factor:
-            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
-            self.assertTrue(False, msg)
+    def run_dynamic_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
+        sdpa_partial = create_attention(score_mod)
+        # The first eager batch, shape (B, H, S, D)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
+
+        # The second eager batch, shape (B * 2, H, S / 2, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
+
+        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
+        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
+        torch._dynamo.reset()
+        # Compiling with dynamic shape in the first batch.
+        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
+        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+        # No re-compilation, use the compiled dynamic shape version.
+        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+    def run_automatic_dynamic_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
+        sdpa_partial = create_attention(score_mod)
+        # The first eager batch, shape (B, H, S, D)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
+
+        # The second eager batch, shape (B * 2, H, S / 2, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
+
+        # The third eager batch, shape (B * 4, H, S / 4, D)
+        B = int(B * 2)
+        S = int(S / 2)
+        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out3 = sdpa_partial(
+            q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
+        )
+        ref_out3 = sdpa_partial(q3, k3, v3)
+
+        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
+        # We check dynamo counters["frames"]["ok"] to ensure:
+        # 1, the first batch is compiled with static shape
+        # 2, the second batch is compiled with dynamic shape
+        # 3, no re-compilation in the third batch
+        torch._dynamo.reset()
+        # The first batch.
+        compiled_sdpa = torch.compile(sdpa_partial)
+        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
+
+        # The second batch (automatic dynamic).
+        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+
+        # The third batch (no re-compilation).
+        compiled_out3 = compiled_sdpa(q3, k3, v3)
+        self._check_equal(golden_out3, ref_out3, compiled_out3, dtype)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", test_score_mods)
     def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
         self.run_test(score_mod, dtype)

+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    @common_utils.parametrize("score_mod", test_score_mods)
+    def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
+        self.run_dynamic_test(score_mod, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes)
+    @common_utils.parametrize("score_mod", test_score_mods)
+    def test_builtin_score_mods_automatic_dynamic(
+        self, dtype: torch.dtype, score_mod: Callable
+    ):
+        self.run_automatic_dynamic_test(score_mod, dtype)
+
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes)
     def test_skip_odd_keys(self, dtype: torch.dtype):
```
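
A hedged note on the counter used in the new tests above: Dynamo increments `counters["frames"]["ok"]` once per successfully compiled frame, so after a reset the tests can assert exactly how many compilations happened (the reset also matters because flex attention's eager path itself goes through Dynamo tracing). A minimal sketch of the same pattern outside the test harness:

```python
import torch

torch._dynamo.reset()  # clear compile state; the tests above rely on this before counting frames

fn = torch.compile(lambda x: x * 2)
fn(torch.randn(4))   # first call: compiled with a static shape
fn(torch.randn(8))   # new shape: automatic dynamic triggers one recompilation
fn(torch.randn(16))  # expected to reuse the dynamic-shape compilation

# Mirrors the assertions in run_automatic_dynamic_test; typically 2 here.
print(torch._dynamo.utils.counters["frames"]["ok"])
```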

## torch/_dynamo/source.py (4 additions, 0 deletions)

```diff
@@ -628,3 +628,7 @@ def is_from_defaults(source: Source):
     if isinstance(source, ChainedSource):
         return is_from_defaults(source.base)
     return False
+
+
+def is_cell_contents(source: Source):
+    return isinstance(source, AttrSource) and source.member == "cell_contents"
```
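
For context on the new helper (a hypothetical illustration, not from this commit): a value captured by a closure lives in a cell object, and reading it goes through the cell's `cell_contents` attribute, which is what `is_cell_contents` matches on an `AttrSource`.

```python
def make_scaler(scale):
    def inner(x):
        return x * scale  # `scale` is captured in a closure cell
    return inner

fn = make_scaler(3)
cell = fn.__closure__[0]
print(cell.cell_contents)  # 3; as the new helper suggests, Dynamo models such reads
                           # with an AttrSource whose member is "cell_contents"
```

Together with the builder.py change below, a literal read out of a closure cell gets a CONSTANT_MATCH guard (i.e. is specialized) rather than becoming a dynamic symbol.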

## torch/_dynamo/variables/builder.py (2 additions, 0 deletions)

```diff
@@ -58,6 +58,7 @@
     FloatTensorSource,
     GetItemSource,
     GradSource,
+    is_cell_contents,
     is_constant_source,
     is_from_defaults,
     is_from_optimizer_source,
@@ -1166,6 +1167,7 @@ def wrap_literal(self, value):
                 # NN modules on the fly)
                 or self.source.guard_source().is_nn_module()
                 or is_from_defaults(self.source)
+                or is_cell_contents(self.source)
             ):
                 self.install_guards(GuardBuilder.CONSTANT_MATCH)
                 return ConstantVariable.create(value=value, source=self.source)
```

## torch/_inductor/kernel/flex_attention.py (1 addition, 1 deletion)

```diff
@@ -162,7 +162,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta):
     # TODO generalize and add proper mask support
     mask = (idx_m != -1) & (idx_d != -1)
-    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc")}}
+    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}
     # TODO dont want to write this if we dont require grad
     if OUTPUT_LOGSUMEXP:
```
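
Why the extra `"mask"` argument likely matters: with dynamic sequence lengths, the number of output rows is no longer guaranteed to line up with the block size, so out-of-bounds lanes presumably must be masked when the accumulator is written back. The template machinery above handles this internally; the snippet below is only a generic Triton sketch of the masked-store pattern (hypothetical kernel, not the flex attention template):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                 # guard lanes past the end of the tensor
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)   # masked store, analogous to store_output(..., "mask")

x = torch.randn(1000, device="cuda")            # deliberately not a multiple of BLOCK_SIZE
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
masked_copy_kernel[grid](x, out, x.numel(), BLOCK_SIZE=256)
```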

## torch/nn/attention/_flex_attention.py (4 additions, 0 deletions)

```diff
@@ -83,6 +83,10 @@ def score_mod(
     """

     if torch.compiler.is_dynamo_compiling():
+        # mark head_dim and dim always to be static
+        for x in [query, key, value]:
+            torch._dynamo.mark_static(x, 1)
+            torch._dynamo.mark_static(x, -1)
         out, _ = flex_attention_hop(query, key, value, score_mod)
         return out

```
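
A minimal sketch (assumed usage) of the `mark_static` calls added above: for inputs of shape `(B, H, S, D)`, dim 1 (`num_heads`) and the last dim (`head_dim`) are pinned as static before tracing, so only batch size and sequence length participate in dynamic-shape compilation.

```python
import torch

q = torch.randn(8, 16, 512, 64, dtype=torch.float16)  # (B, H, S, D)
torch._dynamo.mark_static(q, 1)    # num_heads: always specialized
torch._dynamo.mark_static(q, -1)   # head_dim: always specialized
# B and S are left alone, so they may still be treated as dynamic by torch.compile.
```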
