-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathattention_kernel.py
808 lines (720 loc) · 22.1 KB
/
attention_kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
from collections.abc import Callable
import functools
import math
from typing import Any
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
from jax.experimental.shard_map import shard_map
import numpy as np
import torch
import torch.nn.functional as F
from jetstream_pt import torchjax
# Additive logit penalty for masked positions: a large but finite fraction of
# the float32 maximum (presumably finite rather than -inf so a fully masked
# block cannot produce inf - inf = NaN in the online softmax).
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)
# Short alias for building shard_map partition specs.
P = jax.sharding.PartitionSpec
def ragged_flash_attention_kernel(
    layer_ref,
    start_ref,
    end_ref,
    line_end_ref,
    pre_b_ref,
    pre_i_ref,
    q_ref,
    k_ref,
    v_ref,
    k_scaler_ref,
    v_scaler_ref,
    o_ref,  # outputs
    m_ref,  # row max
    l_ref,  # propagation coefficient (running softmax denominator)
    bk: int,
    mask_value: float,
    normalize_var: bool,
    quantized: bool,
):
  """Pallas kernel for flash attention over a per-row (start, end) window.

  Grid step (b, i) folds K/V block i (of `bk` positions) into the running
  online-softmax state for batch row b. That state lives in the output refs:
  `m_ref` holds the running row max, `l_ref` the running softmax denominator,
  and `o_ref` the output already divided by that denominator. All three are
  reset when i == 0.

  Scalar-prefetch refs read here (indexed by b): `start_ref`, `end_ref`,
  `line_end_ref`. `layer_ref`, `pre_b_ref` and `pre_i_ref` are not read in
  this body; presumably they exist only so the caller's BlockSpec index maps
  (which receive the same scalar-prefetch refs) can use them.

  Args:
    bk: K/V block length along the sequence axis.
    mask_value: additive logit penalty for positions outside the window.
    normalize_var: if True, divide logits by sqrt(head_dim).
    quantized: if True, multiply logits by `k_scaler_ref` and the softmax
      numerator by `v_scaler_ref`.
  """
  with jax.named_scope("attention_kernel"):
    b, i = pl.program_id(0), pl.program_id(1)

    # First block of a row: reset the online-softmax accumulators.
    @pl.when(i == 0)
    def init():
      with jax.named_scope("init"):
        m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
        l_ref[...] = jnp.zeros_like(l_ref)
        o_ref[...] = jnp.zeros_like(o_ref)

    length = line_end_ref[b]
    start = start_ref[b]
    end = end_ref[b]

    # Skip blocks that start at/after the row's last position and rows with
    # an empty window (start == end).
    @pl.when(jnp.logical_and(i * bk < length, start != end))
    def run():
      with jax.named_scope("run_qk"):
        q = q_ref[...].astype(jnp.float32)
        k = k_ref[...].astype(jnp.float32)
        v = v_ref[...].astype(jnp.float32)
        m_prev, l_prev = m_ref[...], l_ref[...]
        # qk[t, s] = sum_d q[t, d] * k[s, d]  (contract the head_dim axis).
        qk = jax.lax.dot_general(
            q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
        )
        if normalize_var:
          qk = qk / jnp.sqrt(k.shape[-1])
        if quantized:
          qk = qk * k_scaler_ref[...]
      with jax.named_scope("run_mask"):
        start = start_ref[b]
        end = end_ref[b]
        # Position of each key column, computed from the grid block index i.
        # NOTE(review): this assumes the fetched K/V tile is block i of the
        # row; confirm against the caller's index maps.
        iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1)
        # start <= end: attend to positions in [start, end).
        mask_start_lt_end = jnp.logical_and(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        # start > end: attend to positions >= start or < end (a wrapped
        # window; presumably a circular KV cache -- TODO confirm).
        mask_start_gt_end = jnp.logical_or(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end)
        mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end)
        qk = qk + jnp.where(mask, 0.0, mask_value)
      with jax.named_scope("run_softmax"):
        # Softmax statistics for this block alone.
        m_curr = qk.max(axis=-1)
        s_curr = jnp.exp(qk - m_curr[..., None])
        # Denominator is summed before v scaling so the scales do not cancel.
        l_curr = jax.lax.broadcast_in_dim(
            s_curr.sum(axis=-1), l_prev.shape, (0,)
        )
        if quantized:
          s_curr = s_curr * v_scaler_ref[...]
        o_curr_times_l_curr = jnp.dot(s_curr, v)
        # Merge with the running state: rescale both sides to the new max.
        m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
        m_next = jnp.maximum(m_prev, m_curr)
        alpha = jnp.exp(m_prev - m_next)
        beta = jnp.exp(m_curr - m_next)
        l_next = alpha * l_prev + beta * l_curr
        # Avoid dividing by zero when nothing has been accumulated.
        l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
        m_ref[...], l_ref[...] = m_next, l_next_safe
        # o_ref stores the normalized output, so un-normalize by l_prev first.
        o_ref[...] = (
            (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr)
            / l_next_safe
        ).astype(o_ref.dtype)
