
Dynamic gradient accumulation #132

Closed
wants to merge 3 commits

Conversation


@Adamits Adamits commented Aug 14, 2023

No description provided.


@kylebgorman kylebgorman left a comment


I have meditated on this a bit and I'm not sure it's a good idea (though it obviously solves the problem for research purposes). I'm bothered by the lack of a closed form solution (e.g., the while loop on line 120 onward). Do you think it's possible to find one quickly, or is this the only way?

Typed up some style notes in the meantime.

def get_batch_size_and_accumulation_steps(
    batch_size: int, max_batch_size: int
) -> Tuple[int, int]:
    """Calculates a batch size and number of gradient accumulation steps
Style notes on the docstring (finicky I know):

  • can you get the first line on a single line, then give a longer one in full sentences? Usually something like:

```
Calculates batch size and the number of gradient accumulation steps.

This calculates the statistics so that we can simulate an effective `max_batch_size`.
...
```

  • Can you put periods at the end of the args thingies?
  • I think the return description could be shorter, something like: Tuple[int, int]: a batch size/accumulation factor tuple.

"""
if batch_size <= max_batch_size:
return batch_size, 1


Can we get rid of the mid-function blank lines? I lose track of things.


Adamits commented Aug 15, 2023

> I have meditated on this a bit and I'm not sure it's a good idea (though it obviously solves the problem for research purposes). I'm bothered by the lack of a closed form solution (e.g., the while loop on line 120 onward). Do you think it's possible to find one quickly, or is this the only way?

That's fine, we certainly don't need this, can move it to an example script, or get rid of it.

I think the problem is just finding the smallest divisor of batch_size. I don't think I know a closed form solution, but surely I could optimize this.

edit: I suppose what I have currently can also break in edge cases.

> Typed up some style notes in the meantime.

Cool!

@kylebgorman

> edit: I suppose what I have currently can also break in edge cases.

I was also thinking about this a bit.


Adamits commented Aug 16, 2023

To clarify what I said before, what we need is the smallest divisor k of batch_size s.t. batch_size / k <= max_batch_size. I think that theoretically we should not need a divisor, since we could accumulate a batch of size m followed by a batch of size n and then backprop, and I would think the effect would be the same as one step with a batch of size m + n. In practice, though, PTL wants one batch size, to be accumulated for k steps, and I would not know how to implement it to allow for variable batch sizes...

Regardless, would you rather that I:

  1. Clean this up to work for all edge cases
  2. Do 1) but move this to an example script for sweeps
  3. Scrap this completely and leave it to users.

If 1/2, I think I need to keep the while loop, but I can ensure that we exit the loop if we reach sqrt(original_batch_size). Then, we probably need to raise an exception that the requested batch size cannot be split into a new batch size < the max batch size, and either the batch size or the max batch size needs to be adjusted. I think requesting a batch size that is a prime number is the only case where this happens (or otherwise if unreasonable max_batch_sizes are set)?

In my experiments, this never happened since I sampled batch sizes that are multiples of 16.

@kylebgorman

> To clarify what I said before, what we need is the smallest divisor k of batch_size s.t. batch_size / k <= max_batch_size. I think that theoretically we should not need a divisor, since we could accumulate a batch of size m followed by a batch of size n and then backprop, and I would think the effect would be the same as one step with a batch of size m + n. In practice, though, PTL wants one batch size, to be accumulated for k steps, and I would not know how to implement it to allow for variable batch sizes...

Makes sense.

> Regardless, would you rather that I:
>
>   1. Clean this up to work for all edge cases

Okay with this but not urgent.

>   2. Do 1) but move this to an example script for sweeps

I think I'd recommend this, or alternatively making it a separate example altogether: i.e., just this function and some documentation explaining how you'd use it.

There is some argument for making it part of the sweeps thing: sweeps code needs other types of special casing too. (For another example, if you want to consider multiple layers in the pointer-generator during a sweep, you need to match encoder and decoder layers, hence the following).

>   3. Scrap this completely and leave it to users.

No need to kill it.

> If 1/2, I think I need to keep the while loop, but I can ensure that we exit the loop if we reach sqrt(original_batch_size).

Good idea. I am more comfortable with the while if there's a hard theoretical bound (and this is said bound methinks).

> Then, we probably need to raise an exception that the requested batch size cannot be split into a new batch size < the max batch size, and either the batch size or the max batch size needs to be adjusted. I think requesting a batch size that is a prime number is the only case where this happens (or otherwise if unreasonable max_batch_sizes are set)?

The prime case definitely. I was thinking something might happen if the two numbers are coprime (or "relatively prime") but I just checked and that's fine.
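
(For reference, a minimal sketch of the kind of solver being discussed: it searches upward for the smallest divisor k of batch_size such that batch_size / k <= max_batch_size, using a while loop that terminates at k = batch_size at the latest. This is an illustration, not the code in this PR; in particular, for a prime batch_size it returns the degenerate split (1, batch_size) rather than raising the exception proposed above, and the docstring follows the style notes from the review.)

```python
import math
from typing import Tuple


def get_batch_size_and_accumulation_steps(
    batch_size: int, max_batch_size: int
) -> Tuple[int, int]:
    """Calculates batch size and the number of gradient accumulation steps.

    This finds the smallest divisor k of batch_size such that
    batch_size / k <= max_batch_size, so that training with a batch size
    of batch_size / k and accumulating gradients over k steps simulates
    an effective batch size of batch_size.

    Args:
        batch_size: the requested (effective) batch size.
        max_batch_size: the largest batch size that fits on the accelerator.

    Returns:
        Tuple[int, int]: a batch size/accumulation factor tuple.
    """
    if batch_size <= max_batch_size:
        return batch_size, 1
    # Smallest k that could possibly satisfy batch_size / k <= max_batch_size.
    k = math.ceil(batch_size / max_batch_size)
    # k = batch_size always divides batch_size, so this terminates; for a
    # prime batch_size it yields the degenerate split (1, batch_size).
    while batch_size % k:
        k += 1
    return batch_size // k, k
```

For example, with batch_size=1024 and max_batch_size=768 this returns (512, 2): train with batches of 512 and accumulate gradients for 2 steps to simulate 1024.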

@kylebgorman

Thinking about this a little, should we (and can we?) merge this into the wandb examples directory instead? It seems like it might be something we could implement there instead of in the main library.


Adamits commented Sep 13, 2023

> Thinking about this a little, should we (and can we?) merge this into the wandb examples directory instead? It seems like it might be something we could implement there instead of in the main library.

Sure, that is where I originally had it.

@kylebgorman

This is still open, of course, and discussed in #148. Having thought a bit more about this, my proposal is that:

  • We implement Lightning's support for finding a max batch size, enabled by flag (--scale_batch_size=power, --scale_batch_size=binsearch, etc.). This itself is pretty useful.
  • We then add another option (different flag?) which automatically finds the max batch size $n_{max}$ and, if the requested batch size $b$ is larger than $n_{max}$, uses one or the other solver here to set up the gradient accumulation trick (a rough sketch follows below). This is mostly useful in the context of a hyperparameter search.
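
(For context, a rough sketch of how these two bullets might fit together, assuming Lightning >= 2.0's Tuner API and reusing the get_batch_size_and_accumulation_steps solver sketched earlier; the configure_trainer wrapper and its exact behavior are illustrative, not how #148 is implemented.)

```python
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.tuner import Tuner


def configure_trainer(
    model: LightningModule,
    datamodule: LightningDataModule,
    requested_batch_size: int,
) -> Trainer:
    """Builds a Trainer that simulates requested_batch_size.

    Assumes datamodule exposes a batch_size attribute, which
    Tuner.scale_batch_size reads and overwrites.
    """
    # First bullet: Lightning searches for the largest batch size that
    # fits on the accelerator ("power" doubles; "binsearch" then bisects).
    tuner = Tuner(Trainer())
    max_batch_size = tuner.scale_batch_size(
        model, datamodule=datamodule, mode="binsearch"
    )
    if requested_batch_size <= max_batch_size:
        datamodule.batch_size = requested_batch_size
        return Trainer()
    # Second bullet: otherwise, solve for a smaller batch size and an
    # accumulation factor whose product is the requested batch size.
    batch_size, accumulation_steps = get_batch_size_and_accumulation_steps(
        requested_batch_size, max_batch_size
    )
    datamodule.batch_size = batch_size
    return Trainer(accumulate_grad_batches=accumulation_steps)
```

In a sweep, the agent would sample requested_batch_size and a wrapper like this would keep it trainable on whatever accelerator is available.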


bonham79 commented Jun 3, 2024

Very late to the party here, but I'm confused about what this offers that's not done by accumulate_grad_batches?

@kylebgorman

This PR might be dead, but I understand @Adamits is working on #148, which will automatically estimate the maximum batch size that can be supported on a given accelerator and then also solve for the best minibatch size and number of gradient accumulation steps needed to simulate a desired batch size.

@kylebgorman

Closing this as implemented.
