Dynamic gradient accumulation #132
I have meditated on this a bit and I'm not sure it's a good idea (though it obviously solves the problem for research purposes). I'm bothered by the lack of a closed form solution (e.g., the while
loop on line 120 onward). Do you think it's possible to find one quickly, or is this the only way?
Typed up some style notes in the meantime.
def get_batch_size_and_accumulation_steps(
    batch_size: int, max_batch_size: int
) -> Tuple[int, int]:
    """Calculates a batch size and number of gradient accumulation steps
Style notes on the docstring (finicky I know):
- can you get the first line on a single line, then give a longer one in full sentences? Usually something like:
```
Calculates batch size and the number of gradient accumulation steps.

This calculates the statistics so that we can simulate an effective `max_batch_size`.
...
```
- Can you put periods at the end of the args thingies?
- I think the return description could be shorter, something like:
Tuple[int, int]: a batch size/ accumulation factor tuple.
    """
    if batch_size <= max_batch_size:
        return batch_size, 1

Can we get rid of the mid-function blank lines? I lose track of things.
That's fine, we certainly don't need this, can move it to an example script, or get rid of it. I think the problem is just finding the smallest divisor of edit: I suppose what I have currently can also break in edge cases.
Cool!
I was also thinking about this a bit.
To clarify what I said before, what we need is the smallest divisor of Regardless, would you rather that I:
If 1/2, I think I need to keep the while loop, but I can ensure that we exit the loop if we reach In my experiments, this never happened since I sampled batch sizes that are multiples of 16.
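For reference, a minimal sketch of the divisor search discussed above. The exact condition was lost from the comments in this thread, so this assumes the goal is the smallest number of accumulation steps that divides `batch_size` evenly while keeping each chunk at or below `max_batch_size`. The function name and signature come from the diff; the body is illustrative and not necessarily what the PR implements.

```python
from typing import Tuple


def get_batch_size_and_accumulation_steps(
    batch_size: int, max_batch_size: int
) -> Tuple[int, int]:
    """Computes a per-step batch size and an accumulation step count.

    Assumption (not confirmed by the thread): we want the fewest
    accumulation steps that divide `batch_size` evenly such that each
    chunk holds at most `max_batch_size` examples.
    """
    if batch_size < 1 or max_batch_size < 1:
        raise ValueError("Both sizes must be positive integers.")
    if batch_size <= max_batch_size:
        return batch_size, 1
    # Ceiling division: the fewest steps that could possibly fit.
    steps = -(-batch_size // max_batch_size)
    # Walks upward to the first divisor of batch_size. This always
    # terminates because batch_size divides itself; a prime batch size
    # therefore degrades to chunks of size 1.
    while batch_size % steps:
        steps += 1
    return batch_size // steps, steps
```

Under this assumption, a request for 64 with a maximum of 16 yields `(16, 4)`, while a prime request such as 17 with a maximum of 16 yields `(1, 17)`, which is the degenerate case raised in the discussion.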
Makes sense.
Okay with this but not urgent.
I think I'd recommend this, or alternatively making it a separate example altogether: i.e., just this function and some documentation explaining how you'd use it. There is some argument for making it part of the sweeps thing: sweeps code needs other types of special casing too. (For another example, if you want to consider multiple layers in the pointer-generator during a sweep, you need to match encoder and decoder layers, hence the following).
No need to kill it.
Good idea. I am more comfortable with the
The prime case definitely. I was thinking something might happen if the two numbers are coprime (or "relatively prime") but I just checked and that's fine.
Thinking about this a little, should we (and can we?) merge this into the wandb examples directory instead? It seems like it might be something we could implement there instead of in the main library.
Sure, that is where I originally had it.
This is still open, of course, and discussed in #148. Having thought a bit more about this, my proposal is that:
Very late to party here, but I'm confused what this offers that's not done by
Closing this as implemented.