@functools.partial(
    jax.jit,
    static_argnames=[
        "bk",
        "mask_value",
        "normalize_var",
        "testing",
        "quantized",
    ],
)
def ragged_mqa(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    layer,
    start: jax.Array,
    end: jax.Array,
    ragged_batch_index=None,
    ragged_block_index=None,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
    testing: bool = False,
    quantized: bool = False,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi query attention driven by `ragged_flash_attention_kernel`.

  Runs on a (batch_size, seq_len // bk) grid. Six scalar-prefetch operands
  are passed, in the order the kernel and every index map below expect them:
  [layer], start, end, line_end, ragged_batch_index, ragged_block_index.

  Args:
    q: [batch_size, time, head_dim] queries. q is never layer-stacked.
    k: [batch_size, seq_len, head_dim] keys, or
      [num_layers, batch_size, seq_len, head_dim] when stacked over layers.
    v: values, same layout as `k`.
    layer: index of the layer to read from a stacked cache.
    start: i32[batch_size] first attended position of each row.
    end: i32[batch_size] exclusive end position of each row. When
      start > end the kernel attends to positions >= start or < end.
    ragged_batch_index: flat int table indexed by b * (seq_len // bk) + i,
      giving the batch row whose tiles are fetched at grid step (b, i).
    ragged_block_index: same indexing, giving the sequence block to fetch.
    k_scaler: per-position key scales, presumably [batch_size, 1, seq_len]
      (with a leading layer axis when stacked). Applied only if `quantized`.
    v_scaler: per-position value scales, same layout as `k_scaler`.
    bk: sequence block size.
    mask_value: additive logit penalty for masked positions.
    normalize_var: divide logits by sqrt(head_dim).
    testing: run the Pallas kernel in interpret mode.
    quantized: apply `k_scaler` / `v_scaler`.

  Returns:
    (out, (m, l)): `out` has q's shape and dtype. `m` and `l` are the row max
    and softmax denominator, f32[batch_size, time].
  """
  with jax.named_scope("ragged_mqa"):
    batch_size, time, head_dim = q.shape
    seq_len = k.shape[-2]
    num_blocks = seq_len // bk
    # q is always 3-D, so a layer-stacked cache is 4-D here. This matches
    # ragged_mqa_reference; testing for 5-D made stacked caches unusable.
    stacked = k.ndim == 4

    def kv_index_map(
        b,
        i,
        layer_ref,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      del start_ref, end_ref, line_end_ref
      index = b * num_blocks + i
      if stacked:
        return (
            layer_ref[0],
            ragged_batch_index_ref[index],
            ragged_block_index_ref[index],
            0,
        )
      return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0

    def q_index_map(
        b,
        i,
        layer_ref,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      # q and the outputs are 3-D regardless of `stacked`, so the block
      # index is always 3-D as well.
      del layer_ref, start_ref, end_ref, line_end_ref, ragged_block_index_ref
      index = b * num_blocks + i
      return ragged_batch_index_ref[index], 0, 0

    def scaler_index_map(
        b,
        i,
        layer_ref,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      # Fetch the scales for the same (row, block) tile as K/V so the scale
      # always describes the tile it multiplies.
      del start_ref, end_ref, line_end_ref
      index = b * num_blocks + i
      row = ragged_batch_index_ref[index]
      block = ragged_block_index_ref[index]
      if stacked:
        return layer_ref[0], row, 0, block
      return row, 0, block

    # Rows with a wrapped window (start >= end) must visit every block.
    line_end = jnp.where(start < end, end, seq_len - 1)

    q_bp = (None, time, head_dim)
    if stacked:
      kv_bp = (None, None, bk, head_dim)
      ks_bp = (None, None, 1, bk)
    else:
      kv_bp = (None, bk, head_dim)
      ks_bp = (None, 1, bk)

    in_specs = [
        pl.BlockSpec(index_map=q_index_map, block_shape=q_bp),  # q
        pl.BlockSpec(index_map=kv_index_map, block_shape=kv_bp),  # k
        pl.BlockSpec(index_map=kv_index_map, block_shape=kv_bp),  # v
        pl.BlockSpec(index_map=scaler_index_map, block_shape=ks_bp),  # k_scaler
        pl.BlockSpec(index_map=scaler_index_map, block_shape=ks_bp),  # v_scaler
    ]
    inputs = (
        # Scalar-prefetch operands (must match num_scalar_prefetch below).
        jnp.array([layer]),
        start,
        end,
        line_end,
        ragged_batch_index,
        ragged_block_index,
        # Blocked operands.
        q,
        k,
        v,
        k_scaler,
        v_scaler,
    )
    out, m, l = pl.pallas_call(
        functools.partial(
            ragged_flash_attention_kernel,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            quantized=quantized,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=6,
            in_specs=in_specs,
            out_specs=[
                pl.BlockSpec(index_map=q_index_map, block_shape=q_bp),  # out
                pl.BlockSpec(index_map=q_index_map, block_shape=q_bp),  # m
                pl.BlockSpec(index_map=q_index_map, block_shape=q_bp),  # l
            ],
            grid=(batch_size, num_blocks),
        ),
        compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
        interpret=testing,
        out_shape=[
            q,
            jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32),
            jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32),
        ],
    )(*inputs)
    return out, (m[..., 0], l[..., 0])
def ragged_mqa_kernel_reference(
    layer_ref,
    start_ref,
    end_ref,
    line_end_ref,
    pre_b_ref,
    pre_i_ref,
    q_ref,
    k_ref,
    v_ref,
    k_scaler_ref,
    v_scaler_ref,
    o_ref,
    m_ref,
    l_ref,
    bk: int,
    mask_value: float,
    normalize_var: bool,
    quantized: bool,
):
  """Pallas kernel for ragged attention over left-aligned rows.

  Grid step (b, i) folds K/V block i (of `bk` positions) into the running
  online-softmax state for batch row b. `m_ref` holds the running row max,
  `l_ref` the running softmax denominator, and `o_ref` the output already
  divided by that denominator. All three are reset when i == 0.

  Row b attends to positions [0, end_ref[b]). `start_ref`, `line_end_ref`,
  `pre_b_ref` and `pre_i_ref` are accepted to keep the scalar-prefetch ref
  layout but are not read in this body; `layer_ref` is explicitly dropped.

  Args:
    bk: K/V block length along the sequence axis.
    mask_value: additive logit penalty for positions >= end_ref[b].
    normalize_var: if True, divide logits by sqrt(head_dim).
    quantized: if True, multiply logits by `k_scaler_ref` and the softmax
      numerator by `v_scaler_ref`.
  """
  b, i = pl.program_id(0), pl.program_id(1)
  del layer_ref

  # First block of a row: reset the online-softmax accumulators.
  @pl.when(i == 0)
  def init():
    m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
    l_ref[...] = jnp.zeros_like(l_ref)
    o_ref[...] = jnp.zeros_like(o_ref)

  # length = lengths_ref[b]
  # Always start from 0, left aligned
  length = end_ref[b]

  # Only blocks that start before the row's end contribute.
  @pl.when(i * bk < length)
  def run():
    q = q_ref[...].astype(jnp.float32)
    k = k_ref[...].astype(jnp.float32)
    v = v_ref[...].astype(jnp.float32)
    m_prev, l_prev = m_ref[...], l_ref[...]
    # qk[t, s] = sum_d q[t, d] * k[s, d]  (contract the head_dim axis).
    qk = jax.lax.dot_general(
        q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
    )
    if normalize_var:
      qk = qk / math.sqrt(k.shape[-1])  # Align with meta llama
    # Quantized
    if quantized:
      qk = qk * k_scaler_ref[...]
    # Keep key columns whose absolute position (from grid block i) is < end.
    mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length
    qk = qk + jnp.where(mask, 0.0, mask_value)
    # Softmax statistics for this block alone.
    m_curr = qk.max(axis=-1)
    s_curr = jnp.exp(qk - m_curr[..., None])
    # Denominator is summed before v scaling so the scales do not cancel.
    l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
    # Quantized
    if quantized:
      s_curr = s_curr * v_scaler_ref[...]
    o_curr_times_l_curr = jnp.dot(s_curr, v)
    # Merge with the running state: rescale both sides to the new max.
    m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
    m_next = jnp.maximum(m_prev, m_curr)
    alpha = jnp.exp(m_prev - m_next)
    beta = jnp.exp(m_curr - m_next)
    l_next = alpha * l_prev + beta * l_curr
    # Avoid dividing by zero when nothing has been accumulated.
    l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
    m_ref[...], l_ref[...] = m_next, l_next_safe
    # o_ref stores the normalized output, so un-normalize by l_prev first.
    o_ref[...] = (
        (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe
    ).astype(o_ref.dtype)
@functools.partial(
    jax.jit,
    static_argnames=[
        "bk",
        "mask_value",
        "normalize_var",
        "testing",
        "quantized",
    ],
)
def ragged_mqa_reference(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    layer,
    start: jax.Array,
    end: jax.Array,
    ragged_batch_index=None,
    ragged_block_index=None,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
    testing: bool = False,
    quantized: bool = False,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi query attention using `ragged_mqa_kernel_reference`.

  Runs on a (batch_size, seq_len // bk) grid. Each row b attends to positions
  [0, end[b]) (left aligned); `start` is forwarded to the kernel but not used
  by it. K/V and scaler tiles are chosen from `end` alone by
  `_compute_ragged_block_indices`. `ragged_batch_index` and
  `ragged_block_index` are only forwarded as scalar-prefetch operands and
  are not read by the kernel or the index maps.

  Args:
    q: [batch_size, time, head_dim] queries.
    k: [batch_size, seq_len, head_dim] keys, or
      [num_layers, batch_size, seq_len, head_dim] when stacked over layers.
    v: values, same layout as `k`.
    layer: index of the layer to read from a stacked cache.
    start: i32[batch_size]; unused by the reference kernel.
    end: i32[batch_size] exclusive end position of each row.
    ragged_batch_index: forwarded scalar-prefetch operand (unused here).
    ragged_block_index: forwarded scalar-prefetch operand (unused here).
    k_scaler: per-position key scales, presumably [batch_size, 1, seq_len]
      (with a leading layer axis when stacked). Applied only if `quantized`.
    v_scaler: per-position value scales, same layout as `k_scaler`.
    bk: sequence block size.
    mask_value: additive logit penalty for masked positions.
    normalize_var: divide logits by sqrt(head_dim).
    testing: run the Pallas kernel in interpret mode.
    quantized: apply `k_scaler` / `v_scaler`.

  Returns:
    (out, (m, l)): `out` has q's shape and dtype. `m` and `l` are the row max
    and softmax denominator, f32[batch_size, time].
  """
  batch_size, time, head_dim = q.shape
  seq_len = k.shape[-2]
  # q is 3-D, so a layer-stacked cache is 4-D.
  stacked = k.ndim == 4

  def _compute_ragged_block_indices(b, i, lengths_ref):
    """Maps grid step (b, i) to the (batch, block) K/V tile to fetch.

    Steps past a row's length re-point at a block the kernel already used
    (last row) or at the next row's first block, presumably so skipped steps
    do not fetch tiles the kernel will never read.
    """
    length = lengths_ref[b]
    not_done = i * bk < length
    am_last_batch = b == batch_size - 1
    # Last block the kernel actually processes for this row:
    # ceil(length / bk) - 1, clamped at 0. The previous `length // bk - 1`
    # produced the out-of-range block index -1 for rows shorter than one
    # block (including empty rows).
    last_block = jax.lax.div(length + bk - 1, bk) - 1
    last_good_block = jnp.where(last_block < 0, 0, last_block)
    # Still inside the row: keep (b, i). Otherwise advance to the next row,
    # or stay on the last row's final processed block.
    b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1))
    i_next = jnp.where(
        not_done, i, jnp.where(am_last_batch, last_good_block, 0)
    )
    return b_next, i_next

  def kv_index_map(b, i, layer_ref, start_ref, end_ref, *_):
    b_next, i_next = _compute_ragged_block_indices(b, i, end_ref)
    if stacked:
      return layer_ref[0], b_next, i_next, 0
    return b_next, i_next, 0

  def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
    b_next, i_next = _compute_ragged_block_indices(b, i, end_ref)
    if stacked:
      return layer_ref[0], b_next, 0, i_next
    return b_next, 0, i_next

  if stacked:
    kv_bp = (None, None, bk, head_dim)
    ks_bp = (None, None, 1, bk)
  else:
    kv_bp = (None, bk, head_dim)
    ks_bp = (None, 1, bk)

  in_specs = [
      pl.BlockSpec(
          index_map=lambda b, i, *_: (b, 0, 0),
          block_shape=(None, time, head_dim),
      ),  # q
      pl.BlockSpec(index_map=kv_index_map, block_shape=kv_bp),  # k
      pl.BlockSpec(index_map=kv_index_map, block_shape=kv_bp),  # v
      pl.BlockSpec(index_map=kv_scale_index_map, block_shape=ks_bp),  # k_scaler
      pl.BlockSpec(index_map=kv_scale_index_map, block_shape=ks_bp),  # v_scaler
  ]

  inputs = (
      # Scalar-prefetch operands (must match num_scalar_prefetch below).
      jnp.array([layer]),
      start,
      end,
      end,  # line_end, not actually used
      ragged_batch_index,
      ragged_block_index,
      # Blocked operands.
      q,
      k,
      v,
      k_scaler,
      v_scaler,
  )

  out, m, l = pl.pallas_call(
      functools.partial(
          ragged_mqa_kernel_reference,
          bk=bk,
          mask_value=mask_value,
          normalize_var=normalize_var,
          quantized=quantized,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=6,
          in_specs=in_specs,
          out_specs=[
              pl.BlockSpec(
                  index_map=lambda b, *_: (b, 0, 0),
                  block_shape=(None, time, head_dim),
              ),
              pl.BlockSpec(
                  index_map=lambda b, *_: (b, 0, 0),
                  block_shape=(None, time, head_dim),
              ),
              pl.BlockSpec(
                  index_map=lambda b, *_: (b, 0, 0),
                  block_shape=(None, time, head_dim),
              ),
          ],
          grid=(batch_size, seq_len // bk),
      ),
      interpret=testing,
      compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
      out_shape=[
          q,
          jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32),
          jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32),
      ],
  )(*inputs)
  return out, (m[..., 0], l[..., 0])
