Add ability to restart on new epoch #383
You can set the epoch via the option `--epoch=[INTEGER]`. This automatically handles changing the data order each epoch by setting the data seed to `seed + epoch`. So `--epoch` is the only flag you need to set when restarting on a new epoch. Everything else in the config can stay the same. Note that we count epochs starting from 0. So to start the 2nd epoch you would add the flag `--epoch=1`.
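The `seed + epoch` rule described above can be illustrated with a minimal sketch. The function name `data_order` and the use of Python's `random` module are assumptions made for illustration only; the actual data loader in this PR may shuffle differently.

```python
import random
from typing import List


def data_order(num_examples: int, seed: int, epoch: int) -> List[int]:
    """Hypothetical helper: return a shuffled index order for one epoch.

    The data seed is derived as `seed + epoch`, so each epoch gets a
    different but reproducible permutation of the same examples.
    """
    rng = random.Random(seed + epoch)  # data seed = seed + epoch
    indices = list(range(num_examples))
    rng.shuffle(indices)
    return indices


# Epochs are counted from 0, so the 2nd epoch corresponds to epoch=1.
first_epoch = data_order(num_examples=100, seed=1, epoch=0)
second_epoch = data_order(num_examples=100, seed=1, epoch=1)
```

Under these assumptions, restarting with the same `seed` and a new `epoch` changes the order while every example is still visited exactly once per epoch.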
olmo/train.py
Outdated
@@ -147,40 +152,47 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ]

        # Dataset / dataloader position.
        checkpoint_epoch = state_dict.get("epoch", 0)
        self.global_step = state_dict["global_step"]
        self.global_data_step = state_dict["global_data_step"]
        self.global_train_examples_seen = state_dict.get(  # newer addition
Is there any reason to keep `self.global_train_examples_seen`? You can perform this `state_dict.get()` backwards compatibility check without keeping the `global_train_examples_seen` variable.
Yeah, good point. Turns out we don't need `global_data_step` either. 4d6e61c
`global_train_examples_seen` and `global_data_step` no longer needed
olmo/train.py
Outdated
"global_train_tokens_seen",
self.global_data_step * self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length,
state_dict.get("global_data_step", 0)  # for backwards compatibility
This will result in `throughput/total_tokens` being reset to 0 if `global_train_tokens_seen` and `global_data_step` are both not present. Maybe `state_dict.get("global_data_step", self.global_step)` is safer?
Absolutely, good catch: b8ca94d
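The fallback chain discussed in this thread can be sketched as a standalone function. This is only an illustration: `restore_tokens_seen` and its plain-argument signature are hypothetical, and the committed code in b8ca94d may be structured differently.

```python
from typing import Any, Dict


def restore_tokens_seen(
    state_dict: Dict[str, Any],
    global_step: int,
    global_train_batch_size: int,
    max_sequence_length: int,
) -> int:
    """Hypothetical helper: recover the token count from a checkpoint.

    Preference order:
      1. `global_train_tokens_seen`, if the checkpoint stores it.
      2. `global_data_step` (older checkpoints), converted to tokens.
      3. `global_step`, so the counter is never silently reset to 0.
    """
    if "global_train_tokens_seen" in state_dict:
        return state_dict["global_train_tokens_seen"]
    # Falling back to global_step instead of 0 is the suggestion from the review.
    data_step = state_dict.get("global_data_step", global_step)
    return data_step * global_train_batch_size * max_sequence_length
```

With a default of 0 in the last step, a checkpoint missing both keys would report zero tokens seen; defaulting to `global_step` keeps the estimate consistent with the number of steps already taken.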
I cherry-picked this commit from #350, which has now started its 2nd epoch.