
Make position embedding expansion specific to a batch to avoid checkpoint size mismatches #4357

Merged (9 commits) on Jun 14, 2022
@@ -24,6 +24,7 @@
from torch.nn.functional import gelu

from nemo.collections.common.parts import form_attention_mask
+from nemo.utils import logging

__all__ = ["TransformerEmbedding", "AttentionBridge"]

@@ -62,14 +63,23 @@ def forward(self, position_ids):
        max_pos_id = position_ids.max()
        # update positional encoding if needed
        if max_pos_id >= self._max_sequence_length:
-            self._max_sequence_length = max_pos_id + 1
+            logging.warning(
+                f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.'
+            )
            self._build_pos_enc(
                hidden_size=self._hidden_size, max_sequence_length=max_pos_id + 1, device=position_ids.device,
            )

+        embeddings = torch.embedding(self.pos_enc, position_ids)
+
+        # Revert expansion of position embeddings since this will cause checkpoint size mismatches.
+        if max_pos_id >= self._max_sequence_length:
+            self._build_pos_enc(
+                hidden_size=self._hidden_size,
+                max_sequence_length=self._max_sequence_length,
+                device=position_ids.device,
+            )
+
-        return torch.embedding(self.pos_enc, position_ids)
+        return embeddings


class TransformerEmbedding(nn.Module):
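For readers without the surrounding file, here is a minimal standalone sketch of the pattern this PR implements: the positional-encoding buffer is rebuilt at a larger size only for the oversized batch, then rebuilt at its configured size before returning, so `state_dict()` and saved checkpoints keep the original buffer shape. Only `pos_enc`, `_build_pos_enc`, `_hidden_size`, and `_max_sequence_length` mirror names from the diff; the class name, constructor signature, and the sinusoidal table below are assumptions for illustration, not NeMo's exact `FixedPositionalEncoding` API.

```python
import math

import torch
from torch import nn


class FixedPositionalEncodingSketch(nn.Module):
    """Sinusoidal position embeddings whose buffer is expanded only temporarily (illustrative sketch)."""

    def __init__(self, hidden_size, max_sequence_length=512):
        super().__init__()
        self._hidden_size = hidden_size
        self._max_sequence_length = max_sequence_length
        self._build_pos_enc(hidden_size, max_sequence_length)

    def _build_pos_enc(self, hidden_size, max_sequence_length, device=None):
        # Illustrative sinusoidal table (assumption), registered as a buffer so it appears in state_dict().
        pos_enc = torch.zeros(max_sequence_length, hidden_size, device=device)
        position = torch.arange(0.0, max_sequence_length, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, hidden_size, 2.0, device=device) * -(math.log(10000.0) / hidden_size)
        )
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, position_ids):
        max_pos_id = int(position_ids.max())
        needs_expansion = max_pos_id >= self._max_sequence_length

        if needs_expansion:
            # Temporary expansion for this batch only; self._max_sequence_length is NOT updated.
            self._build_pos_enc(self._hidden_size, max_pos_id + 1, device=position_ids.device)

        embeddings = torch.embedding(self.pos_enc, position_ids)

        if needs_expansion:
            # Rebuild the original-size buffer so saved checkpoints keep the expected shape.
            self._build_pos_enc(self._hidden_size, self._max_sequence_length, device=position_ids.device)

        return embeddings


# Usage: a batch whose position ids exceed the configured maximum of 8.
pe = FixedPositionalEncodingSketch(hidden_size=16, max_sequence_length=8)
out = pe(torch.arange(12).unsqueeze(0))  # works for this long batch
assert pe.pos_enc.shape[0] == 8          # buffer is back to its original size afterwards
```

The trade-off is rebuilding the table twice for every oversized batch, which the warning in the diff flags as a slow path; chunking inputs to fit within the configured maximum sequence length avoids it entirely.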