57 changes: 55 additions & 2 deletions jetstream_pt/attention_kernel.py
@@ -1,18 +1,22 @@
from collections.abc import Callable
import functools
import math
from typing import Any

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
from jax.experimental.shard_map import shard_map

import numpy as np
import torch
import torch.nn.functional as F
from jetstream_pt import torchjax

import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
P = jax.sharding.PartitionSpec


def ragged_flash_attention_kernel(
@@ -735,3 +739,52 @@ def __call__(
k_scaler,
v_scaler,
)


def shard_kv_heads(
paged_attention_impl: Callable[..., Any],
mesh: jax.sharding.Mesh,
kv_head_mesh_axis_name: str,
):
"""Shard map on kv head."""
in_specs = (
P(None, kv_head_mesh_axis_name, None), # q
P(kv_head_mesh_axis_name, None, None, None), # k
P(kv_head_mesh_axis_name, None, None, None), # v
P(), # lengths
P(), # page_indices
)

out_specs = P(None, kv_head_mesh_axis_name, None) # q

return jax.jit(
shard_map(
paged_attention_impl,
mesh=mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
)


def call_paged_attention(env, xq, keys, values, seq_lens, page_indices):
"""Paged attention kernel."""
xq, keys, values, seq_lens, page_indices = torchjax.from_torch(
(xq, keys, values, seq_lens, page_indices)
)
paged_attention_impl = functools.partial(
paged_attention,
pages_per_compute_block=env.block_size // env.paged_attention_page_size,
# mask_value=float("-inf")
)
sharded_paged_attention_impl = shard_kv_heads(
paged_attention_impl,
env.mesh,
kv_head_mesh_axis_name="x",
)
output = sharded_paged_attention_impl(
xq, keys, values, seq_lens, page_indices
)

return torchjax.to_torch(output)
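
For context, the two functions above shard the Pallas paged-attention kernel over the KV heads: shard_kv_heads splits queries on their head axis and the paged K/V caches on their leading kv-head axis across the given mesh axis (lengths and page indices stay replicated), and call_paged_attention converts the torch tensors to JAX, fixes pages_per_compute_block, and runs the sharded kernel on the "x" mesh axis. A minimal usage sketch follows; the shapes and env fields named in the comments are assumptions for illustration, not values taken from this PR.

# Hypothetical usage sketch (shapes and env fields below are assumed):
# xq:           [batch, num_heads, head_dim]                          query for one decode step
# keys/values:  [num_kv_heads, total_num_pages, page_size, head_dim]  paged KV caches
# seq_lens:     [batch]                                               current sequence lengths
# page_indices: [batch, pages_per_sequence]                           page table per sequence
out = call_paged_attention(env, xq, keys, values, seq_lens, page_indices)
# out comes back as a torch tensor of shape [batch, num_heads, head_dim] via torchjax.to_torch.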
26 changes: 18 additions & 8 deletions jetstream_pt/cache_manager.py
@@ -19,6 +19,7 @@
import torch_xla2

from jetstream_pt import torchjax
from jetstream_pt.page_attention_manager import PageAttentionManager


# pylint: disable-next=all
@@ -663,18 +664,21 @@ def __init__(
self,
cache_k: torch.Tensor, # previous cache
cache_v: torch.Tensor, # previous cache
page_attention_manager: PageAttentionManager,
page_token_indices: torch.Tensor, # page and token indices for the cache
sharding,
env=None,
):
super().__init__()
self.cache_k = cache_k
self.cache_v = cache_v
self.page_attention_manager = page_attention_manager
self.page_token_indices = page_token_indices
self.sharding = sharding
self.env = env
self.stacked = False

def update(self, key, value):
def update(self, key, value, layer_id=0):
"""Update kv cache"""
keyj, valuej, page_token_indicesj = torchjax.from_torch(
(key, value, self.page_token_indices)
@@ -683,32 +687,38 @@ def update(self, key, value):
def _update(cache, x):
x = x.squeeze(2).transpose((1, 0, 2))
x = x[:, page_token_indicesj[2], :]
head, _, page_size, dim = cache.shape
head, _, paged_attention_page_size, dim = cache.shape
selected_cache = cache[:, page_token_indicesj[0], :, :]
selected_cache = selected_cache.reshape((head, -1, dim))

selected_cache = selected_cache.at[:, page_token_indicesj[1], :].set(x)
selected_cache = selected_cache.reshape((head, -1, page_size, dim))
selected_cache = selected_cache.reshape(
(head, -1, paged_attention_page_size, dim)
)

cache = cache.at[:, page_token_indicesj[0], :, :].set(selected_cache)
return cache

# pylint: disable-next=all
self.cache_k._elem = _update(self.cache_k._elem, keyj)
# pylint: disable-next=all
self.cache_k._elem = _update(self.cache_v._elem, valuej)
self.cache_v._elem = _update(self.cache_v._elem, valuej)
return self.cache_k, self.cache_v

def state(self):
"""Get kv cache state"""
# pylint: disable-next=all
return self.cache_k.jax(), self.cache_v.jax()
return torchjax.from_torch((self.cache_k, self.cache_v))

def finalize(self):
"""Do nothing now"""
return

@classmethod
def empty(cls, shape, device, bf16_enable, env):
def empty(cls, shape, device, env):
"""Create empty kv caches"""
default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
k = jnp.zeros(shape, device=device, dtype=default_dtype)
v = jnp.zeros(shape, device=device, dtype=default_dtype)
k, v = torchjax.to_torch((k, v))
return cls(k, v, None, device, env=env)
return cls(k, v, None, None, device, env=env)
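
The updated PageKVCacheGenerate.update above writes incoming tokens into the paged cache through a gather/scatter driven by page_token_indices: index [0] lists the pages touched this step, [1] gives token offsets inside the flattened gather of those pages, and [2] picks which incoming token fills each slot. Below is a hedged toy reproduction of that logic with invented values (the real index triple is produced by PageAttentionManager):

# Toy sketch of the _update gather/scatter; every value here is made up.
import jax.numpy as jnp

head, total_pages, page_size, dim = 1, 8, 4, 2
cache = jnp.zeros((head, total_pages, page_size, dim))
x = jnp.ones((head, 2, dim))      # two incoming tokens, already [head, token, dim]

pages = jnp.array([7, 3])         # page_token_indices[0]: pages written this step
offsets = jnp.array([2, 4 + 0])   # page_token_indices[1]: slot 2 of page 7, slot 0 of page 3,
                                  # expressed as offsets into the flattened two-page gather
selected = cache[:, pages, :, :].reshape((head, -1, dim))  # gather pages and flatten
selected = selected.at[:, offsets, :].set(x)               # write the new tokens
selected = selected.reshape((head, -1, page_size, dim))
cache = cache.at[:, pages, :, :].set(selected)             # scatter pages back into the cache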
14 changes: 14 additions & 0 deletions jetstream_pt/config.py
@@ -139,6 +139,18 @@
"size of top k used when sampling next token",
)

flags.DEFINE_integer(
"paged_attention_total_num_pages",
0,
"total number of pages per layer for page attention",
)

flags.DEFINE_integer(
"paged_attention_page_size",
64,
"page size per page",
)


def create_quantization_config_from_flags():
"""Create Quantization Config from cmd flags"""
@@ -213,6 +225,8 @@ def create_engine_from_config_flags():
generate_cache_stacked=FLAGS.generate_cache_stacked,
new_cache_stacked=FLAGS.new_cache_stacked,
lazy_cache_update=FLAGS.lazy_cache_update,
paged_attention_total_num_pages=FLAGS.paged_attention_total_num_pages,
paged_attention_page_size=FLAGS.paged_attention_page_size,
)

print("Initialize engine", time.perf_counter() - start)