diff --git a/olmo/model.py b/olmo/model.py
index fb9275d6f..bdf3acf99 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -9,7 +9,7 @@
 import math
 import os
 from abc import abstractmethod
-from typing import Dict, List, NamedTuple, Optional, cast
+from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast
 
 import torch
 import torch.backends.cuda
@@ -275,8 +275,14 @@ def get_rotary_embedding(self, seq_len: int, device: Optional[torch.device]) ->
         return pos_emb
 
     def attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_bias: Optional[torch.FloatTensor] = None
-    ) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         B, T, C = q.size()  # batch size, sequence length, d_model
         dtype = k.dtype
 
@@ -299,10 +305,26 @@ def attention(
         # shape: (B, nh, T, hs)
         v = v.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        if use_cache:
+            present = (k, v)
+        else:
+            present = None
+
+        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None
+
         if self.config.rope:
             # Apply rotary embeddings.
-            positions = self.get_rotary_embedding(T, q.device)
-            q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
+            positions = self.get_rotary_embedding(key_len, q.device)
+            q = apply_rotary_pos_emb(positions[key_len - query_len : key_len], q)
+            k = apply_rotary_pos_emb(positions, k)
+
+        if attention_bias is not None:
+            attention_bias = attention_bias[:, :, key_len - query_len : key_len, :key_len]
 
         # Get the attention scores.
         # shape: (B, nh, T, hs)
@@ -319,14 +341,14 @@ def attention(
         att = att.transpose(1, 2).contiguous().view(B, T, C)
 
         # Apply output projection.
-        return self.attn_out(att)
+        return self.attn_out(att), present
 
     @abstractmethod
     def forward(
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         raise NotImplementedError
 
     @classmethod
@@ -364,8 +386,10 @@ def __init__(self, config: ModelConfig):
     def forward(
         self,
         x: torch.Tensor,
-        attention_bias: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Get query, key, value projections.
         # shape:
         #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
@@ -373,15 +397,18 @@ def forward(
         #                      k, v: (batch_size, seq_len, d_model // n_heads)
         q, k, v = self.att_proj(self.norm(x)).split(self.fused_dims, dim=-1)
 
+        # Get attention scores.
+        att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+
         # Add attention scores.
         # shape: (B, T, C)
-        x = x + self.dropout(self.attention(q, k, v, attention_bias))
+        x = x + self.dropout(att)
 
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
         x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.norm(x)))))
 
-        return x
+        return x, cache
 
 
 class OlmoParallelBlock(OlmoBlock):
@@ -419,8 +446,10 @@ def __init__(self, config: ModelConfig):
     def forward(
         self,
         x: torch.Tensor,
-        attention_bias: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Get query, key, value, and feed-forward projections.
         # shape of q, k, v:
         #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
@@ -431,12 +460,12 @@ def forward(
 
         # Get attention scores.
         # shape: (B, T, C)
-        att = self.attention(q, k, v, attention_bias)
+        att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
 
         # Apply output projections (and activation function) and sum the results.
         # We keep these projections separate because we found that we got better throughput this
         # way compared to fusing them.
-        return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att)
+        return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att), cache
 
 
 class OlmoOutput(NamedTuple):
@@ -446,6 +475,11 @@ class OlmoOutput(NamedTuple):
     logits: torch.FloatTensor
     """
    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
     for the next token *before* normalization via (log) softmax.
     """
 
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    """
+    Attention keys and values from each block.
+    """
+
 
 class OlmoGenerateOutput(NamedTuple):
     token_ids: torch.LongTensor
@@ -582,6 +616,9 @@ def forward(
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.Tensor] = None,
         attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        use_cache: bool = False,
+        last_logits_only: bool = False,
     ) -> OlmoOutput:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
@@ -604,7 +641,16 @@ def forward(
             scores before the softmax.
 
             The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
+        :param past_key_values: Pre-computed keys and values for each attention block.
+            Can be used to speed up sequential decoding. The `input_ids` which have
+            their past given to this model should not be passed as `input_ids` as they have already been computed.
+        :param use_cache: If `True`, return key and value tensors for each block.
+        :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
+            This can speed up decoding when you only care about the next token.
         """
+        if past_key_values:
+            assert len(past_key_values) == self.config.n_layers
+
         batch_size, seq_len = input_ids.size()
         assert seq_len <= self.config.max_sequence_length, (
             f"Cannot forward input with seq_len={seq_len}, "
@@ -617,8 +663,14 @@ def forward(
 
         if not (self.config.alibi or self.config.rope):
             # Get positional embeddings.
+            if past_key_values is None:
+                past_length = 0
+            else:
+                past_length = past_key_values[0][0].size(-2)
             # shape: (1, seq_len)
-            pos = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)
+            pos = torch.arange(
+                past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
+            ).unsqueeze(0)
             # shape: (1, seq_len, d_model)
             pos_emb = self.transformer.wpe(pos)  # type: ignore
             x = pos_emb + x
@@ -635,7 +687,15 @@ def forward(
             attention_mask.masked_fill_(attention_mask == 1.0, float("-inf"))
 
         # Merge attention mask with attention bias.
-        if attention_bias is not None or attention_mask is not None or self.config.alibi:
+        if (
+            attention_bias is not None
+            or attention_mask is not None
+            or self.config.alibi
+            # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
+            # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
+            # scores correctly.
+            or past_key_values is not None
+        ):
             if attention_bias is None and self.config.alibi:
                 attention_bias = self.causal_attention_bias + self.alibi_attention_bias
             elif attention_bias is None:
@@ -644,26 +704,44 @@ def forward(
                 attention_bias = attention_bias.to(dtype=x.dtype)
                 attention_bias.masked_fill_(attention_bias == 0.0, float("-inf"))
 
-            attention_bias = attention_bias[:, :, :seq_len, :seq_len].to(x.dtype)
+            # Transform to the right shape and data type.
+            mask_len = seq_len
+            if attention_mask is not None:
+                mask_len = attention_mask.shape[-1]
+            elif past_key_values is not None:
+                mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
+            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(x.dtype)
 
             # Add in the masking bias.
             if attention_mask is not None:
                 attention_bias = attention_bias + attention_mask
 
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
         # Apply blocks one-by-one.
-        for block in self.transformer.blocks:  # type: ignore
+        for block, layer_past in zip(
+            self.transformer.blocks,  # type: ignore
+            past_key_values or [None] * self.config.n_layers,  # type: ignore
+        ):
             # shape: (batch_size, seq_len, d_model)
-            x = block(x, attention_bias=attention_bias)
+            x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+
+        if last_logits_only:
+            # shape: (batch_size, 1, d_model)
+            x = x[:, -1, :].unsqueeze(1)
 
         # Apply final layer norm.
-        # shape: (batch_size, seq_len, d_model)
+        # shape: (batch_size, seq_len or 1, d_model)
         x = self.transformer.ln_f(x)  # type: ignore
 
         # Get logits.
-        # shape: (batch_size, seq_len, vocab_size)
+        # shape: (batch_size, seq_len or 1, vocab_size)
         logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
 
-        return OlmoOutput(logits=logits)  # type: ignore[arg-type]
+        return OlmoOutput(logits=logits, attn_key_values=attn_key_values)  # type: ignore[arg-type]
 
     def fsdp_wrap_fn(self, module, recurse: bool = True, nonwrapped_numel: int = 0):
         del recurse, nonwrapped_numel
@@ -807,29 +885,58 @@ def generate(
 
         tokens_generated = 0
 
+        def flatten_past_key_values(
+            past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
+        ) -> Dict[str, torch.Tensor]:
+            out = {}
+            for i, (key, value) in enumerate(past_key_values):
+                out[f"past_key_{i}"] = key
+                out[f"past_value_{i}"] = value
+            return out
+
+        def unflatten_past_key_values(
+            past_key_values: Dict[str, torch.Tensor]
+        ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+            out = []
+            for i in range(self.config.n_layers):
+                past_key = past_key_values[f"past_key_{i}"]
+                past_value = past_key_values[f"past_value_{i}"]
+                out.append((past_key, past_value))
+            return out
+
         def step(
             last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
         ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
             nonlocal tokens_generated
 
-            input_ids = state["input_ids"]
             attention_mask = state.get("attention_mask")
             attention_bias = state.get("attention_bias")
-            group_size = input_ids.shape[0]
 
             if tokens_generated > 0:
-                input_ids = torch.cat((input_ids, last_predictions.unsqueeze(1)), dim=-1)
+                past_key_values = unflatten_past_key_values(state)
+                input_ids = last_predictions.unsqueeze(1)
                 if attention_mask is not None:
+                    group_size = input_ids.shape[0]
                     attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
+            else:
+                past_key_values = None
+                input_ids = state["input_ids"]
 
             tokens_generated += 1
 
             # Run forward pass of model to get logits, then normalize to get log probs.
-            output = self(input_ids, attention_mask=attention_mask, attention_bias=attention_bias)
+            output = self(
+                input_ids,
+                attention_mask=attention_mask,
+                attention_bias=attention_bias,
+                past_key_values=past_key_values,
+                use_cache=True,
+                last_logits_only=True,
+            )
             log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)
 
             # Create new state.
-            state = {"input_ids": input_ids}
+            state = flatten_past_key_values(output.attn_key_values)
             if attention_mask is not None:
                 state["attention_mask"] = attention_mask
             if attention_bias is not None:
diff --git a/tests/model_test.py b/tests/model_test.py
index 45d2b1caf..c1165a2ba 100644
--- a/tests/model_test.py
+++ b/tests/model_test.py
@@ -24,14 +24,12 @@
             id="alibi-emb-parallel-block-cpu-bf16",
         ),
         pytest.param(
-            False, False, False, BlockType.sequential, False, False, torch.bfloat16, id="posit-emb-cpu-bf16"
+            False, False, False, BlockType.sequential, False, False, torch.bfloat16, id="abs-emb-cpu-bf16"
         ),
         pytest.param(
             True, False, False, BlockType.sequential, False, False, torch.float32, id="alibi-emb-cpu-f32"
         ),
-        pytest.param(
-            False, False, False, BlockType.sequential, False, False, torch.float32, id="posit-emb-cpu-f32"
-        ),
+        pytest.param(False, False, False, BlockType.sequential, False, False, torch.float32, id="abs-emb-cpu-f32"),
         pytest.param(
             False, True, False, BlockType.sequential, False, False, torch.bfloat16, id="rope-emb-cpu-bf16"
         ),
@@ -86,7 +84,7 @@
             False,
             True,
             torch.bfloat16,
-            id="posit-emb-cuda-bf16",
+            id="abs-emb-cuda-bf16",
             marks=(
                 pytest.mark.gpu,
                 pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
@@ -100,7 +98,7 @@
             False,
             True,
             torch.bfloat16,
-            id="posit-emb-flash-cuda-bf16",
+            id="abs-emb-flash-cuda-bf16",
             marks=(
                 pytest.mark.gpu,
                 pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
@@ -114,14 +112,14 @@
             False,
             True,
             torch.float16,
-            id="posit-emb-flash-cuda-f16",
+            id="abs-emb-flash-cuda-f16",
             marks=(
                 pytest.mark.gpu,
                 pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
             ),
         ),
         pytest.param(
-            False, False, False, BlockType.sequential, True, False, torch.float32, id="posit-emb-mqattn-cpu-f32"
+            False, False, False, BlockType.sequential, True, False, torch.float32, id="abs-emb-mqattn-cpu-f32"
         ),
         pytest.param(
             False,
@@ -131,7 +129,7 @@
             True,
             False,
             torch.float32,
-            id="posit-emb-parallel-block-mqattn-cpu-f32",
+            id="abs-emb-parallel-block-mqattn-cpu-f32",
         ),
     ],
 )
@@ -183,14 +181,31 @@ def test_forward(
         device_type="cuda" if cuda else "cpu", enabled=use_amp, dtype=None if not use_amp else dtype
     ):
         output1 = model(torch.tensor(input1, device=model.device).unsqueeze(0))
+        key_value_cache1 = model(
+            torch.tensor(input1[:-1], device=model.device).unsqueeze(0), use_cache=True
+        ).attn_key_values
+        output1_from_cached = model(
+            torch.tensor(input1[-1:], device=model.device).unsqueeze(0), past_key_values=key_value_cache1
+        )
         output2 = model(torch.tensor(input2, device=model.device).unsqueeze(0))
         batch_output = model(**batch_inputs)
+        batch_key_value_cache = model(
+            batch_inputs["input_ids"][:, :-1],
+            attention_mask=batch_inputs["attention_mask"][:, :-1],
+            use_cache=True,
+        ).attn_key_values
+        batch_output_from_cached = model(
+            batch_inputs["input_ids"][:, -1].unsqueeze(1),
+            attention_mask=batch_inputs["attention_mask"],
+            past_key_values=batch_key_value_cache,
+        )
 
-    # Check that logits from individual inputs are equal to logits from batch.
     # With using half-precision types these might have some big differences in a small
     # percentage of the elements.
     atol = 1e-2 if use_amp else None
     rtol = 1e3 if use_amp else None
+
+    # Check that logits from individual inputs are equal to logits from batch.
     torch.testing.assert_close(
         output1.logits[0][: len(input1)], batch_output.logits[0][: len(input1)], rtol=rtol, atol=atol
     )
@@ -198,12 +213,17 @@ def test_forward(
         output2.logits[0][: len(input2)], batch_output.logits[1][: len(input2)], rtol=rtol, atol=atol
     )
 
+    # Check that output using cached attention keys + values matches.
+    torch.testing.assert_close(output1.logits[0][-1], output1_from_cached.logits[0][-1], rtol=rtol, atol=atol)
+    # For the batched output this only makes sense for the longer of the two inputs, since the shorter one is padded on the right.
+    torch.testing.assert_close(output2.logits[0][-1], batch_output_from_cached.logits[1][-1], rtol=rtol, atol=atol)
+
 
 @pytest.mark.parametrize(
     "alibi, flash_attn, cuda, dtype",
     [
         pytest.param(True, False, False, torch.bfloat16, id="alibi-emb-cpu-bf16"),
-        pytest.param(False, False, False, torch.bfloat16, id="posit-emb-cpu-bf16"),
+        pytest.param(False, False, False, torch.bfloat16, id="abs-emb-cpu-bf16"),
         pytest.param(
             True,
             False,
@@ -220,7 +240,7 @@ def test_forward(
             False,
             True,
             torch.bfloat16,
-            id="posit-emb-cuda-bf16",
+            id="abs-emb-cuda-bf16",
             marks=(
                 pytest.mark.gpu,
                 pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
@@ -231,7 +251,7 @@ def test_forward(
             True,
             True,
             torch.bfloat16,
-            id="posit-emb-flash-cuda-bf16",
+            id="abs-emb-flash-cuda-bf16",
             marks=(
                 pytest.mark.gpu,
                 pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
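
Reviewer note (not part of the patch): a minimal sketch of how the new arguments compose for incremental decoding. It assumes `model` is an Olmo instance and `token_ids` is a `(batch_size, seq_len)` LongTensor of prompt tokens; both names are placeholders, and greedy argmax is used only to keep the example short. The pattern mirrors the updated `step()` in `generate()` and the new checks in `tests/model_test.py`.

import torch

# Assumed setup (placeholders): `model` is an Olmo instance, `token_ids` is a
# (batch_size, seq_len) LongTensor of prompt tokens.
with torch.inference_mode():
    # Prefill: run the full prompt once and keep the per-block key/value cache.
    prefill = model(token_ids, use_cache=True, last_logits_only=True)
    cache = prefill.attn_key_values  # one (key, value) pair per transformer block
    next_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Decode: feed only the newly chosen token plus the cache. Position offsets,
    # rotary embeddings, and attention-bias slicing are handled inside forward().
    for _ in range(16):  # greedy decoding of 16 tokens, for illustration only
        out = model(next_token, past_key_values=cache, use_cache=True, last_logits_only=True)
        cache = out.attn_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)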