Merge pull request #216 from allenai/petew-cache-attn
Cache attention keys + values to speed up inference
epwalsh committed Jun 20, 2023
2 parents 05c6d53 + 3ab48df commit 7c866c9
Showing 2 changed files with 168 additions and 41 deletions.
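
The change implements standard key/value caching for autoregressive decoding: keys and values computed for earlier tokens are kept around, and each subsequent forward pass only projects the new token(s), concatenating them onto the cached tensors along the sequence dimension. Below is a minimal sketch of the idea, separate from the OLMo code in the diff; function and variable names are illustrative, and it assumes PyTorch >= 2.0 for scaled_dot_product_attention.

import torch
import torch.nn.functional as F

def cached_attention(q, k, v, layer_past=None, use_cache=True):
    # q, k, v: (batch, n_heads, new_tokens, head_dim), for the new tokens only.
    if layer_past is not None:
        past_k, past_v = layer_past
        # Grow the cache along the sequence dimension so attention sees the full history.
        k = torch.cat((past_k, k), dim=-2)
        v = torch.cat((past_v, v), dim=-2)
    present = (k, v) if use_cache else None
    # With a cache, the query length and key length differ, so any bias/mask must be
    # sliced to (new_tokens, total_tokens) and rotary positions for q must start at the
    # past length rather than 0 -- exactly what the diff below does.
    att = F.scaled_dot_product_attention(q, k, v, is_causal=(layer_past is None))
    return att, present

# Prefill the prompt once, then decode one token at a time while reusing the cache.
q = k = v = torch.randn(1, 8, 5, 64)      # 5 prompt tokens
_, cache = cached_attention(q, k, v)
q1 = k1 = v1 = torch.randn(1, 8, 1, 64)   # 1 new token
_, cache = cached_attention(q1, k1, v1, layer_past=cache)
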
163 changes: 135 additions & 28 deletions olmo/model.py
@@ -9,7 +9,7 @@
import math
import os
from abc import abstractmethod
from typing import Dict, List, NamedTuple, Optional, cast
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast

import torch
import torch.backends.cuda
@@ -275,8 +275,14 @@ def get_rotary_embedding(self, seq_len: int, device: Optional[torch.device]) ->
return pos_emb

def attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_bias: Optional[torch.FloatTensor] = None
) -> torch.Tensor:
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, C = q.size() # batch size, sequence length, d_model
dtype = k.dtype

@@ -299,10 +305,26 @@ def attention(
# shape: (B, nh, T, hs)
v = v.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)

if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)

if use_cache:
present = (k, v)
else:
present = None

query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None

if self.config.rope:
# Apply rotary embeddings.
positions = self.get_rotary_embedding(T, q.device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
positions = self.get_rotary_embedding(key_len, q.device)
q = apply_rotary_pos_emb(positions[key_len - query_len : key_len], q)
k = apply_rotary_pos_emb(positions, k)

if attention_bias is not None:
attention_bias = attention_bias[:, :, key_len - query_len : key_len, :key_len]

# Get the attention scores.
# shape: (B, nh, T, hs)
@@ -319,14 +341,14 @@ def attention(
att = att.transpose(1, 2).contiguous().view(B, T, C)

# Apply output projection.
return self.attn_out(att)
return self.attn_out(att), present

@abstractmethod
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
raise NotImplementedError

@classmethod
@@ -364,24 +386,29 @@ def __init__(self, config: ModelConfig):
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value projections.
# shape:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
q, k, v = self.att_proj(self.norm(x)).split(self.fused_dims, dim=-1)

# Get attention scores.
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Add attention scores.
# shape: (B, T, C)
x = x + self.dropout(self.attention(q, k, v, attention_bias))
x = x + self.dropout(att)

# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.norm(x)))))

return x
return x, cache


class OlmoParallelBlock(OlmoBlock):
@@ -419,8 +446,10 @@ def __init__(self, config: ModelConfig):
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value, and feed-forward projections.
# shape of q, k, v:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
@@ -431,12 +460,12 @@ def forward(

# Get attention scores.
# shape: (B, T, C)
att = self.attention(q, k, v, attention_bias)
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Apply output projections (and activation function) and sum the results.
# We keep these projections separate because we found that we got better throughput this
# way compared to fusing them.
return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att)
return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att), cache


class OlmoOutput(NamedTuple):
@@ -446,6 +475,11 @@ class OlmoOutput(NamedTuple):
for the next token *before* normalization via (log) softmax.
"""

attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
"""
Attention keys and values from each block.
"""


class OlmoGenerateOutput(NamedTuple):
token_ids: torch.LongTensor
@@ -582,6 +616,9 @@ def forward(
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
last_logits_only: bool = False,
) -> OlmoOutput:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
@@ -604,7 +641,16 @@ def forward(
scores before the softmax.
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
:param past_key_values: Pre-computed keys and values for each attention block.
Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
:param use_cache: If `True`, return key and value tensors for each block.
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
This can speed up decoding when you only care about the next token.
"""
if past_key_values:
assert len(past_key_values) == self.config.n_layers

batch_size, seq_len = input_ids.size()
assert seq_len <= self.config.max_sequence_length, (
f"Cannot forward input with seq_len={seq_len}, "
@@ -617,8 +663,14 @@ def forward(

if not (self.config.alibi or self.config.rope):
# Get positional embeddings.
if past_key_values is None:
past_length = 0
else:
past_length = past_key_values[0][0].size(-2)
# shape: (1, seq_len)
pos = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)
pos = torch.arange(
past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
# shape: (1, seq_len, d_model)
pos_emb = self.transformer.wpe(pos) # type: ignore
x = pos_emb + x
@@ -635,7 +687,15 @@ def forward(
attention_mask.masked_fill_(attention_mask == 1.0, float("-inf"))

# Merge attention mask with attention bias.
if attention_bias is not None or attention_mask is not None or self.config.alibi:
if (
attention_bias is not None
or attention_mask is not None
or self.config.alibi
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
# scores correctly.
or past_key_values is not None
):
if attention_bias is None and self.config.alibi:
attention_bias = self.causal_attention_bias + self.alibi_attention_bias
elif attention_bias is None:
@@ -644,26 +704,44 @@ def forward(
attention_bias = attention_bias.to(dtype=x.dtype)
attention_bias.masked_fill_(attention_bias == 0.0, float("-inf"))

attention_bias = attention_bias[:, :, :seq_len, :seq_len].to(x.dtype)
# Transform to the right shape and data type.
mask_len = seq_len
if attention_mask is not None:
mask_len = attention_mask.shape[-1]
elif past_key_values is not None:
mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(x.dtype)

# Add in the masking bias.
if attention_mask is not None:
attention_bias = attention_bias + attention_mask

attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

# Apply blocks one-by-one.
for block in self.transformer.blocks: # type: ignore
for block, layer_past in zip(
self.transformer.blocks, # type: ignore
past_key_values or [None] * self.config.n_layers, # type: ignore
):
# shape: (batch_size, seq_len, d_model)
x = block(x, attention_bias=attention_bias)
x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)

if last_logits_only:
# shape: (batch_size, 1, d_model)
x = x[:, -1, :].unsqueeze(1)

# Apply final layer norm.
# shape: (batch_size, seq_len, d_model)
# shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore

# Get logits.
# shape: (batch_size, seq_len, vocab_size)
# shape: (batch_size, seq_len or 1, vocab_size)
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore

return OlmoOutput(logits=logits) # type: ignore[arg-type]
return OlmoOutput(logits=logits, attn_key_values=attn_key_values) # type: ignore[arg-type]

def fsdp_wrap_fn(self, module, recurse: bool = True, nonwrapped_numel: int = 0):
del recurse, nonwrapped_numel
@@ -807,29 +885,58 @@ def generate(

tokens_generated = 0

def flatten_past_key_values(
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
out = {}
for i, (key, value) in enumerate(past_key_values):
out[f"past_key_{i}"] = key
out[f"past_value_{i}"] = value
return out

def unflatten_past_key_values(
past_key_values: Dict[str, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
out = []
for i in range(self.config.n_layers):
past_key = past_key_values[f"past_key_{i}"]
past_value = past_key_values[f"past_value_{i}"]
out.append((past_key, past_value))
return out

def step(
last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
nonlocal tokens_generated

input_ids = state["input_ids"]
attention_mask = state.get("attention_mask")
attention_bias = state.get("attention_bias")
group_size = input_ids.shape[0]

if tokens_generated > 0:
input_ids = torch.cat((input_ids, last_predictions.unsqueeze(1)), dim=-1)
past_key_values = unflatten_past_key_values(state)
input_ids = last_predictions.unsqueeze(1)
if attention_mask is not None:
group_size = input_ids.shape[0]
attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
else:
past_key_values = None
input_ids = state["input_ids"]

tokens_generated += 1

# Run forward pass of model to get logits, then normalize to get log probs.
output = self(input_ids, attention_mask=attention_mask, attention_bias=attention_bias)
output = self(
input_ids,
attention_mask=attention_mask,
attention_bias=attention_bias,
past_key_values=past_key_values,
use_cache=True,
last_logits_only=True,
)
log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)

# Create new state.
state = {"input_ids": input_ids}
state = flatten_past_key_values(output.attn_key_values)
if attention_mask is not None:
state["attention_mask"] = attention_mask
if attention_bias is not None:
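
For context on how the new arguments fit together, here is a hedged sketch of a greedy decoding loop against the forward signature added in this diff. The repo's own generate() drives a beam-search step function as shown above; greedy_decode below is a hypothetical helper, not part of the repository.

import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens=20):
    # Prefill: run the whole prompt once, collecting the per-block key/value cache.
    out = model(input_ids, use_cache=True, last_logits_only=True)
    past_key_values = out.attn_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = [next_token]
    for _ in range(max_new_tokens - 1):
        # Decode: feed only the newly chosen token plus the cached keys/values.
        out = model(next_token, past_key_values=past_key_values, use_cache=True, last_logits_only=True)
        past_key_values = out.attn_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
    return torch.cat(generated, dim=-1)
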
