Add Torch Profiler Support #1226

Merged (5 commits) on May 21, 2024
11 changes: 10 additions & 1 deletion README.md
@@ -640,7 +640,7 @@ If you need to supply a hostfile for use with the MPI-based DeepSpeed launcher,

# Profiling

-We support profiling with Nsight Systems and PyTorch Memory Profiling.
+We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memory Profiling.

## Nsight Systems Profiling

@@ -656,6 +656,15 @@ The generated output file can then be viewed with the Nsight Systems GUI:

![Alt text](images/nsight_profiling.png)

## PyTorch Profiling

To use the built-in PyTorch profiler, set the config options `profile`, `profile_step_start`, and `profile_step_stop` (the training iterations at which profiling starts and stops).
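
A minimal, illustrative config fragment might look like the following (step values are arbitrary; `tensorboard_dir` must also be set so the trace handler has somewhere to write):

```yaml
{
  # Illustrative values only; add these keys to an existing training config.
  "profile": true,
  "profile_step_start": 10,
  "profile_step_stop": 12,
  "tensorboard_dir": "tensorboard",
}
```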

The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within
TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).
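Note that the linked tutorial installs the TensorBoard profiler plugin with `pip install torch_tb_profiler`; once it is installed, launch TensorBoard with `--logdir` pointing at your `tensorboard_dir` to view the traces.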

![Alt text](images/pytorch_profiling.png)

## PyTorch Memory Profiling

To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`.
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

-    Default = 0d5992f
+    Default = b68ba6d

current git hash of repository

Binary file added images/pytorch_profiling.png
12 changes: 6 additions & 6 deletions megatron/data/helpers.cpp
@@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}

} // for (auto sent_index=sent_index_first; ...
-        } // if (num_remain_sent > 1) {
-    } // for (int doc=0; doc < num_docs; ++doc) {
-  } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+        } // if (num_remain_sent > 1) {
+    } // for (int doc=0; doc < num_docs; ++doc) {
+  } // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
Expand Down Expand Up @@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
-        } // if (num_remain_sent > 1) {
-    } // for (int doc=0; doc < num_docs; ++doc) {
-  } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+        } // if (num_remain_sent > 1) {
+    } // for (int doc=0; doc < num_docs; ++doc) {
+  } // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
22 changes: 22 additions & 0 deletions megatron/training.py
@@ -970,7 +970,28 @@ def train(

# to monitor if we've skipped many iterations in a row and trigger an early exit
overflow_monitor = OverflowMonitor(optimizer)

if neox_args.profile:
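        # torch.profiler.schedule semantics: skip `wait` steps, spend `warmup`
        # steps warming up (recorded but discarded), then trace `active` steps
        # before invoking the trace handler.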
schedule = torch.profiler.schedule(
wait=neox_args.profile_step_start,
warmup=1,
active=neox_args.profile_step_stop - neox_args.profile_step_start,
)
prof = torch.profiler.profile(
schedule=schedule,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
neox_args.tensorboard_dir
),
record_shapes=True,
profile_memory=True,
with_flops=True,
with_modules=True,
with_stack=True,
)
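        # Begin stepping the schedule; `prof.step()` in the loop below advances it.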
prof.start()
while iteration < neox_args.train_iters:
if neox_args.profile:
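            # Advance the profiler's schedule once per training iteration.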
prof.step()
if neox_args.profile and iteration == neox_args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
loss_dict, skipped_iter = train_step(
@@ -983,6 +1004,7 @@
)
if neox_args.profile and iteration == neox_args.profile_step_stop:
torch.cuda.cudart().cudaProfilerStop()
prof.stop()
iteration += 1
neox_args.iteration = iteration
if neox_args.precision == "fp16":