diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index f0b1fe50..a3d5260b 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -138,14 +138,14 @@ def __call__(self, xq, xk, xv, mask, cache): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) - xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim) - xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim) + xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim) mask: mask with 0 and -inf, or None cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, _, _, kv_head_dim = xk.shape - n_rep = head_dim // kv_head_dim + _, num_kv_heads, _, kv_head_dim = xk.shape + n_rep = num_heads // num_kv_heads if seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) @@ -191,14 +191,14 @@ def __call__(self, xq, xk, xv, mask, cache): """ Args: xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) - xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim) - xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim) + xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim) mask: mask with 0 and -inf, or None cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, _, _, kv_head_dim = xk.shape - n_rep = head_dim // kv_head_dim + _, num_kv_heads, _, kv_head_dim = xk.shape + n_rep = num_heads // num_kv_heads if seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))