
GPU-based vectorized SpecAug #9041

Closed · wants to merge 2 commits

Conversation

pzelasko (Collaborator)

What does this PR do?

Context: together with @galv, we found that feature normalization and SpecAugment take approximately 30% of the total forward step time in Canary training due to CPU-bottlenecked implementations. Feature normalization is addressed in #8964.

This PR adds a GPU-based SpecAugment implementation. The original implementation loops over every example and every mask, and waits on the CPU RNG to sample the mask parameters. The new, fast implementation still applies masks sequentially, but it is vectorized over the batch dimension and uses the GPU's RNG. We measured an approximately 5x speedup (70 ms -> 17 ms in profiling, though both numbers include profiler overhead). I also added a flag to revert to the old implementation in case anybody encounters a compatibility issue.

I validated visually that the new implementation behaves as expected.

Old: [image: masked spectrogram produced by the original implementation]

New: [image: masked spectrogram produced by the new implementation]
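
For readers skimming the PR, here is a minimal sketch of the idea (not the actual code in this PR; the function and argument names are made up for illustration): mask widths and start positions are drawn per example with the GPU's RNG and applied to the whole batch at once, while the masks themselves are still applied one after another.

import torch

def apply_time_masks_vectorized(
    x: torch.Tensor,          # (B, D, T) features
    lengths: torch.Tensor,    # (B,) valid frame counts per example
    num_masks: int = 10,
    max_width: int = 30,
    mask_value: float = 0.0,
) -> torch.Tensor:
    """Apply num_masks time masks per example, vectorized over the batch.

    All random numbers are drawn with the GPU's RNG, so the loop over masks
    never synchronizes with the CPU.
    """
    B, _, T = x.shape
    t = torch.arange(T, device=x.device).unsqueeze(0)  # (1, T)
    for _ in range(num_masks):  # masks are still applied sequentially
        # Per-example mask width in [0, max_width] and a start position chosen
        # so the mask stays inside the valid region of each utterance.
        w = torch.randint(0, max_width + 1, (B,), device=x.device)
        start = (torch.rand(B, device=x.device) * (lengths - w).clamp(min=1)).long()
        mask = (t >= start.unsqueeze(1)) & (t < (start + w).unsqueeze(1))  # (B, T)
        x = x.masked_fill(mask.unsqueeze(1), mask_value)
    return x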

Collection: ASR

Changelog

  • 5x faster SpecAugment, reducing typical forward step time by ~10%.

Usage

  • See the usage sketch below.
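
A minimal sketch, assuming the existing SpectrogramAugmentation wrapper in nemo.collections.asr.modules and its current constructor arguments (the values shown are common ASR defaults). The exact name of the new fallback flag introduced by this PR is intentionally not shown.

import torch
from nemo.collections.asr.modules import SpectrogramAugmentation

# Standard SpecAugment configuration; with this PR the masking is vectorized
# and runs on the GPU. The PR also adds a flag to fall back to the original
# loop-based implementation (flag name omitted here on purpose).
spec_augment = SpectrogramAugmentation(
    freq_masks=2,
    time_masks=10,
    freq_width=27,
    time_width=0.05,
).to("cuda")

features = torch.randn(4, 80, 1000, device="cuda")            # (batch, mels, frames)
lengths = torch.tensor([1000, 900, 800, 700], device="cuda")  # valid frames per example
augmented = spec_augment(input_spec=features, length=lengths)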

Jenkins CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

There's no need to comment jenkins on the PR to trigger Jenkins CI.
The GitHub Actions CI will run automatically when the PR is opened.
To run CI on an untrusted fork, a NeMo user with write access must click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
@github-actions github-actions bot added the ASR label Apr 25, 2024
@pzelasko pzelasko requested review from titu1994 and galv April 25, 2024 19:22
) -> torch.Tensor:
    if isinstance(width, float):
        width = length * width
    value = torch.rand(x.shape[0], device=x.device, dtype=x.dtype) * width
galv (Collaborator) commented on Apr 25, 2024

I think that using dtype=x.dtype, followed by multiplication by width, can cause a very subtle mistake here.

Basically, the goal is to uniformly generate an integer in [0, width). Suppose our datatype were bfloat16. The smallest integer not representable by bfloat16 is 257, which seems precariously small. If width were 258, then I don't think this code could ever generate a value such that value.long() would equal 257 (but I could be wrong).

I think it is probably safest to do the entirety of this math using integers. The only concern is if width is a floating point value, which it sometimes is (I guess width is intended as a "stretch" factor on the length when it is a floating point value?). I can probably do this for you if you want.

pzelasko (Collaborator, Author) replied:

Great point. Because width is typically a float in practice, maybe instead we can hardcode float32 as the dtype for rand.
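
For concreteness, a hedged sketch of the two options discussed above (illustrative only, not the code that ultimately landed): stay in integer space when width is an integer, and sample in float32 independently of the feature dtype when width is a fractional stretch factor on the length.

import torch

def sample_mask_widths(x: torch.Tensor, length: torch.Tensor, width) -> torch.Tensor:
    """Sample per-example integer mask widths without low-precision rounding issues."""
    if isinstance(width, int):
        # Option 1: do the sampling entirely in integer space, uniform over [0, width).
        return torch.randint(0, width, (x.shape[0],), device=x.device)
    # Option 2: width is a float "stretch" factor on the length; sample in float32
    # regardless of x.dtype so that large integer values are not rounded away
    # (e.g. bfloat16 cannot represent 257 exactly), then truncate to an integer width.
    max_width = length.to(torch.float32) * width
    return (torch.rand(x.shape[0], device=x.device, dtype=torch.float32) * max_width).long()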

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
pzelasko (Collaborator, Author) commented on May 9, 2024

Closing in favor of #9155

@pzelasko closed this on May 9, 2024