-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathcache_manager.py
724 lines (646 loc) · 23 KB
/
cache_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
import torch
import torch_xla2
from jetstream_pt import torchjax
from jetstream_pt.page_attention_manager import PageAttentionManager
# pylint: disable-next=all
class CacheInterface:
    """Interface for a KV cache that holds the state of a single layer."""

    def update(self, key, value):
        """Store ``key``/``value`` and return the full cache after the write.

        ``key`` and ``value`` arrive with shape
        (Batch, Heads, Seqlen, Head dim); an implementation may keep them in
        whatever layout it likes. Each instance is expected to already know
        which layer and position the write is for.

        This base version does nothing and returns ``None``.
        """
        return None
class KVCachePrefill:
    """KV cache used during prefill.

    It owns no preallocated storage: ``update`` just remembers the most
    recent key/value it was handed and gives them straight back.
    """

    def __init__(self, kv_quantize=False, stacked=False):
        self.cache_k = None
        self.cache_v = None
        self.kv_quantize = kv_quantize
        self.stacked = stacked

    def update(self, key, value, layer_id):
        """Remember ``key``/``value`` and return them unchanged.

        ``layer_id`` is accepted for interface compatibility and ignored.
        When ``kv_quantize`` is set, the result is the 8-tuple
        ``(key, value, None, None, ones, ones, None, None)``. Here ``ones`` is
        a bfloat16 all-ones tensor of shape (batch, 1, seqlen, 1), so callers
        expecting quantized-cache outputs keep working.
        """
        self.cache_k, self.cache_v = key, value
        if not self.kv_quantize:
            return key, value
        # pretend to be quantized: unit scalers, no separate quantized values
        batch_size, _, seq_len, _ = key.shape
        unit_scale = torchjax.to_torch(
            jnp.ones((batch_size, 1, seq_len, 1), dtype=jnp.bfloat16)
        )
        return key, value, None, None, unit_scale, unit_scale, None, None

    def state(self):
        """Return the remembered ``(cache_k, cache_v)`` pair."""
        return self.cache_k, self.cache_v

    # Placeholder, to match with GenerateCache
    def finalize(self):
        """No-op; prefill has nothing deferred to write back."""
        return
# pylint: disable-next=all
def KVCachePrefill_flatten(cache):
    """Flatten a KVCachePrefill for ``jax.tree_util``.

    Returns:
      ``(children, aux_data)``. ``children`` is ``(cache_k, cache_v)`` passed
      through ``torchjax.from_torch``; ``aux_data`` is the ``kv_quantize``
      flag.
    """
    aux_data = cache.kv_quantize
    children = torchjax.from_torch((cache.cache_k, cache.cache_v))
    return children, aux_data
# pylint: disable-next=all
def KVCachePrefill_unflatten(auxdata, data):
    """Rebuild a KVCachePrefill from the output of ``KVCachePrefill_flatten``.

    Args:
      auxdata: The ``kv_quantize`` flag.
      data: ``(cache_k, cache_v)`` as produced by the flatten function.

    Returns:
      The reconstructed KVCachePrefill. Only ``kv_quantize`` is carried in
      aux data, so ``stacked`` comes back as its default.
    """
    cache = KVCachePrefill(auxdata)
    # Flatten converted the tensors with from_torch; undo that with to_torch.
    cache.cache_k, cache.cache_v = torchjax.to_torch(data)
    # jax.tree_util uses the return value as the rebuilt node.
    return cache
# Make KVCachePrefill a JAX pytree node so jax.tree_util can flatten and
# rebuild it with the two functions above.
jax.tree_util.register_pytree_node(
    KVCachePrefill,
    KVCachePrefill_flatten,
    KVCachePrefill_unflatten,
)
class KVCacheGenerate:
"""Kvache generator without quantization"""
# pylint: disable=too-many-instance-attributes
# More than 7 is reasonable in this case.
def __init__(
self,
cache_k: torch.Tensor, # previous cache
cache_v: torch.Tensor, # previous cache
position: int | torch.Tensor, # position to store the cache
sharding,
env=None,
):
super().__init__()
self.cache_k = cache_k
self.cache_v = cache_v
self.input_pos = position
self.sharding = sharding
self.env = env
self.new_ks = None
self.new_vs = None
self.env = env
# Keep this one it's used in the specific model code.
self.stacked = env.generate_cache_stacked
self.batch = jnp.arange(self.env.batch_size)
# The other way is to store the list and loop over to insert in finalize()
if self.env.lazy_cache_update:
if self.env.generate_cache_stacked:
if self.env.new_cache_stacked:
layer, batch, heads, _, dim = self.cache_k.shape
new_dim = (layer, batch, heads, 1, dim)
self.new_ks, self.new_vs = torchjax.to_torch(
(
jnp.zeros(new_dim, dtype=self.env.default_type),
jnp.zeros(new_dim, dtype=self.env.default_type),
)
)
else:
self.new_ks, self.new_vs = [], []
else: # when generate cache is not stacked, new cache cannot stack
assert not self.env.new_cache_stacked
cache_pspec = self.env.partition_by_axis(
self.env.cache_sharding_axis
) # Number of heads
none_pspec = self.env.partition_by_axis()
in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec)
out_specs = (cache_pspec, cache_pspec)
self.update_single_cache_line = jax.jit(
shard_map(
self.update_single_cache_line,
self.env.mesh,
in_specs,
out_specs,
check_rep=False,
)
)
# pylint: disable=method-hidden
# False alarm. The jit above doesn't hide this method.
def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos):
"""The shard map version of single cache line update."""
b = cache_k.shape[-4]
for bb, pp in enumerate(pos.reshape(b)):
slice_dim = 0
update_start_indices = (bb, 0, pp, 0)
if self.env.generate_cache_stacked:
if self.env.new_cache_stacked:
slice_dim = 1
update_start_indices = (0, bb, 0, pp, 0)
# We are not handling generate_cache_stacked=True new_cache_stacked=False here
new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim)
new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim)
cache_k = jax.lax.dynamic_update_slice(
cache_k, new_ks_slice, update_start_indices
)
cache_v = jax.lax.dynamic_update_slice(
cache_v, new_vs_slice, update_start_indices
)
return cache_k, cache_v
def finalize(self):
"""Finalize the cache operation and updates the cache."""
if not self.env.lazy_cache_update:
return
if self.env.ring_buffer:
# Assume no cache stack for ring buffer
# pylint: disable-next=all
self.cache_k._elem = (
self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax())
)
# pylint: disable-next=all
self.cache_v._elem = (
self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax())
)
else:
if self.env.generate_cache_stacked:
_, b, head, _, dim = self.cache_k.shape
if self.env.new_cache_stacked:
self.cache_k, self.cache_v = torch_xla2.interop.call_jax(
self.update_single_cache_line,
self.cache_k,
self.cache_v,
self.new_ks,
self.new_vs,
self.input_pos,
)
else:
for i in range(self.env.num_layers):
# pylint: disable-next=all
self.cache_k._elem = (
self.cache_k.jax()
.at[i, self.batch, :, self.input_pos, :]
.set(self.new_ks[i].jax().reshape(b, head, dim))
)
# pylint: disable-next=all
self.cache_v._elem = (
self.cache_v.jax()
.at[i, self.batch, :, self.input_pos, :]
.set(self.new_vs[i].jax().reshape(b, head, dim))
)
else:
# Try to use shard_map to get rid of the data copy
self.cache_k, self.cache_v = torch_xla2.interop.call_jax(
self.update_single_cache_line,
self.cache_k,
self.cache_v,
self.new_ks,
self.new_vs,
self.input_pos,
)
def update(self, key, value, layer_id: int):
"""Update kv cache"""
keyj, valuej = torchjax.to_torch((key, value))
if self.env.lazy_cache_update:
if self.env.new_cache_stacked:
assert (
self.env.generate_cache_stacked
), "When new cache stacked, must have generate_cache_stacked!"
self.new_ks[layer_id, ...] = keyj
self.new_vs[layer_id, ...] = valuej
return self.cache_k[layer_id], self.cache_v[layer_id]
# Generate cache stacked, but new cache unstacked
if self.env.generate_cache_stacked:
self.new_ks.append(keyj)
self.new_vs.append(valuej)
return self.cache_k[layer_id], self.cache_v[layer_id]
# all cache unstacked
self.new_ks = keyj
self.new_vs = valuej
return self.cache_k, self.cache_v
if self.env.ring_buffer:
assert (
not self.env.new_cache_stacked and not self.env.generate_cache_stacked
), "Ring buffer doesn't support stacked cache."
# pylint: disable-next=all
self.cache_k._elem = (
self.cache_k.jax().at[..., self.input_pos, :].set(keyj)
)
# pylint: disable-next=all
self.cache_v._elem = (
self.cache_v.jax().at[..., self.input_pos, :].set(valuej)
)
return self.cache_k, self.cache_v
# Non lazy cache update, non ring buffer, generate cache stacked
if self.env.generate_cache_stacked:
# pylint: disable-next=all
self.cache_k._elem = (
self.cache_k.jax()
.at[layer_id, self.batch, :, self.input_pos, :]
.set(keyj.squeeze(2))
)
# pylint: disable-next=all
self.cache_v._elem = (
self.cache_v.jax()
.at[layer_id, self.batch, :, self.input_pos, :]
.set(valuej.squeeze(2))
)
return self.cache_k[layer_id], self.cache_v[layer_id]
# Non lazy cache update, non ring buffer, generate cache non stacked
# pylint: disable-next=all
self.cache_k._elem = (
self.cache_k.jax()
.at[self.batch, :, self.input_pos, :]
.set(keyj.squeeze(2))
)
# pylint: disable-next=all
self.cache_v._elem = (
self.cache_v.jax()
.at[self.batch, :, self.input_pos, :]
.set(valuej.squeeze(2))
)
return self.cache_k, self.cache_v
def state(self):
"""Get kv cache state"""
return self.cache_k.jax(), self.cache_v.jax()
@classmethod
def empty(cls, shape, device, env):
"""Create empty kv caches"""
default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
in_shape = shape
if env.testing:
key = jax.random.key(env.testing_seed)
k_key, v_key = jax.random.split(key)
k = jax.random.uniform(k_key, shape=in_shape, dtype=default_dtype)
v = jax.random.uniform(v_key, shape=in_shape, dtype=default_dtype)
else:
k = jnp.zeros(in_shape, device=device, dtype=default_dtype)
v = jnp.zeros(in_shape, device=device, dtype=default_dtype)
k, v = torchjax.to_torch((k, v))
return cls(k, v, 0, device, env=env)
# pylint: disable-next=all
def KVCacheGenerate_flatten(cache):
    """Flatten a KVCacheGenerate for ``jax.tree_util``.

    Children are the key/value caches (as JAX arrays) and the write position.
    The position is a child rather than aux data because it may be an array,
    and JAX requires aux data to be hashable. Aux data is ``(sharding, env)``;
    ``env`` is required because the KVCacheGenerate constructor reads its flags.
    """
    children = (
        cache.cache_k.jax(),
        cache.cache_v.jax(),
        # Works whether input_pos is a torch tensor, a JAX array or an int.
        torchjax.from_torch(cache.input_pos),
    )
    aux_data = (cache.sharding, cache.env)
    return children, aux_data


# pylint: disable-next=all
def KVCacheGenerate_unflatten(auxdata, data):
    """Rebuild a KVCacheGenerate from the output of ``KVCacheGenerate_flatten``.

    The caches are converted back to torch; the position is left as a JAX
    value because the cache indexes JAX arrays with it.
    """
    sharding, env = auxdata
    cache_k, cache_v, position = data
    cache_k, cache_v = torchjax.to_torch((cache_k, cache_v))
    return KVCacheGenerate(cache_k, cache_v, position, sharding, env=env)


jax.tree_util.register_pytree_node(
    KVCacheGenerate, KVCacheGenerate_flatten, KVCacheGenerate_unflatten
)
class Int8KVCacheGenerate:
    """Int8-quantized KV cache for the generate step, with scalers.

    Quantized keys/values live in ``cache_k``/``cache_v`` (int8) and their
    scales in ``k_scaler``/``v_scaler``. The ``env`` flags
    ``lazy_cache_update``, ``ring_buffer``, ``generate_cache_stacked`` and
    ``new_cache_stacked`` choose the write mode. Every write path updates the
    scalers together with the quantized values.
    """

    # pylint: disable=too-many-instance-attributes
    # More than 7 is reasonable in this case.
    def __init__(
        self,
        cache_k,
        cache_v,
        cache_k_scaler,
        cache_v_scaler,
        input_pos,  # used to write cache
        sharding=None,
        env=None,
    ):
        """Wrap existing int8 caches/scalers and prepare lazy-update buffers.

        ``env`` is read unconditionally, so despite the ``None`` default it
        must be provided.
        """
        super().__init__()
        self.cache_k = cache_k
        self.cache_v = cache_v
        self.k_scaler = cache_k_scaler
        self.v_scaler = cache_v_scaler
        self.new_ks = None
        self.new_vs = None
        self.new_k_scaler = None
        self.new_v_scaler = None
        self.batch = jnp.arange(env.batch_size)
        self.input_pos = input_pos
        self.sharding = sharding
        self.env = env
        self.stacked = env.generate_cache_stacked
        if self.env.lazy_cache_update:
            if self.env.generate_cache_stacked:
                if self.env.new_cache_stacked:
                    # Pre-allocated per-layer slots for this step's token.
                    layer, batch, heads, _, dim = self.cache_k.shape
                    new_kv_dim = (layer, batch, heads, 1, dim)
                    self.new_ks, self.new_vs = torchjax.to_torch(
                        (
                            jnp.zeros(new_kv_dim, dtype=jnp.int8),
                            jnp.zeros(new_kv_dim, dtype=jnp.int8),
                        )
                    )
                    new_scale_dim = (layer, batch, 1, 1, 1)
                    self.new_k_scaler, self.new_v_scaler = torchjax.to_torch(
                        (
                            jnp.zeros(new_scale_dim, dtype=self.env.default_type),
                            jnp.zeros(new_scale_dim, dtype=self.env.default_type),
                        )
                    )
                else:
                    # update() appends one entry per layer.
                    self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = (
                        [],
                        [],
                        [],
                        [],
                    )
            else:  # when generate cache is not stacked, new cache cannot stack
                assert not self.env.new_cache_stacked

        cache_pspec = self.env.partition_by_axis(
            self.env.cache_sharding_axis
        )  # Number of heads
        new_cache_pspec = (
            self.env.partition_by_axis(2)
            if self.env.new_cache_stacked
            else self.env.partition_by_axis(1)
        )
        none_pspec = self.env.partition_by_axis()
        # Argument order matches update_single_cache_line: 2 caches, 2 new
        # entries, then 2 scalers + 2 new scalers + pos (replicated).
        in_specs = (
            *([cache_pspec] * 2),
            *([new_cache_pspec] * 2),
            *([none_pspec] * 5),
        )
        out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec)
        # Replace the bound method with a jitted, shard_map-ed version of itself.
        self.update_single_cache_line = shard_map(
            self.update_single_cache_line,
            self.env.mesh,
            in_specs,
            out_specs,
            check_rep=False,
        )
        self.update_single_cache_line = jax.jit(self.update_single_cache_line)

    # pylint: disable=method-hidden
    # False alarm. The jit above doesn't hide this method.
    def update_single_cache_line(
        self,
        cache_k,
        cache_v,
        new_ks,
        new_vs,
        k_scaler,
        v_scaler,
        new_k_scaler,
        new_v_scaler,
        pos,
    ):
        """The shard map version of single cache line update.

        For every batch row ``bb`` (``pos`` reshaped to ``(batch,)``), the
        row's new entry is written at sequence position ``pos[bb]``. This
        happens for all four targets: ``cache_k``, ``cache_v``, ``k_scaler``
        and ``v_scaler``. When the generate cache is stacked but the new
        cache is not, the ``new_*`` arguments are per-layer lists. Each layer
        is then written separately.

        Returns:
          The updated ``(cache_k, cache_v, k_scaler, v_scaler)``.
        """
        targets = [cache_k, cache_v, k_scaler, v_scaler]
        sources = (new_ks, new_vs, new_k_scaler, new_v_scaler)
        per_layer = (
            self.env.generate_cache_stacked and not self.env.new_cache_stacked
        )
        new_stacked = self.env.generate_cache_stacked and self.env.new_cache_stacked
        # With a stacked new cache the batch axis is 1 (axis 0 is the layer).
        slice_dim = 1 if new_stacked else 0
        b = cache_k.shape[-4]
        for bb, pp in enumerate(pos.reshape(b)):
            if per_layer:
                for layer in range(self.env.num_layers):
                    start = (layer, bb, 0, pp, 0)
                    for i, src in enumerate(sources):
                        piece = jax.lax.dynamic_slice_in_dim(
                            src[layer], bb, 1, slice_dim
                        )
                        # Add the leading layer axis expected by the stacked cache.
                        targets[i] = jax.lax.dynamic_update_slice(
                            targets[i], jnp.expand_dims(piece, 0), start
                        )
            else:
                start = (0, bb, 0, pp, 0) if new_stacked else (bb, 0, pp, 0)
                for i, src in enumerate(sources):
                    piece = jax.lax.dynamic_slice_in_dim(src, bb, 1, slice_dim)
                    targets[i] = jax.lax.dynamic_update_slice(targets[i], piece, start)
        return tuple(targets)

    def state(self):
        """Get kv cache state as JAX arrays ``(cache_k, cache_v)``."""
        return torchjax.from_torch((self.cache_k, self.cache_v))

    def scalers(self):
        """Get kv cache scalers as JAX arrays ``(k_scaler, v_scaler)``."""
        return torchjax.from_torch((self.k_scaler, self.v_scaler))

    @classmethod
    # pylint: disable-next=all
    def empty(cls, shape, device, env):
        """Create zero int8 caches and all-ones bfloat16 scalers.

        Scaler shape is ``(shape[0], shape[1], 1, shape[3], 1)`` for a stacked
        generate cache and ``(shape[0], 1, shape[2], 1)`` otherwise. The write
        position starts at 0.
        """
        cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8)
        cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8)
        if env.generate_cache_stacked:
            s_shape = (shape[0], shape[1], 1, shape[3], 1)
        else:
            s_shape = (shape[0], 1, shape[2], 1)
        kscaler = jnp.ones(s_shape, dtype=jnp.bfloat16)
        vscaler = jnp.ones(s_shape, dtype=jnp.bfloat16)

        cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
            (cache_k, cache_v, kscaler, vscaler)
        )
        return cls(cache_k, cache_v, kscaler, vscaler, 0, device, env=env)

    def quantize(self, val):
        """Quantize ``val`` to int8 with one scale per (batch, seqlen) slice.

        The scale is ``max(|val|) / 127`` over axes -3 and -1. The cast to
        int8 does not round explicitly.
        NOTE(review): an all-zero slice yields scale 0 and a 0/0 division.
        """
        # val is (batch, heads, seqlen, dim)
        scale = torch.amax(val.abs(), axis=(-3, -1), keepdim=True)
        scale = scale / 127
        return (val / scale).to(torch.int8), scale

    def update(self, xk, xv, layer_id: int):
        """Quantize ``xk``/``xv`` and record them in the cache.

        Returns:
          ``(cache_k, cache_v, k_quant, v_quant, k_scaler, v_scaler, kscale,
          vscale)``: the caches and scalers, followed by this step's
          quantized values and scales. With ``lazy_cache_update`` the caches
          are returned before the write, which ``finalize`` performs.
        """
        k_quant, kscale = self.quantize(xk)
        v_quant, vscale = self.quantize(xv)

        if self.env.lazy_cache_update:
            if self.env.new_cache_stacked:
                self.new_ks[layer_id, ...] = k_quant
                self.new_vs[layer_id, ...] = v_quant
                self.new_k_scaler[layer_id, ...] = kscale
                self.new_v_scaler[layer_id, ...] = vscale
            elif self.env.generate_cache_stacked:
                self.new_ks.append(k_quant)
                self.new_vs.append(v_quant)
                self.new_k_scaler.append(kscale)
                self.new_v_scaler.append(vscale)
            else:
                self.new_ks = k_quant
                self.new_vs = v_quant
                self.new_k_scaler = kscale
                self.new_v_scaler = vscale
        elif self.env.ring_buffer:
            self.cache_k[:, :, self.input_pos, :] = k_quant
            self.cache_v[:, :, self.input_pos, :] = v_quant
            self.k_scaler[:, :, self.input_pos, :] = kscale
            self.v_scaler[:, :, self.input_pos, :] = vscale
        else:
            # We don't handle left aligned but lazy_cache_update=False
            self.cache_k[self.batch, :, self.input_pos, :] = k_quant.squeeze(2)
            self.cache_v[self.batch, :, self.input_pos, :] = v_quant.squeeze(2)
            self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2)
            self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2)
        return (
            self.cache_k,
            self.cache_v,
            k_quant,
            v_quant,
            self.k_scaler,
            self.v_scaler,
            kscale,
            vscale,
        )

    def finalize(self):
        """Finalize the cache operation and updates the cache.

        Writes the values and scalers buffered by a lazy ``update``. This is
        a no-op unless ``env.lazy_cache_update`` is set.
        """
        if not self.env.lazy_cache_update:
            return
        if self.env.ring_buffer:
            # Assume no cache stack for ring buffer
            # pylint: disable-next=all
            self.cache_k._elem = (
                self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax())
            )
            # pylint: disable-next=all
            self.cache_v._elem = (
                self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax())
            )
            # The scalers must be written too, otherwise the new int8 values
            # would be dequantized with stale scales.
            # pylint: disable-next=all
            self.k_scaler._elem = (
                self.k_scaler.jax()
                .at[..., self.input_pos, :]
                .set(self.new_k_scaler.jax())
            )
            # pylint: disable-next=all
            self.v_scaler._elem = (
                self.v_scaler.jax()
                .at[..., self.input_pos, :]
                .set(self.new_v_scaler.jax())
            )
            return
        # All non-ring-buffer layouts share update_single_cache_line. The new
        # kv scaler also has to go through shard_map instead of indexing,
        # because it needs to reshape to (batch, layer), which messes up the data.
        (
            self.cache_k,
            self.cache_v,
            self.k_scaler,
            self.v_scaler,
        ) = torch_xla2.interop.call_jax(
            self.update_single_cache_line,
            self.cache_k,
            self.cache_v,
            self.new_ks,
            self.new_vs,
            self.k_scaler,
            self.v_scaler,
            self.new_k_scaler,
            self.new_v_scaler,
            self.input_pos,
        )
class PageKVCacheGenerate:
    """Generate-step KV cache stored in attention pages, without quantization."""

    def __init__(
        self,
        cache_k: torch.Tensor,  # previous cache
        cache_v: torch.Tensor,  # previous cache
        page_attention_manager: PageAttentionManager,
        page_token_indices: torch.Tensor,  # page and token indices for the cache
        sharding,
        env=None,
    ):
        super().__init__()
        self.cache_k, self.cache_v = cache_k, cache_v
        self.page_attention_manager = page_attention_manager
        self.page_token_indices = page_token_indices
        self.sharding = sharding
        self.env = env
        # A paged cache is never stacked across layers.
        self.stacked = False

    @staticmethod
    def _scatter_tokens(cache, x, indices):
        """Write rows of ``x`` into ``cache`` at the slots given by ``indices``.

        ``cache`` is unpacked as (heads, _, page_size, head_dim). ``x`` has
        axis 2 squeezed and its first two axes swapped. ``indices[2]`` then
        selects along the new axis 1. ``indices[0]`` picks the pages (axis 1
        of ``cache``) to touch, and ``indices[1]`` the token slots once those
        pages are flattened to (heads, -1, head_dim).
        """
        heads, _, page_size, head_dim = cache.shape
        page_idx, slot_idx, row_idx = indices[0], indices[1], indices[2]
        tokens = x.squeeze(2).transpose((1, 0, 2))[:, row_idx, :]
        flat_pages = cache[:, page_idx, :, :].reshape((heads, -1, head_dim))
        flat_pages = flat_pages.at[:, slot_idx, :].set(tokens)
        pages = flat_pages.reshape((heads, -1, page_size, head_dim))
        return cache.at[:, page_idx, :, :].set(pages)

    def update(self, key, value, layer_id=0):
        """Scatter ``key``/``value`` into the paged caches.

        ``layer_id`` is accepted for interface compatibility and ignored.
        Returns the updated ``(cache_k, cache_v)``.
        """
        key_j, value_j, indices_j = torchjax.from_torch(
            (key, value, self.page_token_indices)
        )
        # pylint: disable-next=all
        self.cache_k._elem = self._scatter_tokens(self.cache_k._elem, key_j, indices_j)
        # pylint: disable-next=all
        self.cache_v._elem = self._scatter_tokens(
            self.cache_v._elem, value_j, indices_j
        )
        return self.cache_k, self.cache_v

    def state(self):
        """Return ``(cache_k, cache_v)`` converted to JAX arrays."""
        # pylint: disable-next=all
        return torchjax.from_torch((self.cache_k, self.cache_v))

    def finalize(self):
        """Nothing is deferred for paged caches, so this does nothing."""
        return

    @classmethod
    def empty(cls, shape, device, env):
        """Create zero-filled paged caches of ``shape``.

        The dtype is bfloat16 when ``env.bf16_enable`` is set, else float32.
        There is no page manager and no token indices yet.
        """
        dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
        k, v = torchjax.to_torch(
            (
                jnp.zeros(shape, device=device, dtype=dtype),
                jnp.zeros(shape, device=device, dtype=dtype),
            )
        )
        return cls(k, v, None, None, device, env=env)