@functools.partial(
    jax.jit,
    static_argnames=[
        "bk",
        "mask_value",
        "normalize_var",
        "q_shard_axis",
        "kv_shard_axis",
        "testing",
    ],
)
def ragged_mha(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    layer,
    start: jax.Array,
    end: jax.Array,
    ragged_batch_index: jax.Array,
    ragged_block_index: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
    q_shard_axis: int = 0,
    kv_shard_axis: int = 0,
    testing: bool = False,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi head attention.

  Vmaps `ragged_mqa_reference`: q is split along `q_shard_axis` and k/v along
  `kv_shard_axis`. All other operands are shared across the vmapped calls.

  Args:
    q: [batch_size, num_heads, time, head_dim] queries. When num_heads is a
      multiple > 1 of the KV head count (GQA), the query heads sharing a KV
      head are folded into the time axis before the kernel runs.
    k: [batch_size, num_kv_heads, seq_len, head_dim] keys, or
      [num_layers, batch_size, num_kv_heads, seq_len, head_dim] when stacked
      over layers (5-D).
    v: values, same layout as `k`.
    layer: index of the layer to read from a stacked cache.
    start: i32[batch_size] start positions.
    end: i32[batch_size] exclusive end positions.
    ragged_batch_index: scalar-prefetch table forwarded to the kernel.
    ragged_block_index: scalar-prefetch table forwarded to the kernel.
    k_scaler: optional key scales with a trailing singleton axis. After
      squeezing it must be [batch_size, 1, seq_len] (or with a leading layer
      axis when stacked). Passing it enables the quantized path.
    v_scaler: optional value scales, same layout as `k_scaler`.
    bk: sequence block size; clamped to seq_len.
    mask_value: additive logit penalty for masked positions.
    normalize_var: divide logits by sqrt(head_dim).
    q_shard_axis: q axis (and output axis) to vmap over.
    kv_shard_axis: k/v axis to vmap over.
    testing: run the Pallas kernel in interpret mode.

  Returns:
    (out, (m, l)) from `ragged_mqa_reference`, stacked along `q_shard_axis`
    by vmap. With GQA the folded head layout is returned as is (no reshape
    back happens here).
  """
  # NOTE: mask_value used to be unconditionally overwritten with
  # DEFAULT_MASK_VALUE here, silently ignoring the caller's argument.
  bk = min(bk, k.shape[-2])
  bq, hq, tq, dq = q.shape
  hkv = k.shape[-3]
  tk = k.shape[-2]
  assert k.shape[-1] == q.shape[-1]
  assert k.shape[-4] == q.shape[-4]
  rep = hq // hkv
  if rep > 1:
    # GQA: fold the `rep` query heads that share a KV head into time.
    q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq)
  stacked = k.ndim == 5
  # Operands after (q, k, v) shared by every vmapped call: layer, start, end,
  # ragged_batch_index, ragged_block_index, k_scale, v_scale.
  replicated_in_axes = 7
  if k_scaler is None:
    quantized = False
    # All-ones scales keep the kernel's operand list fixed; the kernel
    # ignores them when quantized is False.
    if stacked:
      kv_scale_shape = (k.shape[0], bq, 1, tk)
    else:
      kv_scale_shape = (bq, 1, tk)
    k_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16)
    v_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16)
  else:
    quantized = True
    k_scale = jnp.squeeze(k_scaler, -1)
    v_scale = jnp.squeeze(v_scaler, -1)
    if stacked:
      assert k_scale.shape == (k.shape[0], bq, 1, tk)
    else:
      assert k_scale.shape == (bq, 1, tk)

  replicated_inputs = (
      ragged_batch_index,
      ragged_block_index,
      k_scale,
      v_scale,
  )
  # New cache has t=1
  with jax.named_scope("ragged_mha_vmap"):
    out, (m, l) = jax.vmap(
        functools.partial(
            ragged_mqa_reference,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            testing=testing,
            quantized=quantized,
        ),
        in_axes=(
            q_shard_axis,
            kv_shard_axis,
            kv_shard_axis,
            *([None] * replicated_in_axes),
        ),
        out_axes=q_shard_axis,
    )(q, k, v, layer, start, end, *replicated_inputs)
  return out, (m, l)
def reshape_heads(xq, keys):
  """Fold grouped query heads into the time axis for GQA.

  Args:
    xq: 4-D query tensor (batch, q_heads, q_len, head_dim).
    keys: key tensor whose third-from-last axis is the KV head count.

  Returns:
    (xq, rep) where rep = q_heads // kv_heads. If rep > 1, xq becomes
    (batch, kv_heads, rep * q_len, head_dim); otherwise it is returned as is.
  """
  batch, q_heads, q_len, dim = xq.shape
  kv_heads = keys.shape[-3]
  rep = q_heads // kv_heads
  if rep <= 1:
    return xq, rep
  grouped = xq.reshape(batch, kv_heads, rep, q_len, dim)
  return grouped.reshape(batch, kv_heads, rep * q_len, dim), rep
def reshape_outputs(rep, o, m=None, d=None):
  """Undo `reshape_heads`: unfold query heads from the time axis.

  Args:
    rep: query heads per KV head, as returned by `reshape_heads`.
    o: (batch, kv_heads, rep * q_len, head_dim) attention output.
    m: optional (batch, kv_heads, rep * q_len, 1) row statistics.
    d: optional (batch, kv_heads, rep * q_len, 1) row statistics.

  Returns:
    (o, (m, d)) with o reshaped to (batch, kv_heads * rep, q_len, head_dim).
    m and d are reshaped to (..., 1) only when BOTH are given; otherwise they
    are returned untouched.
  """
  batch, kv_heads, folded_len, dim = o.shape
  q_len = folded_len // rep
  q_heads = kv_heads * rep

  def _unfold(t, last):
    return t.reshape(batch, kv_heads, rep, q_len, last).reshape(
        batch, q_heads, q_len, last
    )

  o = _unfold(o, dim)
  if m is not None and d is not None:
    m, d = _unfold(m, 1), _unfold(d, 1)
  return o, (m, d)
def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""
bsz, _, _, head_dim = xq.shape
with jax.named_scope("attn_mat1"):
## Attention start
scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
if k_scaler is not None:
scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen
with jax.named_scope("attn_soft"):
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
if v_scaler is not None:
scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))
with jax.named_scope("attn_mat2"):
output = torch.einsum(
"ikjm,ikml->ikjl", scores, values
) # (bs, n_local_heads, seqlen, head_dim)
return output
def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  """GQA-aware dense attention: fold heads, attend, unfold.

  Returns only the attention output, in xq's original head layout.
  """
  folded_q, rep = reshape_heads(xq, keys)
  attended = _dense_attention(folded_q, keys, values, k_scaler, v_scaler, mask)
  unfolded, _ = reshape_outputs(rep, attended)
  return unfolded
def _flash_attention(
xq,
keys,
values,
layer,
k_scaler=None,
v_scaler=None,
mask=None,
normalize_var=True,
):
"""Flash attention kernel."""
if keys.ndim == 5:
keys = keys[layer]
values = values[layer]
k_scaler = k_scaler[layer] if k_scaler is not None else None
v_scaler = v_scaler[layer] if v_scaler is not None else None
logits = torch.einsum(
"bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32)
)
if normalize_var:
logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama
# Quantized
if k_scaler is not None:
logits = logits * k_scaler.reshape(
k_scaler.shape[-4], 1, 1, k_scaler.shape[-2]
)
# mask = jnp.arange(keys.shape[1])[None] < lengths[:, None]
if mask is not None:
# logits = logits + jnp.where(mask, 0.0, DEFAULT_MASK_VALUE)[:, None]
logits = logits + mask
logits_max, _ = torch.max(logits, axis=-1, keepdim=True)
unnormalized = torch.exp(logits - logits_max)
# Quantized, should not put here, otherwise sum will have this too, which cancels with denominator
# unnormalized = unnormalized * v_scaler
denominator = unnormalized.sum(axis=-1, keepdim=True)
if v_scaler is not None:
unnormalized = unnormalized * v_scaler.reshape(
v_scaler.shape[-4], 1, 1, v_scaler.shape[-2]
)
o = (
torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values)
/ denominator
)
return o, (logits_max, denominator)
def flash_attention(
    xq,
    keys,
    values,
    layer,
    k_scaler=None,
    v_scaler=None,
    mask=None,
    normalize_var=True,
):
  """GQA-aware reference attention returning output and softmax stats.

  Folds grouped query heads, runs `_flash_attention`, then unfolds the
  output, row max and denominator back to xq's head layout.
  """
  folded_q, rep = reshape_heads(xq, keys)
  attended, (row_max, row_sum) = _flash_attention(
      xq=folded_q,
      keys=keys,
      values=values,
      layer=layer,
      k_scaler=k_scaler,
      v_scaler=v_scaler,
      mask=mask,
      normalize_var=normalize_var,
  )
  return reshape_outputs(rep, attended, row_max, row_sum)
class RaggedAttentionKernel:
  """Callable wrapper around a sharded, jitted `ragged_mha`.

  Block size and interpret mode come from `env` (`env.block_size`,
  `env.testing`). The shard axes and shard_map specs are fixed at
  construction time.
  """

  def __init__(
      self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis
  ):
    configured = functools.partial(
        ragged_mha,
        bk=env.block_size,
        q_shard_axis=q_shard_axis,
        kv_shard_axis=kv_shard_axis,
        testing=env.testing,
    )
    sharded = shard_map(
        configured,
        env.mesh,
        input_specs,
        output_specs,
        check_rep=False,
    )
    self.binded_ragged_mha = jax.jit(sharded)

  def __call__(
      self,
      xq,
      keys,
      values,
      layer,
      start,
      end,
      ragged_batch_index,
      ragged_block_index,
      k_scaler=None,
      v_scaler=None,
  ):
    """Run the bound kernel; arguments are forwarded positionally."""
    operands = (
        xq,
        keys,
        values,
        layer,
        start,
        end,
        ragged_batch_index,
        ragged_block_index,
        k_scaler,
        v_scaler,
    )
    return self.binded_ragged_mha(*operands)
def shard_kv_heads(
    paged_attention_impl: Callable[..., Any],
    mesh: jax.sharding.Mesh,
    kv_head_mesh_axis_name: str,
):
  """Wrap a paged-attention function in a jitted shard_map over KV heads.

  q and the output are split on their second axis. k and v are split on
  their first axis. lengths and page_indices are replicated.
  """
  heads = kv_head_mesh_axis_name
  q_spec = P(None, heads, None)
  kv_spec = P(heads, None, None, None)
  sharded = shard_map(
      paged_attention_impl,
      mesh=mesh,
      # q, k, v, lengths, page_indices
      in_specs=(q_spec, kv_spec, kv_spec, P(), P()),
      out_specs=q_spec,
      check_rep=False,
  )
  return jax.jit(sharded)
def call_paged_attention(env, xq, keys, values, seq_lens, page_indices):
  """Run the JAX paged-attention kernel on torch inputs, sharded on KV heads.

  Inputs are converted with `torchjax.from_torch` and the result is converted
  back with `torchjax.to_torch`. The compute block spans
  `env.block_size // env.paged_attention_page_size` pages.
  """
  q_jax, k_jax, v_jax, lens_jax, pages_jax = torchjax.from_torch(
      (xq, keys, values, seq_lens, page_indices)
  )
  pages_per_block = env.block_size // env.paged_attention_page_size
  impl = functools.partial(
      paged_attention,
      pages_per_compute_block=pages_per_block,
  )
  sharded_impl = shard_kv_heads(impl, env.mesh, kv_head_mesh_axis_name="x")
  result = sharded_impl(q_jax, k_jax, v_jax, lens_jax, pages_jax)
  return torchjax.to_torch(result)