-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathlayers.py
856 lines (754 loc) · 26.7 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable-all
"""This version contains modification to make it easier to trace and support batch."""
from typing import Optional, Tuple
import jax
from . import attention_kernel as ak
import jax.numpy as jnp
import torch
import torch.nn.functional as F
import torch_xla2
from jax import lax
from jetstream_pt import torchjax
from jetstream_pt.environment import QuantizationConfig
from jetstream_pt.model_base import ModuleBase
from jetstream_pt.quantize import (
dequantize_tensor,
load_q_weight_helper,
quantize_tensor,
blockwise_jax_kernel,
blockwise_jax_kernel_dot_general,
blockwise_jax_kernel_einsum_flatten,
)
from torch import nn
from . import attention_kernel as ak
from absl import flags
def _calc_cosine_dist(x, y):
x = x.flatten().to(torch.float32)
y = y.flatten().to(torch.float32)
return (torch.dot(x, y) / (x.norm() * y.norm())).item()
import numpy as np
class Int8Embedding(torch.nn.Module):
  """Embedding table held as int8 values plus a bfloat16 per-column scale.

  Both buffers start out as ones; the real table and scale are meant to be
  assigned afterwards (create_quantized_from_nn_embedding replaces
  ``weight`` and ``weight_scaler`` directly).

  Buffers:
    weight: int8 tensor of shape (num_embeddings, embedding_dims).
    weight_scaler: bfloat16 tensor of shape (embedding_dims,).
  """

  def __init__(self, num_embeddings, embedding_dims, device="cpu"):
    super().__init__()
    # Registration order is kept (weight first) so state_dict order is stable.
    self.register_buffer(
        "weight",
        torch.ones(
            (num_embeddings, embedding_dims), device=device, dtype=torch.int8
        ),
    )
    self.register_buffer(
        "weight_scaler",
        torch.ones((embedding_dims,), device=device, dtype=torch.bfloat16),
    )

  def forward(self, input):
    # Look up int8 rows, then rescale; int8 * bfloat16 promotes to bfloat16.
    rows = F.embedding(input, self.weight)
    return rows * self.weight_scaler
