
[HF] Add input embedding argument to HF model #442

Merged · 5 commits · Feb 9, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed default value of `--tokenizer` argument to `scripts/prepare_tulu_data.py` to be an absolute path, not a relative path, so the script can be run from other directories.
- Added the option to directly pass input embeddings to `OLMo` and `OLMoForCausalLM`.

## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

2 changes: 2 additions & 0 deletions hf_olmo/modeling_olmo.py
@@ -48,6 +48,7 @@ def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
@@ -64,6 +65,7 @@ def forward(
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.forward(
input_ids=input_ids,
input_embeddings=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
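For reference, a minimal usage sketch of the new `inputs_embeds` argument on the HF wrapper. This is not part of the PR: the checkpoint name is a placeholder, and it assumes `hf_olmo` exposes `OLMoForCausalLM` at the package level.

```python
# Illustrative sketch: feed precomputed embeddings to the HF wrapper.
# "allenai/OLMo-1B" stands in for whatever checkpoint you actually use.
import torch
from hf_olmo import OLMoForCausalLM

model = OLMoForCausalLM.from_pretrained("allenai/OLMo-1B")  # placeholder checkpoint
model.eval()

vocab = model.model.transformer.wte.num_embeddings
input_ids = torch.randint(0, vocab, (1, 8))

with torch.no_grad():
    # Look up the token embeddings manually; `inputs_embeds` is treated as the
    # output of the embedding layer, so this should match passing `input_ids`.
    embeds = model.model.transformer.wte(input_ids)
    logits_from_ids = model(input_ids=input_ids).logits
    logits_from_embeds = model(inputs_embeds=embeds).logits

print(torch.allclose(logits_from_ids, logits_from_embeds))
```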
17 changes: 9 additions & 8 deletions olmo/model.py
@@ -1137,6 +1137,7 @@ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.
def forward(
self,
input_ids: torch.LongTensor,
Contributor Author (inline comment):
Ideally I would make input_ids optional, but this would not be a fully backwards compatible change.

input_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
@@ -1145,6 +1146,8 @@ def forward(
) -> OlmoOutput:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
embeddings. When provided, it is treated as the output of the input embedding layer.
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
which input IDs are masked. A `1` value in the mask means that
the corresponding input ID should *not* be ignored. A `0` means
@@ -1174,22 +1177,20 @@ def forward(
if past_key_values:
assert len(past_key_values) == self.config.n_layers

batch_size, seq_len = input_ids.size()
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
if past_key_values is None:
past_length = 0
else:
past_length = past_key_values[0][0].size(-2)

# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore

if not (self.config.alibi or self.config.rope):
# Get positional embeddings.
# shape: (1, seq_len)
pos = torch.arange(
past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
# shape: (1, seq_len, d_model)
pos_emb = self.transformer.wpe(pos) # type: ignore
x = pos_emb + x
@@ -1229,7 +1230,7 @@ def forward(
if attention_mask is not None:
mask_len = attention_mask.shape[-1]
elif past_key_values is not None:
mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
mask_len = past_key_values[0][0].shape[-2] + seq_len
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

# Add in the masking bias.
@@ -1470,7 +1471,7 @@ def generate(
tokens_generated = 0

def flatten_past_key_values(
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
out = {}
for i, (key, value) in enumerate(past_key_values):
@@ -1479,7 +1480,7 @@ def flatten_past_key_values(
return out

def unflatten_past_key_values(
past_key_values: Dict[str, torch.Tensor]
past_key_values: Dict[str, torch.Tensor],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
out = []
for i in range(self.config.n_layers):
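A motivating use case for `input_embeddings` on the core model is soft prompting. The helper below is an illustrative sketch, not code from the PR: the function name and shapes are assumptions, and `input_ids` is still passed because, per the review comment above, it remains a required positional argument.

```python
# Illustrative soft-prompt sketch for olmo.model.Olmo.forward (not from the PR).
import torch


def forward_with_soft_prompt(olmo_model, input_ids, soft_prompt):
    """Prepend trainable prompt vectors to the token embeddings and run OLMo.

    olmo_model:  an `olmo.model.Olmo` instance
    input_ids:   LongTensor of shape (batch_size, seq_len)
    soft_prompt: FloatTensor of shape (prompt_len, d_model)
    """
    batch_size = input_ids.size(0)
    # Output of the input embedding layer: (batch_size, seq_len, d_model).
    token_embeds = olmo_model.transformer.wte(input_ids)
    # Broadcast the prompt over the batch and prepend it along the sequence dim.
    prompt = soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)
    embeds = torch.cat([prompt, token_embeds], dim=1)
    # Sequence length and device are now derived from `input_embeddings`,
    # so `input_ids` is only passed here to satisfy the positional signature.
    return olmo_model.forward(input_ids, input_embeddings=embeds)
```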