
Commit 65da5be

Merge branch 'mcore_te_api_fixes' into 'main'
Use TE provided APIs

See merge request ADLR/megatron-lm!690
jaredcasper committed Jul 24, 2023
2 parents 5f03f6d + 35b2860 commit 65da5be
Showing 2 changed files with 5 additions and 5 deletions.
megatron/core/transformer/custom_layers/transformer_engine.py (6 changes: 3 additions & 3 deletions)

@@ -9,7 +9,7 @@
 from megatron.core.transformer.transformer_config import TransformerConfig


-class TELayerNorm(te.pytorch.module.LayerNorm):
+class TELayerNorm(te.pytorch.LayerNorm):
     """
     Wrapper for the Transformer-Engine's `LayerNorm`.
     """
@@ -20,7 +20,7 @@ def __init__(
         super().__init__(hidden_size=hidden_size, eps=eps, sequence_parallel=sequence_parallel)


-class TELinear(te.pytorch.module.Linear):
+class TELinear(te.pytorch.Linear):
     """
     Wrapper for the Transformer-Engine's `Linear` layer.
@@ -111,7 +111,7 @@ def __init__(self, input_size: int, output_size: int, config: TransformerConfig,
         )


-class TECoreAttention(te.pytorch.transformer.DotProductAttention):
+class TECoreAttention(te.pytorch.DotProductAttention):
     """
     Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
     has "flash attention" enabled.
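
Context, not part of the commit: all three wrapped classes are exported at the top level of transformer_engine.pytorch, so the subclasses above no longer need the internal te.pytorch.module and te.pytorch.transformer paths. A minimal sketch of the public names, assuming a CUDA device and a TE release contemporary with this commit (constructor arguments such as num_attention_heads and kv_channels may vary across versions):

    import transformer_engine as te

    # Public API classes subclassed by the wrappers above.
    norm   = te.pytorch.LayerNorm(hidden_size=1024, eps=1e-5)
    linear = te.pytorch.Linear(in_features=1024, out_features=4096, bias=True)
    attn   = te.pytorch.DotProductAttention(num_attention_heads=16, kv_channels=64)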
megatron/model/transformer.py (4 changes: 2 additions & 2 deletions)

@@ -1516,7 +1516,7 @@ def custom_forward(*args, **kwargs):
             l = 0
             while l < self.num_layers:
                 if self.transformer_impl == 'transformer_engine':
-                    hidden_states = transformer_engine.pytorch.distributed.checkpoint(
+                    hidden_states = transformer_engine.pytorch.checkpoint(
                         custom(l, l + self.recompute_num_layers),
                         self.distribute_saved_activations,
                         tensor_parallel.get_cuda_rng_tracker,
@@ -1540,7 +1540,7 @@ def custom_forward(*args, **kwargs):
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
                     if self.transformer_impl == 'transformer_engine':
-                        hidden_states = transformer_engine.pytorch.distributed.checkpoint(
+                        hidden_states = transformer_engine.pytorch.checkpoint(
                             custom(l, l + 1),
                             self.distribute_saved_activations,
                             tensor_parallel.get_cuda_rng_tracker,
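
Likewise not part of the commit, a hedged sketch of the call pattern used above for TE's public activation-recompute helper. It assumes an initialized torch.distributed process group, a CUDA device, and Megatron-Core's RNG tracker; the run_layers helper and the per-layer call signature are hypothetical, and newer TE releases take distribute_saved_activations, get_rng_state_tracker, and tp_group as keyword arguments rather than positionally:

    import transformer_engine
    from megatron.core import parallel_state, tensor_parallel

    def run_layers(layers, hidden_states, attention_mask):
        # Recompute `layers` during the backward pass instead of saving activations.
        def custom_forward(hidden_states, attention_mask):
            for layer in layers:
                hidden_states = layer(hidden_states, attention_mask)
            return hidden_states

        return transformer_engine.pytorch.checkpoint(
            custom_forward,                                    # chunk of layers to recompute
            False,                                             # distribute_saved_activations
            tensor_parallel.get_cuda_rng_tracker,              # CUDA RNG state tracker
            parallel_state.get_tensor_model_parallel_group(),  # tensor-parallel group
            hidden_states,
            attention_mask,
        )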
