Megatron perceiver with tensor parallelism only #4318

Merged
162 commits merged on Jul 6, 2022

Commits (162)
99037b5
Temp
MaximumEntropy Mar 28, 2022
f37fe52
Merge
MaximumEntropy Mar 28, 2022
12e6574
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Mar 29, 2022
5d2cdbb
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Mar 29, 2022
db6a25b
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Mar 31, 2022
15513e3
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 3, 2022
06baa1c
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 4, 2022
3f6ed7e
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 6, 2022
c0b029e
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 6, 2022
423b4ce
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 14, 2022
afdbcdc
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 15, 2022
3153ab5
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 18, 2022
a6fe975
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 19, 2022
2daf757
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Apr 22, 2022
13e79f8
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 3, 2022
9b6e43a
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 3, 2022
cdedcf6
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 6, 2022
016cd1f
Add megatron dataset
MaximumEntropy May 8, 2022
d1639d4
Update config and fix global batch fetcher
MaximumEntropy May 8, 2022
3b7c91a
Add dataset class
MaximumEntropy May 9, 2022
2e22cfc
Update comments
MaximumEntropy May 9, 2022
590f40e
Style
MaximumEntropy May 9, 2022
17bd54e
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 9, 2022
340feb5
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 9, 2022
968938b
Update yaml
MaximumEntropy May 9, 2022
8b8134d
Fix duplicate yaml key
MaximumEntropy May 9, 2022
7a7cb89
Translate method and preprocess script for raw text
MaximumEntropy May 11, 2022
23836dc
Merge branch 'main' of github.com:NVIDIA/NeMo into nmt_memmap_dataloader
MaximumEntropy May 11, 2022
0c4657a
Merge branch 'main' into nmt_memmap_dataloader
michalivne May 11, 2022
a144f10
Style
MaximumEntropy May 11, 2022
6443799
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 11, 2022
8b59cca
Remove pdb
MaximumEntropy May 11, 2022
d3917f3
Fix arg name
MaximumEntropy May 11, 2022
a5a6ee5
Fix other arg
MaximumEntropy May 11, 2022
b2c3e9c
Change sampler back
MaximumEntropy May 11, 2022
f46f0d1
Move back to global batch fetcher to use distributed sampler
MaximumEntropy May 11, 2022
f225b4b
Add text memmap data
MaximumEntropy May 11, 2022
373e1f5
Update monitor
MaximumEntropy May 11, 2022
f04b4b5
Fixes for PP
MaximumEntropy May 12, 2022
201e1be
Remove unused import
MaximumEntropy May 12, 2022
79f676d
Truncate examples in text memmap
MaximumEntropy May 12, 2022
7fc54b8
NMT training batch interpolation key
MaximumEntropy May 12, 2022
a0a446b
tarred data fix
MaximumEntropy May 12, 2022
2baef4b
Change dataset type check
MaximumEntropy May 12, 2022
fc7eb49
Fix sampler
MaximumEntropy May 12, 2022
e67b465
Pass dataset cfg to determine type
MaximumEntropy May 13, 2022
bbdce07
Log global step on validation step as well
MaximumEntropy May 13, 2022
7f38112
Fix NMT model saving with artifacts
MaximumEntropy May 13, 2022
3c328a9
Initialize DDP in decode if not initialized. Needed for inference onl…
MaximumEntropy May 14, 2022
708868c
Megatron NMT inference script
MaximumEntropy May 14, 2022
bb8467d
Inference config file
MaximumEntropy May 14, 2022
f33a546
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 14, 2022
ef0e632
hardcode max delta temporarily
MaximumEntropy May 14, 2022
16a1a10
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 14, 2022
5fde572
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 14, 2022
0230165
Style
MaximumEntropy May 17, 2022
e2ddd34
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 17, 2022
19d5e92
detokenizer if processor is not none
MaximumEntropy May 18, 2022
5f8de2a
Sampler config
MaximumEntropy May 18, 2022
59406ab
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 18, 2022
16386c7
Compat with configs without sampler arg
MaximumEntropy May 18, 2022
bd2aa5d
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 18, 2022
7b486b0
Style
MaximumEntropy May 18, 2022
4a84a1f
Comment for validation dataset type
MaximumEntropy May 18, 2022
b382932
Fix tokenizer building
MaximumEntropy May 18, 2022
7faa950
CI test for megatron nmt
MaximumEntropy May 18, 2022
57bb39c
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 18, 2022
20a1b5d
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 18, 2022
f64c7f2
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 20, 2022
7e6c4ea
Fix tokenizer in restore
MaximumEntropy May 20, 2022
0727bdf
Style
MaximumEntropy May 20, 2022
b2fadff
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 20, 2022
1c6a72c
O2 restore from fix
MaximumEntropy May 20, 2022
1b4d414
Remove print
MaximumEntropy May 21, 2022
c30e0d8
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 21, 2022
5fa2553
Change tokenizer model name in config
MaximumEntropy May 21, 2022
405b65b
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 21, 2022
153b345
Logging
MaximumEntropy May 21, 2022
c924138
Set seed for distributed sampler
MaximumEntropy May 23, 2022
d1642b3
Cluster debugging messages
MaximumEntropy May 23, 2022
51c644d
Style
MaximumEntropy May 23, 2022
134cdc0
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 23, 2022
74705da
Fix max generation delta
MaximumEntropy May 23, 2022
d725791
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 23, 2022
b751372
No LM Init
MaximumEntropy May 24, 2022
6607611
Merge branch 'main' into fix_no_lm_init
MaximumEntropy May 24, 2022
ee87d78
Merge branch 'fix_no_lm_init' of github.com:NVIDIA/NeMo into nmt_memm…
MaximumEntropy May 24, 2022
244ebde
Use nlp save restore connector
MaximumEntropy May 24, 2022
37b2397
Remove useless infer args
MaximumEntropy May 25, 2022
b047c77
Typo
MaximumEntropy May 25, 2022
80639da
UTF8 safe print of translation result
MaximumEntropy May 25, 2022
5caca92
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 25, 2022
be27a7f
Style
MaximumEntropy May 25, 2022
aff7d14
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 25, 2022
9fa3c4e
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 26, 2022
3efd318
Add save restore connector back with comment
MaximumEntropy May 26, 2022
605a7c7
Refactor
MaximumEntropy May 26, 2022
1c721b8
Fix CI test
MaximumEntropy May 31, 2022
031c127
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 31, 2022
04c91ba
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 31, 2022
120ae1e
Add missing args
MaximumEntropy May 31, 2022
0169239
Merge branch 'main' into nmt_memmap_dataloader
aklife97 May 31, 2022
fd0b976
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 31, 2022
1fe838b
Address comments
MaximumEntropy May 31, 2022
09884f3
Empty to restart
MaximumEntropy May 31, 2022
4323dc0
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy May 31, 2022
ff3df3d
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy May 31, 2022
d1b71b5
Fix CI test
MaximumEntropy May 31, 2022
a6c9dda
Merge branch 'nmt_memmap_dataloader' of github.com:NVIDIA/NeMo into n…
MaximumEntropy May 31, 2022
19630e2
Merge branch 'main' of github.com:NVIDIA/NeMo into nmt_memmap_dataloader
MaximumEntropy Jun 1, 2022
2d3ba1a
Check for test ds
MaximumEntropy Jun 1, 2022
4214909
Fix merge conflicts
MaximumEntropy Jun 1, 2022
d9b4102
set fusion to false
MaximumEntropy Jun 1, 2022
37cda3b
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy Jun 1, 2022
2423d07
Merge branch 'main' of github.com:NVIDIA/NeMo into main
MaximumEntropy Jun 1, 2022
86137df
Merge branch 'main' into nmt_memmap_dataloader
MaximumEntropy Jun 1, 2022
cccf09c
Initial perceiver encoder
MaximumEntropy Jun 1, 2022
7d770e9
Fix conflicts and some perceiver fixes
MaximumEntropy Jun 1, 2022
a8286a6
Perceiver with PP=1
MaximumEntropy Jun 2, 2022
a440c78
Remove init cross attn
MaximumEntropy Jun 2, 2022
ae12cae
Fix conflicts and merge main
MaximumEntropy Jun 2, 2022
3e3629d
CI test and remove init cross attn arg
MaximumEntropy Jun 2, 2022
632c43d
Remove init cross attn layers from file
MaximumEntropy Jun 2, 2022
5cf1582
Style
MaximumEntropy Jun 2, 2022
e811827
Clean up
MaximumEntropy Jun 2, 2022
e7150b3
Merge branch 'main' into megatron_perceiver
michalivne Jun 3, 2022
89cb8d8
Merge branch 'main' into megatron_perceiver
MaximumEntropy Jun 7, 2022
0ab2227
update branch
ericharper Jun 13, 2022
15d3101
Set headscale false (#4364)
MaximumEntropy Jun 13, 2022
2366748
Add wandb as dependency (#4365)
titu1994 Jun 13, 2022
c26100d
Raise trainer error (#4356)
MaximumEntropy Jun 13, 2022
7246e6b
Set headscale false (#4364) (#4366)
titu1994 Jun 14, 2022
372f9f7
Finetuning changes for BART (#4003)
MaximumEntropy Jun 14, 2022
98f8988
Make position embedding expansion specific to a batch to avoid checkp…
MaximumEntropy Jun 14, 2022
4dbda18
Refactor bias act fusion
MaximumEntropy Jun 15, 2022
fec4763
Update NMT config
MaximumEntropy Jun 15, 2022
31099f4
Fix electronic bug, new time ITN rule (#4355)
ekmb Jun 15, 2022
f94e587
Update ci tests
MaximumEntropy Jun 15, 2022
bb87904
Merge branch 'r1.10.0' into bias_act_fusion_refactor
MaximumEntropy Jun 15, 2022
f67625c
Correct support for dataclasses in default module dim (#4372)
titu1994 Jun 15, 2022
8187bc9
fix pad id bug (#4377)
yidong72 Jun 16, 2022
a9c0cab
Question answering bug fix (#4381)
Zhilin123 Jun 16, 2022
e6c5347
Merge branch 'r1.10.0' into bias_act_fusion_refactor
MaximumEntropy Jun 16, 2022
76ae9b5
Fix ASR Typos in tutorials (#4384)
titu1994 Jun 17, 2022
f304ca6
Merge branch 'r1.10.0' into bias_act_fusion_refactor
MaximumEntropy Jun 17, 2022
b2fba34
Add Docs for NeMo Adapters (#4369)
titu1994 Jun 17, 2022
d8c5fac
Update NeMo docs (#4397)
titu1994 Jun 17, 2022
faaf02f
Punctuation and capitalization tests race condition (#4399)
PeganovAnton Jun 18, 2022
bf433c0
Merge branch 'r1.10.0' into bias_act_fusion_refactor
MaximumEntropy Jun 21, 2022
b25bfb5
Merge bias act fusion refactor
MaximumEntropy Jun 21, 2022
dd836b5
bias act fusion changes
MaximumEntropy Jun 21, 2022
6c42156
Merge branch 'main' of github.com:NVIDIA/NeMo into megatron_perceiver
MaximumEntropy Jun 26, 2022
7608f3b
Address comments
MaximumEntropy Jun 26, 2022
b9e3ca5
Merge branch 'main' of github.com:NVIDIA/NeMo into megatron_perceiver
MaximumEntropy Jun 29, 2022
3c38900
Fix geglu without fusion
MaximumEntropy Jul 2, 2022
ba8ff05
Merge and fix
MaximumEntropy Jul 3, 2022
64ddf1c
Merge branch 'geglu_no_fusion_fix' of github.com:NVIDIA/NeMo into meg…
MaximumEntropy Jul 3, 2022
35db236
Reset files to main
MaximumEntropy Jul 5, 2022
6ff9618
Remove hidden blocks
MaximumEntropy Jul 5, 2022
ffa9914
Fix style
MaximumEntropy Jul 6, 2022
0a28392
Merge branch 'main' of github.com:NVIDIA/NeMo into megatron_perceiver
MaximumEntropy Jul 6, 2022
216a034
Merge branch 'main' into megatron_perceiver
MaximumEntropy Jul 6, 2022
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/conf/megatron_bart_config.yaml
@@ -56,7 +56,7 @@ model:

seq_length: 512
max_position_embeddings: ${.seq_length}
num_layers: 12
num_layers: 12 # For perceiver models, this is the number of cross-attention blocks. Each layer has 1 cross-attention and "num_self_attention_per_cross_attention" self-attention layers.
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 12
@@ -76,11 +76,13 @@ model:
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
bias: True # Whether to use bias terms in all weight matrices.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
encoder_arch: 'transformer'
decoder_arch: 'transformer'
encoder_arch: 'transformer' # Options: ['transformer', 'perceiver']
decoder_arch: 'transformer' # Options: ['transformer']
activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
hidden_steps: 32 # Number of latent vectors to use for perceiver encoders
num_self_attention_per_cross_attention: 1 # Number of self-attention layers for every cross-attention layer.

tokenizer:
library: 'megatron'
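The same three perceiver-related keys (encoder_arch, hidden_steps, num_self_attention_per_cross_attention) are added to each encoder-decoder config in this PR. The following minimal sketch is an illustration rather than part of the PR: it uses only keys visible in the diffs above, and the layer-count arithmetic follows the comment on num_layers. OmegaConf is used simply because NeMo configs are OmegaConf/Hydra based.

from omegaconf import OmegaConf

# Perceiver-related encoder settings, mirroring the keys added in the config diffs.
encoder_cfg = OmegaConf.create(
    {
        "encoder_arch": "perceiver",                  # new option introduced by this PR
        "num_layers": 12,                             # number of cross-attention blocks for perceiver encoders
        "num_self_attention_per_cross_attention": 1,  # self-attention layers per cross-attention layer
        "hidden_steps": 32,                           # number of latent vectors
    }
)

# Each "layer" is one cross-attention plus num_self_attention_per_cross_attention self-attention layers.
total_attention_layers = encoder_cfg.num_layers * (1 + encoder_cfg.num_self_attention_per_cross_attention)
print(total_attention_layers)  # 24 with the values above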
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/conf/megatron_t5_config.yaml
@@ -57,7 +57,7 @@ model:

seq_length: 512
max_position_embeddings: ${.seq_length}
num_layers: 12
num_layers: 12 # For perceiver models, this is the number of cross-attention blocks. Each layer has 1 cross-attention and "num_self_attention_per_cross_attention" self-attention layers.
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 12
@@ -78,11 +78,13 @@ model:
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
bias: True # Whether to use bias terms in all weight matrices.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
encoder_arch: 'transformer'
decoder_arch: 'transformer'
encoder_arch: 'transformer' # Options: ['transformer', 'perceiver']
decoder_arch: 'transformer' # Options: ['transformer']
activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
hidden_steps: 32 # Number of latent vectors to use for perceiver encoders
num_self_attention_per_cross_attention: 1 # Number of self-attention layers for every cross-attention layer.

tokenizer:
library: 'megatron'
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/conf/megatron_ul2_config.yaml
@@ -55,7 +55,7 @@ model:

seq_length: 512
max_position_embeddings: ${.seq_length}
num_layers: 12
num_layers: 12 # For perceiver models, this is the number of cross-attention blocks. Each layer has 1 cross-attention and "num_self_attention_per_cross_attention" self-attention layers.
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 12
@@ -75,11 +75,13 @@ model:
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
bias: True # Whether to use bias terms in all weight matrices.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
encoder_arch: 'transformer'
decoder_arch: 'transformer'
encoder_arch: 'transformer' # Options: ['transformer', 'perceiver']
decoder_arch: 'transformer' # Options: ['transformer']
activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
hidden_steps: 32 # Number of latent vectors to use for perceiver encoders
num_self_attention_per_cross_attention: 1 # Number of self-attention layers for every cross-attention layer.

tokenizer:
library: 'megatron'
@@ -66,7 +66,7 @@ model:

seq_length: 512
max_position_embeddings: ${.seq_length}
num_layers: 12
num_layers: 12 # For perceiver models, this is the number of cross-attention blocks. Each layer has 1 cross-attention and "num_self_attention_per_cross_attention" self-attention layers.
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 12
@@ -91,6 +91,8 @@ model:
activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
hidden_steps: 32 # Number of latent vectors to use for perceiver encoders
num_self_attention_per_cross_attention: 1 # Number of self-attention layers for every cross-attention layer.

# precision
native_amp_init_scale: 4294967296 # 2 ** 32
@@ -120,6 +120,8 @@ def setup_optimizer_param_groups(self):

def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder):
# TODO: create get_encoder_decoder_model() here for different losses (e.g., nll, vae, mim)
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.encoder_arch == 'perceiver':
raise ValueError("Perceivers with pipeline parallel > 1 are not supported yet.")
if hasattr(self.cfg, 'bias_gelu_fusion'):
logging.warning('bias_gelu_fusion is deprecated. Please use bias_activation_fusion instead.')
activation_fusion = self.cfg.bias_gelu_fusion
@@ -163,6 +165,8 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder):
normalization=self.cfg.get('normalization', 'layernorm'),
transformer_block_type=self.cfg.get('transformer_block_type', 'pre_ln'),
headscale=self.cfg.get('headscale', False),
hidden_steps=self.cfg.get('hidden_steps', -1),
num_self_attention_per_cross_attention=self.cfg.get('num_self_attention_per_cross_attention', 1),
add_encoder=add_encoder,
add_decoder=add_decoder,
)
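As the new guard above (and the PR title) indicates, the perceiver encoder currently supports tensor parallelism only. Below is a hedged sketch of a pre-flight check a user might run on a config dict before launching training; the key names tensor_model_parallel_size and pipeline_model_parallel_size are the usual Megatron-style names and are assumptions here, not taken from this diff.

def check_perceiver_parallelism(cfg: dict) -> None:
    """Raise early if a perceiver encoder is combined with pipeline parallelism."""
    if cfg.get("encoder_arch", "transformer") == "perceiver":
        if cfg.get("pipeline_model_parallel_size", 1) > 1:
            raise ValueError("Perceivers with pipeline parallelism > 1 are not supported yet.")
        # Tensor parallelism (tensor_model_parallel_size >= 1) is fine for perceiver encoders.


check_perceiver_parallelism({"encoder_arch": "perceiver", "pipeline_model_parallel_size": 1})  # passes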
@@ -76,7 +76,6 @@ def get_decoder_model(
headscale=False,
transformer_block_type="pre_ln",
hidden_steps=-1,
hidden_blocks=1,
parent_model_type=ModelType.encoder_or_decoder,
layer_type=None,
chunk_size=64,
@@ -13,7 +13,9 @@
# limitations under the License.

"""Transformer based language model."""
import torch

from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

@@ -41,15 +43,25 @@ def __init__(
# AttnMaskType enum mask type (e.g., padding, casual)
encoder_attn_mask_type: AttnMaskType = None,
decoder_attn_mask_type: AttnMaskType = None,
hidden_steps: int = None,
):
super(MegatronTransformerEncoderDecoderModule, self).__init__()

self.encoder = encoder
self.decoder = decoder
self.hidden_steps = hidden_steps
if isinstance(encoder, MegatronPerceiverEncoderModule) and hidden_steps is None:
raise ValueError(
f"hidden_steps cannot be None for perceiver encoders. It is needed to compute the encoder-decoder cross attention mask."
)

# try to infer mask_type if not given
if encoder_attn_mask_type is None:
if encoder is None:
encoder_attn_mask_type = None
# Perceiver does not have a `.model` attribute, assume it always uses padding mask.
elif isinstance(encoder, MegatronPerceiverEncoderModule):
encoder_attn_mask_type = AttnMaskType.padding
elif hasattr(encoder.model, 'self_attn_mask_type'):
encoder_attn_mask_type = encoder.model.self_attn_mask_type
else:
@@ -136,6 +148,10 @@ def forward(
return enc_output

# decoder
# Adjust encoder attention mask if encoder is a perceiver.
if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule):
enc_attn_mask = torch.ones(enc_output.size(0), self.hidden_steps).to(enc_output.device)

dec_output = self.decode(
dec_input=dec_input,
dec_attn_mask=dec_attn_mask,
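The forward() change above is needed because a perceiver encoder maps a variable-length source sequence onto a fixed number of latent vectors (hidden_steps), so the original token-level padding mask no longer matches the encoder output. A small standalone sketch of that mask replacement follows; the [batch, steps, hidden] layout and the hidden size of 768 are illustrative assumptions (Megatron internally may use a [seq, batch, hidden] layout).

import torch

batch_size, src_len, hidden_steps, hidden_size = 4, 50, 32, 768

# Token-level padding mask over the source sequence (1 = keep, 0 = pad).
enc_attn_mask = torch.randint(0, 2, (batch_size, src_len))

# Perceiver output: a fixed number of latent vectors, independent of src_len.
enc_output = torch.randn(batch_size, hidden_steps, hidden_size)

# Mirror the change in forward(): the decoder may attend to every latent,
# so the cross-attention mask becomes all ones of shape [batch, hidden_steps].
enc_attn_mask = torch.ones(enc_output.size(0), hidden_steps, device=enc_output.device)
print(enc_attn_mask.shape)  # torch.Size([4, 32])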
45 changes: 43 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
@@ -13,6 +13,7 @@
# limitations under the License.

"""Transformer based language model."""
from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule
from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import MegatronTransformerEncoderModule
from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import (
MegatronRetrievalTransformerEncoderModule,
@@ -35,7 +36,7 @@

__all__ = []

AVAILABLE_ENCODERS = ["transformer"]
AVAILABLE_ENCODERS = ["transformer", "perceiver", "retro"]


def get_encoder_model(
@@ -74,11 +75,12 @@ def get_encoder_model(
normalization="layernorm",
headscale=False,
transformer_block_type="pre_ln",
hidden_steps=-1,
hidden_steps=32,
hidden_blocks=1,
parent_model_type=ModelType.encoder_or_decoder,
layer_type=None,
chunk_size=64,
num_self_attention_per_cross_attention=1,
layer_number_offset=0, # this is use only for attention norm_factor scaling
):
"""Build language model and return along with the key to save."""
@@ -168,6 +170,45 @@ def get_encoder_model(
chunk_size=chunk_size,
layer_number_offset=layer_number_offset,
)
elif arch == "perceiver":
encoder = MegatronPerceiverEncoderModule(
init_method=init_method,
output_layer_init_method=scaled_init_method,
hidden_size=hidden_size,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
kv_channels=kv_channels,
ffn_hidden_size=ffn_hidden_size,
encoder_attn_mask_type=encoder_attn_mask_type,
pre_process=pre_process,
post_process=post_process,
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
position_embedding_type=position_embedding_type,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
bias=bias,
normalization=normalization,
transformer_block_type=transformer_block_type,
headscale=headscale,
parent_model_type=parent_model_type,
hidden_steps=hidden_steps,
num_self_attention_per_cross_attention=num_self_attention_per_cross_attention,
)
else:
raise ValueError(f"Unknown encoder arch = {arch}. Available encoder arch = {AVAILABLE_ENCODERS}")

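get_encoder_model() now dispatches on arch, with "perceiver" accepted alongside "transformer" and "retro", and falls through to the ValueError above for anything unknown. A stripped-down sketch of that dispatch pattern (the real function forwards the many keyword arguments shown above, which are omitted here):

AVAILABLE_ENCODERS = ["transformer", "perceiver", "retro"]


def select_encoder_arch(arch: str) -> str:
    """Validate the requested encoder architecture, mirroring the check in get_encoder_model()."""
    if arch not in AVAILABLE_ENCODERS:
        raise ValueError(f"Unknown encoder arch = {arch}. Available encoder arch = {AVAILABLE_ENCODERS}")
    return arch


select_encoder_arch("perceiver")  # accepted after this PR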