Skip to content

Commit

Permalink
Update bench to consider github.com/ml-explore/mlx#801
Browse files (browse the repository at this point in the history)
  • Loading branch information
atiorh committed Mar 14, 2024
1 parent f2468ea commit fb3c5fa
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions llms/mistral/mistral.py
Original file line number | Diff line number | Diff line change
@@ -70,38 +70,44 @@ def __call__(
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

if not args.optimized_sdpa:
# Baseline implementation requires the keys and values to be repeated
keys = mx.repeat(keys, self.repeats, axis=1)
values = mx.repeat(values, self.repeats, axis=1)
if args.optimized_sdpa:
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
else:
queries = queries.reshape(
B,
L,
self. n_kv_heads,
self.n_heads // self.n_kv_heads,
-1,
).moveaxis(1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, 1, -1).moveaxis(1, 3)
values = values.reshape(B, L, self.n_kv_heads, 1, -1).moveaxis(1, 3)

if cache is not None:
axis = 2 if args.optimized_sdpa else -2
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=key_cache.shape[axis])
keys = self.rope(keys, offset=key_cache.shape[axis])
keys = mx.concatenate([key_cache, keys], axis=axis)
values = mx.concatenate([value_cache, values], axis=axis)
else:
queries = self.rope(queries)
keys = self.rope(keys)

if args.optimized_sdpa:
# Optimized implementation
# TODO(atiorh): mx.fast_inference_sdpa --> mx.fast.scaled_dot_product_attention
output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask).transpose(0, 2, 1, 3)
else:
# Baseline implementation
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
scores = (queries * self.scale) @ keys.swapaxes(-1, -2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values)
output = (scores @ values).moveaxis(3, 1)

return self.wo(output.transpose(0, 2, 1, 3).reshape(B, L, -1)), (keys, values)
return self.wo(output.reshape(B, L, -1)), (keys, values)


class FeedForward(nn.Module):
Expand Down

0 comments on commit fb3c5fa

Please sign in to comment.