Merge pull request #383 from allenai/epwalsh/start-new-epoch
Add ability to restart on new epoch
epwalsh authored Nov 27, 2023
2 parents f09a500 + b8ca94d commit e16e606
Showing 4 changed files with 40 additions and 32 deletions.
5 changes: 5 additions & 0 deletions olmo/config.py
@@ -659,6 +659,11 @@ class TrainConfig(BaseConfig):
     Used to seed all initial RNG states.
     """

+    epoch: int = 0
+    """
+    Increment this when starting a new epoch.
+    """
+
     dry_run: bool = False
     """
     If ``True``, don't actually train.
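The new field defaults to 0, so existing configs and checkpoints are unaffected; it only needs to be bumped when relaunching a run for another pass over the data. A minimal sketch of how that override might look, assuming an OmegaConf-style structured config with dotlist overrides (the `MiniTrainConfig` class and the seed value are illustrative stand-ins, not OLMo's actual `TrainConfig`):

    from dataclasses import dataclass

    from omegaconf import OmegaConf


    @dataclass
    class MiniTrainConfig:
        # Stand-in for TrainConfig; field names mirror the diff, values are hypothetical.
        seed: int = 12345  # base RNG seed, kept fixed for the whole run
        epoch: int = 0     # new field: increment when starting a new epoch


    # Simulate relaunching with an `epoch=1` override while everything else stays the same.
    cfg = OmegaConf.merge(OmegaConf.structured(MiniTrainConfig), OmegaConf.from_dotlist(["epoch=1"]))
    print(cfg.seed, cfg.epoch)  # 12345 1 -> same base seed, new epoch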
2 changes: 1 addition & 1 deletion olmo/data/__init__.py
@@ -93,7 +93,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
         IterableDataset(
             dataset,  # type: ignore
             train_config.global_train_batch_size,
-            seed=train_config.seed,
+            seed=train_config.seed + train_config.epoch,
             shuffle=True,
             drop_last=train_config.data.drop_last,
             max_examples=train_config.global_train_batch_size * train_config.max_duration,
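Because the shuffle is now seeded with `seed + epoch`, each epoch gets a different but fully reproducible example order, while restarts within an epoch reproduce the same order. A rough sketch of that property (illustration only; OLMo's `IterableDataset` handles the deterministic shuffling and sharding internally):

    import numpy as np


    def epoch_order(seed: int, epoch: int, num_examples: int) -> np.ndarray:
        # Same idea as seeding the shuffle with `train_config.seed + train_config.epoch`.
        return np.random.default_rng(seed + epoch).permutation(num_examples)


    print(epoch_order(12345, 0, 8))  # identical every time epoch 0 is (re)started
    print(epoch_order(12345, 1, 8))  # a different, equally reproducible order for epoch 1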
64 changes: 33 additions & 31 deletions olmo/train.py
@@ -101,12 +101,11 @@ class Trainer:
     train_loader: DataLoader
     device: torch.device
     evaluators: List[Evaluator]
+    epoch: int = 0
     global_step: int = 0
-    global_data_step: int = 0
-    """This is now redundant since adding 'global_train_examples_seen'."""
-    global_train_examples_seen: int = 0
-    """Tracks the global number of training examples seen for the purpose of restoring the dataset
-    position on restarts."""
+    global_train_examples_seen_this_epoch: int = 0
+    """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
+    the data loader position on restarts."""
     global_train_tokens_seen: int = 0
     """Tracks the global total number of tokens trained on."""
     checkpoints: List[Path] = field(default_factory=list)
@@ -118,9 +117,9 @@ class Trainer:

     def trainer_state_dict(self) -> Dict[str, Any]:
         return {
+            "epoch": self.epoch,
             "global_step": self.global_step,
-            "global_data_step": self.global_data_step,
-            "global_train_examples_seen": self.global_train_examples_seen,
+            "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
             "global_train_tokens_seen": self.global_train_tokens_seen,
             "world_size": get_world_size(),
             "checkpoints": self.checkpoints,
@@ -147,40 +146,44 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
         ]

         # Dataset / dataloader position.
+        checkpoint_epoch = state_dict.get("epoch", 0)
         self.global_step = state_dict["global_step"]
-        self.global_data_step = state_dict["global_data_step"]
-        self.global_train_examples_seen = state_dict.get(  # newer addition
-            "global_train_examples_seen", self.global_data_step * self.cfg.global_train_batch_size
+        self.global_train_examples_seen_this_epoch = state_dict.get(
+            "global_train_examples_seen_this_epoch",
+            state_dict.get(  # for backwards compatibility
+                "global_train_examples_seen",
+                state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
+            ),
         )
-        self.global_train_tokens_seen = state_dict.get(  # newer addition
+        self.global_train_tokens_seen = state_dict.get(
             "global_train_tokens_seen",
-            self.global_data_step * self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length,
+            state_dict.get("global_data_step", self.global_step)  # for backwards compatibility
+            * self.cfg.global_train_batch_size
+            * self.cfg.model.max_sequence_length,
         )
+
         if not self.cfg.restore_dataloader:
-            self.global_data_step = 0
-            self.global_train_examples_seen = 0
+            self.epoch = 0
             self.global_train_tokens_seen = 0
-        elif self.cfg.fast_forward_batches:
-            self.global_data_step += self.cfg.fast_forward_batches
+            self.global_train_examples_seen_this_epoch = 0
+        elif checkpoint_epoch != self.epoch:
+            log.info(f"Starting new epoch (epoch = {self.epoch})")
+            self.global_train_examples_seen_this_epoch = 0
+
+        if self.cfg.fast_forward_batches:
+            log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
             # Technically we don't "see" these batches that we fast-forward through, but we use
             # this variable to update the position of the dataset so we need to include them here.
-            self.global_train_examples_seen += self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
+            self.global_train_examples_seen_this_epoch += (
+                self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
+            )
            # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
            # that variable is meant to track the actual number of tokens trained on.

-        if self.global_data_step > 0:
-            if self.global_data_step > self.global_step:
-                log.info(
-                    f"Fast-forwarding data loader to step {self.global_step:,d}+{self.global_data_step-self.global_step:,d} "
-                    f"({self.global_train_examples_seen:,d} examples)"
-                )
-            else:
-                log.info(
-                    f"Fast-forwarding data loader to step {self.global_data_step:,d} "
-                    f"({self.global_train_examples_seen:,d} examples)"
-                )
+        if self.global_train_examples_seen_this_epoch > 0:
             assert isinstance(self.train_loader.dataset, IterableDataset)
-            self.train_loader.dataset.start_index = self.global_train_examples_seen
+            log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
+            self.train_loader.dataset.start_index = self.global_train_examples_seen_this_epoch

         # Reset learning rate and weight decay to the values from the config, not the checkpoint.
         log.info("Resetting learning rate...")
@@ -789,8 +792,7 @@ def on_trace_ready(p):
             assert batch_size == self.cfg.device_train_batch_size
             global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
             self.global_step += 1
-            self.global_data_step += 1
-            self.global_train_examples_seen += global_batch_size
+            self.global_train_examples_seen_this_epoch += global_batch_size
             self.global_train_tokens_seen += global_batch_size * seq_len
             speed_monitor.batch_start(
                 self.global_train_tokens_seen,
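The nested `state_dict.get` fallbacks keep older checkpoints loadable: a trainer state that only recorded `global_data_step` (or `global_train_examples_seen`) still restores a sensible data loader position. A small sketch of that fallback chain with hypothetical numbers:

    # Pre-refactor trainer state: no 'epoch' and no 'global_train_examples_seen_this_epoch'.
    old_state = {"global_step": 1000, "global_data_step": 1200}
    global_train_batch_size = 2048  # stand-in for cfg.global_train_batch_size

    examples_seen_this_epoch = old_state.get(
        "global_train_examples_seen_this_epoch",
        old_state.get(
            "global_train_examples_seen",
            old_state.get("global_data_step", old_state["global_step"]) * global_train_batch_size,
        ),
    )
    print(examples_seen_this_epoch)  # 1200 * 2048 = 2,457,600 -> data loader restarts at that index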
1 change: 1 addition & 0 deletions scripts/train.py
@@ -165,6 +165,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
     # Consolidate components into `Trainer` object.
     with Trainer(
         cfg=cfg,
+        epoch=cfg.epoch,
         model=olmo_model,
         fsdp_model=fsdp_model,
         optim=optim,
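Putting the pieces together, restarting on a new epoch amounts to loading the last checkpoint with the config's `epoch` bumped: the trainer then sees a checkpoint epoch that differs from the configured one and resets its position in the data. A hypothetical, simplified trace of that decision (values invented; the reset logic is condensed from the `load_trainer_state_dict` changes above):

    # Trainer state saved at the end of the first pass over the data (hypothetical numbers).
    checkpoint_state = {"epoch": 0, "global_step": 80_000,
                        "global_train_examples_seen_this_epoch": 163_840_000}

    cfg_epoch = 1  # `epoch` bumped in the config when relaunching for the second pass

    checkpoint_epoch = checkpoint_state.get("epoch", 0)
    if checkpoint_epoch != cfg_epoch:
        # New epoch: start the data loader from index 0, with a new shuffle order (seed + epoch).
        start_index = 0
    else:
        # Same epoch: fast-forward the data loader past everything already seen.
        start_index = checkpoint_state["global_train_examples_seen_this_epoch"]

    print(start_index)  # 0 -> fresh, reshuffled pass; global_step keeps counting from 80,000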
