
fixing rope pytest benchmark grad accumulation#3743

Merged
jjsjann123 merged 2 commits into main from rope_patch_grad_accumulation
Jan 24, 2025
Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Jan 22, 2025

#3349 removed grad accumulation, but the rope benchmark implementation needs an update to get that working.
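For context, PyTorch backward accumulates into an existing .grad buffer, so a timed backward that runs without clearing grads pays for an extra accumulation. A toy illustration of that semantic (pure-Python sketch, not the actual benchmark code):

```python
class Param:
    """Minimal stand-in for a tensor with PyTorch-style grad accumulation."""
    def __init__(self):
        self.grad = None

    def backward_step(self, g):
        # PyTorch semantics: backward accumulates into .grad when it already exists
        self.grad = g if self.grad is None else self.grad + g

p = Param()
p.backward_step(1.0)
p.backward_step(1.0)   # without clearing, the second step also pays for the add
assert p.grad == 2.0

p.grad = None          # clearing grads between rounds removes the accumulation
p.backward_step(1.0)
assert p.grad == 1.0
```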

Reference implementation.

                             Model  Batch-Size  Sequence-Length  ... Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
0  Llama-2-7b-hf                      2             4096  ...            0.166                 5              0.857
0  Llama-3-8B                         2             8192  ...            0.567                 5              1.433
0  mistralai/Mistral-Nemo-Base-2407   1             4096  ...            0.138                 6              0.166
0  Qwen/Qwen2.5-7B-Instruct           1             4096  ...            0.072                 8              0.397
0  microsoft/Phi-3.5-mini-instruct    1             8192  ...            0.236                 6              0.494

After l2_cache clear:

                             Model  Batch-Size  Sequence-Length  ... Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
0  Llama-2-7b-hf                      2             4096  ...            0.166                 5              0.870
0  Llama-3-8B                         2             8192  ...            0.567                 5              1.444
0  mistralai/Mistral-Nemo-Base-2407   1             4096  ...            0.138                 6              0.192
0  Qwen/Qwen2.5-7B-Instruct           1             4096  ...            0.072                 8              0.417
0  microsoft/Phi-3.5-mini-instruct    1             8192  ...            0.234                 6              0.516

Before this PR:

Name (time in us)                                                                       Mean                    Median
---------------------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']        1,192.8558 (14.56)        1,191.9040 (14.53)
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']           1,767.5348 (21.58)        1,766.8410 (21.54)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']        275.4680 (3.36)           275.7265 (3.36)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']               488.4243 (5.96)           488.3105 (5.95)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']                757.9140 (9.25)           757.6910 (9.24)
---------------------------------------------------------------------------------------------------------------------------------

In this PR:

Name (time in us)                                                                    Mean                Median
-----------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']       871.5996 (5.23)       871.6050 (5.24)
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']        1,443.0095 (8.66)     1,442.9955 (8.67)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']     166.5515 (1.0)        166.4480 (1.0)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']            386.4463 (2.32)       386.5565 (2.32)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']             452.3351 (2.72)       452.0685 (2.72)
-----------------------------------------------------------------------------------------------------------------------

Given the existing issue with pytest/torch.profiler, if I instead run each benchmark separately:

test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']     871.1912  871.2465
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']        1.4427  1.4427
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']   191.6567  191.6795
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']          416.8007  416.8935
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']           514.7512  514.4900

So these numbers do match the manual benchmark with l2_cache cleared. I think that justifies this PR.
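The manual benchmark clears L2 before each timed backward. A minimal sketch of that measurement loop (names are hypothetical; on a real GPU, clear_l2 would typically overwrite a device buffer larger than the L2 cache rather than the no-op callables used here):

```python
import time

def time_backward(run_bwd, clear_l2, rounds=10):
    """Average wall time of run_bwd over several rounds, evicting L2 first.

    clear_l2 runs outside the timed region, so each measured backward
    starts from a cold cache instead of reusing the previous round's lines.
    """
    times = []
    for _ in range(rounds):
        clear_l2()                      # evict before, not inside, the timing
        t0 = time.perf_counter()
        run_bwd()
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)
```

Skipping the clear_l2 step is exactly what makes the first kernel of a standalone run look faster than it would in steady state.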

@github-actions

github-actions bot commented Jan 22, 2025

PR Reviewer Guide 🔍

(Review updated until commit 378979b)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 2 🔵🔵⚪⚪⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Input Validation

The gen_inputs() function is called without checking if it returns a valid input. Consider adding input validation to ensure the inputs are correct.

inputs = gen_inputs()
Potential Bug

The iobytes variable is computed based on how Thunder autograd worked, but it's not clear if this is still relevant for TorchCompile and Eager Executor. Consider reviewing this logic to ensure it's correct.

run_benchmark(

@jjsjann123
Collaborator Author

jjsjann123 commented Jan 22, 2025

phi3 now shows 450 us vs 494 us (ref), which looks strange. I'll pull a profile to investigate.

Even stranger is that, running bwd with only phi3 gives

---------------------------------------------------------------------------- benchmark: 1 tests ---------------------------------------------------------------------------
Name (time in us)                                                             Min       Max      Mean  StdDev    Median     IQR  Outliers  OPS (Kops/s)  Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']     512.7370  518.0190  514.6843  1.7697  514.4470  3.2210       4;0        1.9429      10           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------

But with qwen2 running before phi3, we get faster kernels... looks like we are picking up some cache here and our heuristic isn't making the best decision 😢

@jjsjann123
Collaborator Author

following up on #3394 (comment)

@jjsjann123
Collaborator Author

jjsjann123 commented Jan 23, 2025

Two separate issues I'm seeing with the benchmark numbers:

  1. Our profiler seems to be non-deterministically dropping cuda events. That's why running phi3 alone has a longer measured kernel time.
  2. The standalone benchmark has a faster kernel time on one of the segments. That's coming from the standalone benchmark not clearing L2 on backward, so the first kernel gets some extra speed up.

Investigating...

cc'ing @naoyam in case you are looking at backward time.

@naoyam
Collaborator

naoyam commented Jan 23, 2025

Yes, I'm looking at Phi3 but still only focusing on the forward fusion.

I wonder if it's specific to Phi3? Or could it be just because it was executed last? For example, if Qwen2 was executed after Phi3, would Qwen2 see a similar discrepancy?

@jjsjann123
Collaborator Author

Strangely, that seems to only affect phi3.

I swapped it to run phi3 before qwen2 and haven't noticed any events dropped in qwen2. But given that it's not deterministic for phi3 in the first place, I'm not sure if this means there's anything specific to phi3.

@jjsjann123
Collaborator Author

Re: missing event.

  1. It's missing one of the small kernels at the beginning.
  2. The first round seems to catch all 6 events, while subsequent rounds were missing the very first kernel in the group.

So I'm now wondering if we can actually run

prof.start()
# measure
prof.stop()

prof.start()
# measure
prof.stop()

The patterns I see in pytorch all use the context manager; not sure if we are hitting a bug in the profiler?
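Both usage styles should, in principle, capture the same events. A toy stand-in (not torch.profiler; just illustrating the two call patterns under comparison):

```python
class ToyProfiler:
    """Minimal profiler supporting both start()/stop() pairs and `with`."""
    def __init__(self):
        self.events = []
        self._active = False

    def start(self):
        self._active = True

    def stop(self):
        self._active = False

    def record(self, name):
        # Only events recorded while active are kept.
        if self._active:
            self.events.append(name)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()

prof = ToyProfiler()
# Style 1: repeated start()/stop() pairs, as in the benchmark harness.
prof.start(); prof.record("k1"); prof.stop()
prof.start(); prof.record("k2"); prof.stop()
# Style 2: context manager, the pattern used throughout pytorch.
with prof:
    prof.record("k3")
assert prof.events == ["k1", "k2", "k3"]
```

With a correct profiler the two styles are equivalent; the dropped events described above suggest the repeated start/stop style may not be exercised as well in torch.profiler.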

@jjsjann123
Collaborator Author

I don't think this is specific to phi3. At least I'm seeing the same thing happening to qwen2, when it's not running as the first benchmark.

So there's indeed something wrong with the benchmark profiler and pytest.

@jjsjann123 jjsjann123 requested a review from Priya2698 January 23, 2025 19:21
@jjsjann123 jjsjann123 marked this pull request as ready for review January 23, 2025 19:21
@Priya2698
Collaborator

I don't think this is specific to phi3. At least I'm seeing the same thing happening to qwen2, when it's not running as the first benchmark.

So there's indeed something wrong with the benchmark profiler and pytest.

Can you open an issue for this mentioning the order of benchmarks to reproduce this?
I'm wondering if the start(), stop() API was changed/deprecated.

  # a reference point for torchcompile and eager executor for comparison.
  run_benchmark(
-     benchmark, unary_bwd_torch, [output, grad(), fwd_inputs()], iobytes=iobytes()
+     benchmark, unary_bwd_torch, [output, grad(), *fwd_inputs], iobytes=iobytes()
Collaborator Author


The reason I need to expand fwd_inputs is to make it work with the infra code that clears grad from torch.Tensor inputs, since fwd_inputs is a sequence of torch.Tensor.

An alternative is to flatten inputs here instead.
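A pure-Python sketch of why the expansion matters, assuming the infra walks the flat argument list and resets .grad on anything tensor-like (FakeTensor and clear_grads are hypothetical stand-ins, not the actual infra code):

```python
class FakeTensor:
    """Stand-in for torch.Tensor with a .grad attribute."""
    def __init__(self):
        self.grad = "stale"

def clear_grads(args):
    # Infra-style pass over the benchmark args: only top-level tensors are
    # seen, so tensors hidden inside a nested sequence would be skipped.
    for a in args:
        if isinstance(a, FakeTensor):
            a.grad = None

t1, t2 = FakeTensor(), FakeTensor()
fwd_inputs = [t1, t2]

nested = ["output", "grad", fwd_inputs]      # tensors hidden one level down
clear_grads(nested)
assert t1.grad == "stale"                    # missed: not cleared

flat = ["output", "grad", *fwd_inputs]       # expanded, tensors visible
clear_grads(flat)
assert t1.grad is None and t2.grad is None   # cleared
```

The alternative mentioned above, flattening inside the infra instead, would let callers keep passing nested sequences.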

Collaborator

@Priya2698 Priya2698 left a comment


Approving this PR, since it does not introduce the issue of the RoPE benchmarks missing CUDA events.

@jjsjann123
Collaborator Author

!test --diff-bench

@jjsjann123 jjsjann123 merged commit 8689c33 into main Jan 24, 2025
60 of 61 checks passed
@jjsjann123 jjsjann123 deleted the rope_patch_grad_accumulation branch January 24, 2025 18:01
@jjsjann123
Collaborator Author

Failure doesn't look related. Merged as-is.