class WeightOnlyPerChannelQuantizedLinear(torch.nn.Module):
  """Bias-free linear layer whose weight is quantized per output channel.

  Buffers (as allocated here; replaced by ``_load_quantized_weights``):
    weight: int8 tensor of shape (out_features, in_features).
    weight_scaler: bfloat16 tensor of shape (out_features,).
    zero_point: bfloat16 tensor of shape (out_features,) for asymmetric
      weights, otherwise None.
  """

  def __init__(
      self,
      in_features,
      out_features,
      bias=False,
      device=None,
      quant_config=None,
  ):
    """
    Args:
      in_features: size of the contracted (last) input dimension.
      out_features: number of output channels.
      bias: must be False; bias is not supported.
      device: device for the placeholder buffers.
      quant_config: QuantizationConfig; a fresh default instance is created
        when omitted (avoids sharing one import-time default object).

    Raises:
      AssertionError: if ``bias`` is truthy.
    """
    super().__init__()
    if quant_config is None:
      quant_config = QuantizationConfig()
    self.in_features = in_features
    self.out_features = out_features
    weight = torch.ones(
        (out_features, in_features), dtype=torch.int8, device=device
    )
    self.register_buffer("weight", weight)
    weight_scaler = torch.ones(
        (out_features,), dtype=torch.bfloat16, device=device
    )
    self.register_buffer("weight_scaler", weight_scaler)
    self.is_symmetric_weight = quant_config.is_symmetric_weight
    if not self.is_symmetric_weight:
      zero_point = torch.ones(
          (out_features,), dtype=torch.bfloat16, device=device
      )
      self.register_buffer("zero_point", zero_point)
    else:
      self.register_buffer("zero_point", None)
    assert not bias, "Quantized Linear doesn't support bias."
    # Number of bits of weight tensor
    self.n_bit = quant_config.num_bits_weight
    # Quantize activation
    self.quantize_activation = quant_config.enable_activation_quantization
    # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
    self.run_fake_quantize = False

  def _load_quantized_weights(self, w_q, scale, zp=None):
    """
    Load weights quantized by 'quantize_tensor'.
    """
    self.weight, self.weight_scaler, self.zero_point = load_q_weight_helper(
        w_q, scale, zp, block_size=-1
    )

  def quantize_weight_from_nn_linear(self, weight):
    """Quantize a float (out_features, in_features) weight and load it.

    Raises:
      AssertionError: if ``weight`` is not 2D or has an unexpected shape.
    """
    assert weight.dim() == 2, "Expect 2D weight from torch.nn.Linear."
    assert weight.shape == (
        self.out_features,
        self.in_features,
    ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
    w_q, scale, zp = quantize_tensor(
        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
    )
    self._load_quantized_weights(w_q, scale, zp)

  def forward(self, inputs):
    """Apply the quantized linear map to ``inputs``.

    The activation-quantized path quantizes and contracts axis 2, so it
    assumes 3D inputs (presumably (batch, seq, in_features) — confirm).
    """
    if self.run_fake_quantize:
      # Fake quantization, debugging purpose: dequantize the weight and run a
      # regular float matmul.
      scaler = self.weight_scaler.unsqueeze(-1)
      if not self.is_symmetric_weight:
        zero_point = self.zero_point.unsqueeze(-1) / scaler
      else:
        zero_point = None
      w_dequantized = dequantize_tensor(
          self.weight.to(torch.bfloat16), scaler, zero_point
      )
      return F.linear(inputs, w_dequantized)

    act_s = None
    if self.quantize_activation:
      inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
      # We have to call jax because we need to specify the output dtype of dot
      # dot(int8, int8)->bf16.
      # This semantic cannot be represented in torch. The inferred output dtype
      # will be int8 in torch, causing the dot result to overflow.
      result = torchjax.call_jax(
          jax.lax.dot_general,
          inputs,
          self.weight,
          # Contract inputs axis 2 with weight axis 1 (in_features); no batch dims.
          (((2,), (1,)), ((), ())),
          None,
          jnp.bfloat16.dtype,
      )
    else:
      result = F.linear(inputs, self.weight)

    result = result * self.weight_scaler
    if act_s is not None:
      result = result * act_s
    if not self.is_symmetric_weight:
      # x @ (w_q * s - zp)^T = (x @ w_q^T) * s - sum_c(x_c) * zp.
      zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
      if act_s is not None:
        # `inputs` now holds the quantized activations, so the zero-point
        # term must be rescaled by the activation scale just like the main
        # product above; otherwise the two terms are on different scales.
        zp_out = zp_out * act_s
      result = result - zp_out
    return result
class WeightOnlyBlockwiseQuantizedLinear(torch.nn.Module):
  """Bias-free linear layer whose weight is quantized in blocks along in_features.

  Buffers (as allocated here; replaced by ``_load_quantized_weights``):
    weight: int8, shape (n_blocks, block_size, out_features), or
      (n_blocks, out_features, block_size) when ``use_dot_general``.
    weight_scaler: bfloat16, shape (n_blocks, out_features).
    zero_point: bfloat16, shape (n_blocks, out_features) for asymmetric
      weights, otherwise None.
  where n_blocks = in_features // block_size.
  """

  def __init__(
      self,
      in_features,
      out_features,
      bias=False,
      device=None,
      quant_config=None,
  ):
    """
    Args:
      in_features: size of the contracted input dimension.
      out_features: number of output channels.
      bias: must be False; bias is not supported.
      device: device for the placeholder buffers.
      quant_config: QuantizationConfig; a fresh default instance is created
        when omitted.

    Raises:
      AssertionError: if ``bias`` is truthy or activation quantization is
        enabled.
    """
    super().__init__()
    if quant_config is None:
      quant_config = QuantizationConfig()
    # Previously a bias request was silently dropped; fail loudly instead,
    # matching WeightOnlyPerChannelQuantizedLinear.
    assert not bias, "Quantized Linear doesn't support bias."
    self.in_features = in_features
    self.out_features = out_features
    # Use dot general instead of einsum
    # Use dot general is slow now.
    self.use_dot_general = False
    # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
    # Same perf as non flattened one now.
    self.flatten = False
    self.block_size = quant_config.block_size_weight
    n_blocks = in_features // self.block_size
    assert (
        not quant_config.enable_activation_quantization
    ), "Activation quantization not supported for blockwise quantized matmul."
    if self.use_dot_general:
      weight = torch.ones(
          (n_blocks, out_features, self.block_size),
          dtype=torch.int8,
          device=device,
      )
    else:
      weight = torch.ones(
          (n_blocks, self.block_size, out_features),
          dtype=torch.int8,
          device=device,
      )
    self.register_buffer("weight", weight)
    weight_scaler = torch.ones(
        (n_blocks, out_features), dtype=torch.bfloat16, device=device
    )
    self.register_buffer("weight_scaler", weight_scaler)
    self.is_symmetric_weight = quant_config.is_symmetric_weight
    if not self.is_symmetric_weight:
      zero_point = torch.ones(
          (n_blocks, out_features), dtype=torch.bfloat16, device=device
      )
      self.register_buffer("zero_point", zero_point)
    else:
      self.register_buffer("zero_point", None)
    self.n_bit = quant_config.num_bits_weight
    # Quantize activation
    self.quantize_activation = quant_config.enable_activation_quantization
    # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
    self.run_fake_quantize = False

  def _load_quantized_weights(self, w_q, scale, zp=None):
    """
    Load weights quantized by 'quantize_tensor'.
    """
    self.weight, self.weight_scaler, self.zero_point = load_q_weight_helper(
        w_q, scale, zp, self.block_size
    )

  def quantize_weight_from_nn_linear(self, weight):
    """Blockwise-quantize a float (out_features, in_features) weight and load it.

    Raises:
      AssertionError: if ``weight`` is not 2D or has an unexpected shape.
    """
    assert weight.dim() == 2, "Expect 2D weight from torch.nn.Linear."
    assert weight.shape == (
        self.out_features,
        self.in_features,
    ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
    w_q, scale, zp = quantize_tensor(
        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
    )
    self._load_quantized_weights(w_q, scale, zp)

  def forward(self, inputs):
    """Apply the blockwise-quantized linear map to ``inputs``."""
    if not self.run_fake_quantize:
      if self.use_dot_general or self.flatten:
        assert (
            self.zero_point is None
        ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
      if self.use_dot_general:
        blockwise_matmul_kernel = blockwise_jax_kernel_dot_general
      elif self.flatten:
        blockwise_matmul_kernel = blockwise_jax_kernel_einsum_flatten
      else:
        blockwise_matmul_kernel = blockwise_jax_kernel
      result = torchjax.call_jax(
          blockwise_matmul_kernel,
          inputs,
          self.weight,
          self.weight_scaler,
          self.zero_point,
      )
      return result

    # Fake quantization, debugging purpose: rebuild a dense float weight of
    # shape (out_features, in_features) and run a regular matmul.
    # First bring the blocked weight to (out_features, n_blocks, block_size).
    if self.use_dot_general:
      # Stored as (n_blocks, out_features, block_size).
      weight = self.weight.permute(1, 0, 2)
    else:
      # Stored as (n_blocks, block_size, out_features).
      weight = self.weight.permute(2, 0, 1)
    weight = weight.to(torch.bfloat16)
    # (n_blocks, out_features) -> (out_features, n_blocks, 1) to broadcast
    # over each block.
    scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
    if not self.is_symmetric_weight:
      zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
    else:
      zero_point = None
    # Dequantize the rearranged weight (the original passed the raw,
    # un-permuted int8 buffer here, whose layout does not match `scaler`).
    w_dequantized = dequantize_tensor(weight, scaler, zero_point)
    w_dequantized = w_dequantized.reshape(w_dequantized.shape[0], -1)
    return F.linear(inputs, w_dequantized)
def get_quantized_linear_layer(config: "QuantizationConfig"):
  """Pick the linear-layer class matching ``config``.

  Returns ``nn.Linear`` when weight quantization is off; otherwise the
  blockwise or per-channel weight-only quantized linear class, selected by
  ``config.is_blockwise_weight``.
  """
  if config.enable_weight_quantization:
    if config.is_blockwise_weight:
      return WeightOnlyBlockwiseQuantizedLinear
    return WeightOnlyPerChannelQuantizedLinear
  return nn.Linear
def create_quantized_from_nn_linear(
    float_linear: nn.Linear, config: "QuantizationConfig"
):
  """Build a weight-quantized replacement for ``float_linear``.

  Args:
    float_linear: float linear layer whose weight is quantized.
    config: quantization settings; selects the quantized layer class.

  Returns:
    A new quantized linear layer (created on the "meta" device, then loaded
    with the quantized weight), or ``float_linear`` itself when weight
    quantization is disabled.
  """
  if not config.enable_weight_quantization:
    # get_quantized_linear_layer would return nn.Linear here, which takes no
    # quant config (the old positional call passed it as `dtype`) and has no
    # quantize_weight_from_nn_linear, so the old code always crashed. There is
    # nothing to quantize: keep the float layer.
    return float_linear
  clazz_ = get_quantized_linear_layer(config)
  obj = clazz_(
      float_linear.in_features,
      float_linear.out_features,
      bias=float_linear.bias is not None,
      device="meta",
      quant_config=config,
  )
  obj.quantize_weight_from_nn_linear(float_linear.weight)
  return obj
def get_quantized_embedding_layer(config: "QuantizationConfig"):
  """Return ``Int8Embedding`` if weight quantization is on, else ``nn.Embedding``."""
  return Int8Embedding if config.enable_weight_quantization else nn.Embedding
def create_quantized_from_nn_embedding(
    float_embedding: nn.Embedding, config: "QuantizationConfig"
):
  """Build a quantized replacement for ``float_embedding``.

  Args:
    float_embedding: float embedding whose table is quantized.
    config: quantization settings.

  Returns:
    An embedding module whose ``weight``/``weight_scaler`` come from
    ``quantize_tensor(float_embedding.weight, 0)``, or ``float_embedding``
    itself when weight quantization is disabled.
  """
  if not config.enable_weight_quantization:
    # The quantized tensors cannot be assigned to nn.Embedding.weight (a
    # Parameter), so the old code crashed here; keep the float layer instead.
    return float_embedding
  clazz_ = get_quantized_embedding_layer(config)
  # Both buffers are replaced right below, so allocate the placeholders on
  # "meta" instead of materializing a full table on CPU.
  obj = clazz_(
      float_embedding.num_embeddings,
      float_embedding.embedding_dim,
      device="meta",
  )
  # Reduce over axis 0 (the rows), giving one scale per embedding column.
  weights, scaler, _ = quantize_tensor(float_embedding.weight, 0)
  obj.weight = weights
  obj.weight_scaler = scaler
  return obj
class RMSNorm(torch.nn.Module):
  """Root-mean-square normalization over the last axis with a learned gain.

  The gain ``weight`` has shape (dim,) and is initialized to ones on
  ``device`` (default "meta", so real values presumably get loaded later).
  """

  def __init__(self, dim: int, eps: float = 1e-6, device="meta"):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim, device=device))

  def _norm(self, x):
    # x / sqrt(mean(x^2) + eps), computed along the last axis.
    mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_square + self.eps)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Normalize in float32, cast back to the input dtype, then apply the gain.
    normalized = self._norm(x.float())
    normalized = normalized.type_as(x)
    return normalized * self.weight
