
Commit

Merge branch 'main' of https://github.com/EleutherAI/gpt-neox into dmoe_integration
DayOfThePenguin committed May 22, 2024
2 parents 33e41d7 + d3d59f2 commit 613aeb9
Showing 9 changed files with 56 additions and 15 deletions.
22 changes: 20 additions & 2 deletions README.md
@@ -96,7 +96,6 @@ To install the remaining basic dependencies, run:
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-wandb.txt # optional, if logging using WandB
pip install -r requirements/requirements-tensorboard.txt # optional, if logging via tensorboard
python ./megatron/fused_kernels/setup.py install # optional, if using fused kernels
```

from the repository root.
@@ -106,6 +105,16 @@ from the repository root.
</aside>

### Fused Kernels
We now support AMD GPUs (MI100, MI250X) through JIT compilation of fused kernels. Fused kernels are built and loaded on demand. To avoid waiting during job launch, you can also pre-build them manually by running the following in a Python interpreter:

```python
from megatron.fused_kernels import load
load()
```
This automatically adapts the build process to the GPU vendor (AMD, NVIDIA) without platform-specific code changes. To further test the fused kernels with `pytest`, run `pytest tests/model/test_fused_kernels.py`.

### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
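As a sketch, enabling Flash Attention in a gpt-neox `.yml` config could look like the following; the layer count is illustrative and should match your model's `num-layers`:

```yaml
# Hypothetical config fragment: use flash attention across all 12 layers.
# The [[pattern], repeat] form follows configs/neox_arguments.md.
{
  "attention_config": [[["flash"], 12]],
}
```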
@@ -581,7 +590,7 @@ If you need to supply a hostfile for use with the MPI-based DeepSpeed launcher,

# Profiling

We support profiling with Nsight Systems and PyTorch Memory Profiling.
We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memory Profiling.

## Nsight Systems Profiling

@@ -597,6 +606,15 @@ The generated output file can then be viewed with the Nsight Systems GUI:

![Alt text](images/nsight_profiling.png)

## PyTorch Profiling

To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop`.
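As a sketch, a config fragment that captures steps 10 through 20 might look like this (the step values are illustrative; in `megatron/training.py` the start step is used as the profiler's `wait` count, with one warmup step before tracing begins):

```yaml
# Hypothetical config fragment enabling the built-in PyTorch profiler.
{
  "profile": true,
  "profile_step_start": 10,
  "profile_step_stop": 20,
}
```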

The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within
TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).

![Alt text](images/pytorch_profiling.png)

## PyTorch Memory Profiling

To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`.
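As a sketch, under the same config conventions (the snapshot path below is hypothetical):

```yaml
# Hypothetical config fragment enabling PyTorch memory snapshots.
{
  "memory_profiling": true,
  "memory_profiling_path": "./memory_snapshots",
}
```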
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 3ef5b66
Default = 1b85a2f

current git hash of repository

Binary file added images/pytorch_profiling.png
6 changes: 3 additions & 3 deletions megatron/fused_kernels/__init__.py
@@ -135,8 +135,8 @@ def _cpp_extention_load_helper(
        srcpath / "fused_rotary_positional_embedding.cpp",
        srcpath / "fused_rotary_positional_embedding_cuda.cu",
    ]
    fused_rotary_positional_embedding_cuda = _cpp_extention_load_helper(
        "fused_rotary_positional_embedding_cuda",
    fused_rotary_positional_embedding = _cpp_extention_load_helper(
        "fused_rotary_positional_embedding",
        sources,
        extra_cuda_flags,
        extra_include_paths,
@@ -174,7 +174,7 @@ def load_fused_kernels():
        print(e)
        print("=" * 100)
        print(
            f"ERROR: Fused kernels configured but not properly installed. Please run `pip install {str(srcpath)}` to install them"
            f"ERROR: Fused kernels configured but not properly installed. Please run `from megatron.fused_kernels import load` and then `load()` to load them correctly"
        )
        print("=" * 100)
        exit()
7 changes: 5 additions & 2 deletions megatron/model/norms.py
@@ -14,7 +14,6 @@

import torch
from torch.nn import LayerNorm as LayerNorm
from .fused_layer_norm import MixedFusedLayerNorm


def get_norm(neox_args):
@@ -23,7 +22,11 @@ def get_norm(neox_args):
        eps = neox_args.rms_norm_epsilon
    elif neox_args.norm == "layernorm":
        eps = neox_args.layernorm_epsilon
        norm = MixedFusedLayerNorm if neox_args.layernorm_fusion else LayerNorm
        if neox_args.layernorm_fusion:
            from .fused_layer_norm import MixedFusedLayerNorm
            norm = MixedFusedLayerNorm
        else:
            norm = LayerNorm
    elif neox_args.norm == "scalenorm":
        eps = neox_args.scalenorm_epsilon
        norm = ScaleNorm
4 changes: 2 additions & 2 deletions megatron/neox_arguments/arguments.py
@@ -1063,8 +1063,8 @@ def calculate_derived(self):
), "Mamba does not yet have dropout implemented"
if "rwkv" in self.attention_config:
assert (
not self.is_pipe_parallel and self.model_parallel_size == 1
), "RWKV not currently compatible with parallelism"
self.model_parallel_size == 1
), "RWKV not currently compatible with model parallelism"
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV"
assert (
22 changes: 22 additions & 0 deletions megatron/training.py
@@ -891,7 +891,28 @@ def train(

    # to monitor if we've skipped many iterations in a row and trigger an early exit
    overflow_monitor = OverflowMonitor(optimizer)

    if neox_args.profile:
        schedule = torch.profiler.schedule(
            wait=neox_args.profile_step_start,
            warmup=1,
            active=neox_args.profile_step_stop - neox_args.profile_step_start,
        )
        prof = torch.profiler.profile(
            schedule=schedule,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                neox_args.tensorboard_dir
            ),
            record_shapes=True,
            profile_memory=True,
            with_flops=True,
            with_modules=True,
            with_stack=True,
        )
        prof.start()
    while iteration < neox_args.train_iters:
        if neox_args.profile:
            prof.step()
        if neox_args.profile and iteration == neox_args.profile_step_start:
            torch.cuda.cudart().cudaProfilerStart()
        loss_dict, skipped_iter = train_step(
@@ -904,6 +925,7 @@
        )
        if neox_args.profile and iteration == neox_args.profile_step_stop:
            torch.cuda.cudart().cudaProfilerStop()
            prof.stop()
        iteration += 1
        neox_args.iteration = iteration
        if neox_args.precision == "fp16":
4 changes: 2 additions & 2 deletions requirements/requirements.txt
@@ -1,6 +1,6 @@
git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742#egg=deepspeed
deepspeed@git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742#egg=deepspeed
ftfy>=6.0.1
git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
lm_dataformat@git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
huggingface_hub>=0.11.0
jinja2==3.1.4
lm_eval>=0.4.0,<=0.4.1
4 changes: 1 addition & 3 deletions tests/model/test_fused_kernels.py
@@ -30,9 +30,7 @@
)


@pytest.mark.xfail(
    reason="ModuleNotFoundError: No module named 'scaled_masked_softmax_cuda'"
)
@pytest.mark.xfail(reason="SystemExit: None")
def test_load_fused_kernels():
    load()
    try:
