
Commit

Merge pull request #265 from allenai/LayerNormAffine-ManualLayerNorm-Profiling

Profiling
dirkgr committed Sep 20, 2023
2 parents 2df922b + 69ccb13 commit ef85d5c
Showing 2 changed files with 152 additions and 107 deletions.
5 changes: 5 additions & 0 deletions olmo/config.py
@@ -769,6 +769,11 @@ class TrainConfig(BaseConfig):
     Whether to run the Python profiler on batches 6, 7, and 8.
     """

+    torch_profiling: bool = False
+    """
+    Whether to run the PyTorch profiler on batches 6, 7, and 8.
+    """
+
     @property
     def autocast_precision(self) -> torch.dtype:
         if self.precision == "amp_bf16":
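For reference, a minimal usage sketch showing how the new flag can be switched on next to the existing python_profiling flag. The TrainConfig.load call and the config path are assumptions about the surrounding codebase, not something this diff introduces.

# Hypothetical sketch: enable both profilers on a loaded training config.
# TrainConfig.load(...) and the YAML path are assumed here, not shown in this commit.
from olmo.config import TrainConfig

cfg = TrainConfig.load("configs/my_run.yaml")  # illustrative path
cfg.python_profiling = True  # existing flag: cProfile around batches 6-8
cfg.torch_profiling = True   # new flag: PyTorch profiler over a similar window

The trainer then reads these as self.cfg.python_profiling and self.cfg.torch_profiling in fit() (olmo/train.py), as the diff below shows.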
254 changes: 147 additions & 107 deletions olmo/train.py
@@ -857,122 +857,162 @@ def fit(self):
         if wandb.run is not None:
             wandb.log(sys_metrics, step=0)

-        # Profiler
+        # Python Profiler stuff
         if self.cfg.python_profiling:
-            profiler = cProfile.Profile()
+            python_profiler = cProfile.Profile()
         else:
-            profiler = None
+            python_profiler = None

+        # PyTorch Profiler stuff
+        if self.cfg.torch_profiling and get_global_rank() == 0:
+            from torch.profiler import schedule
+
+            profiling_schedule = schedule(wait=1, warmup=5, active=3)
+
+            def on_trace_ready(p):
+                profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
+                profiler_output_dir.mkdir(exist_ok=True)
+
+                output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
+                log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
+                output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
+                log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
+
+                p.export_chrome_trace(str(profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
+                p.export_stacks(str(profiler_output_dir / f"{p.step_num}.gpu.stacks"), "self_cuda_time_total")
+                p.export_stacks(str(profiler_output_dir / f"{p.step_num}.cpu.stacks"), "self_cpu_time_total")
+
+            from torch.profiler import ProfilerActivity
+
+            torch_profiler = torch.profiler.profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=False,
+                profile_memory=False,
+                with_stack=True,
+                schedule=profiling_schedule,
+                on_trace_ready=on_trace_ready,
+            )
+            del profiling_schedule
+        else:
+            import contextlib
+
+            torch_profiler = contextlib.nullcontext()

         # Train.
         first_batch: bool = True
         canceled: bool = False
-        for batch in self.train_loader:
-            # Bookkeeping.
-            # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
-            # batches see the same number of tokens, which should be the case for language model pre-training
-            # (at least when drop_last=True).
-            # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that overhead.
-            # So for now I'm putting these assertions here so if the assumption is violated it will fail loudly.
-            batch_size, seq_len = batch["input_ids"].shape
-            assert seq_len == self.cfg.model.max_sequence_length
-            assert batch_size == self.cfg.device_train_batch_size
-            global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
-            self.global_step += 1
-            self.global_data_step += 1
-            self.global_train_examples_seen += global_batch_size
-            self.global_train_tokens_seen += global_batch_size * seq_len
-            speed_monitor.batch_start(
-                self.global_train_tokens_seen,
-                batch_size * seq_len,  # num tokens in batch for this device
-                # We start monitoring speed after the first batch since the first
-                # batch might be an outlier due to compiling and other initialization overhead.
-                record=not first_batch,
-            )
-
-            # Run train step on batch.
-            metrics = self.train_step(batch)
-
-            # Maybe collect other metrics.
-            if self.should_log_this_step():
-                # Speed metrics.
-                metrics.update(speed_monitor.check())
-                # System metrics.
-                metrics.update(self.system_metrics())
-                # Learning rate metrics.
-                metrics.update(lr_monitor.check())
-
-            # Log metrics to console.
-            if self.global_step % self.cfg.console_log_interval == 0:
-                self.log_metrics_to_console(f"[step={self.global_step}/{self.cfg.max_duration}]", metrics)
-
-            # Log metrics to W&B.
-            if (
-                wandb.run is not None
-                and self.cfg.wandb is not None
-                and self.global_step % self.cfg.wandb.log_interval == 0
-            ):
-                wandb.log(metrics, step=self.global_step)
-
-            # Check if run should be canceled.
-            if self.global_step % self.cfg.canceled_check_interval == 0:
-                canceled = self.check_if_cancelled()
-
-            # Maybe save sharded checkpoint.
-            if canceled or (
-                self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
-            ):
-                log.info("Saving checkpoint...")
-                checkpoint_path = self.save_sharded_checkpoint()
-                log.info(f"Checkpoint saved to {checkpoint_path}")
-
-                # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                speed_monitor.reset()
-
-            # Maybe save unsharded checkpoint.
-            if (
-                not canceled  # we already save a sharded checkpoint when canceled
-                and self.cfg.save_interval_unsharded is not None
-                and self.global_step % self.cfg.save_interval_unsharded == 0
-                and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
-            ):
-                log.info("Saving unsharded checkpoint...")
-                checkpoint_path = self.save_unsharded_checkpoint()
-                log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
-
-                # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                speed_monitor.reset()
-
-            # Maybe run evaluations.
-            if not canceled and self.global_step % self.cfg.eval_interval == 0:
-                eval_metrics = self.eval()
+        with torch_profiler as p:
+            for batch in self.train_loader:
+                # Bookkeeping.
+                # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
+                # batches see the same number of tokens, which should be the case for language model pre-training
+                # (at least when drop_last=True).
+                # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
+                # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
+                # fail loudly.
+                batch_size, seq_len = batch["input_ids"].shape
+                assert seq_len == self.cfg.model.max_sequence_length
+                assert batch_size == self.cfg.device_train_batch_size
+                global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
+                self.global_step += 1
+                self.global_data_step += 1
+                self.global_train_examples_seen += global_batch_size
+                self.global_train_tokens_seen += global_batch_size * seq_len
+                speed_monitor.batch_start(
+                    self.global_train_tokens_seen,
+                    batch_size * seq_len,  # num tokens in batch for this device
+                    # We start monitoring speed after the first batch since the first
+                    # batch might be an outlier due to compiling and other initialization overhead.
+                    record=not first_batch,
+                )
+
+                # Run train step on batch.
+                metrics = self.train_step(batch)
+
+                # Maybe collect other metrics.
+                if self.should_log_this_step():
+                    # Speed metrics.
+                    metrics.update(speed_monitor.check())
+                    # System metrics.
+                    metrics.update(self.system_metrics())
+                    # Learning rate metrics.
+                    metrics.update(lr_monitor.check())
+
+                # Log metrics to console.
+                if self.global_step % self.cfg.console_log_interval == 0:
+                    self.log_metrics_to_console(f"[step={self.global_step}/{self.cfg.max_duration}]", metrics)

-                # Log metrics to W&B.
-                if wandb.run is not None:
-                    wandb.log(eval_metrics, step=self.global_step)
-
-                # Reset speed monitor so that we don't count the time taken to run evaluations.
-                speed_monitor.reset()
-
-                # Reset model to 'train' mode.
-                self.fsdp_model.train()
-
-            # End of batch.
-            first_batch = False
-
-            if canceled:
-                break
-
-            # Profiler stuff
-            # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
-            if profiler is not None:
-                if self.global_step == 5:
-                    profiler.enable()
-                elif self.global_step == 8:
-                    profiler.disable()
-                    profiler.print_stats(sort=SortKey.CUMULATIVE)
-                    profiler = None
-        else:
-            log.info("Training loop complete")
+                # Log metrics to W&B.
+                if (
+                    wandb.run is not None
+                    and self.cfg.wandb is not None
+                    and self.global_step % self.cfg.wandb.log_interval == 0
+                ):
+                    wandb.log(metrics, step=self.global_step)
+
+                # Check if run should be canceled.
+                if self.global_step % self.cfg.canceled_check_interval == 0:
+                    canceled = self.check_if_cancelled()
+
+                # Maybe save sharded checkpoint.
+                if canceled or (
+                    self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
+                ):
+                    log.info("Saving checkpoint...")
+                    checkpoint_path = self.save_sharded_checkpoint()
+                    log.info(f"Checkpoint saved to {checkpoint_path}")
+
+                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                    speed_monitor.reset()
+
+                # Maybe save unsharded checkpoint.
+                if (
+                    not canceled  # we already save a sharded checkpoint when canceled
+                    and self.cfg.save_interval_unsharded is not None
+                    and self.global_step % self.cfg.save_interval_unsharded == 0
+                    and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
+                ):
+                    log.info("Saving unsharded checkpoint...")
+                    checkpoint_path = self.save_unsharded_checkpoint()
+                    log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
+
+                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                    speed_monitor.reset()
+
+                # Maybe run evaluations.
+                if not canceled and self.global_step % self.cfg.eval_interval == 0:
+                    eval_metrics = self.eval()
+
+                    # Log metrics to W&B.
+                    if wandb.run is not None:
+                        wandb.log(eval_metrics, step=self.global_step)
+
+                    # Reset speed monitor so that we don't count the time taken to run evaluations.
+                    speed_monitor.reset()
+
+                    # Reset model to 'train' mode.
+                    self.fsdp_model.train()
+
+                # End of batch.
+                first_batch = False
+                if p is not None:
+                    p.step()
+
+                if canceled:
+                    break
+
+                # Python Profiler stuff
+                # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
+                if python_profiler is not None:
+                    if self.global_step == 5:
+                        python_profiler.enable()
+                    elif self.global_step == 8:
+                        python_profiler.disable()
+                        python_profiler.print_stats(sort=SortKey.CUMULATIVE)
+                        python_profiler = None
+            else:
+                log.info("Training loop complete")

         # Save final unsharded model-only checkpoint.
         if not canceled and self.cfg.save_interval_unsharded is not None:
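For readers less familiar with torch.profiler, here is a small self-contained sketch of the pattern wired up above: a schedule(wait=1, warmup=5, active=3) window, an on_trace_ready callback that prints a summary table and exports a Chrome trace, and a p.step() call once per batch. The toy model, optimizer, and output directory are placeholders, not part of this commit.

# Standalone sketch of the profiling pattern used in olmo/train.py above.
# Model, optimizer, and output directory are illustrative placeholders.
from pathlib import Path

import torch
from torch.profiler import ProfilerActivity, profile, schedule


def on_trace_ready(p):
    out_dir = Path("profiler_output")  # the commit writes under <save_folder>/profiler
    out_dir.mkdir(exist_ok=True)
    # Print a summary table, then dump a Chrome trace viewable in chrome://tracing or Perfetto.
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    p.export_chrome_trace(str(out_dir / f"{p.step_num}.chrome_trace.json"))


model = torch.nn.Linear(512, 512)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wait 1 step, warm up for 5, then record 3 active steps, mirroring
# schedule(wait=1, warmup=5, active=3) in the diff.
with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when running on GPU
    with_stack=True,
    schedule=schedule(wait=1, warmup=5, active=3),
    on_trace_ready=on_trace_ready,
) as p:
    for step in range(10):
        x = torch.randn(8, 512)
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        p.step()  # advance the profiler schedule once per batch, as the trainer does

Note that in the commit the profiler is only constructed on rank 0 (get_global_rank() == 0) and every other rank gets contextlib.nullcontext(), which is why the training loop guards the call with "if p is not None: p.step()".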