def reshape_for_broadcast(
    freqs_cis: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
  """View ``freqs_cis`` so it broadcasts against ``x``.

  ``freqs_cis`` must have shape (x.shape[0], x.shape[-3], x.shape[-1]). The
  result keeps axes 0, 1 and the last axis of ``x`` and uses size 1
  everywhere else (e.g. (b, s, h, d) -> (b, s, 1, d) for a 4D ``x``).
  """
  ndim = x.ndim
  assert 1 < ndim
  expected_shape = (x.shape[0], x.shape[-3], x.shape[-1])
  assert freqs_cis.shape == expected_shape, f"freqs_cis: {freqs_cis.shape }, x: {x.shape}"
  view_shape = [1] * ndim
  view_shape[0] = x.shape[0]  # batch size
  view_shape[1] = x.shape[1]
  view_shape[-1] = x.shape[-1]
  return freqs_cis.view(*view_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Apply rotary position embedding to queries and keys.

  Adjacent feature pairs of ``xq``/``xk`` are treated as complex numbers and
  multiplied by ``freqs_cis`` (reshaped by ``reshape_for_broadcast``). The
  computation runs in float32; results are cast back to the input dtypes.
  ``flatten(3)`` assumes 4D inputs, (bs, seqlen, heads, dim).
  """

  def as_complex_pairs(t):
    # (..., dim) -> (..., dim // 2) complex, pairing adjacent features.
    return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

  q_complex = as_complex_pairs(xq)
  k_complex = as_complex_pairs(xk)
  rotation = reshape_for_broadcast(freqs_cis, q_complex)
  q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
  k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
  return q_rotated.type_as(xq), k_rotated.type_as(xk)
class AttentionKernel:
  """Attention over a float KV cache.

  Chooses between ragged, flash, paged and dense attention (from
  ``attention_kernel``) based on flags in ``env``. With
  ``env.lazy_cache_update`` during single-token decode, it also attends over
  the new key/value separately and merges the two partial results.
  """

  def __init__(self, env, layer_id):
    self.env = env
    # xq is (batch, heads, seqlen, dim): shard batch (0) or heads (1).
    self.q_shard_axis = 0 if self.env.shard_on_batch else 1
    # Cache shard axis: batch, or the head axis, which is 2 for a stacked
    # cache (presumably a leading layer axis — confirm) and 1 otherwise.
    self.kv_shard_axis = (
        0
        if self.env.shard_on_batch
        else 2
        if self.env.generate_cache_stacked
        else 1
    )
    q_pspec = self.env.partition_by_axis(self.q_shard_axis)  # Number of heads
    kv_pspec = self.env.partition_by_axis(self.kv_shard_axis)  # Number of heads
    others_pspec = self.env.partition_by_axis()
    self.dense_attention = ak.dense_attention
    self.flash_attention = ak.flash_attention
    self.page_attention = ak.call_paged_attention
    # Ragged kernel for keys/values whose ndim is not 4 (cache layout sharded
    # on kv_shard_axis).
    self.ragged_attention_orig = ak.RaggedAttentionKernel(
        env,
        input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)),
        output_specs=(q_pspec, (q_pspec, q_pspec)),
        q_shard_axis=self.q_shard_axis,
        kv_shard_axis=self.kv_shard_axis,
    )
    # Ragged kernel for 4D keys/values, sharded the same way as xq.
    self.ragged_attention_new = ak.RaggedAttentionKernel(
        env,
        input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)),
        output_specs=(q_pspec, (q_pspec, q_pspec)),
        q_shard_axis=self.q_shard_axis,
        kv_shard_axis=self.q_shard_axis,
    )
    self.layer_id = layer_id

  def __call__(
      self,
      xq,
      xk,
      xv,
      mask,
      cache,
      start=None,
      end=None,
      ragged_batch_index=None,
      ragged_block_index=None,
  ):
    """
    Args:
      xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
      xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
      xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
      mask: mask with 0 and -inf, or None
      cache: CacheManagerInterface object
      start, end, ragged_batch_index, ragged_block_index: forwarded unchanged
        to the ragged attention kernel.

    Returns:
      Attention output of shape (batch size, num_heads, seqlen, head_dim).
    """
    bsz, num_heads, seqlen, head_dim = xq.shape
    num_kv_heads = xk.shape[-3]
    n_rep = num_heads // num_kv_heads

    def attend(xq, keys, values, local_mask=None):
      """Attend xq over keys/values.

      Returns (output, (max, denom)). max/denom are None for the paged and
      dense paths; presumably they are the per-query softmax max and
      denominator used by the merge in the caller.
      """
      if keys.ndim == 4:
        impl = self.ragged_attention_new
      else:
        impl = self.ragged_attention_orig

      true_len = seqlen
      # When GQA is enabled, it not necessary to expand
      if (
          not (self.env.ragged_mha and n_rep > 1)
          and seqlen == 1
          and not self.env.page_attention
      ):
        # Zero-pad single-token queries to length 2 along seqlen (presumably a
        # kernel shape requirement); the padding is sliced off below.
        true_len = 2
        xq = torch.nn.functional.pad(
            xq, (0, 0, 0, true_len - seqlen), "constant", 0
        )

      local_max = None
      local_denom = None
      if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
        local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
            impl,
            xq,
            keys,
            values,
            self.layer_id,
            start,
            end,
            ragged_batch_index,
            ragged_block_index,
            None,  # k_scaler
            None,  # v_scaler
        )
      elif self.env.flash_attention and seqlen == 1:
        with torch_xla2.default_env():
          local_output, (local_max, local_denom) = self.flash_attention(
              xq=xq,
              keys=keys,
              values=values,
              layer=self.layer_id,
              k_scaler=None,
              v_scaler=None,
              mask=local_mask,
          )
      elif self.env.page_attention and seqlen == 1:
        local_output = self.page_attention(
            self.env,
            torch.squeeze(xq, 2),
            keys,
            values,
            cache.page_attention_manager.lengths,
            cache.page_attention_manager.page_indices,
        )
      else:
        local_output = self.dense_attention(
            xq=xq,
            keys=keys,
            values=values,
            k_scaler=None,
            v_scaler=None,
            mask=local_mask,
        )

      local_output = local_output.reshape(bsz, num_heads, true_len, head_dim)
      if local_max is not None:
        local_max = local_max.reshape(bsz, num_heads, true_len, 1)
        local_denom = local_denom.reshape(bsz, num_heads, true_len, 1)

      if true_len != seqlen:
        # Drop the query padding added above.
        local_output = local_output[:, :, 0:seqlen, :]
        if local_max is not None:
          local_max = local_max[:, :, 0:seqlen, :]
        if local_denom is not None:
          local_denom = local_denom[:, :, 0:seqlen, :]
      self.env.apply_sharding(local_output, axis=self.q_shard_axis)
      return local_output, (local_max, local_denom)

    with jax.named_scope("attn_insert_cache"):
      orig_keys, orig_values = cache.update(xk, xv, self.layer_id)
    with jax.named_scope("attn_qkv"):
      existing_output, (existing_max, existing_denom) = attend(
          xq=xq, keys=orig_keys, values=orig_values, local_mask=mask
      )
    # Updating cache during each step still has very large impact on latency.
    # For non flash attention or prefill, existing output contains everything
    if not self.env.lazy_cache_update or seqlen > 1:
      return existing_output

    # For flash attention, existing output contains the existing kv cache generated logits
    with jax.named_scope("attn_new_qkv"):
      new_output, (new_max, new_denom) = attend(
          xq=xq, keys=xk, values=xv, local_mask=None
      )

    with jax.named_scope("attn_global"):
      # Merge the two partial results: rescale each side by
      # exp(its max - global max) so both share one max, then combine the
      # outputs weighted by their rescaled denominators.
      global_max = torch.max(existing_max, new_max)
      alpha = torch.exp(existing_max - global_max)
      beta = torch.exp(new_max - global_max)
      global_denom = alpha * existing_denom + beta * new_denom
      # NOTE: global_denom == 0 is not guarded (torch.where(global_denom == 0.0,
      # 1.0, global_denom) was considered).
      attn_out = (
          existing_denom * alpha * existing_output
          + beta * new_output * new_denom
      ) / global_denom
      return attn_out
class Int8KVAttentionKernel:
  """Attention over a quantized KV cache (selected by Attention when
  ``env.quant_config.enable_kv_quantization`` is set).

  Same dispatch as the float kernel (ragged, flash or dense; no paged path),
  but forwards the key/value scalers returned by ``cache.update`` to the
  underlying attention implementations.
  """

  def __init__(self, env, layer_id):
    self.env = env
    # xq is (batch, heads, seqlen, dim): shard batch (0) or heads (1).
    self.q_shard_axis = 0 if self.env.shard_on_batch else 1
    # Cache shard axis: batch, or the head axis, which is 2 for a stacked
    # cache (presumably a leading layer axis — confirm) and 1 otherwise.
    self.kv_shard_axis = (
        0
        if self.env.shard_on_batch
        else 2
        if self.env.generate_cache_stacked
        else 1
    )
    q_pspec = self.env.partition_by_axis(self.q_shard_axis)  # Number of heads
    kv_pspec = self.env.partition_by_axis(self.kv_shard_axis)  # Number of heads
    others_pspec = self.env.partition_by_axis()
    self.dense_attention = ak.dense_attention
    self.flash_attention = ak.flash_attention
    # Ragged kernel for keys/values whose ndim is not 4.
    self.ragged_attention_orig = ak.RaggedAttentionKernel(
        env,
        input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)),
        output_specs=(q_pspec, (q_pspec, q_pspec)),
        q_shard_axis=self.q_shard_axis,
        kv_shard_axis=self.kv_shard_axis,
    )
    # Ragged kernel for 4D keys/values, sharded the same way as xq.
    self.ragged_attention_new = ak.RaggedAttentionKernel(
        env,
        input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)),
        output_specs=(q_pspec, (q_pspec, q_pspec)),
        q_shard_axis=self.q_shard_axis,
        kv_shard_axis=self.q_shard_axis,
    )
    self.layer_id = layer_id

  def __call__(
      self,
      xq,
      xk,
      xv,
      mask,
      cache,
      start=None,
      end=None,
      ragged_batch_index=None,
      ragged_block_index=None,
  ):
    """
    Args:
      xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
      xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
      xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
      mask: mask with 0 and -inf, or None
      cache: CacheManagerInterface object
      start, end, ragged_batch_index, ragged_block_index: forwarded unchanged
        to the ragged attention kernel.

    Returns:
      Attention output of shape (batch size, num_heads, seqlen, head_dim).
    """
    bsz, num_heads, seqlen, head_dim = xq.shape
    num_kv_heads = xk.shape[-3]
    n_rep = num_heads // num_kv_heads

    def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
      """Attend xq over (scaled) keys/values.

      Returns (output, (max, denom)); max/denom are None on the dense path.
      """
      if keys.ndim == 4:
        impl = self.ragged_attention_new
      else:
        impl = self.ragged_attention_orig

      true_len = seqlen
      # When GQA is enabled, it not necessary to expand
      if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1:
        # Zero-pad single-token queries to length 2 along seqlen (presumably a
        # kernel shape requirement); the padding is sliced off below.
        true_len = 2
        xq = torch.nn.functional.pad(
            xq, (0, 0, 0, true_len - seqlen), "constant", 0
        )

      if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
        local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
            impl,
            xq,
            keys,
            values,
            self.layer_id,
            start,
            end,
            ragged_batch_index,
            ragged_block_index,
            k_scaler,
            v_scaler,
        )
      elif self.env.flash_attention and seqlen == 1:
        with torch_xla2.default_env():
          local_output, (local_max, local_denom) = self.flash_attention(
              xq=xq,
              keys=keys,
              values=values,
              layer=self.layer_id,
              k_scaler=k_scaler,
              v_scaler=v_scaler,
              mask=local_mask,
          )
      else:
        local_output = self.dense_attention(
            xq=xq,
            keys=keys,
            values=values,
            k_scaler=k_scaler,
            v_scaler=v_scaler,
            mask=local_mask,
        )
        local_max = None
        local_denom = None

      local_output = local_output.reshape(bsz, num_heads, true_len, head_dim)
      if local_max is not None:
        local_max = local_max.reshape(bsz, num_heads, true_len, 1)
        local_denom = local_denom.reshape(bsz, num_heads, true_len, 1)

      if true_len != seqlen:
        # Drop the query padding added above.
        local_output = local_output[:, :, 0:seqlen, :]
        if local_max is not None:
          local_max = local_max[:, :, 0:seqlen, :]
          local_denom = local_denom[:, :, 0:seqlen, :]
      self.env.apply_sharding(local_output, axis=self.q_shard_axis)
      return local_output, (local_max, local_denom)

    with jax.named_scope("attn_insert_cache"):
      # The quantized cache returns the cached keys/values, the new ones, and
      # the matching scalers for each.
      (
          orig_keys,
          orig_values,
          new_key,
          new_value,
          k_scaler,
          v_scaler,
          new_k_scaler,
          new_v_scaler,
      ) = cache.update(xk, xv, self.layer_id)
    with jax.named_scope("attn_qkv"):
      existing_output, (existing_max, existing_denom) = attend(
          xq=xq,
          keys=orig_keys,
          values=orig_values,
          k_scaler=k_scaler,
          v_scaler=v_scaler,
          local_mask=mask,
      )
    # For non flash attention or prefill, existing output contains everything
    if not self.env.lazy_cache_update or seqlen > 1:
      return existing_output

    # For flash attention, existing output contains the existing kv cache generated logits
    with jax.named_scope("attn_new_qkv"):
      # At this point, flash attention or ragged attention must have been enabled
      new_output, (new_max, new_denom) = attend(
          xq, new_key, new_value, new_k_scaler, new_v_scaler, None
      )

    with jax.named_scope("attn_global"):
      # Merge the two partial results: rescale each side by
      # exp(its max - global max), then combine the outputs weighted by their
      # rescaled denominators.
      global_max = torch.max(existing_max, new_max)
      alpha = torch.exp(existing_max - global_max)
      beta = torch.exp(new_max - global_max)
      global_denom = alpha * existing_denom + beta * new_denom
      # NOTE: global_denom == 0 is not guarded (torch.where(global_denom == 0.0,
      # 1.0, global_denom) was considered).
      attn_out = (
          existing_denom * alpha * existing_output
          + beta * new_output * new_denom
      ) / global_denom
      return attn_out
class Attention(ModuleBase):
  """Attention module.

  Projects the input to queries/keys/values (fused ``wqkv`` or separate
  ``wq``/``wk``/``wv`` depending on ``env.qkv_fusion``), applies rotary
  embedding, runs the attention kernel against ``cache`` and projects back
  with ``wo``. Linear layers are quantized according to ``env.quant_config``.
  """

  def __init__(
      self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id
  ):
    super().__init__()
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads
    self.head_dim = head_dim
    self.n_rep = self.n_heads // self.n_kv_heads
    self.env = env
    self.hidden_size = hidden_size
    self.layer_id = layer_id

    LinearLayer = get_quantized_linear_layer(env.quant_config)
    linear_kwargs = {}
    if LinearLayer is not torch.nn.Linear:
      # Only the quantized linear classes accept a quant_config argument.
      linear_kwargs = {"quant_config": env.quant_config}

    self.wo = LinearLayer(
        n_heads * self.head_dim,
        hidden_size,
        bias=False,
        device=device,
        **linear_kwargs,
    )
    Kernel = (
        Int8KVAttentionKernel
        if env.quant_config.enable_kv_quantization
        else AttentionKernel
    )
    self.attention_kernel = Kernel(env, self.layer_id)
    self.q_size = n_heads * self.head_dim
    self.kv_size = self.n_kv_heads * self.head_dim
    if self.env.qkv_fusion:
      # Checkpoints with separate wq/wk/wv are merged into wqkv on load.
      self._register_load_state_dict_pre_hook(self.load_hook)
      self.wqkv = LinearLayer(
          hidden_size,
          (n_heads + 2 * self.n_kv_heads) * self.head_dim,
          bias=False,
          device=device,
          **linear_kwargs,
      )
    else:
      self.wq = LinearLayer(
          hidden_size,
          n_heads * self.head_dim,
          bias=False,
          device=device,
          **linear_kwargs,
      )
      self.wk = LinearLayer(
          hidden_size,
          self.n_kv_heads * self.head_dim,
          bias=False,
          device=device,
          **linear_kwargs,
      )
      self.wv = LinearLayer(
          hidden_size,
          self.n_kv_heads * self.head_dim,
          bias=False,
          device=device,
          **linear_kwargs,
      )

  def load_hook(self, state_dict, prefix, *args):
    """State-dict pre-hook: replace wq/wk/wv weights by one concatenated wqkv.

    Concatenation is along dim 0 (output features) in q, k, v order, matching
    the split order used in ``forward``.
    """
    if prefix + "wq.weight" in state_dict:
      wq = state_dict.pop(prefix + "wq.weight")
      wk = state_dict.pop(prefix + "wk.weight")
      wv = state_dict.pop(prefix + "wv.weight")
      state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

  def forward(
      self,
      x: torch.Tensor,
      freqs_cis: torch.Tensor,
      mask: Optional[torch.Tensor],
      cache,
      start=None,
      end=None,
      ragged_batch_index=None,
      ragged_block_index=None,
  ):
    """Run attention for ``x`` of shape (bsz, seqlen, hidden_size).

    Args:
      x: input activations; axis 0 is batch and axis -2 is seqlen.
      freqs_cis: rotary factors passed to ``apply_rotary_emb``.
      mask: additive attention mask or None. A 2D mask is expanded to 4D.
      cache: cache object passed to the attention kernel.
      start, end, ragged_batch_index, ragged_block_index: forwarded to the
        attention kernel.

    Returns:
      Tensor of shape (bsz, seqlen, output size of ``wo``).
    """
    with jax.named_scope("attn_linear_before_cache"):
      bsz, seqlen = x.shape[0], x.shape[-2]

      # qkv fuse
      if self.env.qkv_fusion:
        xq, xk, xv = self.wqkv(x).split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
      else:
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
      xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
      xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
      xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

      # Shard on batch (axis 0) or on heads (axis 2 in this layout).
      shard_axis = 0 if self.env.shard_on_batch else 2
      self.env.apply_sharding(xq, axis=shard_axis)
      self.env.apply_sharding(xk, axis=shard_axis)
      self.env.apply_sharding(xv, axis=shard_axis)

    with jax.named_scope("attn_rope"):
      xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # (bsz, seqlen, heads, dim) -> (bsz, heads, seqlen, dim) for the kernel.
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)
    xq = xq.transpose(1, 2)

    # The kernel accepts mask=None, so only reshape when a mask is present
    # (the old code dereferenced mask.ndim unconditionally).
    if mask is not None and mask.ndim == 2:
      if seqlen == 1:
        # Presumably (bsz, kv_len) during decode -> (bsz, 1, 1, kv_len).
        mask = mask[:, None, None, :]
      else:
        # Presumably (seqlen, kv_len) -> (1, 1, seqlen, kv_len).
        mask = mask[None, None, :, :]

    output = self.attention_kernel(
        xq=xq,
        xk=xk,
        xv=xv,
        mask=mask,
        cache=cache,
        start=start,
        end=end,
        ragged_batch_index=ragged_batch_index,
        ragged_block_index=ragged_block_index,
    ).type_as(xq)
    # (bsz, heads, seqlen, dim) -> (bsz, seqlen, heads * dim).
    output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
    return self.wo(output)