Commit 6d29ee4

fix dataloader max steps

epwalsh committed May 24, 2023
1 parent 7ffe204 commit 6d29ee4
Showing 3 changed files with 11 additions and 6 deletions.
3 changes: 2 additions & 1 deletion olmo/data/__init__.py
@@ -40,6 +40,7 @@ def build_eval_dataloader(
 
 
 def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
+    assert train_config.device_train_batch_size is not None
     collator = DataCollator(
         pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
     )
@@ -50,7 +51,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
             seed=train_config.seed,
             shuffle=True,
             drop_last=train_config.data.drop_last,
-            max_steps=train_config.max_duration,
+            max_steps=train_config.device_train_batch_size * train_config.max_duration,
         ),
         batch_size=train_config.device_train_batch_size,
         collate_fn=collator,
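For context on the change above: the new multiplication implies that the IterableDataset counts max_steps (and start_step) in individual training instances per device, while max_duration is counted in optimizer steps, so the budget has to be scaled by device_train_batch_size. Below is a minimal sketch of that relationship, assuming a hypothetical InstanceDataset rather than the repository's actual class:

# Minimal sketch (assumed, not the repository's implementation): a dataset that
# counts max_steps in *instances*, so an optimizer-step budget must be converted
# as max_steps = device_train_batch_size * max_duration.
from torch.utils.data import DataLoader, IterableDataset


class InstanceDataset(IterableDataset):  # hypothetical stand-in
    def __init__(self, num_instances: int, max_steps: int, start_step: int = 0):
        self.num_instances = num_instances
        self.max_steps = max_steps    # measured in instances, not optimizer steps
        self.start_step = start_step  # also measured in instances

    def __iter__(self):
        for idx in range(self.start_step, min(self.max_steps, self.num_instances)):
            yield idx


device_train_batch_size, max_duration = 4, 3  # 3 optimizer steps of 4 instances each
loader = DataLoader(
    InstanceDataset(num_instances=100, max_steps=device_train_batch_size * max_duration),
    batch_size=device_train_batch_size,
)
print(len(list(loader)))  # -> 3 batches, matching max_duration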
13 changes: 8 additions & 5 deletions olmo/train.py
@@ -397,7 +397,8 @@ def fast_forward_batches(self):
         else:
             log.info(f"Fast-forwarding data loader to {self.global_data_step}")
             assert isinstance(self.train_loader.dataset, IterableDataset)
-            self.train_loader.dataset.start_step = self.global_data_step
+            assert self.cfg.device_train_batch_size is not None
+            self.train_loader.dataset.start_step = self.cfg.device_train_batch_size * self.global_data_step
 
     def save_checkpoint(self, checkpoint_type: CheckpointType = CheckpointType.sharded) -> Path:
         if checkpoint_type == CheckpointType.sharded:
@@ -734,10 +735,6 @@ def fit(self):
                 # Reset speed monitor so that we don't count the time taken to save checkpoints.
                 speed_monitor.reset()
 
-            if time_limit_reached:
-                log.info("Training time limit reached, ending early")
-                break
-
             # Maybe run evaluations.
             if self.global_step % self.cfg.eval_interval == 0:
                 eval_metrics = self.eval()
@@ -755,6 +752,12 @@
             # End of batch.
             first_batch = False
 
+            if time_limit_reached:
+                log.info("Training time limit reached, ending early")
+                break
+        else:
+            log.info("Training loop complete")
+
         # Save final unsharded model-only checkpoint.
         log.info("Saving final unsharded model checkpoint...")
         checkpoint_path = self.save_unsharded_checkpoint()
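The two fit() hunks above move the time-limit check to the end of the batch loop and pair the loop with an else clause, so "Training loop complete" is only logged when the loop exhausts the data without breaking. A self-contained sketch of that for/else pattern (standalone values, not the repository's code):

# Standalone sketch of the for/else control flow used in fit(): the else branch
# runs only if the loop was never exited via break (e.g. by a time limit).
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

time_limit_reached = False  # flip to True to exercise the early-exit path

for step in range(3):
    # ... train on one batch ...
    # End of batch.
    if time_limit_reached:
        log.info("Training time limit reached, ending early")
        break
else:
    log.info("Training loop complete")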
1 change: 1 addition & 0 deletions scripts/train.py
@@ -200,6 +200,7 @@ def main(cfg: TrainConfig) -> None:
     if not cfg.dry_run:
         log.info("Starting training...")
         trainer.fit()
+        log.info("Training complete")
     else:
         log.info("Dry run complete")
 
