diff --git a/olmo/config.py b/olmo/config.py
index 2fba4a87c..ffe8db1ca 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -769,6 +769,11 @@ class TrainConfig(BaseConfig):
     Whether to run the Python profiler on batches 6, 7, and 8.
     """
 
+    torch_profiling: bool = False
+    """
+    Whether to run the PyTorch profiler on batches 6, 7, and 8.
+    """
+
     @property
     def autocast_precision(self) -> torch.dtype:
         if self.precision == "amp_bf16":
diff --git a/olmo/train.py b/olmo/train.py
index 1e33b33c9..9d4114adb 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -857,122 +857,162 @@ def fit(self):
             if wandb.run is not None:
                 wandb.log(sys_metrics, step=0)
 
-        # Profiler
+        # Python Profiler stuff
         if self.cfg.python_profiling:
-            profiler = cProfile.Profile()
+            python_profiler = cProfile.Profile()
         else:
-            profiler = None
+            python_profiler = None
+
+        # PyTorch Profiler stuff
+        if self.cfg.torch_profiling and get_global_rank() == 0:
+            from torch.profiler import schedule
+
+            profiling_schedule = schedule(wait=1, warmup=5, active=3)
+
+            def on_trace_ready(p):
+                profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
+                profiler_output_dir.mkdir(exist_ok=True)
+
+                output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
+                log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
+                output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
+                log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
+
+                p.export_chrome_trace(str(profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
+                p.export_stacks(str(profiler_output_dir / f"{p.step_num}.gpu.stacks"), "self_cuda_time_total")
+                p.export_stacks(str(profiler_output_dir / f"{p.step_num}.cpu.stacks"), "self_cpu_time_total")
+
+            from torch.profiler import ProfilerActivity
+
+            torch_profiler = torch.profiler.profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=False,
+                profile_memory=False,
+                with_stack=True,
+                schedule=profiling_schedule,
+                on_trace_ready=on_trace_ready,
+            )
+            del profiling_schedule
+        else:
+            import contextlib
+
+            torch_profiler = contextlib.nullcontext()
 
         # Train.
         first_batch: bool = True
         canceled: bool = False
-        for batch in self.train_loader:
-            # Bookkeeping.
-            # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
-            # batches see the same number of tokens, which should be the case for language model pre-training
-            # (at least when drop_last=True).
-            # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that overhead.
-            # So for now I'm putting these assertions here so if the assumption is violated it will fail loudly.
-            batch_size, seq_len = batch["input_ids"].shape
-            assert seq_len == self.cfg.model.max_sequence_length
-            assert batch_size == self.cfg.device_train_batch_size
-            global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
-            self.global_step += 1
-            self.global_data_step += 1
-            self.global_train_examples_seen += global_batch_size
-            self.global_train_tokens_seen += global_batch_size * seq_len
-            speed_monitor.batch_start(
-                self.global_train_tokens_seen,
-                batch_size * seq_len,  # num tokens in batch for this device
-                # We start monitoring speed after the first batch since the first
-                # batch might be an outlier due to compiling and other initialization overhead.
-                record=not first_batch,
-            )
-
-            # Run train step on batch.
-            metrics = self.train_step(batch)
-
-            # Maybe collect other metrics.
-            if self.should_log_this_step():
-                # Speed metrics.
-                metrics.update(speed_monitor.check())
-                # System metrics.
-                metrics.update(self.system_metrics())
-                # Learning rate metrics.
-                metrics.update(lr_monitor.check())
-
-            # Log metrics to console.
-            if self.global_step % self.cfg.console_log_interval == 0:
-                self.log_metrics_to_console(f"[step={self.global_step}/{self.cfg.max_duration}]", metrics)
-
-            # Log metrics to W&B.
-            if (
-                wandb.run is not None
-                and self.cfg.wandb is not None
-                and self.global_step % self.cfg.wandb.log_interval == 0
-            ):
-                wandb.log(metrics, step=self.global_step)
-
-            # Check if run should be canceled.
-            if self.global_step % self.cfg.canceled_check_interval == 0:
-                canceled = self.check_if_cancelled()
-
-            # Maybe save sharded checkpoint.
-            if canceled or (
-                self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
-            ):
-                log.info("Saving checkpoint...")
-                checkpoint_path = self.save_sharded_checkpoint()
-                log.info(f"Checkpoint saved to {checkpoint_path}")
-
-                # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                speed_monitor.reset()
-
-            # Maybe save unsharded checkpoint.
-            if (
-                not canceled  # we already save a sharded checkpoint when canceled
-                and self.cfg.save_interval_unsharded is not None
-                and self.global_step % self.cfg.save_interval_unsharded == 0
-                and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
-            ):
-                log.info("Saving unsharded checkpoint...")
-                checkpoint_path = self.save_unsharded_checkpoint()
-                log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
-
-                # Reset speed monitor so that we don't count the time taken to save checkpoints.
-                speed_monitor.reset()
-
-            # Maybe run evaluations.
-            if not canceled and self.global_step % self.cfg.eval_interval == 0:
-                eval_metrics = self.eval()
+        with torch_profiler as p:
+            for batch in self.train_loader:
+                # Bookkeeping.
+                # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
+                # batches see the same number of tokens, which should be the case for language model pre-training
+                # (at least when drop_last=True).
+                # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
+                # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
+                # fail loudly.
+                batch_size, seq_len = batch["input_ids"].shape
+                assert seq_len == self.cfg.model.max_sequence_length
+                assert batch_size == self.cfg.device_train_batch_size
+                global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
+                self.global_step += 1
+                self.global_data_step += 1
+                self.global_train_examples_seen += global_batch_size
+                self.global_train_tokens_seen += global_batch_size * seq_len
+                speed_monitor.batch_start(
+                    self.global_train_tokens_seen,
+                    batch_size * seq_len,  # num tokens in batch for this device
+                    # We start monitoring speed after the first batch since the first
+                    # batch might be an outlier due to compiling and other initialization overhead.
+                    record=not first_batch,
+                )
+
+                # Run train step on batch.
+                metrics = self.train_step(batch)
+
+                # Maybe collect other metrics.
+                if self.should_log_this_step():
+                    # Speed metrics.
+                    metrics.update(speed_monitor.check())
+                    # System metrics.
+                    metrics.update(self.system_metrics())
+                    # Learning rate metrics.
+                    metrics.update(lr_monitor.check())
+
+                # Log metrics to console.
+                if self.global_step % self.cfg.console_log_interval == 0:
+                    self.log_metrics_to_console(f"[step={self.global_step}/{self.cfg.max_duration}]", metrics)
 
                 # Log metrics to W&B.
-                if wandb.run is not None:
-                    wandb.log(eval_metrics, step=self.global_step)
-
-                # Reset speed monitor so that we don't count the time taken to run evaluations.
-                speed_monitor.reset()
-
-                # Reset model to 'train' mode.
-                self.fsdp_model.train()
-
-            # End of batch.
-            first_batch = False
-
-            if canceled:
-                break
-
-            # Profiler stuff
-            # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
-            if profiler is not None:
-                if self.global_step == 5:
-                    profiler.enable()
-                elif self.global_step == 8:
-                    profiler.disable()
-                    profiler.print_stats(sort=SortKey.CUMULATIVE)
-                    profiler = None
-        else:
-            log.info("Training loop complete")
+                if (
+                    wandb.run is not None
+                    and self.cfg.wandb is not None
+                    and self.global_step % self.cfg.wandb.log_interval == 0
+                ):
+                    wandb.log(metrics, step=self.global_step)
+
+                # Check if run should be canceled.
+                if self.global_step % self.cfg.canceled_check_interval == 0:
+                    canceled = self.check_if_cancelled()
+
+                # Maybe save sharded checkpoint.
+                if canceled or (
+                    self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
+                ):
+                    log.info("Saving checkpoint...")
+                    checkpoint_path = self.save_sharded_checkpoint()
+                    log.info(f"Checkpoint saved to {checkpoint_path}")
+
+                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                    speed_monitor.reset()
+
+                # Maybe save unsharded checkpoint.
+                if (
+                    not canceled  # we already save a sharded checkpoint when canceled
+                    and self.cfg.save_interval_unsharded is not None
+                    and self.global_step % self.cfg.save_interval_unsharded == 0
+                    and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
+                ):
+                    log.info("Saving unsharded checkpoint...")
+                    checkpoint_path = self.save_unsharded_checkpoint()
+                    log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
+
+                    # Reset speed monitor so that we don't count the time taken to save checkpoints.
+                    speed_monitor.reset()
+
+                # Maybe run evaluations.
+                if not canceled and self.global_step % self.cfg.eval_interval == 0:
+                    eval_metrics = self.eval()
+
+                    # Log metrics to W&B.
+                    if wandb.run is not None:
+                        wandb.log(eval_metrics, step=self.global_step)
+
+                    # Reset speed monitor so that we don't count the time taken to run evaluations.
+                    speed_monitor.reset()
+
+                    # Reset model to 'train' mode.
+                    self.fsdp_model.train()
+
+                # End of batch.
+                first_batch = False
+                if p is not None:
+                    p.step()
+
+                if canceled:
+                    break
+
+                # Python Profiler stuff
+                # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
+                if python_profiler is not None:
+                    if self.global_step == 5:
+                        python_profiler.enable()
+                    elif self.global_step == 8:
+                        python_profiler.disable()
+                        python_profiler.print_stats(sort=SortKey.CUMULATIVE)
+                        python_profiler = None
+            else:
+                log.info("Training loop complete")
 
         # Save final unsharded model-only checkpoint.
         if not canceled and self.cfg.save_interval_unsharded is not None: