diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index f1022f0a..01d8cc36 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -70,38 +70,44 @@ def __call__(
         queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
 
         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if not args.optimized_sdpa:
-            # Baseline implementation requires the keys and values to be repeated
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
+        if args.optimized_sdpa:
+            queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+            keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+            values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        else:
+            queries = queries.reshape(
+                B,
+                L,
+                self.n_kv_heads,
+                self.n_heads // self.n_kv_heads,
+                -1,
+            ).moveaxis(1, 3)
+            keys = keys.reshape(B, L, self.n_kv_heads, 1, -1).moveaxis(1, 3)
+            values = values.reshape(B, L, self.n_kv_heads, 1, -1).moveaxis(1, 3)
 
         if cache is not None:
+            axis = 2 if args.optimized_sdpa else -2
             key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=key_cache.shape[axis])
+            keys = self.rope(keys, offset=key_cache.shape[axis])
+            keys = mx.concatenate([key_cache, keys], axis=axis)
+            values = mx.concatenate([value_cache, values], axis=axis)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
 
         if args.optimized_sdpa:
-            # Optimized implementation
-            # TODO(atiorh): mx.fast_inference_sdpa --> mx.fast.scaled_dot_product_attention
-            output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
+            output = mx.fast.scaled_dot_product_attention(
+                queries, keys, values, scale=self.scale, mask=mask).transpose(0, 2, 1, 3)
         else:
             # Baseline implementation
-            scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+            scores = (queries * self.scale) @ keys.swapaxes(-1, -2)
             if mask is not None:
                 scores += mask
             scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-            output = (scores @ values)
+            output = (scores @ values).moveaxis(3, 1)
 
-        return self.wo(output.transpose(0, 2, 1, 3).reshape(B, L, -1)), (keys, values)
+        return self.wo(output.reshape(B, L, -1)), (keys, values)
 
 
 class FeedForward(nn.Module):
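
For reference, a minimal standalone sketch (not part of the patch) of the equivalence the new baseline branch relies on: reshaping queries to (B, L, n_kv_heads, n_heads // n_kv_heads, head_dim) and keeping a singleton group axis on keys/values lets matmul broadcasting reproduce the old mx.repeat-based grouped-query attention without materializing the repeated copies. The shapes and head counts below are made-up examples, not values from the model config.

# Sketch: broadcast-based GQA vs. the old mx.repeat layout (illustrative only).
import mlx.core as mx

B, L, n_heads, n_kv_heads, head_dim = 1, 8, 32, 8, 128
repeats = n_heads // n_kv_heads
scale = head_dim**-0.5

q = mx.random.normal((B, L, n_heads * head_dim))
k = mx.random.normal((B, L, n_kv_heads * head_dim))
v = mx.random.normal((B, L, n_kv_heads * head_dim))

# Old layout: (B, n_heads, L, head_dim), keys/values copied per query head.
q4 = q.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)
k4 = mx.repeat(k.reshape(B, L, n_kv_heads, -1).transpose(0, 2, 1, 3), repeats, axis=1)
v4 = mx.repeat(v.reshape(B, L, n_kv_heads, -1).transpose(0, 2, 1, 3), repeats, axis=1)
s4 = mx.softmax((q4 * scale) @ k4.swapaxes(-1, -2), axis=-1)
out4 = (s4 @ v4).transpose(0, 2, 1, 3).reshape(B, L, -1)

# New layout: (B, n_kv_heads, repeats, L, head_dim); the singleton group axis
# on keys/values broadcasts across the query groups, so no copies are made.
q5 = q.reshape(B, L, n_kv_heads, repeats, -1).moveaxis(1, 3)
k5 = k.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
v5 = v.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
s5 = mx.softmax((q5 * scale) @ k5.swapaxes(-1, -2), axis=-1)
out5 = (s5 @ v5).moveaxis(3, 1).reshape(B, L, -1)

print(mx.allclose(out4, out5, atol=1e-5))  # expected: True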