From f61641faa2a608d80753dfa70af5773fdcfc8e44 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 2 Jul 2024 19:56:28 +0000 Subject: [PATCH 01/57] Almost working except mask, need to rebase to main to pick up the the ring buffer support then fix the mask. Int8 updates also included but not tested. --- jetstream_pt/attention_kernel.py | 58 +++++++ jetstream_pt/cache_manager.py | 121 +++++++++++--- jetstream_pt/engine.py | 67 +++++--- jetstream_pt/environment.py | 42 ++++- jetstream_pt/layers.py | 155 ++++++++++++++---- jetstream_pt/third_party/gemma/model.py | 8 +- .../third_party/llama/model_exportable.py | 17 +- jetstream_pt/third_party/mixtral/model.py | 5 +- keys_original | Bin 0 -> 66746 bytes original_scores | Bin 0 -> 2244 bytes tests/test_model_impl.py | 1 + 11 files changed, 379 insertions(+), 95 deletions(-) create mode 100644 keys_original create mode 100644 original_scores diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 96bb4233..1d213959 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -309,6 +309,64 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): output = torch.einsum("ikjm,ikml->ikjl", scores, values) return output +def flash_attention(xq, keys, values, mask=None, normalize_var=True): + """The vanilla attention kernel implementation.""" + import pdb; pdb.set_trace() + # mask_value: float = DEFAULT_MASK_VALUE + logits = torch.einsum( + "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) + ) + + if normalize_var: + logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama + if mask is not None: + # logits = logits + torch.where(mask, 0.0, mask_value)[:, None] + logits = logits + mask + + logits_max, _ = torch.max(logits, axis=-1, keepdim=True) + # unnormalized = torch.exp(logits - logits_max[..., None]) + unnormalized = torch.exp(logits - logits_max) + denominator = unnormalized.sum(axis=-1, keepdim=True) + # print(f"logits {logits.shape} logits_max {logits_max.shape} denominator {denominator}") + o = ( + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(values), values) + # / denominator[..., None] + / denominator + ) + return o, (logits_max, denominator) + + +def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, normalize_var=True): + mask_value: float = DEFAULT_MASK_VALUE + logits = torch.einsum( + "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) + ) + + if normalize_var: + logits = logits / torch.sqrt(keys.shape[-1]) # Align with meta llama + # Quantized + logits = logits * k_scaler + + # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] + if mask is not None: + # logits = logits + jnp.where(mask, 0.0, DEFAULT_MASK_VALUE)[:, None] + logits = logits + mask + + logits_max = torch.max(axis=-1, keepdim=True) + unnormalized = torch.exp(logits - logits_max) + #Quantized, should not put here, otherwise sum will have this too, which cancels with denominator + # unnormalized = unnormalized * v_scaler + + denominator = unnormalized.sum(axis=-1, keepdim=True) + unnormalized = unnormalized * v_scaler + + o = ( + jnp.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(values), values) + / denominator + ) + + return o, (logits_max, denominator) + class RaggedAttentionKernel: """Ragged attention kernel.""" diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 13789f91..aa2b7d58 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -38,12 +38,13 @@ 
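For reference, the core of the two kernels added to attention_kernel.py above is the "attention with softmax statistics" pattern: return the un-renormalized output together with the per-row softmax max and denominator, so a second partial result can be merged in later. A minimal, self-contained sketch of that pattern follows (illustrative names and shapes, not code from the patch); note that flash_attention_quantized as written still has placeholder calls (torch.max is missing its logits input, torch.sqrt is applied to a Python int, and the final einsum goes through jnp on torch tensors).

import math
import torch

def attention_with_stats(xq, keys, values, mask=None):
  # xq: [b, h, q, d]; keys/values: [b, h, k, d]. Returns the attention output
  # plus the per-row (max, denominator) needed to merge partial results.
  logits = torch.einsum("bhqd,bhkd->bhqk", xq.float(), keys.float())
  logits = logits / math.sqrt(keys.shape[-1])
  if mask is not None:
    logits = logits + mask
  logits_max, _ = torch.max(logits, dim=-1, keepdim=True)
  unnormalized = torch.exp(logits - logits_max)
  denominator = unnormalized.sum(dim=-1, keepdim=True)
  output = torch.einsum(
      "bhqk,bhkd->bhqd", unnormalized.type_as(values), values
  ) / denominator
  return output, (logits_max, denominator)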
def update(self, key, value): class KVCachePrefill: """Prefill kv cache""" - def __init__(self, kv_quantize=False): + def __init__(self, kv_quantize=False, stacked=False): self.kv_quantize = kv_quantize self.cache_k = None self.cache_v = None + self.stacked = stacked - def update(self, key, value): + def update(self, key, value, layer_id): """This cache just remembers the stuff.""" self.cache_k = key self.cache_v = value @@ -100,24 +101,54 @@ def __init__( self.sharding = sharding self.env = env - def update(self, key, value): - """Update kv cache""" - keyj, valuej = torchjax.to_torch((key, value)) + self.new_ks = None + self.new_vs = None + self.env = env + self.stacked = env.generate_cache_stacked + # The other way is to store the list and loop over to insert in finalize() + if self.stacked: + layer, batch, heads, time, dim = self.cache_k.shape + self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, heads, 1, dim)))) + + def finalize(self): + if not self.stacked: + return + # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_ks._elem, -2)) + # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_vs._elem, -2)) if self.env.ring_buffer: - # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) - # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej) + self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(self.new_vs._elem) else: batch = jnp.arange(self.env.batch_size) - # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set( - keyj.squeeze(2) - ) - # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set( - valuej.squeeze(2) - ) + self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem) + + def update(self, key, value, layer_id:int): + """Update kv cache""" + # Will process in insert() at the end of the transformer forward pass + keyj, valuej = torchjax.to_torch((key, value)) + if self.stacked: + self.new_ks[layer_id, :, :, :, :] = keyj + self.new_vs[layer_id, :, :, :, :] = valuej + # self.new_ks.append(value) + # self.new_vs.append(value) + return self.cache_k[layer_id], self.cache_v[layer_id] + else: + if self.env.ring_buffer: + # pylint: disable-next=all + self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) + # pylint: disable-next=all + self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej) + else: + batch = jnp.arange(self.env.batch_size) + # pylint: disable-next=all + self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set( + keyj.squeeze(2) + ) + # pylint: disable-next=all + self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set( + valuej.squeeze(2) + ) return self.cache_k, self.cache_v def state(self): @@ -126,11 +157,12 @@ def state(self): return self.cache_k.jax(), self.cache_v.jax() @classmethod - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" - default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32 - k = jnp.zeros(shape, device=device, dtype=default_dtype) - v = jnp.zeros(shape, device=device, dtype=default_dtype) + default_dtype = 
jnp.bfloat16 if env.bf16_enable else jnp.float32 + in_shape = shape + k = jnp.zeros(in_shape, device=device, dtype=default_dtype) + v = jnp.zeros(in_shape, device=device, dtype=default_dtype) k, v = torchjax.to_torch((k, v)) return cls(k, v, 0, device, env=env) @@ -178,6 +210,11 @@ def __init__( self.input_pos = input_pos self.sharding = sharding self.env = env + self.stacked = env.generate_cache_stacked + + if self.stacked: + layer, batch, heads, len, dim = self.cache_k.shape + self.new_ks, self.new_vs, self.new_k_scalers, self.new_v_scalers = torchjax.to_torch((jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, 1, 1, 1)), jnp.zeros((layer, batch, 1, 1, 1)))) def state(self): """Get kv cache state""" @@ -189,13 +226,17 @@ def scalers(self): @classmethod # pylint: disable-next=all - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - # bf16_enable is a placeholder parameter, it's not used in Int8KVCache - kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + + if env.generate_cache_stacked: + kscaler = jnp.ones((shape[0], shape[1], 1, shape[2], 1), dtype=jnp.bfloat16) + vscaler = jnp.ones((shape[0], shape[1], 1, shape[2], 1), dtype=jnp.bfloat16) + else: + kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) @@ -209,10 +250,19 @@ def quantize(self, val): scale = scale / 127 return (val / scale).to(torch.int8), scale - def update(self, xk, xv): + def update(self, xk, xv, layer_id:int): """Update kv cache""" k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) + + if self.stacked: + self.new_ks[layer_id, ...] = k_quant + self.new_vs[layer_id, ...] = v_quant + self.new_k_scalers[layer_id, ...] = kscale + self.new_v_scalers[layer_id, ...] 
= vscale + + return self.cache_k[layer_id], self.cache_v[layer_id], k_quant, v_quant, self.k_scaler[layer_id], self.v_scaler[layer_id], kscale, vscale + if self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant @@ -224,4 +274,21 @@ def update(self, xk, xv): self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2) self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2) self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2) - return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler + return self.cache_k, self.cache_v, k_quant, v_quant, self.k_scaler, self.v_scaler, kscale, vscale + + def finalize(self): + if not self.stacked: + return + # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_ks._elem, -2)) + # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_vs._elem, -2)) + if self.env.ring_buffer: + self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(self.new_vs._elem) + self.k_scaler._elem = self.k_scaler._elem.at[:, :, :, self.pos].set(self.new_k_scalers._elem) + self.v_scaler._elem = self.v_scaler._elem.at[:, :, :, self.pos].set(self.new_v_scalers._elem) + else: + batch = jnp.arange(self.env.batch_size) + self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem) + self.k_scaler._elem = self.k_scaler._elem.at[:, batch, :, self.pos].set(self.new_k_scalers._elem) + self.v_scaler._elem = self.v_scaler._elem.at[:, batch, :, self.pos].set(self.new_v_scalers._elem) \ No newline at end of file diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 0b27db3d..a8fd4397 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -195,7 +195,10 @@ def _call_model_generate( # The mode is needed so that tensors created inside of # the model (such as via torch.ones etc) also have the right type res = torch.func.functional_call(self.pt_model, paramst, argst) - updated_caches = [c.state() for c in caches_obj] + updated_caches = [] + for c in caches_obj: + c.finalize() + updated_caches.append(c.state()) scales = [] if self.env.quant_config.enable_kv_quantization: scales = [c.scalers() for c in caches_obj] @@ -328,7 +331,7 @@ def _insert_no_wrap( tokens = decode_state.tokens.at[slot].set(prefix.token) x = jnp.arange(0, self.env.cache_sequence_length) - cond = jnp.logical_and(x <= current_pos, x >= pos) + cond = jnp.logical_and(x < decode_state.current_position, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) start = decode_state.start.at[slot].set( @@ -338,19 +341,28 @@ def _insert_no_wrap( if not self.env.quant_config.enable_kv_quantization: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, new_entry): + def insert(cache, new_entry, update_index): res = jax.lax.dynamic_update_slice( cache, new_entry, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res - caches = [ - (insert(k, newk), insert(v, newv)) - for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) - ] + if self.env.generate_cache_stacked: + caches = decode_state.caches + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + newk = 
jnp.expand_dims(newk, 0) + newv = jnp.expand_dims(newv, 0) + caches = [(insert(caches[0][0], newk, update_index),insert(caches[0][1], newv, update_index))] + else: + update_index = [slot, 0, pos, 0] + caches = [ + (insert(k, newk, update_index), insert(v, newv, update_index)) + for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) + ] else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) @@ -416,10 +428,10 @@ def _insert_wrap( cond = jax.lax.cond( decode_state.current_position > start_insert, lambda x, start_insert, current_position: jnp.logical_and( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), lambda x, start_insert, current_position: jnp.logical_or( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), x, start_insert, @@ -435,21 +447,25 @@ def _insert_wrap( old_scales = decode_state.cache_scales cache_inserts = prefix.caches + print(f"YY old_caches: {len(decode_state.caches)} cache_inserts: {len(cache_inserts)}") scales = [] caches = [] if not self.env.quant_config.enable_kv_quantization: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, new_entry): + def insert(cache, new_entry, layer_id): new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2)) - res = cache.at[slot, :, update_indexes, :].set(new_entry) + if self.env.generate_cache_stacked: + res = cache.at[layer_id, slot, :, update_indexes, :].set(new_entry) + else: + res = cache.at[slot, :, update_indexes, :].set(new_entry) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res - caches = [ - (insert(k, newk), insert(v, newv)) - for (k, v), (newk, newv) in zip(old_caches, cache_inserts) - ] + for idx, (newk, newv) in enumerate(prefix.caches): + caches = [ + (insert(old_caches[0][0], newk, idx), insert(old_caches[0][1], newv, idx)) + ] else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) @@ -580,11 +596,9 @@ def generate( pos = decode_state.current_position if self.env.ring_buffer: input_indexes = jnp.full((1,), pos) - mask = decode_state.mask.at[:, decode_state.current_position].set(0) else: input_indexes = decode_state.input_pos - batch = jnp.arange(self.env.batch_size) - mask = decode_state.mask.at[batch, decode_state.input_pos].set(0) + ragged_batch_index, ragged_block_index = ( self.precompute_ragged_block_indices(decode_state) ) @@ -592,6 +606,17 @@ def generate( (-1) ), ragged_block_index.reshape((-1)) + + def update_mask(): + if self.env.ring_buffer: + return decode_state.mask.at[:, decode_state.current_position].set(0) + + batch = jnp.arange(self.env.batch_size) + return decode_state.mask.at[batch, decode_state.input_pos].set(0) + + mask = decode_state.mask + if not self.env.flash_attention: + mask = update_mask(mask, decode_state.current_position) logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, @@ -605,6 +630,10 @@ def generate( ragged_block_index, ) + if self.env.flash_attention: + # fill mask later, now use flash attention + mask = update_mask() + next_token = self._sampling(logits, self.env.batch_size) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index fb1b99ba..a9f7bae9 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -96,6 +96,13 @@ class JetEngineEnvironmentData: # Ring buffer ring_buffer: bool = True + # Ring buffer + ring_buffer: bool = True + + # + flash_attention: bool = True + + 
generate_cache_stacked: bool = False # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -122,6 +129,9 @@ def __init__(self, data: JetEngineEnvironmentData): self.ragged_mha = self._data.ragged_mha self.block_size = self._data.block_size self.starting_position = self._data.starting_position + self.flash_attention = self._data.flash_attention + self.generate_cache_stacked = self._data.generate_cache_stacked + self.num_layers = self._data.num_layers self.ring_buffer = self._data.ring_buffer P = jax.sharding.PartitionSpec @@ -136,12 +146,23 @@ def __init__(self, data: JetEngineEnvironmentData): self.x_sharding = jsharding.NamedSharding(self.mesh, P("x")) self.replicated = jsharding.NamedSharding(self.mesh, P()) + if self.generate_cache_stacked: + + self.attention_kv_axis_names = ( + "layer", + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", + ) if data.shard_on_batch: - cache_sharding_axis = 0 + self.kv_cache_shard_axis = "batch" else: - cache_sharding_axis = self.attention_kv_axis_names.index( - self.kv_cache_shard_axis - ) + self.kv_cache_shard_axis = "num_attn_heads" + + cache_sharding_axis = self.attention_kv_axis_names.index( + self.kv_cache_shard_axis + ) if self.cache_shape[cache_sharding_axis] == 1: # cannot shard on an axis that is 1 @@ -196,17 +217,24 @@ def make_caches_generate(self): caches = [] shape = self._data.cache_shape - for _ in range(self.num_layers): + if self.generate_cache_stacked: + cache_shape = (self.num_layers, *shape) + layered_cache_count = 1 + else: + cache_shape = shape + layered_cache_count = self.num_layers + + for _ in range(layered_cache_count): if self._data.quant_config.enable_kv_quantization: caches.append( cache_manager.Int8KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + cache_shape, self.cache_sharding, self, env=self ) ) else: caches.append( cache_manager.KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + cache_shape, self.cache_sharding, self, env=self ) ) return caches diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index e2756d73..ca703073 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -380,18 +380,20 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class AttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env self.shard_axis = 0 if self.env.shard_on_batch else 1 qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention + self.flash_attention = ak.flash_attention self.ragged_attention = ak.RaggedAttentionKernel( env, input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), output_specs=(qkv_pspec, (others_pspec, others_pspec)), sharding_axis=self.shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -416,17 +418,15 @@ def __call__( bsz, num_heads, seqlen, head_dim = xq.shape _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - with jax.named_scope("attn_insert_cache"): - keys, values = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + if not self.env.ragged_mha and seqlen == 1: + xq_expanded = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + else: + xq_expanded = xq - with 
jax.named_scope("attn_qkv"): + def attend(xq, keys, values, local_mask=None): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, keys, @@ -436,31 +436,79 @@ def __call__( ragged_batch_index, ragged_block_index, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, mask=local_mask) else: - output = self.dense_attention(xq, keys, values, None, None, mask) + local_output = self.dense_attention(xq, keys, values, None, None, local_mask) + local_max = None + local_denom = None if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] - # For XLA matmul performance boost - # output = torch.matmul(scores, values) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + local_output = local_output[:, :, 0:1, :] + if local_max is not None: + local_max = local_max[:, :, 0:1, :] + if local_denom is not None: + local_denom = local_denom[:, :, 0:1, :] + + print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + if local_max is not None and local_denom is not None: + print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + self.env.apply_sharding(local_output, axis=self.shard_axis) + return local_output, (local_max, local_denom) + + + + with jax.named_scope("attn_insert_cache"): + keys, values = cache.update(xk, xv, self.layer_id) + keys = repeat_kv(keys, n_rep) + values = repeat_kv(values, n_rep) + + print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, mask) + + # For non flash attention or prefill, existing output contains everything + if not self.env.flash_attention or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + new_keys = repeat_kv(xk, n_rep) + new_values = repeat_kv(xv, n_rep) + new_output, (new_max, new_denom) = attend(xq, new_keys, new_values, None) + # if cache.cache_k is None: # Prefill + # return new_output + + with jax.named_scope("attn_global"): + print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") + + global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) + existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output + + + return attn_out class Int8KVAttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env self.shard_axis = 0 if self.env.shard_on_batch else 1 qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention + self.flash_attention = ak.flash_attention_quantized self.ragged_attention = ak.RaggedAttentionKernel( env, input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), output_specs=(qkv_pspec, (others_pspec, others_pspec)), sharding_axis=self.shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -487,16 +535,13 @@ def __call__( n_rep = num_heads // num_kv_heads 
if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - - with jax.named_scope("attn_insert_cache"): - keys, values, k_scaler, v_scaler = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + xq_expanded = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + else: + xq_expanded = xq - with jax.named_scope("attn_qkv"): + def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, keys, @@ -508,22 +553,63 @@ def __call__( k_scaler, v_scaler, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, mask=local_mask) else: - output = self.dense_attention( - xq, keys, values, k_scaler, v_scaler, mask - ) + local_output = self.dense_attention(xq, keys, values, k_scaler, v_scaler, local_mask) + local_max = None + local_denom = None if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] + local_output = local_output[:, :, 0:1, :] + if local_max is not None: + local_max = local_max[:, :, 0:1, :] + local_denom = local_denom[:, :, 0:1, :] + + print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + if local_max is not None and local_denom is not None: + print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + self.env.apply_sharding(local_output, axis=self.shard_axis) + return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + keys, values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) + keys = repeat_kv(keys, n_rep) + values = repeat_kv(values, n_rep) + + print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, k_scaler, v_scaler, mask) + + # For non flash attention or prefill, existing output contains everything + if not self.env.flash_attention or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + new_keys = repeat_kv(new_key, n_rep) + new_values = repeat_kv(new_value, n_rep) + new_output, (new_max, new_denom) = attend(xq, new_keys, new_values, new_k_scaler, new_v_scaler, None) + # if cache.cache_k is None: # Prefill + # return new_output + + with jax.named_scope("attn_global"): + print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") - self.env.apply_sharding(output, axis=self.shard_axis) - return output + global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) + existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output + + return attn_out class Attention(ModuleBase): """Attention module.""" - def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): + def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id): super().__init__() self.n_heads 
= n_heads self.n_kv_heads = n_kv_heads @@ -531,6 +617,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): self.n_rep = self.n_heads // self.n_kv_heads self.env = env self.hidden_size = hidden_size + self.layer_id = layer_id LinearLayer = get_quantized_linear_layer(env.quant_config) linear_kwargs = {} @@ -550,7 +637,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): if env.quant_config.enable_kv_quantization else AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, self.layer_id) self.q_size = n_heads * self.head_dim self.kv_size = self.n_kv_heads * self.head_dim @@ -630,16 +717,20 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) + if cache is not None and cache.cache_k is not None: + print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, xk, xv, mask, + # cache[self.layer_id], cache, start, end, ragged_batch_index, ragged_block_index, ).type_as(xq) + print(f"output {output.shape}") output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 1072dad9..ff0a903e 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -73,6 +73,7 @@ def __init__( head_dim: int, device, env, + layer_id, ): super().__init__() @@ -272,7 +273,7 @@ def forward(self, x): class GemmaDecoderLayer(nn.Module): - def __init__(self, config: gemma_config.GemmaConfig, env): + def __init__(self, config: gemma_config.GemmaConfig, env, layer_id): super().__init__() self.self_attn = GemmaAttention( config.hidden_size, @@ -281,6 +282,7 @@ def __init__(self, config: gemma_config.GemmaConfig, env): config.head_dim, config.device, env, + layer_id, ) self.mlp = GemmaMLP( @@ -340,8 +342,8 @@ def __init__(self, config: gemma_config.GemmaConfig, env): self.env = env self.layers = nn.ModuleList() - for _ in range(config.num_hidden_layers): - self.layers.append(GemmaDecoderLayer(config, env)) + for layer_id, _ in enumerate(range(config.num_hidden_layers)): + self.layers.append(GemmaDecoderLayer(config, env, layer_id)) self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, device=config.device ) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 0752cc45..896ad62f 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -103,6 +103,7 @@ def __init__( args.dim, env=env, device=args.device, + layer_id=layer_id, ) self.feed_forward = FeedForward( dim=args.dim, @@ -259,12 +260,18 @@ def forward( freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - assert len(caches) == len( - self.layers - ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" + # Should check more thoroughly, as of now, when prefill, it's always not stacked. When generate, it's controlled by the parameter. 
+ # target_cache_layers = 1 if self.env.generate_cache_stacked else len(self.layers) + # assert len(caches) == target_cache_layers, f"Number of caches ({len(caches)}) and layers ({target_cache_layers}) dont match" + end = None if start is None else (start + input_pos) % self.env.cache_len - for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)): + for layer_id, layer in enumerate(self.layers): + if not caches[0].stacked: + cache = caches[layer_id] + else: # For stacked case, there is only 1 layer of kv cache + cache = caches[0] + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): h = layer( h, freqs_cis, diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index 7d053703..df0e0056 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -38,7 +38,7 @@ def __init__(self, config: ModelArgs, env) -> None: config.vocab_size, config.dim, device=config.device ) self.layers = nn.ModuleList( - TransformerBlock(config, env) for _ in range(config.n_layer) + TransformerBlock(config, env, layer_id) for layer_id, _ in enumerate(range(config.n_layer)) ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) LinearLayer = get_quantized_linear_layer(env.quant_config) @@ -142,7 +142,7 @@ def get_weight_sharding_type(): class TransformerBlock(nn.Module): - def __init__(self, config: ModelArgs, env) -> None: + def __init__(self, config: ModelArgs, env, layer_id) -> None: super().__init__() self.attention = Attention( config.n_head, @@ -151,6 +151,7 @@ def __init__(self, config: ModelArgs, env) -> None: config.dim, env=env, device=config.device, + layer_id=layer_id ) self.block_sparse_moe = MOEFeedForward(config, config.device, env) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) diff --git a/keys_original b/keys_original new file mode 100644 index 0000000000000000000000000000000000000000..272e12a02610df127bad2bea3c64347f0d43e72b GIT binary patch literal 66746 zcmZ^~c{Ejj^fzuMBx6Y#l4eDk4EO9UO_Gq*r%4fosECS2GE`EL$`FMnNs1C0;K6g0yW)p-o$2kjBt-Uq{X}(g&`R&Hng8pcYKYt% z*{*XeL*(bkcYXQHTH+HHuyRGHs_QD>C99VBP7Lq~Qw>oW;vUj1EJW2kL~VrIG-Z{Yt25G%bw&Q4 z-R=LHbuSsc|0C^?f)G`9P?hP_%gxUre9IUJb>t^R_m2*Lnx9 z-}dF>C;!2+d8YWM^bJXWt`q*54uCm2>fEu@f=xtsGTu0j1{DYx=9i3?L2Gc4Wxm*V z)F%*2Z}Oh`!-dN#L6W$*a44NzD~9~{9_IHMOrcSIup>zpqpgd;Izo;D-g?2F`wL-B zehihJ(1krm7xSjA+T8GB43CU^D~b6Si);GX;=Bj#^uEt&m|Br529)h**lEJUTO(0# z@E!V3`h*sq+m1GOyW`X}IriEWNRvjCz@d|x6mmd?G&?@ivl?4AHuHoPKLaTH>=;pI zikfuXC!jU)f!IIx9*i~|fFtbnsr-?maH`FOo0g2?6GIB2t*;+G4eH0qw`|4hfl^R- zCgGgQ&9w4|5w`3|l$Iy|1G_DAS^L@zIy$07Jk{>ZCXbI$ePE{eu)5q<)+U~5UIyt! z@8`>lmhlq-(e;QgUR^RtRGS`y5!|2juGW)FSgfe2dJs%CRQTo0Sm}GMSh^C`LcLy` zA<2wHzV4GF8q-m1{yYe;MW=&t_Zw8%HV!I|I2EdZ?H*}x>eM2;&2K(|$L<;MZ%h{69kvrn`r7cb;SQW?HHrFF20+FFcYOb9 zA$@B}hGEiTQDs=YV16%xr* zE2Jbo5%Wr>!dtDIoOPm!hHf}RD$N5qurVI=l@I@% zFOZy&NBtJ6ipJgUiX*H}QG&(>nDllcF0yffqXVOGg?xQ!ptTjMj-EiuGuuF_7%7g| zt%XS`zr_Pj27tqxdz5gZ5*Er@VWPDOZp@X(i0T@6w0s99?arlhF$W=eZW?(?$MDSy z5n%qTQ!rk)2RgdFfvIiIw61mmpS&%HDe>!R(=S>4Gi3_?UZ%(mL&ws{;Dr!cE3mz1 zhp3t5C;0164!^fa(sx>U9JIOCa-_)Jrzt?`aM%haI?VTXFhNW`DmYs`1|lj(X#h0(ITc0$85?(P@V;S*0Za7He5Nt^`YP!J0CKC^o8xS1gxCoEq<6+MSW+j z$6sbcaQTcmJYYsN-LM|QryX~4bNmOW^Yz71)qWIcVImHQAehF}V3%|iSsv8Jm7gSZ zY=#o$*kp)D(3Hcy&0*&rHg{@YO7Sz|va=6hxQTIm+ zxQz-G9?$wt0hu!Vv-73oaoY$!G(-cjrx! 
diff --git a/original_scores b/original_scores
new file mode 100644
index 0000000000000000000000000000000000000000..d44b7cb61ccbe9d9e1fc09e10119fcdabeb855ec
GIT binary patch
literal 2244

From: Lance Wang
Date: Tue, 2 Jul 2024 22:59:26 +0000
Subject: [PATCH 02/57] Fixed the test_model_impl for llama, but test_llama_e2e is still failing.
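Both patches revolve around the layer-stacked generate cache: during a decode step each layer's update() only buffers the new K/V, and finalize() writes all layers back into the [layer, batch, heads, seq, dim] cache in one pass after the transformer forward. A minimal sketch of that pattern is below (illustrative only, ring-buffer case, not the exact KVCacheGenerate class in cache_manager.py).

import jax.numpy as jnp

class StackedDecodeCache:
  """Buffers per-layer K/V during the forward pass, writes once in finalize()."""

  def __init__(self, cache_k, cache_v, pos):
    self.cache_k, self.cache_v = cache_k, cache_v  # [layer, batch, heads, seq, dim]
    self.pos = pos                                 # current decode position
    layers, batch, heads, _, dim = cache_k.shape
    self.new_ks = jnp.zeros((layers, batch, heads, 1, dim), cache_k.dtype)
    self.new_vs = jnp.zeros_like(self.new_ks)

  def update(self, key, value, layer_id):
    # Stash this layer's new-token K/V; attention runs against the
    # not-yet-updated cache slice for this layer.
    self.new_ks = self.new_ks.at[layer_id].set(key)
    self.new_vs = self.new_vs.at[layer_id].set(value)
    return self.cache_k[layer_id], self.cache_v[layer_id]

  def finalize(self):
    # Write all layers' buffered K/V into the cache at the current position.
    self.cache_k = self.cache_k.at[:, :, :, self.pos].set(self.new_ks[..., 0, :])
    self.cache_v = self.cache_v.at[:, :, :, self.pos].set(self.new_vs[..., 0, :])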
--- jetstream_pt/attention_kernel.py | 2 +- jetstream_pt/cache_manager.py | 5 +++-- jetstream_pt/config.py | 16 +++++++++++++++- jetstream_pt/engine.py | 4 +++- jetstream_pt/environment.py | 26 +++++++++++--------------- keys_original | Bin 66746 -> 0 bytes original_scores | Bin 2244 -> 0 bytes tests/test_model_impl.py | 11 ++++++----- 8 files changed, 39 insertions(+), 25 deletions(-) delete mode 100644 keys_original delete mode 100644 original_scores diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 1d213959..afb25d39 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -311,7 +311,7 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): def flash_attention(xq, keys, values, mask=None, normalize_var=True): """The vanilla attention kernel implementation.""" - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() # mask_value: float = DEFAULT_MASK_VALUE logits = torch.einsum( "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index aa2b7d58..04e5bca7 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -120,8 +120,9 @@ def finalize(self): self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(self.new_vs._elem) else: batch = jnp.arange(self.env.batch_size) - self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem) + layer, batch, head, len, dim = self.cache_k.shape + self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem.reshape(batch, layer, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem.reshape(batch, layer, head, dim)) def update(self, key, value, layer_id:int): """Update kv cache""" diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 78f8da9f..72653a48 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -86,10 +86,22 @@ ) flags.DEFINE_bool( "ring_buffer", - True, + False, "Whether to enable ring buffer", required=False, ) +flags.DEFINE_bool( + "flash_attention", + True, + "Whether to enable flas attention", + required=False, +) +flags.DEFINE_bool( + "generate_cache_stacked", + True, + "Whether to stack the generate cache to the layer dimension", + required=False, +) flags.DEFINE_float( "temperature", 1.0, @@ -184,6 +196,8 @@ def create_engine_from_config_flags(): nucleus_topp=FLAGS.nucleus_topp, topk=FLAGS.topk, ring_buffer=FLAGS.ring_buffer, + flash_attention=FLAGS.flash_attention, + generate_cache_stacked=FLAGS.generate_cache_stacked, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index a8fd4397..44868af5 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -331,7 +331,7 @@ def _insert_no_wrap( tokens = decode_state.tokens.at[slot].set(prefix.token) x = jnp.arange(0, self.env.cache_sequence_length) - cond = jnp.logical_and(x < decode_state.current_position, x >= pos) + cond = jnp.logical_and(x < current_pos, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) start = decode_state.start.at[slot].set( @@ -861,6 +861,8 @@ def create_pytorch_engine( nucleus_topp=None, topk=None, ring_buffer=True, + flash_attention=False, + generate_cache_stacked=False, ) -> PyTorchEngine: 
"""Returns: The pytorch engine.""" diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index a9f7bae9..e4c20fe6 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -94,15 +94,11 @@ class JetEngineEnvironmentData: starting_position: int = 512 # Ring buffer - ring_buffer: bool = True + ring_buffer: bool = False - # Ring buffer - ring_buffer: bool = True - - # flash_attention: bool = True - generate_cache_stacked: bool = False + generate_cache_stacked: bool = True # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -133,6 +129,12 @@ def __init__(self, data: JetEngineEnvironmentData): self.generate_cache_stacked = self._data.generate_cache_stacked self.num_layers = self._data.num_layers self.ring_buffer = self._data.ring_buffer + + if self.generate_cache_stacked: + self.cache_shape = (self.num_layers, *self._data.cache_shape) + else: + self.cache_shape = self._data.cache_shape + P = jax.sharding.PartitionSpec num_of_partitions = jax.device_count() @@ -215,26 +217,20 @@ def make_caches_prefill(self): def make_caches_generate(self): """Create kv caches for inference generation""" caches = [] - shape = self._data.cache_shape - if self.generate_cache_stacked: - cache_shape = (self.num_layers, *shape) - layered_cache_count = 1 - else: - cache_shape = shape - layered_cache_count = self.num_layers + layered_cache_count = 1 if self.generate_cache_stacked else self.num_layers for _ in range(layered_cache_count): if self._data.quant_config.enable_kv_quantization: caches.append( cache_manager.Int8KVCacheGenerate.empty( - cache_shape, self.cache_sharding, self, env=self + self.cache_shape, self.cache_sharding, env=self ) ) else: caches.append( cache_manager.KVCacheGenerate.empty( - cache_shape, self.cache_sharding, self, env=self + self.cache_shape, self.cache_sharding, env=self ) ) return caches diff --git a/keys_original b/keys_original deleted file mode 100644 index 272e12a02610df127bad2bea3c64347f0d43e72b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 66746 zcmZ^~c{Ejj^fzuMBx6Y#l4eDk4EO9UO_Gq*r%4fosECS2GE`EL$`FMnNs1C0;K6g0yW)p-o$2kjBt-Uq{X}(g&`R&Hng8pcYKYt% z*{*XeL*(bkcYXQHTH+HHuyRGHs_QD>C99VBP7Lq~Qw>oW;vUj1EJW2kL~VrIG-Z{Yt25G%bw&Q4 z-R=LHbuSsc|0C^?f)G`9P?hP_%gxUre9IUJb>t^R_m2*Lnx9 z-}dF>C;!2+d8YWM^bJXWt`q*54uCm2>fEu@f=xtsGTu0j1{DYx=9i3?L2Gc4Wxm*V z)F%*2Z}Oh`!-dN#L6W$*a44NzD~9~{9_IHMOrcSIup>zpqpgd;Izo;D-g?2F`wL-B zehihJ(1krm7xSjA+T8GB43CU^D~b6Si);GX;=Bj#^uEt&m|Br529)h**lEJUTO(0# z@E!V3`h*sq+m1GOyW`X}IriEWNRvjCz@d|x6mmd?G&?@ivl?4AHuHoPKLaTH>=;pI zikfuXC!jU)f!IIx9*i~|fFtbnsr-?maH`FOo0g2?6GIB2t*;+G4eH0qw`|4hfl^R- zCgGgQ&9w4|5w`3|l$Iy|1G_DAS^L@zIy$07Jk{>ZCXbI$ePE{eu)5q<)+U~5UIyt! z@8`>lmhlq-(e;QgUR^RtRGS`y5!|2juGW)FSgfe2dJs%CRQTo0Sm}GMSh^C`LcLy` zA<2wHzV4GF8q-m1{yYe;MW=&t_Zw8%HV!I|I2EdZ?H*}x>eM2;&2K(|$L<;MZ%h{69kvrn`r7cb;SQW?HHrFF20+FFcYOb9 zA$@B}hGEiTQDs=YV16%xr* zE2Jbo5%Wr>!dtDIoOPm!hHf}RD$N5qurVI=l@I@% zFOZy&NBtJ6ipJgUiX*H}QG&(>nDllcF0yffqXVOGg?xQ!ptTjMj-EiuGuuF_7%7g| zt%XS`zr_Pj27tqxdz5gZ5*Er@VWPDOZp@X(i0T@6w0s99?arlhF$W=eZW?(?$MDSy z5n%qTQ!rk)2RgdFfvIiIw61mmpS&%HDe>!R(=S>4Gi3_?UZ%(mL&ws{;Dr!cE3mz1 zhp3t5C;0164!^fa(sx>U9JIOCa-_)Jrzt?`aM%haI?VTXFhNW`DmYs`1|lj(X#h0(ITc0$85?(P@V;S*0Za7He5Nt^`YP!J0CKC^o8xS1gxCoEq<6+MSW+j z$6sbcaQTcmJYYsN-LM|QryX~4bNmOW^Yz71)qWIcVImHQAehF}V3%|iSsv8Jm7gSZ zY=#o$*kp)D(3Hcy&0*&rHg{@YO7Sz|va=6hxQTIm+ zxQz-G9?$wt0hu!Vv-73oaoY$!G(-cjrx! 
[remaining GIT binary patch data for keys_original omitted]

diff --git a/original_scores b/original_scores
deleted file mode 100644
index d44b7cb61ccbe9d9e1fc09e10119fcdabeb855ec..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 2244
[GIT binary patch data for original_scores omitted]
zDtR&sPl2g%l8zkNeXf?_2cA$9B##;?mYz)1<42Z7_kwv~#Iy74ZufoP-}}7xIp_bs zDJmytiV}&aLE%k_DU-#JV_2dwW?8i+i{3g_r(rZ<`FTdA!j($zjY3B)W3p&-!?Kn# z2BTG+WzlCZH5heSjDCsLWXURs5Syux5cM#gZ{=rda`a}0z9nL%(WGHg`)-QOj%o*f zj>PPwcH*zv{H12+WRU}7cA2N5LR1ba5uX8!*)@c}s|w&h{MfJ(t=42T>a~o)w8Sc& zYtiW~dR>A+%ZSaQ$@9&_8MAo4c|=%>YMAKU^^Cti>IHr=eci{EXA3zre;gG(h%Y?VsS>3k!XEy7m!W^mtW4n2qN4Hq z;%nIb>Iif>atQYR7)Dy#CXw=B8Y0TdVPi}TKKgSwzTx+nP9GOT^0&0WiS%gP;ui}i z+uYbSmpr&@PtHQ4*@tv)z69$ISF!8MZrTzrZfD&^2ieVO$y`rcIXsO?g%#auGP$iC z>Z9XXkLge71-d#o`srSqd!{p)l<9~SnU3TY+fDEYJOry--3S$aBKEtCJqV{g=?jl< z;3bm@I~2|Gq6_XE&ZM9$D35fzy0N(%ljzfNzBs3_oqg>~F&XJu2;Rp(X8pCk_*Z=( zX}%H#)kT%i)wl{Sm;He0SIzj!4;U~{|j5`sRrwvB6zFJi)*|YL^pV(AkGzW zt?wVh2PK{I^EckW2DTQj%y|zh99_{|+J@WmcfqpresJHZ8klXh^vT^B^46DfNbcBT ztYKog%H&vRyzEHo7rM!7>O;t0g&9vh+==imp(jp$gzX#F(GwyM(u=w%xy6y(sio#oKNs9Y(LqJrElH1*SLW;L{1_fltBEk?xvjV>0q325fg5>z^0UqwCBd1ptQC?<|GMT3p@d}w@P9A6d77a zJpgs$U0cV*>vZR#>GYT<;oR*DEDKJFFoo~?`iFXSK5z@(5zT;m3J9_v>up-#$i3+AjLc2m6savz|V4=O0@`bsG67cm{X7nxadHJ6o|TjxedgD1RXt zr*%CAztyq0=KG_Vd+Q-v(d7YD+N>sr+dtBY0sBy1_Y$7{qKh8Wu^2`i^v6Q|LRkLt zC@k9@Of%9S=$5D%uuPIjK6Ua#_fLXwN@yrlD?d2>p%Sf6fPNYsV z4A!W^PhT70h3MEm}Xjzm@7A1|8Ma9HKM@VH65Hoxn20-TXHfLYR|MxOQ= pos - seqlen) + cond = jnp.logical_and(x < pos, x >= pos - seqlen) res = jnp.where(cond, 0, float("-inf")) return torchjax.to_torch(res) @@ -91,6 +91,7 @@ def _make_one_cache_for_generate(self, env, pos): # pylint: disable-next=all def test_attention(self): + torch.manual_seed(0) env, model_arg = helpers.make_env_tiny(False) attention_orig = model_original.Attention(model_arg) @@ -137,11 +138,11 @@ def test_attention(self): # insert prefilled cache entry cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_k._elem) cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_v._elem) # self._compare_cache(attention_orig.cache_k, cache_decode.cache_k) @@ -301,10 +302,10 @@ def test_transformer_block(self): # insert prefilled cache entry cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_k._elem) cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_v._elem) # Now do one with decode From eca6de750a4dbacd9e34cea8ba26570f4b8d073c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 3 Jul 2024 17:06:15 +0000 Subject: [PATCH 03/57] Adds lazy_cache_update and restructure the cache flags. 
--- jetstream_pt/cache_manager.py | 77 +++++++++++++------ jetstream_pt/config.py | 7 ++ jetstream_pt/engine.py | 1 + jetstream_pt/environment.py | 22 ++++-- .../third_party/llama/model_exportable.py | 1 + 5 files changed, 76 insertions(+), 32 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 04e5bca7..2dc58f7a 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -59,6 +59,9 @@ def state(self): """Get prefill cache state""" return self.cache_k, self.cache_v + # Placeholder, to match with GenerateCache + def finalize(self): + return # pylint: disable-next=all def KVCachePrefill_flatten(cache): @@ -106,48 +109,67 @@ def __init__( self.env = env self.stacked = env.generate_cache_stacked # The other way is to store the list and loop over to insert in finalize() - if self.stacked: - layer, batch, heads, time, dim = self.cache_k.shape - self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, heads, 1, dim)))) + if self.env.lazy_cache_update: + if self.stacked: + layer, batch, heads, time, dim = self.cache_k.shape + new_dim = (layer, batch, heads, 1, dim) + else: + batch, heads, time, dim = self.cache_k.shape + new_dim = (batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_dim), jnp.zeros(new_dim))) + def finalize(self): - if not self.stacked: + if not self.env.lazy_cache_update: return # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_ks._elem, -2)) # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_vs._elem, -2)) if self.env.ring_buffer: - self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(self.new_vs._elem) + self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(self.new_vs._elem) else: - batch = jnp.arange(self.env.batch_size) - layer, batch, head, len, dim = self.cache_k.shape - self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem.reshape(batch, layer, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem.reshape(batch, layer, head, dim)) + if self.env.generate_cache_stacked: + layer, batch, head, len, dim = self.cache_k.shape + self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos, :].set(self.new_ks._elem.reshape(batch, layer, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos, :].set(self.new_vs._elem.reshape(batch, layer, head, dim)) + else: + batch, head, len, dim = self.cache_k.shape + self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos, :].set(self.new_ks._elem.reshape(batch, layer, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos, :].set(self.new_vs._elem.reshape(batch, layer, head, dim)) def update(self, key, value, layer_id:int): """Update kv cache""" # Will process in insert() at the end of the transformer forward pass keyj, valuej = torchjax.to_torch((key, value)) - if self.stacked: - self.new_ks[layer_id, :, :, :, :] = keyj - self.new_vs[layer_id, :, :, :, :] = valuej + if self.env.lazy_cache_update: + self.new_ks[layer_id, ...] = keyj + self.new_vs[layer_id, ...] 
= valuej # self.new_ks.append(value) # self.new_vs.append(value) return self.cache_k[layer_id], self.cache_v[layer_id] + + if self.env.ring_buffer: + # pylint: disable-next=all + self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(keyj) + # pylint: disable-next=all + self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(valuej) else: - if self.env.ring_buffer: - # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) + batch = jnp.arange(self.env.batch_size) + # pylint: disable-next=all + if self.env.generate_cache_stacked: + self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos, :].set( + keyj.squeeze(2) + ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej) + self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos, :].set( + valuej.squeeze(2) + ) else: - batch = jnp.arange(self.env.batch_size) - # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set( + self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos, :].set( keyj.squeeze(2) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set( + self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos, :].set( valuej.squeeze(2) ) return self.cache_k, self.cache_v @@ -162,8 +184,14 @@ def empty(cls, shape, device, env): """Create empty kv caches""" default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 in_shape = shape - k = jnp.zeros(in_shape, device=device, dtype=default_dtype) - v = jnp.zeros(in_shape, device=device, dtype=default_dtype) + if env.testing: + key = jax.random.key(env.seed) + k_key, v_key = jax.random.split(key) + k = jnp.random.uniform(k_key, shape=in_shape, dtype=default_dtype) + v = jnp.random.uniform(v_key, shape=in_shape, dtype=default_dtype) + else: + k = jnp.zeros(in_shape, device=device, dtype=default_dtype) + v = jnp.zeros(in_shape, device=device, dtype=default_dtype) k, v = torchjax.to_torch((k, v)) return cls(k, v, 0, device, env=env) @@ -211,9 +239,8 @@ def __init__( self.input_pos = input_pos self.sharding = sharding self.env = env - self.stacked = env.generate_cache_stacked - if self.stacked: + if self.env.generate_cache_stacked: layer, batch, heads, len, dim = self.cache_k.shape self.new_ks, self.new_vs, self.new_k_scalers, self.new_v_scalers = torchjax.to_torch((jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, 1, 1, 1)), jnp.zeros((layer, batch, 1, 1, 1)))) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 72653a48..d7dfd640 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -102,6 +102,12 @@ "Whether to stack the generate cache to the layer dimension", required=False, ) +flags.DEFINE_bool( + "lazy_cache_update", + True, + "Whether to update the cache during attention or delayed until all the layers are done", + required=False, +) flags.DEFINE_float( "temperature", 1.0, @@ -198,6 +204,7 @@ def create_engine_from_config_flags(): ring_buffer=FLAGS.ring_buffer, flash_attention=FLAGS.flash_attention, generate_cache_stacked=FLAGS.generate_cache_stacked, + create_pytorch_engine=FLAGS.create_pytorch_engine, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 44868af5..64340d71 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -863,6 +863,7 @@ def create_pytorch_engine( ring_buffer=True, 
flash_attention=False, generate_cache_stacked=False, + lazy_cache_update=False, ) -> PyTorchEngine: """Returns: The pytorch engine.""" diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index e4c20fe6..7075fa6a 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -99,6 +99,8 @@ class JetEngineEnvironmentData: flash_attention: bool = True generate_cache_stacked: bool = True + + lazy_cache_update: bool = True # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -112,6 +114,8 @@ class JetEngineEnvironmentData: # temperature parameter for scaling probability temperature: float = 1.0 + testing: bool = False + # pylint: disable-next=all class JetEngineEnvironment: @@ -129,6 +133,11 @@ def __init__(self, data: JetEngineEnvironmentData): self.generate_cache_stacked = self._data.generate_cache_stacked self.num_layers = self._data.num_layers self.ring_buffer = self._data.ring_buffer + self.lazy_cache_update = self._data.lazy_cache_update + self.testing = self._data.testing + + if self.lazy_cache_update: + self.flash_attention = True if self.generate_cache_stacked: self.cache_shape = (self.num_layers, *self._data.cache_shape) @@ -149,14 +158,13 @@ def __init__(self, data: JetEngineEnvironmentData): self.replicated = jsharding.NamedSharding(self.mesh, P()) if self.generate_cache_stacked: - self.attention_kv_axis_names = ( - "layer", - "batch", - "num_attn_heads", - "sequence_length", - "head_dim", - ) + "layer", + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", + ) if data.shard_on_batch: self.kv_cache_shard_axis = "batch" else: diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 896ad62f..015cacb8 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -282,6 +282,7 @@ def forward( ragged_batch_index, ragged_block_index, ) + cache.finalize() with jax.named_scope("transformer_norm"): h = self.norm(h) From 841f39382489f23660037bf23a4996224b8bcac6 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 3 Jul 2024 17:31:45 +0000 Subject: [PATCH 04/57] Disable all the prints. Fix create engine. 
--- jetstream_pt/config.py | 2 +- jetstream_pt/engine.py | 12 ++++++------ jetstream_pt/layers.py | 28 ++++++++++++++-------------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index d7dfd640..d8fbfb8d 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -204,7 +204,7 @@ def create_engine_from_config_flags(): ring_buffer=FLAGS.ring_buffer, flash_attention=FLAGS.flash_attention, generate_cache_stacked=FLAGS.generate_cache_stacked, - create_pytorch_engine=FLAGS.create_pytorch_engine, + lazy_cache_update=FLAGS.lazy_cache_update, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 64340d71..f0084d36 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -447,7 +447,7 @@ def _insert_wrap( old_scales = decode_state.cache_scales cache_inserts = prefix.caches - print(f"YY old_caches: {len(decode_state.caches)} cache_inserts: {len(cache_inserts)}") + # print(f"YY old_caches: {len(decode_state.caches)} cache_inserts: {len(cache_inserts)}") scales = [] caches = [] if not self.env.quant_config.enable_kv_quantization: @@ -677,11 +677,11 @@ def update_mask(): input_pos, mask, ) - print( - "new_pos", - (decode_state.current_position + 1) % self.env.cache_sequence_length, - ) - print(f"new_token: {jnp.squeeze(next_token)}") + # print( + # "new_pos", + # (decode_state.current_position + 1) % self.env.cache_sequence_length, + # ) + # print(f"new_token: {jnp.squeeze(next_token)}") return new_decode_state, result_tokens # pylint: disable-next=all diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index ca703073..6af6126f 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -451,9 +451,9 @@ def attend(xq, keys, values, local_mask=None): if local_denom is not None: local_denom = local_denom[:, :, 0:1, :] - print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") - if local_max is not None and local_denom is not None: - print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + # if local_max is not None and local_denom is not None: + # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") self.env.apply_sharding(local_output, axis=self.shard_axis) return local_output, (local_max, local_denom) @@ -464,7 +464,7 @@ def attend(xq, keys, values, local_mask=None): keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, mask) @@ -481,8 +481,8 @@ def attend(xq, keys, values, local_mask=None): # return new_output with jax.named_scope("attn_global"): - print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") - print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") + # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum @@ 
-567,9 +567,9 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): local_max = local_max[:, :, 0:1, :] local_denom = local_denom[:, :, 0:1, :] - print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") - if local_max is not None and local_denom is not None: - print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + # if local_max is not None and local_denom is not None: + # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") self.env.apply_sharding(local_output, axis=self.shard_axis) return local_output, (local_max, local_denom) @@ -578,7 +578,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): keys = repeat_kv(keys, n_rep) values = repeat_kv(values, n_rep) - print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, k_scaler, v_scaler, mask) @@ -595,8 +595,8 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # return new_output with jax.named_scope("attn_global"): - print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") - print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") + # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum @@ -718,7 +718,7 @@ def forward( xq = xq.transpose(1, 2) if cache is not None and cache.cache_k is not None: - print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") + # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, xk, @@ -731,6 +731,6 @@ def forward( ragged_batch_index, ragged_block_index, ).type_as(xq) - print(f"output {output.shape}") + # print(f"output {output.shape}") output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) From 90ffbbbd67373bdded7ef9295513f8f120a55853 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 3 Jul 2024 17:39:54 +0000 Subject: [PATCH 05/57] Fix typos and minor errors. 
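For context on the RNG fix below: jax.numpy has no random submodule, so the
PATCH 03 version (jnp.random.uniform) could not run; JAX randomness always goes
through explicit PRNG keys. A tiny sketch of the intended pattern (toy shape,
seed 0 assumed):

    import jax
    import jax.numpy as jnp

    key = jax.random.key(0)            # explicit key derived from the test seed
    k_key, v_key = jax.random.split(key)
    k = jax.random.uniform(k_key, shape=(2, 4, 16, 8), dtype=jnp.float32)
    v = jax.random.uniform(v_key, shape=(2, 4, 16, 8), dtype=jnp.float32)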
--- jetstream_pt/cache_manager.py | 6 +++--- jetstream_pt/environment.py | 3 +++ jetstream_pt/layers.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 2dc58f7a..0dec446b 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -185,10 +185,10 @@ def empty(cls, shape, device, env): default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 in_shape = shape if env.testing: - key = jax.random.key(env.seed) + key = jax.random.key(env.testing_seed) k_key, v_key = jax.random.split(key) - k = jnp.random.uniform(k_key, shape=in_shape, dtype=default_dtype) - v = jnp.random.uniform(v_key, shape=in_shape, dtype=default_dtype) + k = jax.random.uniform(k_key, shape=in_shape, dtype=default_dtype) + v = jax.random.uniform(v_key, shape=in_shape, dtype=default_dtype) else: k = jnp.zeros(in_shape, device=device, dtype=default_dtype) v = jnp.zeros(in_shape, device=device, dtype=default_dtype) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 7075fa6a..e5c2b9b7 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -116,6 +116,8 @@ class JetEngineEnvironmentData: testing: bool = False + testing_seed: int = 0 + # pylint: disable-next=all class JetEngineEnvironment: @@ -135,6 +137,7 @@ def __init__(self, data: JetEngineEnvironmentData): self.ring_buffer = self._data.ring_buffer self.lazy_cache_update = self._data.lazy_cache_update self.testing = self._data.testing + self.testing_seed = self._data.testing_seed if self.lazy_cache_update: self.flash_attention = True diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 6af6126f..9aa48ce4 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -717,7 +717,7 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - if cache is not None and cache.cache_k is not None: + # if cache is not None and cache.cache_k is not None: # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, From dd4de9ee93adc5e1772b077fa303e1416d762608 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 3 Jul 2024 23:45:04 +0000 Subject: [PATCH 06/57] Fixes create engine. --- jetstream_pt/engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index f0084d36..8d9e9f86 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -934,6 +934,9 @@ def create_pytorch_engine( nucleus_topp=nucleus_topp, topk=topk, ring_buffer=ring_buffer, + flash_attention=flash_attention, + generate_cache_stacked=generate_cache_stacked, + lazy_cache_update=lazy_cache_update, ) if shard_on_batch and sharding_config: From 91181cd4d76dabc1b912dd71f5745f0280b3ccf9 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 4 Jul 2024 00:44:05 +0000 Subject: [PATCH 07/57] Adds new_cache_stacked and fixes cache update. --- jetstream_pt/cache_manager.py | 56 +++++++++++++++++++++-------------- jetstream_pt/config.py | 7 +++++ jetstream_pt/engine.py | 2 ++ jetstream_pt/environment.py | 3 ++ 4 files changed, 46 insertions(+), 22 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 0dec446b..50216b96 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -107,16 +107,21 @@ def __init__( self.new_ks = None self.new_vs = None self.env = env + # Keep this one it's used in the specific model code. 
self.stacked = env.generate_cache_stacked + self.batch = jnp.arange(self.env.batch_size) + # The other way is to store the list and loop over to insert in finalize() if self.env.lazy_cache_update: - if self.stacked: - layer, batch, heads, time, dim = self.cache_k.shape - new_dim = (layer, batch, heads, 1, dim) - else: - batch, heads, time, dim = self.cache_k.shape - new_dim = (batch, heads, 1, dim) - self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_dim), jnp.zeros(new_dim))) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + layer, batch, heads, time, dim = self.cache_k.shape + new_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_dim), jnp.zeros(new_dim))) + else: + self.new_ks, self.new_vs = [], [] + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked def finalize(self): @@ -129,23 +134,31 @@ def finalize(self): self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(self.new_vs._elem) else: if self.env.generate_cache_stacked: - layer, batch, head, len, dim = self.cache_k.shape - self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos, :].set(self.new_ks._elem.reshape(batch, layer, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos, :].set(self.new_vs._elem.reshape(batch, layer, head, dim)) + layer, b, head, len, dim = self.cache_k.shape + if self.env.new_cache_stacked: + self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, layer, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, layer, head, dim)) + else: + def body_func(i): + self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + _ = jax.lax.fori_loop(0, self.env.num_layers, body_func) else: - batch, head, len, dim = self.cache_k.shape - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos, :].set(self.new_ks._elem.reshape(batch, layer, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos, :].set(self.new_vs._elem.reshape(batch, layer, head, dim)) + b, head, len, dim = self.cache_k.shape + self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) def update(self, key, value, layer_id:int): """Update kv cache""" # Will process in insert() at the end of the transformer forward pass keyj, valuej = torchjax.to_torch((key, value)) if self.env.lazy_cache_update: - self.new_ks[layer_id, ...] = keyj - self.new_vs[layer_id, ...] = valuej - # self.new_ks.append(value) - # self.new_vs.append(value) + if self.env.new_cache_stacked: + self.new_ks[layer_id, ...] = keyj + self.new_vs[layer_id, ...] 
= valuej + else: + self.new_ks.append(keyj) + self.new_vs.append(valuej) return self.cache_k[layer_id], self.cache_v[layer_id] if self.env.ring_buffer: @@ -154,22 +167,21 @@ def update(self, key, value, layer_id:int): # pylint: disable-next=all self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(valuej) else: - batch = jnp.arange(self.env.batch_size) # pylint: disable-next=all if self.env.generate_cache_stacked: - self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos, :].set( + self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set( keyj.squeeze(2) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos, :].set( + self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set( valuej.squeeze(2) ) else: - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos, :].set( + self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set( keyj.squeeze(2) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos, :].set( + self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set( valuej.squeeze(2) ) return self.cache_k, self.cache_v diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index d8fbfb8d..c2e44f6e 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -102,6 +102,12 @@ "Whether to stack the generate cache to the layer dimension", required=False, ) +flags.DEFINE_bool( + "new_cache_stacked", + True, + "Whether to stack the generate cache to the layer dimension", + required=False, +) flags.DEFINE_bool( "lazy_cache_update", True, @@ -204,6 +210,7 @@ def create_engine_from_config_flags(): ring_buffer=FLAGS.ring_buffer, flash_attention=FLAGS.flash_attention, generate_cache_stacked=FLAGS.generate_cache_stacked, + new_cache_stacked=FLAGS.new_cache_stacked, lazy_cache_update=FLAGS.lazy_cache_update, ) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 8d9e9f86..7eaa3b1d 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -863,6 +863,7 @@ def create_pytorch_engine( ring_buffer=True, flash_attention=False, generate_cache_stacked=False, + new_cache_stacked=False, lazy_cache_update=False, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -936,6 +937,7 @@ def create_pytorch_engine( ring_buffer=ring_buffer, flash_attention=flash_attention, generate_cache_stacked=generate_cache_stacked, + new_cache_stacked=new_cache_stacked, lazy_cache_update=lazy_cache_update, ) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index e5c2b9b7..ea103a9e 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -100,6 +100,8 @@ class JetEngineEnvironmentData: generate_cache_stacked: bool = True + new_cache_stacked: bool = True + lazy_cache_update: bool = True # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") @@ -133,6 +135,7 @@ def __init__(self, data: JetEngineEnvironmentData): self.starting_position = self._data.starting_position self.flash_attention = self._data.flash_attention self.generate_cache_stacked = self._data.generate_cache_stacked + self.new_cache_stacked = self._data.new_cache_stacked self.num_layers = self._data.num_layers self.ring_buffer = self._data.ring_buffer self.lazy_cache_update = self._data.lazy_cache_update From 50a83d4d07c2e703e921e0ff69ca6140d6961cec Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 4 Jul 2024 01:06:14 +0000 Subject: [PATCH 08/57] Fix cache update 
when new_cach_stacked is False. --- jetstream_pt/cache_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 50216b96..a589a17f 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -157,8 +157,12 @@ def update(self, key, value, layer_id:int): self.new_ks[layer_id, ...] = keyj self.new_vs[layer_id, ...] = valuej else: - self.new_ks.append(keyj) - self.new_vs.append(valuej) + if self.env.generate_cache_stacked: + self.new_ks.append(keyj) + self.new_vs.append(valuej) + else: + self.new_ks = keyj + self.new_vs = valuej return self.cache_k[layer_id], self.cache_v[layer_id] if self.env.ring_buffer: From 0336fb55f90c3e1d2ea9419a67eec1651d375162 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sun, 7 Jul 2024 20:49:20 +0000 Subject: [PATCH 09/57] Fix the cache manager and make unit tests pass except for 1. --- jetstream_pt/cache_manager.py | 23 +++--- jetstream_pt/engine.py | 2 +- tests/helpers.py | 4 +- tests/test_llama_e2e.py | 129 ++++++++++++++++++++++++++++++++++ 4 files changed, 148 insertions(+), 10 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index a589a17f..a556af7a 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -130,6 +130,7 @@ def finalize(self): # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_ks._elem, -2)) # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_vs._elem, -2)) if self.env.ring_buffer: + # Assume no cache stack for ring buffer self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(self.new_ks._elem) self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(self.new_vs._elem) else: @@ -139,10 +140,10 @@ def finalize(self): self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, layer, head, dim)) self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, layer, head, dim)) else: - def body_func(i): + def body_func(i:int, _): self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) - _ = jax.lax.fori_loop(0, self.env.num_layers, body_func) + _ = jax.lax.fori_loop(0, self.env.num_layers, body_func, init_val=jnp.zeros((self.env.num_layers,))) else: b, head, len, dim = self.cache_k.shape self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) @@ -156,31 +157,37 @@ def update(self, key, value, layer_id:int): if self.env.new_cache_stacked: self.new_ks[layer_id, ...] = keyj self.new_vs[layer_id, ...] 
= valuej + return self.cache_k[layer_id], self.cache_v[layer_id] else: if self.env.generate_cache_stacked: self.new_ks.append(keyj) self.new_vs.append(valuej) + return self.cache_k[layer_id], self.cache_v[layer_id] else: self.new_ks = keyj self.new_vs = valuej - return self.cache_k[layer_id], self.cache_v[layer_id] + return self.cache_k, self.cache_v + if self.env.ring_buffer: + # Assume no cache stack for ring buffer # pylint: disable-next=all self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(keyj) - # pylint: disable-next=all self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(valuej) + return self.cache_k, self.cache_v else: - # pylint: disable-next=all if self.env.generate_cache_stacked: - self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set( + # pylint: disable-next=all + self.cache_k._elem = self.cache_k._elem.at[layer_id, self.batch, :, self.pos, :].set( keyj.squeeze(2) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set( + self.cache_v._elem = self.cache_v._elem.at[layer_id, self.batch, :, self.pos, :].set( valuej.squeeze(2) ) + return self.cache_k[layer_id], self.cache_v[layer_id] else: + # pylint: disable-next=all self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set( keyj.squeeze(2) ) @@ -188,7 +195,7 @@ def update(self, key, value, layer_id:int): self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set( valuej.squeeze(2) ) - return self.cache_k, self.cache_v + return self.cache_k, self.cache_v def state(self): """Get kv cache state""" diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 7eaa3b1d..b8832334 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -616,7 +616,7 @@ def update_mask(): mask = decode_state.mask if not self.env.flash_attention: - mask = update_mask(mask, decode_state.current_position) + mask = update_mask() logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, diff --git a/tests/helpers.py b/tests/helpers.py index 00442517..62c0789b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,7 @@ from jetstream_pt import environment -def make_env_tiny(bf16_enable=True): +def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) @@ -26,6 +26,8 @@ def make_env_tiny(bf16_enable=True): environment_data.cache_sequence_length, config.dim // config.n_heads, ) + environment_data.testing = True + env_data_update_fn(environment_data) env = environment.JetEngineEnvironment(environment_data) env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu return env, config diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index dcbcf5f2..1f3108e7 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -223,6 +223,135 @@ def test_llama_e2e_float32(self): out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertEqual(out_tokens, expected_output_tokens) + def test_llama_e2e_float32_left_aligned_cache(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=False + env_data.generate_cache_stacked=False + env_data.new_cache_stacked=False + 
env_data.lazy_cache_update=False + env, model_arg = helpers.make_env_tiny(False, update_env_data) + + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + + def test_llama_e2e_float32_left_aligned_generate_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=False + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=False + env_data.lazy_cache_update=False + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + def test_llama_e2e_float32_left_aligned_new_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=False + env_data.generate_cache_stacked=False + env_data.new_cache_stacked=True + env_data.lazy_cache_update=False + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + def test_llama_e2e_float32_left_aligned_all_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=False + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=True + env_data.lazy_cache_update=False + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + + def test_llama_e2e_float32_left_aligned_lazy_cache_update(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=True + env_data.generate_cache_stacked=False + env_data.new_cache_stacked=False + env_data.lazy_cache_update=True + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + + # Foriloop has issue + def test_llama_e2e_float32_left_aligned_lazy_cache_update_generate_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=False + env_data.lazy_cache_update=True + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + + @unittest.skip("When generate cache is not stacked, new cache cannot stack") + def test_llama_e2e_float32_left_aligned_lazy_cache_update_new_cache_stacked(self): + """end to end jetstream llama test with float32""" + 
jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=True + env_data.generate_cache_stacked=False + env_data.new_cache_stacked=True + env_data.lazy_cache_update=True + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + + def test_llama_e2e_float32_left_aligned_lazy_cache_update_all_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=True + env_data.lazy_cache_update=True + env, model_arg = helpers.make_env_tiny(False, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + def test_llama_e2e_bfloat16(self): "end to end jetstream llama test with bfloat16" jax.config.update("jax_platform_name", "cpu") From ba3d3852ba5b36fd1e882dbd0c4c42837d48add9 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sun, 7 Jul 2024 21:22:13 +0000 Subject: [PATCH 10/57] Updates the exportable model to return cache. --- jetstream_pt/third_party/llama/model_exportable.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 015cacb8..499fd10a 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -162,7 +162,7 @@ def forward( with jax.named_scope("ffn"): out = h + self.feed_forward.forward(ffns) - return out + return out, cache def precompute_freqs_cis( @@ -263,16 +263,17 @@ def forward( # Should check more thoroughly, as of now, when prefill, it's always not stacked. When generate, it's controlled by the parameter. # target_cache_layers = 1 if self.env.generate_cache_stacked else len(self.layers) # assert len(caches) == target_cache_layers, f"Number of caches ({len(caches)}) and layers ({target_cache_layers}) dont match" - + end = None if start is None else (start + input_pos) % self.env.cache_len + # For stacked case, cannot get cache inside the loop which will cause cache copy + cache = caches[0] for layer_id, layer in enumerate(self.layers): if not caches[0].stacked: cache = caches[layer_id] - else: # For stacked case, there is only 1 layer of kv cache - cache = caches[0] + # else: # For stacked case, there is only 1 layer of kv cache with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): - h = layer( + h, cache = layer( h, freqs_cis, mask, @@ -282,7 +283,7 @@ def forward( ragged_batch_index, ragged_block_index, ) - cache.finalize() + cache.finalize() with jax.named_scope("transformer_norm"): h = self.norm(h) From c808ca8dcc49f15f232a60c612e2c661881c1902 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 8 Jul 2024 17:50:04 +0000 Subject: [PATCH 11/57] Removed the fori loop in cache finalize. Moves the cache.finalize() to the end of existing cache attention. 
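A likely reason for dropping the fori loop (an inference, not stated in the
commit): jax.lax.fori_loop requires a purely functional body, so state has to
be threaded through the loop carry; assigning into self.cache_k._elem as a side
effect inside the body does not take effect across iterations under tracing. A
working fori_loop variant would have to look roughly like this sketch
(simplified shapes, and it assumes the per-step K/V are already stacked into
one [layer, ...] array):

    import jax
    import jax.numpy as jnp

    def finalize_stacked(cache_k, cache_v, new_ks, new_vs, batch, pos):
      """Illustrative only: scatter one decode step of per-layer K/V into a
      stacked [layer, batch, heads, seq, dim] cache via a functional carry."""

      def body(i, carry):
        ck, cv = carry
        ck = ck.at[i, batch, :, pos, :].set(new_ks[i, :, :, 0, :])
        cv = cv.at[i, batch, :, pos, :].set(new_vs[i, :, :, 0, :])
        return ck, cv

      return jax.lax.fori_loop(0, cache_k.shape[0], body, (cache_k, cache_v))

The commit instead unrolls a plain Python for over the layers at trace time,
which also works when the new K/V slices are held in a Python list rather than
a stacked array.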
--- jetstream_pt/cache_manager.py | 4 ++-- jetstream_pt/layers.py | 3 ++- jetstream_pt/third_party/llama/model_exportable.py | 6 +++--- tests/test_llama_e2e.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index a556af7a..97703800 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -140,10 +140,9 @@ def finalize(self): self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, layer, head, dim)) self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, layer, head, dim)) else: - def body_func(i:int, _): + for i in range(self.env.num_layers): self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) - _ = jax.lax.fori_loop(0, self.env.num_layers, body_func, init_val=jnp.zeros((self.env.num_layers,))) else: b, head, len, dim = self.cache_k.shape self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) @@ -154,6 +153,7 @@ def update(self, key, value, layer_id:int): # Will process in insert() at the end of the transformer forward pass keyj, valuej = torchjax.to_torch((key, value)) if self.env.lazy_cache_update: + # When new cache stacked, must have generate_cache_stacked if self.env.new_cache_stacked: self.new_ks[layer_id, ...] = keyj self.new_vs[layer_id, ...] = valuej diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 9aa48ce4..50ee6e5d 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -467,7 +467,8 @@ def attend(xq, keys, values, local_mask=None): # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, mask) - + with jax.named_scope("attn_cache_lazy_update"): + cache.finalize() # For non flash attention or prefill, existing output contains everything if not self.env.flash_attention or seqlen > 1: return existing_output diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 499fd10a..4fa26e46 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -162,7 +162,7 @@ def forward( with jax.named_scope("ffn"): out = h + self.feed_forward.forward(ffns) - return out, cache + return out def precompute_freqs_cis( @@ -273,7 +273,7 @@ def forward( # else: # For stacked case, there is only 1 layer of kv cache with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): - h, cache = layer( + h = layer( h, freqs_cis, mask, @@ -283,7 +283,7 @@ def forward( ragged_batch_index, ragged_block_index, ) - cache.finalize() + # cache.finalize() with jax.named_scope("transformer_norm"): h = self.norm(h) diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 1f3108e7..916485f3 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -302,7 +302,7 @@ def update_env_data(env_data): self.assertEqual(out_tokens, expected_output_tokens) - # Foriloop has issue + # Won't work after removed the cache.finalize() in the Transformer def test_llama_e2e_float32_left_aligned_lazy_cache_update_generate_cache_stacked(self): """end 
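The approach, roughly: wrap the per-cache-line scatter in
jax.experimental.shard_map.shard_map so each device scatters into its local
shard of the KV cache directly, with the aim (per the code comment below) of
avoiding a data copy. A minimal, self-contained sketch of the pattern (toy 1-D
mesh, made-up shapes, heads divisible by the device count; this is not the
repo's exact wrapper):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ("x",))
    cache_spec = P(None, "x", None, None)  # shard the heads axis of [b, h, seq, d]

    def update_line(cache_k, new_k, pos):
      # Runs per device on its local [b, h_local, seq, d] block.
      b = jnp.arange(cache_k.shape[0])
      return cache_k.at[b, :, pos, :].set(new_k[b, :, 0, :])

    update_sharded = jax.jit(
        shard_map(
            update_line,
            mesh,
            in_specs=(cache_spec, cache_spec, P()),
            out_specs=cache_spec,
            check_rep=False,
        )
    )
    # e.g. cache_k: [4, 8, 1024, 128], new_k: [4, 8, 1, 128], pos: [4] int32
    # cache_k = update_sharded(cache_k, new_k, pos)

check_rep=False follows the patch's own call; it disables shard_map's static
replication checking.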
to end jetstream llama test with float32""" jax.config.update("jax_platform_name", "cpu") From e2874fc80be15babc754872c70883183a57df6d2 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 8 Jul 2024 18:27:47 +0000 Subject: [PATCH 12/57] Try to use shard_map for cache update. --- jetstream_pt/cache_manager.py | 17 +++++++++++++++-- jetstream_pt/environment.py | 8 ++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 97703800..f2275ffb 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -16,6 +16,7 @@ import jax.numpy as jnp import torch from jetstream_pt import torchjax +from jax.experimental.shard_map import shard_map # pylint: disable-next=all @@ -123,6 +124,16 @@ def __init__( else: # when generate cache is not stacked, new cache cannot stack assert not self.env.new_cache_stacked + cache_pspec = self.env.partition_by_axis(self.cache_sharding_axis) # Number of heads + in_specs = (cache_pspec, cache_pspec) + out_specs = (cache_pspec, cache_pspec) + self.update_single_cache_line = shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs) + + def update_single_cache_line(self, cache_k, cache_v): + b, head, len, dim = cache_k.shape + cache_k._elem = cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) + cache_v._elem = cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) + return cache_k, cache_v def finalize(self): if not self.env.lazy_cache_update: @@ -144,9 +155,11 @@ def finalize(self): self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) else: + # Try to use shard_map to get rid of the data copy b, head, len, dim = self.cache_k.shape - self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) + self.cache_k, self.cache_v = self.update_single_cache_line(self.cache_k, self.cache_v) + # self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) + # self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) def update(self, key, value, layer_id:int): """Update kv cache""" diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index ea103a9e..5ec2ce08 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -176,16 +176,16 @@ def __init__(self, data: JetEngineEnvironmentData): else: self.kv_cache_shard_axis = "num_attn_heads" - cache_sharding_axis = self.attention_kv_axis_names.index( + self.cache_sharding_axis = self.attention_kv_axis_names.index( self.kv_cache_shard_axis ) - if self.cache_shape[cache_sharding_axis] == 1: + if self.cache_shape[self.cache_sharding_axis] == 1: # cannot shard on an axis that is 1 # default to last - cache_sharding_axis = len(self.cache_shape) - 1 + self.cache_sharding_axis = len(self.cache_shape) - 1 - self.cache_sharding = self.sharding_by_axis(cache_sharding_axis) + self.cache_sharding = self.sharding_by_axis(self.cache_sharding_axis) self._load_sharding_config() def _load_sharding_config(self): From 0015e901616d5174ab2db04a1c144bad11e2e5d8 Mon 
Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 8 Jul 2024 21:15:27 +0000 Subject: [PATCH 13/57] Fix update single cache line in cache.finalize() --- jetstream_pt/cache_manager.py | 27 +++++++++++-------- jetstream_pt/layers.py | 8 +++--- .../third_party/llama/model_exportable.py | 1 - 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index f2275ffb..eb861950 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -17,7 +17,7 @@ import torch from jetstream_pt import torchjax from jax.experimental.shard_map import shard_map - +import torch_xla2 # pylint: disable-next=all class CacheInterface: @@ -124,15 +124,20 @@ def __init__( else: # when generate cache is not stacked, new cache cannot stack assert not self.env.new_cache_stacked - cache_pspec = self.env.partition_by_axis(self.cache_sharding_axis) # Number of heads - in_specs = (cache_pspec, cache_pspec) + cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads + none_pspec = self.env.partition_by_axis() + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec) out_specs = (cache_pspec, cache_pspec) - self.update_single_cache_line = shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs) + self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) + + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): + b, head, _, dim = cache_k.shape + for bb, pp in enumerate(self.pos.reshape(b)): + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, 0) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, 0) + cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, (bb, 0, pp, 0)) + cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, (bb, 0, pp, 0)) - def update_single_cache_line(self, cache_k, cache_v): - b, head, len, dim = cache_k.shape - cache_k._elem = cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) - cache_v._elem = cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) return cache_k, cache_v def finalize(self): @@ -156,8 +161,8 @@ def finalize(self): self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) else: # Try to use shard_map to get rid of the data copy - b, head, len, dim = self.cache_k.shape - self.cache_k, self.cache_v = self.update_single_cache_line(self.cache_k, self.cache_v) + b, head, _, dim = self.cache_k.shape + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) # self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) # self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) @@ -355,4 +360,4 @@ def finalize(self): self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem) self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem) self.k_scaler._elem = self.k_scaler._elem.at[:, batch, :, self.pos].set(self.new_k_scalers._elem) - self.v_scaler._elem = self.v_scaler._elem.at[:, batch, :, self.pos].set(self.new_v_scalers._elem) \ No newline at end of file + self.v_scaler._elem = self.v_scaler._elem.at[:, 
batch, :, self.pos].set(self.new_v_scalers._elem) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 50ee6e5d..2edfaf2b 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -458,11 +458,11 @@ def attend(xq, keys, values, local_mask=None): return local_output, (local_max, local_denom) - + #import pdb; pdb.set_trace() with jax.named_scope("attn_insert_cache"): - keys, values = cache.update(xk, xv, self.layer_id) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + orig_keys, orig_values = cache.update(xk, xv, self.layer_id) + keys = repeat_kv(orig_keys, n_rep) + values = repeat_kv(orig_values, n_rep) # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 4fa26e46..1b288dbc 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -250,7 +250,6 @@ def forward( ragged_batch_index: precomputed batch index for ragged attention ragged_block_index: precomputed block index for ragged attention """ - with jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] h = self.tok_embeddings(tokens) From 7661bb2c6d02c55574b3d02bc36a7cffba8e06a9 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 8 Jul 2024 21:32:40 +0000 Subject: [PATCH 14/57] Adds int8 support. --- jetstream_pt/cache_manager.py | 75 ++++++++++++++++++----------------- jetstream_pt/layers.py | 9 +++-- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index eb861950..6860becb 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -163,8 +163,6 @@ def finalize(self): # Try to use shard_map to get rid of the data copy b, head, _, dim = self.cache_k.shape self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) - # self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, head, dim)) - # self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, head, dim)) def update(self, key, value, layer_id:int): """Update kv cache""" @@ -277,13 +275,29 @@ def __init__( self.cache_v = cache_v self.k_scaler = cache_k_scaler self.v_scaler = cache_v_scaler + self.new_ks = None + self.new_vs = None + self.new_k_scaler = None + self.new_v_scaler = None + + self.batch = jnp.arange(self.env.batch_size) self.input_pos = input_pos self.sharding = sharding self.env = env - if self.env.generate_cache_stacked: - layer, batch, heads, len, dim = self.cache_k.shape - self.new_ks, self.new_vs, self.new_k_scalers, self.new_v_scalers = torchjax.to_torch((jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, heads, 1, dim)), jnp.zeros((layer, batch, 1, 1, 1)), jnp.zeros((layer, batch, 1, 1, 1)))) + cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads + none_pspec = self.env.partition_by_axis() + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec) + out_specs = (cache_pspec, cache_pspec) + self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) + + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): + 
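+    # Meant to run per device under shard_map: cache_k / cache_v are the
+    # local (batch, heads, seq_len, dim) shards, new_ks / new_vs hold one
+    # freshly quantized token per batch row, and every row is written at
+    # its own decode position (left-aligned cache, no ring buffer).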
b, head, _, dim = cache_k.shape + for bb, pp in enumerate(self.pos.reshape(b)): + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, 0) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, 0) + cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, (bb, 0, pp, 0)) + cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, (bb, 0, pp, 0)) def state(self): """Get kv cache state""" @@ -300,12 +314,8 @@ def empty(cls, shape, device, env): cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - if env.generate_cache_stacked: - kscaler = jnp.ones((shape[0], shape[1], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], shape[1], 1, shape[2], 1), dtype=jnp.bfloat16) - else: - kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) @@ -324,40 +334,33 @@ def update(self, xk, xv, layer_id:int): k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) - if self.stacked: - self.new_ks[layer_id, ...] = k_quant - self.new_vs[layer_id, ...] = v_quant - self.new_k_scalers[layer_id, ...] = kscale - self.new_v_scalers[layer_id, ...] = vscale - - return self.cache_k[layer_id], self.cache_v[layer_id], k_quant, v_quant, self.k_scaler[layer_id], self.v_scaler[layer_id], kscale, vscale - + if self.env.lazy_cache_update: + self.new_ks = k_quant + self.new_vs = v_quant + self.new_k_scaler = kscale + self.new_v_scaler = vscale + if self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant self.k_scaler[:, :, self.input_pos, :] = kscale self.v_scaler[:, :, self.input_pos, :] = vscale else: - batch = jnp.arange(self.env.batch_size) - self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2) - self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2) - self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2) - self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2) + self.cache_k[self.batch, :, self.input_pos, :] = k_quant.squeeze(2) + self.cache_v[self.batch, :, self.input_pos, :] = v_quant.squeeze(2) + self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) + self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) return self.cache_k, self.cache_v, k_quant, v_quant, self.k_scaler, self.v_scaler, kscale, vscale def finalize(self): - if not self.stacked: + if not self.env.lazy_cache_update: return - # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_ks._elem, -2)) - # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_vs._elem, -2)) if self.env.ring_buffer: - self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(self.new_vs._elem) - self.k_scaler._elem = self.k_scaler._elem.at[:, :, :, self.pos].set(self.new_k_scalers._elem) - self.v_scaler._elem = self.v_scaler._elem.at[:, :, :, self.pos].set(self.new_v_scalers._elem) + # Assume no cache stack for ring buffer + self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[..., self.pos, 
:].set(self.new_vs._elem) else: - batch = jnp.arange(self.env.batch_size) - self.cache_k._elem = self.cache_k._elem.at[:, batch, :, self.pos].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[:, batch, :, self.pos].set(self.new_vs._elem) - self.k_scaler._elem = self.k_scaler._elem.at[:, batch, :, self.pos].set(self.new_k_scalers._elem) - self.v_scaler._elem = self.v_scaler._elem.at[:, batch, :, self.pos].set(self.new_v_scalers._elem) + self.k_scaler[self.batch, :, self.input_pos, :] = self.new_k_scaler.squeeze(2) + self.v_scaler[self.batch, :, self.input_pos, :] = self.new_v_scaler.squeeze(2) + # Try to use shard_map to get rid of the data copy + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 2edfaf2b..67b33e7a 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -575,14 +575,17 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): return local_output, (local_max, local_denom) with jax.named_scope("attn_insert_cache"): - keys, values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) + keys = repeat_kv(orig_keys, n_rep) + values = repeat_kv(orig_values, n_rep) # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, k_scaler, v_scaler, mask) + with jax.named_scope("attn_cache_lazy_update"): + cache.finalize() + # For non flash attention or prefill, existing output contains everything if not self.env.flash_attention or seqlen > 1: return existing_output From c965a4263afea2311cb79affc9d37cc546ca3c4f Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 9 Jul 2024 21:47:06 +0000 Subject: [PATCH 15/57] Int8 left aligned lazy cache update working, performance still not good enough. 
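The lazy cache update leans on the flash-attention style merge done under attn_global: attend() returns each partial output together with its running logit max and softmax denominator, and the partial computed over the existing cache is renormalized against the partial for the new token. Below is a minimal sketch of that merge, assuming plain torch tensors that broadcast over the head and query dims; the helper name is illustrative, and unlike the in-tree code (which folds exp(max) straight into global_sum) it subtracts a shared max first for numerical stability.

import torch

def merge_softmax_partials(o1, m1, d1, o2, m2, d2):
    # Each triple is a partial attention output, its row-wise logit max,
    # and the softmax denominator accumulated over that partial's keys.
    m = torch.maximum(m1, m2)        # shared max keeps the exponents bounded
    w1 = d1 * torch.exp(m1 - m)      # rescaled denominators
    w2 = d2 * torch.exp(m2 - m)
    return (o1 * w1 + o2 * w2) / (w1 + w2)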
--- jetstream_pt/attention_kernel.py | 16 +++++---- jetstream_pt/cache_manager.py | 15 ++++---- jetstream_pt/environment.py | 3 -- jetstream_pt/layers.py | 34 ++++++++----------- .../third_party/llama/model_exportable.py | 9 +++-- 5 files changed, 36 insertions(+), 41 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index afb25d39..d8d9401a 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -198,6 +198,8 @@ def scaler_index_map(b, i, *_): grid=(batch_size, seq_len // bk), ), compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + #interpret=True, + #debug=True, out_shape=[ q, jax.ShapeDtypeStruct( @@ -278,7 +280,7 @@ def ragged_mha( shard_axis, *([None] * replicated_in_axes), ), - out_axes=shard_axis, + out_axes=(shard_axis,(shard_axis, shard_axis)) )(q, k, v, start, end, *replicated_inputs) return out, (m, l) @@ -311,7 +313,6 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): def flash_attention(xq, keys, values, mask=None, normalize_var=True): """The vanilla attention kernel implementation.""" - # import pdb; pdb.set_trace() # mask_value: float = DEFAULT_MASK_VALUE logits = torch.einsum( "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) @@ -343,25 +344,26 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n ) if normalize_var: - logits = logits / torch.sqrt(keys.shape[-1]) # Align with meta llama + logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama # Quantized - logits = logits * k_scaler + bs, hs, ls, ds = k_scaler.shape + logits = logits * k_scaler.reshape(k_scaler.shape[0], 1, 1, k_scaler.shape[2]) # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] if mask is not None: # logits = logits + jnp.where(mask, 0.0, DEFAULT_MASK_VALUE)[:, None] logits = logits + mask - logits_max = torch.max(axis=-1, keepdim=True) + logits_max, _ = torch.max(logits, axis=-1, keepdim=True) unnormalized = torch.exp(logits - logits_max) #Quantized, should not put here, otherwise sum will have this too, which cancels with denominator # unnormalized = unnormalized * v_scaler denominator = unnormalized.sum(axis=-1, keepdim=True) - unnormalized = unnormalized * v_scaler + unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[0], 1, 1, v_scaler.shape[2]) o = ( - jnp.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(values), values) + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(values), values) / denominator ) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 6860becb..ea2fd713 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -52,7 +52,7 @@ def update(self, key, value, layer_id): if self.kv_quantize: # pretend to be quantized bsz, _, seq, _ = key.shape ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)) - return key, value, ones, ones + return key, value, None, None, ones, ones, None, None return key, value @@ -185,7 +185,7 @@ def update(self, key, value, layer_id:int): return self.cache_k, self.cache_v - if self.env.ring_buffer: + elif self.env.ring_buffer: # Assume no cache stack for ring buffer # pylint: disable-next=all self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(keyj) @@ -280,7 +280,7 @@ def __init__( self.new_k_scaler = None self.new_v_scaler = None - self.batch = jnp.arange(self.env.batch_size) + self.batch = jnp.arange(env.batch_size) self.input_pos = input_pos self.sharding = sharding 
self.env = env @@ -293,11 +293,12 @@ def __init__( def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): b, head, _, dim = cache_k.shape - for bb, pp in enumerate(self.pos.reshape(b)): + for bb, pp in enumerate(self.input_pos.reshape(b)): new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, 0) new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, 0) cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, (bb, 0, pp, 0)) cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, (bb, 0, pp, 0)) + return cache_k, cache_v def state(self): """Get kv cache state""" @@ -340,7 +341,7 @@ def update(self, xk, xv, layer_id:int): self.new_k_scaler = kscale self.new_v_scaler = vscale - if self.env.ring_buffer: + elif self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant self.k_scaler[:, :, self.input_pos, :] = kscale @@ -357,8 +358,8 @@ def finalize(self): return if self.env.ring_buffer: # Assume no cache stack for ring buffer - self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(self.new_vs._elem) + self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(self.new_vs._elem) else: self.k_scaler[self.batch, :, self.input_pos, :] = self.new_k_scaler.squeeze(2) self.v_scaler[self.batch, :, self.input_pos, :] = self.new_v_scaler.squeeze(2) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 5ec2ce08..9af21b58 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -142,9 +142,6 @@ def __init__(self, data: JetEngineEnvironmentData): self.testing = self._data.testing self.testing_seed = self._data.testing_seed - if self.lazy_cache_update: - self.flash_attention = True - if self.generate_cache_stacked: self.cache_shape = (self.num_layers, *self._data.cache_shape) else: diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 67b33e7a..da5b5365 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -390,7 +390,7 @@ def __init__(self, env, layer_id): self.ragged_attention = ak.RaggedAttentionKernel( env, input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), + output_specs=(qkv_pspec, (qkv_pspec, qkv_pspec)), sharding_axis=self.shard_axis, ) self.layer_id = layer_id @@ -425,7 +425,7 @@ def __call__( xq_expanded = xq def attend(xq, keys, values, local_mask=None): - if self.env.ragged_mha and seqlen == 1: + if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] != 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, @@ -436,7 +436,7 @@ def attend(xq, keys, values, local_mask=None): ragged_batch_index, ragged_block_index, ) - elif self.env.flash_attention: + elif self.env.flash_attention or keys.shape[-2] == 1: with torch_xla2.default_env(): local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, mask=local_mask) else: @@ -458,7 +458,6 @@ def attend(xq, keys, values, local_mask=None): return local_output, (local_max, local_denom) - #import pdb; pdb.set_trace() with jax.named_scope("attn_insert_cache"): orig_keys, orig_values = cache.update(xk, xv, self.layer_id) keys = repeat_kv(orig_keys, n_rep) @@ -467,10 +466,10 @@ def attend(xq, keys, values, local_mask=None): # print(f"attention kernel xq {xq.shape} seqlen 
{seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, mask) - with jax.named_scope("attn_cache_lazy_update"): - cache.finalize() + + # Updating cache during each step still has very large impact on latency. # For non flash attention or prefill, existing output contains everything - if not self.env.flash_attention or seqlen > 1: + if not self.env.lazy_cache_update or seqlen > 1: return existing_output # For flash attention, existing output contains the existing kv cache generated logits @@ -541,7 +540,7 @@ def __call__( xq_expanded = xq def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): - if self.env.ragged_mha and seqlen == 1: + if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] != 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, @@ -554,9 +553,11 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): k_scaler, v_scaler, ) - elif self.env.flash_attention: + local_max = local_max.reshape(*local_max.shape, 1) + local_denom = local_denom.reshape(*local_denom.shape, 1) + elif self.env.flash_attention or keys.shape[-2] == 1: with torch_xla2.default_env(): - local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, mask=local_mask) + local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, k_scaler, v_scaler, mask=local_mask) else: local_output = self.dense_attention(xq, keys, values, k_scaler, v_scaler, local_mask) local_max = None @@ -578,16 +579,11 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) keys = repeat_kv(orig_keys, n_rep) values = repeat_kv(orig_values, n_rep) - - # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, k_scaler, v_scaler, mask) - with jax.named_scope("attn_cache_lazy_update"): - cache.finalize() - # For non flash attention or prefill, existing output contains everything - if not self.env.flash_attention or seqlen > 1: + if not self.env.lazy_cache_update or seqlen > 1: return existing_output # For flash attention, existing output contains the existing kv cache generated logits @@ -599,10 +595,10 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # return new_output with jax.named_scope("attn_global"): - # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") - # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") - + existing_denom = existing_denom[:, 0:1] + existing_max = existing_max[:, 0:1] global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) + existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum new_output = new_output * new_denom * torch.exp(new_max) / global_sum attn_out = existing_output + new_output diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 1b288dbc..79c351be 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -258,17 +258,17 @@ def forward( bsz, seqlen = tokens.shape freqs_cis = self.freqs_cis[input_pos] 
freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - + # Should check more thoroughly, as of now, when prefill, it's always not stacked. When generate, it's controlled by the parameter. # target_cache_layers = 1 if self.env.generate_cache_stacked else len(self.layers) # assert len(caches) == target_cache_layers, f"Number of caches ({len(caches)}) and layers ({target_cache_layers}) dont match" - end = None if start is None else (start + input_pos) % self.env.cache_len # For stacked case, cannot get cache inside the loop which will cause cache copy - cache = caches[0] for layer_id, layer in enumerate(self.layers): - if not caches[0].stacked: + if not self.env.generate_cache_stacked: cache = caches[layer_id] + else: + cache = caches[0] # else: # For stacked case, there is only 1 layer of kv cache with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): @@ -282,7 +282,6 @@ def forward( ragged_batch_index, ragged_block_index, ) - # cache.finalize() with jax.named_scope("transformer_norm"): h = self.norm(h) From 8af04745adea32ee08e1d7bb7d29b0926bd63538 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 9 Jul 2024 23:20:10 +0000 Subject: [PATCH 16/57] Fix the stacked cache introduced in the previous couple of commits. --- jetstream_pt/attention_kernel.py | 3 +- jetstream_pt/layers.py | 29 ++++++++++--------- .../third_party/llama/model_exportable.py | 8 ++--- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index d8d9401a..291178b4 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -264,7 +264,8 @@ def ragged_mha( ragged_batch_index, ragged_block_index, ) - + # New cache has t=1 + bk = min(bk, k.shape[-2]) with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index da5b5365..c098b18d 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -389,7 +389,7 @@ def __init__(self, env, layer_id): self.flash_attention = ak.flash_attention self.ragged_attention = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), + input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), output_specs=(qkv_pspec, (qkv_pspec, qkv_pspec)), sharding_axis=self.shard_axis, ) @@ -419,12 +419,13 @@ def __call__( _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq_expanded = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - else: - xq_expanded = xq + # if not self.env.ragged_mha and seqlen == 1: + # xq_expanded = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + # else: + # xq_expanded = xq def attend(xq, keys, values, local_mask=None): + # As of right now, ragged attention doesn't support attention calculation with prefill and new cache line if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] != 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, @@ -436,20 +437,22 @@ def attend(xq, keys, values, local_mask=None): ragged_batch_index, ragged_block_index, ) - elif self.env.flash_attention or keys.shape[-2] == 1: + elif self.env.flash_attention: with torch_xla2.default_env(): local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, mask=local_mask) else: + if seqlen == 1: + xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) local_output = self.dense_attention(xq, 
keys, values, None, None, local_mask) local_max = None local_denom = None - if not self.env.ragged_mha and seqlen == 1: - local_output = local_output[:, :, 0:1, :] - if local_max is not None: - local_max = local_max[:, :, 0:1, :] - if local_denom is not None: - local_denom = local_denom[:, :, 0:1, :] + if seqlen == 1: + local_output = local_output[:, :, 0:1, :] + if local_max is not None: + local_max = local_max[:, :, 0:1, :] + if local_denom is not None: + local_denom = local_denom[:, :, 0:1, :] # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") # if local_max is not None and local_denom is not None: @@ -465,7 +468,7 @@ def attend(xq, keys, values, local_mask=None): # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): - existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, mask) + existing_output, (existing_max, existing_denom) = attend(xq, keys, values, mask) # Updating cache during each step still has very large impact on latency. # For non flash attention or prefill, existing output contains everything diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 79c351be..1bb4a1c0 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -265,11 +265,11 @@ def forward( end = None if start is None else (start + input_pos) % self.env.cache_len # For stacked case, cannot get cache inside the loop which will cause cache copy for layer_id, layer in enumerate(self.layers): - if not self.env.generate_cache_stacked: - cache = caches[layer_id] - else: + if caches[0].stacked: cache = caches[0] - # else: # For stacked case, there is only 1 layer of kv cache + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): h = layer( From db0e3a3c2b29c99277e9c504b19f35ad5b2b8b82 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 10 Jul 2024 00:41:00 +0000 Subject: [PATCH 17/57] Put original ragged attention back. --- jetstream_pt/attention_kernel.py | 177 ++++++++++++++++++++++++++++++- 1 file changed, 176 insertions(+), 1 deletion(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 291178b4..2306d977 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -213,6 +213,180 @@ def scaler_index_map(b, i, *_): return out, (m[..., 0], l[..., 0]) +def ragged_mqa_kernel_reference( + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + *, + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for flash attention.""" + b, i = pl.program_id(0), pl.program_id(1) + + @pl.when(i == 0) + def init(): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + # length = lengths_ref[b] + # Always start from 0, left aligned + length = end_ref[b] + + @pl.when(i * bk < length) + def run(): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] 
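+    # m_ref / l_ref carry the running row max and softmax denominator across
+    # kv blocks (online softmax), so each bk-wide block can be folded into
+    # o_ref without materializing the full attention matrix.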
+ + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + + if normalize_var: + qk = qk / math.sqrt(k.shape[-1]) # Align with meta llama + # Quantized + if quantized: + qk = qk * k_scaler_ref[...] + + mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length + qk = qk + jnp.where(mask, 0.0, mask_value) + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + # Quantized + if quantized: + s_curr = s_curr * v_scaler_ref[...] + + o_curr_times_l_curr = jnp.dot(s_curr, v) + + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + +@functools.partial(jax.jit, static_argnames=["bk", "mask_value"]) +def ragged_mqa_reference( + q: jax.Array, + k: jax.Array, + v: jax.Array, + start: jax.Array, + end: jax.Array, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, + ragged_batch_index=None, + ragged_block_index=None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + batch_size, num_heads, head_dim = q.shape + assert end.shape == (batch_size,) + assert end.dtype == jnp.int32 + seq_len = k.shape[1] + + def _compute_ragged_block_indices(b, i, lengths_ref): + length = lengths_ref[b] + not_done = i * bk < length + am_last_batch = b == batch_size - 1 + # if length < bk, then it's -1, should be 0? 
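+    # Maps grid step (b, i) to the (batch, kv block) pair to prefetch next:
+    # rows that are not done keep streaming their own blocks, finished rows
+    # fall through to block 0 of the next row, and the final row keeps
+    # re-reading its last useful block.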
+ last_good_block = jax.lax.div(length, bk) - 1 + + # if not done, then still work on b, otherwise next batch + b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1)) + # if not done, i next = i + # if done + #if last batch, previous good block + #if not last batch, i next = 0 + i_next = jnp.where( + not_done, i, jnp.where(am_last_batch, last_good_block, 0) + ) + return b_next, i_next + + def kv_index_map(b, i, lengths_ref): + b_next, i_next = _compute_ragged_block_indices(b, i, lengths_ref) + return b_next, i_next, 0 + + in_specs = [ + pl.BlockSpec(kv_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + ] + inputs = ( + start, + end, + end, # line_end, not actually used + ragged_batch_index, + ragged_block_index, + q, + k, + v, + ) + quantized = False + if k_scaler is not None: + in_specs = in_specs + [ + pl.BlockSpec(kv_index_map, (None, 1, bk)), + pl.BlockSpec(kv_index_map, (None, 1, bk)), + ] + inputs = inputs + (k_scaler, v_scaler) + quantized = True + + out, m, l = pl.pallas_call( + functools.partial( + ragged_flash_attention_kernel, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=5, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(lambda b, i, _: (b, 0, 0), (None, num_heads, head_dim)), + pl.BlockSpec(lambda b, i, _: (b, 0, 0), (None, num_heads, head_dim)), + pl.BlockSpec(lambda b, i, _: (b, 0, 0), (None, num_heads, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + out_shape=[ + q, + jax.ShapeDtypeStruct( + (batch_size, num_heads, head_dim), jnp.float32 + ), + jax.ShapeDtypeStruct( + (batch_size, num_heads, head_dim), jnp.float32 + ), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + @functools.partial( jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"] ) @@ -269,7 +443,8 @@ def ragged_mha( with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - ragged_mqa, + # ragged_mqa, + ragged_mqa_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, From 7572a11a0ffaf391cb6a9e8b55f1f218f8317ac0 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 10 Jul 2024 04:55:40 +0000 Subject: [PATCH 18/57] Add the original ragged attention kernel. --- jetstream_pt/attention_kernel.py | 35 +++++++++++++++++-------------- keys_original | Bin 0 -> 66746 bytes original_scores | Bin 0 -> 2500 bytes 3 files changed, 19 insertions(+), 16 deletions(-) create mode 100644 keys_original create mode 100644 original_scores diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 2306d977..b8b5e2ad 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -113,10 +113,10 @@ def ragged_mqa( v: jax.Array, start: jax.Array, end: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, ragged_batch_index=None, ragged_block_index=None, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, @@ -227,7 +227,6 @@ def ragged_mqa_kernel_reference( o_ref, m_ref, l_ref, - *, bk: int, mask_value: float, normalize_var: bool, @@ -289,17 +288,17 @@ def run(): (l_prev * alpha * o_ref[...] 
+ beta * o_curr_times_l_curr) / l_next_safe ).astype(o_ref.dtype) -@functools.partial(jax.jit, static_argnames=["bk", "mask_value"]) +@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]) def ragged_mqa_reference( q: jax.Array, k: jax.Array, v: jax.Array, start: jax.Array, end: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, ragged_batch_index=None, ragged_block_index=None, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, @@ -328,8 +327,12 @@ def _compute_ragged_block_indices(b, i, lengths_ref): ) return b_next, i_next - def kv_index_map(b, i, lengths_ref): - b_next, i_next = _compute_ragged_block_indices(b, i, lengths_ref) + def kv_index_map(b, i, start_ref, + end_ref, + line_end_ref, + ragged_batch_index_ref, + ragged_block_index_ref): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) return b_next, i_next, 0 in_specs = [ @@ -368,9 +371,9 @@ def kv_index_map(b, i, lengths_ref): num_scalar_prefetch=5, in_specs=in_specs, out_specs=[ - pl.BlockSpec(lambda b, i, _: (b, 0, 0), (None, num_heads, head_dim)), - pl.BlockSpec(lambda b, i, _: (b, 0, 0), (None, num_heads, head_dim)), - pl.BlockSpec(lambda b, i, _: (b, 0, 0), (None, num_heads, head_dim)), + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, num_heads, head_dim)), + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, num_heads, head_dim)), + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, num_heads, head_dim)), ], grid=(batch_size, seq_len // bk), ), @@ -433,18 +436,18 @@ def ragged_mha( else: replicated_in_axes = 6 replicated_inputs = ( - jnp.squeeze(k_scaler, -1), - jnp.squeeze(v_scaler, -1), ragged_batch_index, ragged_block_index, + jnp.squeeze(k_scaler, -1), + jnp.squeeze(v_scaler, -1), ) # New cache has t=1 bk = min(bk, k.shape[-2]) with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - # ragged_mqa, - ragged_mqa_reference, + ragged_mqa, + #ragged_mqa_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, @@ -456,7 +459,7 @@ def ragged_mha( shard_axis, *([None] * replicated_in_axes), ), - out_axes=(shard_axis,(shard_axis, shard_axis)) + out_axes=shard_axis )(q, k, v, start, end, *replicated_inputs) return out, (m, l) diff --git a/keys_original b/keys_original new file mode 100644 index 0000000000000000000000000000000000000000..4ce7d0b90ca8bc8e433e6b89d5bc2a88b3a21190 GIT binary patch literal 66746 zcmZ^~c{Ejj^fzuOBtw}>BxxW?LP+=QZAwa{Xi`a1QA(k?WFAVACXoh8QY4||p1qaQ zTtb>>4Vq}4eDu5D_56O%AJ1<+XRWjDS!bQQ&-;Cbz4vQ4KAwtQZyQs+d|G!I&!n~;Ph zA~Y;`PSpIRk>L?@R$8krmoqb)t}KfmDLZoR!uiYP|JPqs-NVA?M*04)ovOsMX4@V88Uf!x2WZ+lb3h5^zl?y*{DKo1+vKhv-i~T?A|2@8roWuVS z|9`IPa4UBg+y8&?|9zJK`vwKg7#|WaH)Cj@f#`Cr40E*t8ARRY79lG|BVD`E`&@f{lyWT9P+glCzSmhHqdVT=E{*Z*X zKPN-+uN-mu@cVQvXSKLY@jYyeh~$PgBW{QtiZnD4Vn6Al?&6t3Th2^6p|+9oUE(S9 zkVx*XT{zuMj`vup(MG*1JpXnc9vgZFzKktl>DHy7nQ%mURR1?9oK@sFg%L31lp>y+ z7cVU6y$iGZgwfMcail&(lO@Wxg{r+SoQB%C)ix9Yw?skdf*9~K-wmA$C%{^FXTJPv zGHC8jgtr&}gCn7%gd?9^_-CIFR32J(c3waju2~?#1LyLDqZ)Oz{n{pI+3{9LE~$W* z1I?kxV-aOk*g?=Z1N6;_1*?BA#M={o2_4plabxIivH!zu7#o!UVcnm@p_Ri~nth7C zx~0JQ;Qlmel?dl;3$fz7)8de@Yg|F%t^O=%dR!JQ|^G>~^ke1cMQoxKlY>+yIt8yf;!k9u;H#YBFzc`8c(4v}`<{1Eik--KsBCgbFH zh2U5H0qSRV6+BP&f?el(-~iziZu+FmHS5#_<&_fgcaZ^`+t}ilmXBbe@CsDFnqsN; zbUd9i24<`tFAXl#!h3$PVnBV4XzTl2_!h~O>%EQFE!ayR|F+Os9Lw=jjzC(sJh)Y% z$2QFckX{#$&CkooUvnEiGu%qi8;}RS_kN?=xEOv_+ts`Ltfw2hNN-DIBbA0M#XBu({_xc(x)7KB>&&Pov#9&StKl 
literal 0 HcmV?d00001 diff --git a/original_scores b/original_scores new file mode 100644 index 0000000000000000000000000000000000000000..e8c777a2efc4be8601a37ea3c4d3866069fcc635 GIT binary patch literal 2500 zcmaKu3sh5A7KSeggb<&IidvS{s1PX-Lnd7eDq!uQl9>r8H4LnJW zFwc$e6KWF96Y^JWYMRMCZkn6bW+t{)E!MAySuGJb12A$Tb zH>Ox5v&>04vo0xGueC}{;<0Lzm(?Uun|g)B$9jr|S93xS^JJ&T_qghzjL$o(Ub5P% zGdj^RQT<*oyKC|M*`T*t4LT=3i#!-A;@M8`nLn6Oyu){XRe5MkRG2ILO_cw0v$7SI zHl4z8IcfNQ>z{y|=M8a6jx;>jhc}OF!N|#>#9jgG3m=l(RWtBhM~bbnehjCX*dKEv z| z2+VU6v!ib%W9uydmFgonUXqQ1)VVkkbLj!eAh34mv2_@5V-rF-?yoS}y-gkBZJi)T4_KR<30Gcm6<_g#-VR^B@|^iWd>_ zbfoKN?}6LZ3hvO=^`QFKV|-B;$NFtt0E(}+;@%r|c<|33va%69>Y$|_Un zw7Q|>u5b&f+_3>4e-g+x3(N3XVFhh!o{R;7&+Yb*TzkVBf7B+pbGsiMzMA!~pgh!gNy+;MqR# zS;u(vuRTLW%te|}6U1G~SVp#Jex+5N4(ikQM|hNR7i{N-!^XBtFe*BbJy_w#{q$fl zaoiSDapOePc+RJzuaCsdy#i=3*m~aMQt!)*qsur-j zRtUMApH@TI)ahkUkR1@ zbKz*oMk0}J!lKA3SbF+BwwP9sQoRMaK827s=8q_=A4O*L?ae+<3F4l#ZpTd(S0U*A znM8fK69j5+_DT!T)8A$z#Ek*_<#I@zP-5?s7KiNz_M^{E1aYMi{o3=8+5G{r9@t5i zfAl%dR13LvpH78+fgW7to*P)CFDK)c>yVk>sPq!};?tH)F5t;#-1dDo7Kdfw*OyE2 z+uY$?oJSJ0ec>Rb^ZQ}DCW7m~f3(UyzBjg1&%wbvbKoD+do(!TAMO8^7;J68HFUrT4VBIC&WNtaT~u9mKLX4N>^?@KAC;;5w)3T3!b86GiS86Fv_ z2#ZjLg*q)VY}|MS-=t7PC=`)O=PYNt70G4%44oby>*k#78|Ck=jzgRcZ&^=ufGefD z1`d6T(7`{sQr_sQE2jH~4Soyrd%i|+#dOK1E2I0C_`bzx;}zPK@p_|OA>F^2|633> zU(mQh{`X$S$BINjug4A-2EU>3KjSO=?{&Ge{Ne4nvsB`ZANT5RiC5RNvpGK2!&zQ1 L9{l0#-_`d&RQ8R9 literal 0 HcmV?d00001 From efeb5de7a541482c7f01dc78824a86dda0b52fc4 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 10 Jul 2024 18:10:54 +0000 Subject: [PATCH 19/57] Fixes the bf16/int8 cache stack. --- jetstream_pt/cache_manager.py | 99 ++++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 24 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index ea2fd713..0bb14a14 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -131,13 +131,19 @@ def __init__( self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): - b, head, _, dim = cache_k.shape + b = cache_k.shape[-4] for bb, pp in enumerate(self.pos.reshape(b)): - new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, 0) - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, 0) - cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, (bb, 0, pp, 0)) - cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, (bb, 0, pp, 0)) - + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + # We are not handling generate_cache_stacked=True new_cache_stacked=False here + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) + cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) return cache_k, cache_v def finalize(self): @@ -153,15 +159,13 @@ def finalize(self): if self.env.generate_cache_stacked: layer, b, head, len, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, layer, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, 
:].set(self.new_vs._elem.reshape(b, layer, head, dim)) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) else: for i in range(self.env.num_layers): self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) else: # Try to use shard_map to get rid of the data copy - b, head, _, dim = self.cache_k.shape self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) def update(self, key, value, layer_id:int): @@ -285,6 +289,19 @@ def __init__( self.sharding = sharding self.env = env + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + layer, batch, heads, time, dim = self.cache_k.shape + new_kv_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_kv_dim), jnp.zeros(new_kv_dim))) + if self.env.new_cache_stacked: + new_scale_dim = (layer, batch, 1, 1, 1) + self.new_k_scaler, self.new_v_scaler = torchjax.to_torch((jnp.zeros(new_scale_dim), jnp.zeros(new_scale_dim))) + else: + self.new_ks, self.new_vs = [], [] + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads none_pspec = self.env.partition_by_axis() in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec) @@ -292,12 +309,19 @@ def __init__( self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): - b, head, _, dim = cache_k.shape + b = cache_k.shape[-4] for bb, pp in enumerate(self.input_pos.reshape(b)): - new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, 0) - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, 0) - cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, (bb, 0, pp, 0)) - cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, (bb, 0, pp, 0)) + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + update_start_indices = (0, bb, 0, pp, 0) + if self.env.new_cache_stacked: + slice_dim = 1 + # We are not handling generate_cache_stacked=True new_cache_stacked=False + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) + cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) return cache_k, cache_v def state(self): @@ -315,8 +339,8 @@ def empty(cls, shape, device, env): cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + kscaler = jnp.ones((shape[-4], 1, shape[-2], 1), dtype=jnp.bfloat16) + vscaler = jnp.ones((shape[-4], 1, shape[-2], 1), dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) @@ -335,23 +359,38 @@ def 
update(self, xk, xv, layer_id:int): k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) + self.new_k_scaler = kscale + self.new_v_scaler = vscale + if self.env.lazy_cache_update: - self.new_ks = k_quant - self.new_vs = v_quant - self.new_k_scaler = kscale - self.new_v_scaler = vscale - + if self.env.new_cache_stacked: + self.new_ks[layer_id, ...] = k_quant + self.new_vs[layer_id, ...] = v_quant + else: + if self.env.generate_cache_stacked: + self.new_ks.append(k_quant) + self.new_vs.append(v_quant) + else: + self.new_ks = k_quant + self.new_vs = v_quant elif self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant self.k_scaler[:, :, self.input_pos, :] = kscale self.v_scaler[:, :, self.input_pos, :] = vscale else: + # We don't handle left aligned but lazy_cache_update=False self.cache_k[self.batch, :, self.input_pos, :] = k_quant.squeeze(2) self.cache_v[self.batch, :, self.input_pos, :] = v_quant.squeeze(2) self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) - return self.cache_k, self.cache_v, k_quant, v_quant, self.k_scaler, self.v_scaler, kscale, vscale + + ret_cache_k = self.cache_k[layer_id] if self.env.generate_cache_stacked else self.cache_k + ret_cache_v = self.cache_v[layer_id] if self.env.generate_cache_stacked else self.cache_v + ret_k_scaler = self.k_scaler[layer_id] if self.env.generate_cache_stacked else self.k_scaler + ret_v_scaler = self.v_scaler[layer_id] if self.env.generate_cache_stacked else self.v_scaler + + return ret_cache_k, ret_cache_v, k_quant, v_quant, ret_k_scaler, ret_v_scaler, kscale, vscale def finalize(self): if not self.env.lazy_cache_update: @@ -363,5 +402,17 @@ def finalize(self): else: self.k_scaler[self.batch, :, self.input_pos, :] = self.new_k_scaler.squeeze(2) self.v_scaler[self.batch, :, self.input_pos, :] = self.new_v_scaler.squeeze(2) - # Try to use shard_map to get rid of the data copy - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + if self.env.generate_cache_stacked: + layer, b, head, len, dim = self.cache_k.shape + if self.env.new_cache_stacked: + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + # self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, layer, head, dim)) + # self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, layer, head, dim)) + else: + # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
+ for i in range(self.env.num_layers): + self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + else: + # Try to use shard_map to get rid of the data copy + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) From ed1acffdb435dc0d409cc83afc44208fde303332 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 10 Jul 2024 21:15:04 +0000 Subject: [PATCH 20/57] Fix int8 stacked cache insertion in engine and finalization. --- jetstream_pt/cache_manager.py | 44 +++++++++++++++++++++-------------- jetstream_pt/engine.py | 35 +++++++++++++++++++--------- jetstream_pt/environment.py | 3 ++- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 0bb14a14..b955feb6 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -111,14 +111,13 @@ def __init__( # Keep this one it's used in the specific model code. self.stacked = env.generate_cache_stacked self.batch = jnp.arange(self.env.batch_size) - # The other way is to store the list and loop over to insert in finalize() if self.env.lazy_cache_update: if self.env.generate_cache_stacked: if self.env.new_cache_stacked: layer, batch, heads, time, dim = self.cache_k.shape new_dim = (layer, batch, heads, 1, dim) - self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_dim), jnp.zeros(new_dim))) + self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_dim, dtype=self.env.default_type), jnp.zeros(new_dim, dtype=self.env.default_type))) else: self.new_ks, self.new_vs = [], [] else: # when generate cache is not stacked, new cache cannot stack @@ -288,15 +287,16 @@ def __init__( self.input_pos = input_pos self.sharding = sharding self.env = env + self.stacked = env.generate_cache_stacked if self.env.lazy_cache_update: if self.env.generate_cache_stacked: layer, batch, heads, time, dim = self.cache_k.shape new_kv_dim = (layer, batch, heads, 1, dim) - self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_kv_dim), jnp.zeros(new_kv_dim))) + self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_kv_dim, dtype=jnp.int8), jnp.zeros(new_kv_dim, dtype=jnp.int8))) if self.env.new_cache_stacked: new_scale_dim = (layer, batch, 1, 1, 1) - self.new_k_scaler, self.new_v_scaler = torchjax.to_torch((jnp.zeros(new_scale_dim), jnp.zeros(new_scale_dim))) + self.new_k_scaler, self.new_v_scaler = torchjax.to_torch((jnp.zeros(new_scale_dim, dtype=self.env.default_type), jnp.zeros(new_scale_dim, dtype=self.env.default_type))) else: self.new_ks, self.new_vs = [], [] else: # when generate cache is not stacked, new cache cannot stack @@ -338,9 +338,13 @@ def empty(cls, shape, device, env): """Create empty kv caches""" cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - - kscaler = jnp.ones((shape[-4], 1, shape[-2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[-4], 1, shape[-2], 1), dtype=jnp.bfloat16) + + if env.generate_cache_stacked: + s_shape = (shape[0], shape[1], 1, shape[3], 1) + else: + s_shape = (shape[0], 1, shape[2], 1) + kscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) + vscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( 
(cache_k, cache_v, kscaler, vscaler) @@ -359,20 +363,23 @@ def update(self, xk, xv, layer_id:int): k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) - self.new_k_scaler = kscale - self.new_v_scaler = vscale - if self.env.lazy_cache_update: if self.env.new_cache_stacked: self.new_ks[layer_id, ...] = k_quant self.new_vs[layer_id, ...] = v_quant + self.new_k_scaler[layer_id, ...] = kscale + self.new_v_scaler[layer_id, ...] = vscale else: if self.env.generate_cache_stacked: self.new_ks.append(k_quant) self.new_vs.append(v_quant) + self.new_k_scaler.append(kscale) + self.new_v_scaler.append(vscale) else: self.new_ks = k_quant self.new_vs = v_quant + self.new_k_scaler = kscale + self.new_v_scaler = vscale elif self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant @@ -400,19 +407,22 @@ def finalize(self): self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(self.new_ks._elem) self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(self.new_vs._elem) else: - self.k_scaler[self.batch, :, self.input_pos, :] = self.new_k_scaler.squeeze(2) - self.v_scaler[self.batch, :, self.input_pos, :] = self.new_v_scaler.squeeze(2) if self.env.generate_cache_stacked: - layer, b, head, len, dim = self.cache_k.shape + layer, b, head, _, dim = self.cache_k.shape if self.env.new_cache_stacked: self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) - # self.cache_k._elem = self.cache_k._elem.at[:, self.batch, :, self.pos, :].set(self.new_ks._elem.reshape(b, layer, head, dim)) - # self.cache_v._elem = self.cache_v._elem.at[:, self.batch, :, self.pos, :].set(self.new_vs._elem.reshape(b, layer, head, dim)) + self.k_scaler._elem = self.k_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, layer, 1, 1)) + self.v_scaler._elem = self.k_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, layer, 1, 1)) else: # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
for i in range(self.env.num_layers): - self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + self.k_scaler._elem = self.k_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i]._elem.reshape(b, 1, 1)) + self.v_scaler._elem = self.v_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i]._elem.reshape(b, 1, 1)) else: # Try to use shard_map to get rid of the data copy + b = self.cache_k.shape[-4] self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) + self.v_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index b8832334..4a8a2344 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -366,15 +366,18 @@ def insert(cache, new_entry, update_index): else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, scaler, new_entry): + def insert(cache, scaler, new_entry, update_index): reduce_axis = (1, 3) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) + if self.env.generate_cache_stacked: + vals = jnp.expand_dims(vals, 0) + scales = jnp.expand_dims(scales, 0) new_scaler = jax.lax.dynamic_update_slice( scaler, scales, - [slot, 0, pos, 0], + update_index, ) new_scaler = jax.lax.with_sharding_constraint( new_scaler, self.replicated @@ -382,19 +385,29 @@ def insert(cache, scaler, new_entry): res = jax.lax.dynamic_update_slice( cache, vals, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res, new_scaler - for (k, v), (kscaler, vscaler), (newk, newv) in zip( - decode_state.caches, decode_state.cache_scales, prefix.caches - ): - kcache, kscale = insert(k, kscaler, newk) - vcache, vscale = insert(v, vscaler, newv) - caches.append((kcache, vcache)) - scales.append((kscale, vscale)) - + if self.env.generate_cache_stacked: + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + #newk = jnp.expand_dims(newk, 0) + #newv = jnp.expand_dims(newv, 0) + cache_k, k_scale = insert(decode_state.caches[0][0], decode_state.cache_scales[0][0], newk, update_index) + cache_v, v_scale = insert(decode_state.caches[0][1], decode_state.cache_scales[0][1], newv, update_index) + caches = [(cache_k, cache_v)] + scales = [(k_scale, v_scale)] + else: + update_index = [slot, 0, pos, 0] + for (k, v), (kscaler, vscaler), (newk, newv) in zip( + decode_state.caches, decode_state.cache_scales, prefix.caches + ): + kcache, kscale = insert(k, kscaler, newk, update_index) + vcache, vscale = insert(v, vscaler, newv, update_index) + caches.append((kcache, vcache)) + scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) return DecodeState( tokens, diff --git 
a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 9af21b58..48b0a5eb 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -16,6 +16,7 @@ from typing import Tuple import jax +import jax.numpy as jnp import jax.sharding as jsharding from jax.experimental import mesh_utils import torch_xla2 @@ -141,7 +142,7 @@ def __init__(self, data: JetEngineEnvironmentData): self.lazy_cache_update = self._data.lazy_cache_update self.testing = self._data.testing self.testing_seed = self._data.testing_seed - + self.default_type = jnp.bfloat16 if self._data.bf16_enable else jnp.float32 if self.generate_cache_stacked: self.cache_shape = (self.num_layers, *self._data.cache_shape) else: From 9460f660740f316630952372c8010e2106cfc68c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 11 Jul 2024 03:15:49 +0000 Subject: [PATCH 21/57] Fixes int8 with lazy cache update. --- jetstream_pt/attention_kernel.py | 8 ++++---- jetstream_pt/cache_manager.py | 2 +- jetstream_pt/engine.py | 4 ++-- jetstream_pt/environment.py | 2 +- jetstream_pt/layers.py | 23 ++++++++--------------- 5 files changed, 16 insertions(+), 23 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index b8b5e2ad..66397767 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -446,8 +446,8 @@ def ragged_mha( with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - ragged_mqa, - #ragged_mqa_reference, + # ragged_mqa, + ragged_mqa_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, @@ -509,7 +509,7 @@ def flash_attention(xq, keys, values, mask=None, normalize_var=True): denominator = unnormalized.sum(axis=-1, keepdim=True) # print(f"logits {logits.shape} logits_max {logits_max.shape} denominator {denominator}") o = ( - torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(values), values) + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) # / denominator[..., None] / denominator ) @@ -542,7 +542,7 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[0], 1, 1, v_scaler.shape[2]) o = ( - torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(values), values) + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) / denominator ) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index b955feb6..03d92cbe 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -94,7 +94,7 @@ def __init__( self, cache_k: torch.Tensor, # previous cache cache_v: torch.Tensor, # previous cache - position: int, # position to store the cache + position: int | torch.Tensor, # position to store the cache sharding, env=None, ): diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 4a8a2344..8b329867 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -628,7 +628,7 @@ def update_mask(): return decode_state.mask.at[batch, decode_state.input_pos].set(0) mask = decode_state.mask - if not self.env.flash_attention: + if not self.env.lazy_cache_update: mask = update_mask() logits, new_caches, new_scales = self._call_model_generate( params, @@ -643,7 +643,7 @@ def update_mask(): ragged_block_index, ) - if self.env.flash_attention: + if self.env.lazy_cache_update: # fill mask later, now use flash attention mask = update_mask() diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 48b0a5eb..1836cf58 100644 --- 
a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -86,7 +86,7 @@ class JetEngineEnvironmentData: shard_on_batch: bool = False # Whether to enable ragged multi head attention. - ragged_mha: bool = False + ragged_mha: bool = True # The block size for the ragged attention. block_size: int = 512 diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index c098b18d..5c83febe 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -537,11 +537,6 @@ def __call__( _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq_expanded = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - else: - xq_expanded = xq - def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] != 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( @@ -562,15 +557,17 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): with torch_xla2.default_env(): local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, k_scaler, v_scaler, mask=local_mask) else: + if seqlen == 1: + xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) local_output = self.dense_attention(xq, keys, values, k_scaler, v_scaler, local_mask) local_max = None local_denom = None - if not self.env.ragged_mha and seqlen == 1: - local_output = local_output[:, :, 0:1, :] - if local_max is not None: - local_max = local_max[:, :, 0:1, :] - local_denom = local_denom[:, :, 0:1, :] + if seqlen == 1: + local_output = local_output[:, :, 0:1, :] + if local_max is not None: + local_max = local_max[:, :, 0:1, :] + local_denom = local_denom[:, :, 0:1, :] # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") # if local_max is not None and local_denom is not None: @@ -583,7 +580,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): keys = repeat_kv(orig_keys, n_rep) values = repeat_kv(orig_values, n_rep) with jax.named_scope("attn_qkv"): - existing_output, (existing_max, existing_denom) = attend(xq_expanded, keys, values, k_scaler, v_scaler, mask) + existing_output, (existing_max, existing_denom) = attend(xq, keys, values, k_scaler, v_scaler, mask) # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: @@ -598,14 +595,10 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # return new_output with jax.named_scope("attn_global"): - existing_denom = existing_denom[:, 0:1] - existing_max = existing_max[:, 0:1] global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) - existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum new_output = new_output * new_denom * torch.exp(new_max) / global_sum attn_out = existing_output + new_output - return attn_out From f1880860f826382f6b18eb7d7c807f5a1161e013 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 11 Jul 2024 03:18:54 +0000 Subject: [PATCH 22/57] Updates the int8 test. 
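Background on the attn_global merge above: the flash-attention helpers return each partial result together with its per-row (logits_max, denominator) statistics, and the decode path combines the cached-KV partial output with the new-token partial output using the running-softmax identity. A minimal standalone sketch of that merge, assuming only torch; merge_partial_attention is an illustrative name, not a function in this patch:

import torch

def merge_partial_attention(o1, m1, d1, o2, m2, d2):
    # o*: partial attention outputs, shape (b, h, q, d)
    # m*: per-row logit maxima, d*: softmax denominators, shape (b, h, q, 1)
    m = torch.maximum(m1, m2)
    w1 = d1 * torch.exp(m1 - m)
    w2 = d2 * torch.exp(m2 - m)
    # Equals full softmax attention computed over the union of both key sets.
    return (o1 * w1 + o2 * w2) / (w1 + w2)

The patch divides by existing_denom * exp(existing_max) + new_denom * exp(new_max) directly, without first subtracting a shared maximum; subtracting max(m1, m2) as above is the numerically safer form when logits can be large.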
--- tests/test_quantization.py | 109 ++++++++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 190e6fda..210abe93 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -72,69 +72,114 @@ def _print_diff(self, w, w_dq): def test_kv_cache(self): """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.ragged_mha=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=True + env_data.lazy_cache_update=True + env, _ = helpers.make_env_tiny(False, update_env_data) + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = (env.num_layers, batch, 2, 100, 2) # layer, bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny() + cache = cache_manager.Int8KVCacheGenerate.empty( - cache_shape, None, False, env + cache_shape, None, env ) # seqlen is 1 - k = self._xla_tensor((3, 2, 1, 2)) - v = self._xla_tensor((3, 2, 1, 2)) - - cache.input_pos = [57] - new_k, new_v, scaler_k, scaler_v = cache.update(k, v) - new_k = new_k * scaler_k - new_v = new_v * scaler_v - - self.assertTrue( - jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1) - ) - self.assertTrue( - jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1) - ) + k = self._xla_tensor((batch, 2, 1, 2)) + v = self._xla_tensor((batch, 2, 1, 2)) + + # cache.input_pos = [57] if env.ring_buffer else torch_xla2.default_env().to_xla(torch.tensor([57] * batch, dtype=torch.int32)) + cache.input_pos = [57] if env.ring_buffer else jnp.array([57] * batch) + layer = 1 + # layer id may or may not take effect, depends on the env config. 
+ cache.update(k, v, layer_id=layer) + cache.finalize() + new_k = cache.cache_k * cache.k_scaler + new_v = cache.cache_v * cache.v_scaler + + if env.generate_cache_stacked: + self.assertTrue( + jnp.allclose(k._elem, new_k._elem[layer, :, :, 57:58, :], atol=0.1) + ) + self.assertTrue( + jnp.allclose(v._elem, new_v._elem[layer, :, :, 57:58, :], atol=0.1) + ) + else: + self.assertTrue( + jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1) + ) + self.assertTrue( + jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1) + ) def test_kv_kernel(self): """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.ragged_mha=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=True + env_data.lazy_cache_update=True + env, _ = helpers.make_env_tiny(False, update_env_data) + + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = (env.num_layers, batch, 2, 100, 2) # bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # layers, bs, num heads, seqlen, dim + with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny(False) + key = jax.random.PRNGKey(123) key2 = jax.random.PRNGKey(456) cache_k_jax = jax.random.normal(key, cache_shape) cache_v_jax = jax.random.normal(key2, cache_shape) - cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) + start = jnp.zeros((batch,), dtype=jnp.int32) + pos = [57] if env.ring_buffer else jnp.array([57] * batch, dtype=jnp.int32) + mask = jax.lax.broadcast_in_dim(jnp.array([0] * 57 + [float("-inf")] * 43), (env.batch_size, 100), (1,)) - cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env) + cache_k, cache_v, start, mask = torchjax.to_torch((cache_k_jax, cache_v_jax, start, mask)) + + cache = cache_manager.KVCacheGenerate(cache_k, cache_v, pos, None, env) # 1 is seqlen - xq = jax.random.normal(key, (3, 2, 1, 2)) - xk = jax.random.normal(key, (3, 2, 1, 2)) - xv = jax.random.normal(key, (3, 2, 1, 2)) + xq = jax.random.normal(key, (batch, 2, 1, 2)) + xk = jax.random.normal(key, (batch, 2, 1, 2)) + xv = jax.random.normal(key, (batch, 2, 1, 2)) xq, xk, xv = torchjax.to_torch((xq, xk, xv)) - attention_float = layers.AttentionKernel(env) - float_res = attention_float(xq, xk, xv, None, cache) + layer = 1 + attention_float = layers.AttentionKernel(env, layer_id=layer) + float_res = attention_float(xq, xk, xv, mask, cache, start=start, end=pos) # == cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) - cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3)) - cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3)) + cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (-3, -1)) + cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) cache_int = cache_manager.Int8KVCacheGenerate( cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, - [0], + pos, None, env, ) - attention_quant = layers.Int8KVAttentionKernel(env) - int_res = attention_quant(xq, xk, xv, None, cache_int) - + attention_quant = layers.Int8KVAttentionKernel(env, layer_id=layer) + + int_res = attention_quant(xq, xk, xv, mask, cache_int, start=jnp.zeros((batch,), dtype=jnp.int32), end=pos) self.assertTrue(jnp.allclose(float_res.jax(), int_res.jax(), atol=0.01)) def test_quantize_dequantize_tensor(self): From 124bb7154c548dfa3d4cdc8d00a71275064567f6 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 11 Jul 2024 
17:07:10 +0000 Subject: [PATCH 23/57] Fix the int8 ragged attention output sharding. --- jetstream_pt/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 5c83febe..3ab7d3de 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -508,7 +508,7 @@ def __init__(self, env, layer_id): self.ragged_attention = ak.RaggedAttentionKernel( env, input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), + output_specs=(qkv_pspec, (qkv_pspec, qkv_pspec)), sharding_axis=self.shard_axis, ) self.layer_id = layer_id From ffaba5a96766f867f8cdb21062c3cdbea15937d9 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 11 Jul 2024 17:24:28 +0000 Subject: [PATCH 24/57] Fix group query attention broadcasting issue. --- jetstream_pt/attention_kernel.py | 6 ++++++ jetstream_pt/layers.py | 28 ++++++++++++++++------------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 66397767..daf1f451 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -443,6 +443,12 @@ def ragged_mha( ) # New cache has t=1 bk = min(bk, k.shape[-2]) + bq, hq, tq, dq = q.shape + hkv = k.shape[1] + rep = hq // hkv + if rep > 1: + q = q.reshape(bq, hkv, rep * tq, dq) + with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 3ab7d3de..cd015284 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -463,12 +463,13 @@ def attend(xq, keys, values, local_mask=None): with jax.named_scope("attn_insert_cache"): orig_keys, orig_values = cache.update(xk, xv, self.layer_id) - keys = repeat_kv(orig_keys, n_rep) - values = repeat_kv(orig_values, n_rep) + if not self.env.ragged_mha: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): - existing_output, (existing_max, existing_denom) = attend(xq, keys, values, mask) + existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, mask) # Updating cache during each step still has very large impact on latency. 
# For non flash attention or prefill, existing output contains everything @@ -477,9 +478,10 @@ def attend(xq, keys, values, local_mask=None): # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): - new_keys = repeat_kv(xk, n_rep) - new_values = repeat_kv(xv, n_rep) - new_output, (new_max, new_denom) = attend(xq, new_keys, new_values, None) + if not self.env.ragged_mha: + xk = repeat_kv(xk, n_rep) + xv = repeat_kv(xv, n_rep) + new_output, (new_max, new_denom) = attend(xq, xk, xv, None) # if cache.cache_k is None: # Prefill # return new_output @@ -577,10 +579,11 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): with jax.named_scope("attn_insert_cache"): orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) - keys = repeat_kv(orig_keys, n_rep) - values = repeat_kv(orig_values, n_rep) + if not self.env.ragged_mha: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) with jax.named_scope("attn_qkv"): - existing_output, (existing_max, existing_denom) = attend(xq, keys, values, k_scaler, v_scaler, mask) + existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, k_scaler, v_scaler, mask) # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: @@ -588,9 +591,10 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): - new_keys = repeat_kv(new_key, n_rep) - new_values = repeat_kv(new_value, n_rep) - new_output, (new_max, new_denom) = attend(xq, new_keys, new_values, new_k_scaler, new_v_scaler, None) + if not self.env.ragged_mha: + new_key = repeat_kv(new_key, n_rep) + new_value = repeat_kv(new_value, n_rep) + new_output, (new_max, new_denom) = attend(xq, new_key, new_value, new_k_scaler, new_v_scaler, None) # if cache.cache_k is None: # Prefill # return new_output From 78789f1fffc6323143c1ba39512a13d904064404 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 11 Jul 2024 18:05:41 +0000 Subject: [PATCH 25/57] Fix shard map input issue. Variables not listed as inputs are freezed into jit function. 
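The shard map fix refers to a standard jax.jit behavior: values a jitted function reads from enclosing scope (such as self.pos) are captured at trace time and baked into the compiled program as constants, so later updates to the attribute are silently ignored; only values passed as arguments are treated as dynamic inputs. A minimal sketch of the gotcha, assuming only jax; the class and names here are illustrative, not code from this repository:

import jax
import jax.numpy as jnp

class Cache:
    def __init__(self):
        self.pos = jnp.array([0])
        # Reads self.pos via closure: the value seen at the first call is
        # traced into the compiled function as a constant.
        self.update_closure = jax.jit(lambda buf: buf.at[self.pos].set(1.0))
        # Takes pos as an argument: it stays a real, dynamic input.
        self.update_arg = jax.jit(lambda buf, pos: buf.at[pos].set(1.0))

cache = Cache()
buf = jnp.zeros(4)
print(cache.update_closure(buf))         # writes index 0
cache.pos = jnp.array([2])
print(cache.update_closure(buf))         # still writes index 0: stale constant
print(cache.update_arg(buf, cache.pos))  # writes index 2 as intended

This is why update_single_cache_line now receives pos / input_pos explicitly instead of reading it from self.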
--- jetstream_pt/cache_manager.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 03d92cbe..f18a8a9e 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -125,13 +125,13 @@ def __init__( cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads none_pspec = self.env.partition_by_axis() - in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec) + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) out_specs = (cache_pspec, cache_pspec) self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) - def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): b = cache_k.shape[-4] - for bb, pp in enumerate(self.pos.reshape(b)): + for bb, pp in enumerate(pos.reshape(b)): slice_dim = 0 update_start_indices = (bb, 0, pp, 0) if self.env.generate_cache_stacked: @@ -158,14 +158,14 @@ def finalize(self): if self.env.generate_cache_stacked: layer, b, head, len, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.pos) else: for i in range(self.env.num_layers): self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) else: # Try to use shard_map to get rid of the data copy - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.pos) def update(self, key, value, layer_id:int): """Update kv cache""" @@ -304,13 +304,13 @@ def __init__( cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads none_pspec = self.env.partition_by_axis() - in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec) + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) out_specs = (cache_pspec, cache_pspec) self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) - def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs): + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, input_pos): b = cache_k.shape[-4] - for bb, pp in enumerate(self.input_pos.reshape(b)): + for bb, pp in enumerate(input_pos.reshape(b)): slice_dim = 0 update_start_indices = (bb, 0, pp, 0) if self.env.generate_cache_stacked: @@ -410,9 +410,9 @@ def finalize(self): if self.env.generate_cache_stacked: layer, b, head, _, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, 
self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) self.k_scaler._elem = self.k_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, layer, 1, 1)) - self.v_scaler._elem = self.k_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, layer, 1, 1)) + self.v_scaler._elem = self.v_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, layer, 1, 1)) else: # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. for i in range(self.env.num_layers): @@ -423,6 +423,6 @@ def finalize(self): else: # Try to use shard_map to get rid of the data copy b = self.cache_k.shape[-4] - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) - self.v_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) + self.v_scaler._elem = self.v_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) From e4643eeb9a89b75b7059f2030db9ec7c3ab4428c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 12 Jul 2024 23:16:29 +0000 Subject: [PATCH 26/57] Fix the flash attention mask shape; Fix the update single cache line quant version --- jetstream_pt/attention_kernel.py | 8 +++---- jetstream_pt/cache_manager.py | 37 ++++++++++++++++++++------------ jetstream_pt/engine.py | 3 ++- jetstream_pt/layers.py | 26 +++++++++++----------- 4 files changed, 43 insertions(+), 31 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index daf1f451..60a9bef9 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -305,7 +305,7 @@ def ragged_mqa_reference( ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" batch_size, num_heads, head_dim = q.shape - assert end.shape == (batch_size,) + #assert end.shape == (batch_size,) assert end.dtype == jnp.int32 seq_len = k.shape[1] @@ -361,7 +361,7 @@ def kv_index_map(b, i, start_ref, out, m, l = pl.pallas_call( functools.partial( - ragged_flash_attention_kernel, + ragged_mqa_kernel_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, @@ -452,7 +452,7 @@ def ragged_mha( with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - # ragged_mqa, + #ragged_mqa, ragged_mqa_reference, bk=bk, mask_value=mask_value, @@ -546,7 +546,7 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n denominator = unnormalized.sum(axis=-1, keepdim=True) unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[0], 1, 1, v_scaler.shape[2]) - + #import pdb; pdb.set_trace() o = ( torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) / denominator diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index f18a8a9e..1813dcc1 
100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -304,25 +304,35 @@ def __init__( cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads none_pspec = self.env.partition_by_axis() - in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) - out_specs = (cache_pspec, cache_pspec) + in_specs = (*([cache_pspec] * 4), *([none_pspec] * 5)) + out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) - def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, input_pos): + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, k_scaler, v_scaler, new_k_scaler, new_v_scaler, pos): b = cache_k.shape[-4] - for bb, pp in enumerate(input_pos.reshape(b)): + + for bb, pp in enumerate(pos.reshape(b)): slice_dim = 0 update_start_indices = (bb, 0, pp, 0) if self.env.generate_cache_stacked: - update_start_indices = (0, bb, 0, pp, 0) if self.env.new_cache_stacked: slice_dim = 1 - # We are not handling generate_cache_stacked=True new_cache_stacked=False + update_start_indices = (0, bb, 0, pp, 0) + # We are not handling generate_cache_stacked=True new_cache_stacked=False here + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) - return cache_k, cache_v + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim(new_k_scaler, bb, 1, slice_dim) + k_scaler = jax.lax.dynamic_update_slice(k_scaler, new_k_scaler_slice, update_start_indices) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim(new_v_scaler, bb, 1, slice_dim) + v_scaler = jax.lax.dynamic_update_slice(v_scaler, new_v_scaler_slice, update_start_indices) + + return cache_k, cache_v, k_scaler, v_scaler def state(self): """Get kv cache state""" @@ -410,9 +420,8 @@ def finalize(self): if self.env.generate_cache_stacked: layer, b, head, _, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) - self.k_scaler._elem = self.k_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, layer, 1, 1)) - self.v_scaler._elem = self.v_scaler._elem.at[:, self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, layer, 1, 1)) + caches = [self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler] + self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, *caches, self.input_pos) else: # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
for i in range(self.env.num_layers): @@ -423,6 +432,6 @@ def finalize(self): else: # Try to use shard_map to get rid of the data copy b = self.cache_k.shape[-4] - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) - self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) - self.v_scaler._elem = self.v_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) + self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler, self.input_pos) + #self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) + #self.v_scaler._elem = self.v_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 8b329867..b383a2d1 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -221,7 +221,8 @@ def _call_model_prefill(self, weights, tokens, input_indexes): dtype=self.default_dtype, ) mask = jnp.triu(mask, k=1) - args = (tokens, input_indexes, caches, mask) + start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32) + args = (tokens, input_indexes, caches, mask, start) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index cd015284..b716fc17 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -538,9 +538,12 @@ def __call__( bsz, num_heads, seqlen, head_dim = xq.shape _, num_kv_heads, _, kv_head_dim = xk.shape n_rep = num_heads // num_kv_heads - + def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): - if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] != 1: + if not self.env.ragged_mha and seqlen == 1: + xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + + if self.env.ragged_mha: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, @@ -555,28 +558,26 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): ) local_max = local_max.reshape(*local_max.shape, 1) local_denom = local_denom.reshape(*local_denom.shape, 1) - elif self.env.flash_attention or keys.shape[-2] == 1: + elif self.env.flash_attention: with torch_xla2.default_env(): local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, k_scaler, v_scaler, mask=local_mask) else: - if seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) local_output = self.dense_attention(xq, keys, values, k_scaler, v_scaler, local_mask) local_max = None local_denom = None - if seqlen == 1: - local_output = local_output[:, :, 0:1, :] - if local_max is not None: - local_max = local_max[:, :, 0:1, :] - local_denom = local_denom[:, :, 0:1, :] + if local_output.shape[-2] == 2: + local_output = local_output[:, :, 0:1, :] + if local_max is not None: + local_max = local_max[:, :, 0:1, :] + local_denom = local_denom[:, :, 0:1, :] # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") # if local_max is not None and local_denom is not None: # 
print(f"local_max {local_max.shape} local_denom {local_denom.shape}") self.env.apply_sharding(local_output, axis=self.shard_axis) return local_output, (local_max, local_denom) - + #import pdb; pdb.set_trace() with jax.named_scope("attn_insert_cache"): orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) if not self.env.ragged_mha: @@ -716,7 +717,8 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) - + if mask.ndim == 2: + mask = mask[:, None, None, :] # if cache is not None and cache.cache_k is not None: # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( From ef0b148174f964a3d9c7f45dab5307309946f706 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 12 Jul 2024 18:57:41 +0000 Subject: [PATCH 27/57] Adds the kv cache test. --- tests/test_quantization.py | 145 +++++++++++++++++++++++-------------- 1 file changed, 92 insertions(+), 53 deletions(-) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 210abe93..20056c5b 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -22,7 +22,7 @@ import torch import torch_xla2 from jax.experimental import mesh_utils -from jetstream_pt import cache_manager, layers, quantize, torchjax +from jetstream_pt import cache_manager, layers, quantize, torchjax, environment from jetstream_pt.environment import QuantizationConfig from jetstream_pt.layers import ( WeightOnlyBlockwiseQuantizedLinear, @@ -33,6 +33,7 @@ from tests import helpers from torch.utils import _pytree as pytree from torch_xla2 import tensor +import copy torch.manual_seed(12345) @@ -79,8 +80,10 @@ def update_env_data(env_data): env_data.generate_cache_stacked=True env_data.new_cache_stacked=True env_data.lazy_cache_update=True - env, _ = helpers.make_env_tiny(False, update_env_data) - + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + env, _ = helpers.make_env_tiny(True, update_env_data) + batch = env.batch_size if env.generate_cache_stacked: cache_shape = (env.num_layers, batch, 2, 100, 2) # layer, bs, num heads, seqlen, dim @@ -88,36 +91,43 @@ def update_env_data(env_data): cache_shape = (batch, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - cache = cache_manager.Int8KVCacheGenerate.empty( + cache = cache_manager.KVCacheGenerate.empty( cache_shape, None, env ) # seqlen is 1 k = self._xla_tensor((batch, 2, 1, 2)) v = self._xla_tensor((batch, 2, 1, 2)) - # cache.input_pos = [57] if env.ring_buffer else torch_xla2.default_env().to_xla(torch.tensor([57] * batch, dtype=torch.int32)) - cache.input_pos = [57] if env.ring_buffer else jnp.array([57] * batch) - layer = 1 - # layer id may or may not take effect, depends on the env config. 
- cache.update(k, v, layer_id=layer) - cache.finalize() - new_k = cache.cache_k * cache.k_scaler - new_v = cache.cache_v * cache.v_scaler - - if env.generate_cache_stacked: - self.assertTrue( - jnp.allclose(k._elem, new_k._elem[layer, :, :, 57:58, :], atol=0.1) - ) - self.assertTrue( - jnp.allclose(v._elem, new_v._elem[layer, :, :, 57:58, :], atol=0.1) - ) - else: - self.assertTrue( - jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1) - ) - self.assertTrue( - jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1) - ) + def update_finalize_compare(in_k, in_v, in_layer, in_pos): + cache.input_pos = [in_pos] if env.ring_buffer else jnp.array([in_pos] * batch) + + # layer id may or may not take effect, depends on the env config. + cache.update(in_k, in_v, layer_id=in_layer) + cache.finalize() + if env.quant_config.enable_kv_quantization: + new_k = cache.cache_k * cache.k_scaler + new_v = cache.cache_v * cache.v_scaler + else: + new_k = cache.cache_k + new_v = cache.cache_v + + if env.generate_cache_stacked: + self.assertTrue( + jnp.allclose(k._elem, new_k._elem[in_layer, :, :, in_pos:(in_pos + 1), :], atol=0.1) + ) + self.assertTrue( + jnp.allclose(v._elem, new_v._elem[in_layer, :, :, in_pos:(in_pos + 1), :], atol=0.1) + ) + else: + self.assertTrue( + jnp.allclose(k._elem, new_k._elem[:, :, in_pos:(in_pos + 1), :], atol=0.1) + ) + self.assertTrue( + jnp.allclose(v._elem, new_v._elem[:, :, in_pos:(in_pos + 1), :], atol=0.1) + ) + update_finalize_compare(k, v, in_layer=1, in_pos=57) + update_finalize_compare(k, v, in_layer=1, in_pos=58) + update_finalize_compare(k, v, in_layer=2, in_pos=3) def test_kv_kernel(self): """test kv cache quantization""" @@ -128,9 +138,10 @@ def update_env_data(env_data): env_data.generate_cache_stacked=True env_data.new_cache_stacked=True env_data.lazy_cache_update=True + env_data.quant_config.enable_kv_quantization=False + env_data.batch_size = 4 env, _ = helpers.make_env_tiny(False, update_env_data) - batch = env.batch_size if env.generate_cache_stacked: cache_shape = (env.num_layers, batch, 2, 100, 2) # bs, num heads, seqlen, dim @@ -141,47 +152,75 @@ def update_env_data(env_data): key = jax.random.PRNGKey(123) key2 = jax.random.PRNGKey(456) - cache_k_jax = jax.random.normal(key, cache_shape) - cache_v_jax = jax.random.normal(key2, cache_shape) + cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_dtype) + cache_v_jax = jax.random.normal(key2, cache_shape, dtype=env.default_dtype) + # cache_k_jax = jnp.zeros(cache_shape, dtype=env.default_dtype) + # cache_v_jax = jnp.zeros(cache_shape, dtype=env.default_dtype) start = jnp.zeros((batch,), dtype=jnp.int32) - pos = [57] if env.ring_buffer else jnp.array([57] * batch, dtype=jnp.int32) - mask = jax.lax.broadcast_in_dim(jnp.array([0] * 57 + [float("-inf")] * 43), (env.batch_size, 100), (1,)) - - cache_k, cache_v, start, mask = torchjax.to_torch((cache_k_jax, cache_v_jax, start, mask)) - cache = cache_manager.KVCacheGenerate(cache_k, cache_v, pos, None, env) + cache_k, cache_v, start = torchjax.to_torch((cache_k_jax, cache_v_jax, start)) + + # Prepare quantized cache before written in + cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (-3, -1)) + cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) # 1 is seqlen - xq = jax.random.normal(key, (batch, 2, 1, 2)) - xk = jax.random.normal(key, (batch, 2, 1, 2)) - xv = jax.random.normal(key, (batch, 2, 1, 2)) + xq = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_dtype) + xk = jax.random.normal(key, (batch, 2, 
1, 2), dtype=env.default_dtype) + xv = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_dtype) xq, xk, xv = torchjax.to_torch((xq, xk, xv)) - layer = 1 - attention_float = layers.AttentionKernel(env, layer_id=layer) - float_res = attention_float(xq, xk, xv, mask, cache, start=start, end=pos) + def get_var(position: int): + pos = [position] if env.ring_buffer else jnp.array([position] * batch, dtype=jnp.int64) + mask = jax.lax.broadcast_in_dim(jnp.array([0] * position + [float("-inf")] * (100 - position)), (env.batch_size, 100), (1,)) + mask = torchjax.to_torch((mask)) + return pos, mask - # == - cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) - cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (-3, -1)) - cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) - cache_int = cache_manager.Int8KVCacheGenerate( + cache = cache_manager.KVCacheGenerate(cache_k, cache_v, None, None, env) + # layer_id doesn't matter, will assign later + attention_float = layers.AttentionKernel(env, layer_id=0) + + float_res = [] + def update_finalize_record(in_attention, in_cache, in_q, in_k, in_v, in_layer, in_pos): + pos, mask = get_var(in_pos) + in_attention.layer_id=in_layer + in_cache.input_pos = pos + ret = in_attention(in_q, in_k, in_v, mask, in_cache, start=start, end=pos) + in_cache.finalize() + return ret + + float_res.append(update_finalize_record(attention_float, cache, xq, xk, xv, 1, 57)) + float_res.append(update_finalize_record(attention_float, cache, xq, xk, xv, 1, 58)) + float_res.append(update_finalize_record(attention_float, cache, xq, xk, xv, 2, 3)) + + # Running into the issue of multiple env object always share the same quant_config. + # Record the results and compare as a workaround. + env._data.quant_config.enable_kv_quantization = True + env = environment.JetEngineEnvironment(env._data) + + cache_int = cache_manager.KVCacheGenerate( cache_k_int, cache_v_int, - cache_k_scaler, - cache_v_scaler, - pos, + None, None, env, + cache_k_scaler=cache_k_scaler, + cache_v_scaler=cache_v_scaler ) - attention_quant = layers.Int8KVAttentionKernel(env, layer_id=layer) - - int_res = attention_quant(xq, xk, xv, mask, cache_int, start=jnp.zeros((batch,), dtype=jnp.int32), end=pos) - self.assertTrue(jnp.allclose(float_res.jax(), int_res.jax(), atol=0.01)) + # layer_id doesn't matter, will assign later + attention_quant = layers.Int8KVAttentionKernel(env, layer_id=0) + + int_res = [] + int_res.append(update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 57)) + int_res.append(update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 58)) + int_res.append(update_finalize_record(attention_quant, cache_int, xq, xk, xv, 2, 3)) + for f, i in zip(float_res, int_res): + self.assertTrue(jnp.allclose(f.jax(), i.jax(), atol=0.01)) + def test_quantize_dequantize_tensor(self): def quantize_dequantize_weight(w, n_bit): From 02e2d0bded7ded3b9aec0b85db9d23afee42b90d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 12 Jul 2024 23:58:40 +0000 Subject: [PATCH 28/57] Replace quantized cache "pos" with "input_pos" to align with bf16 cache. Fix the kv cache quantization test. 
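For context on the quantized-cache tests in this series: the int8 KV cache keeps a symmetric per-(batch, position) scale obtained by reducing the absolute maximum over the head and feature axes, and using negative axes keeps the same reduction valid whether or not the cache carries a leading layer dimension (generate_cache_stacked). A minimal sketch, assuming only torch; quantize_kv and dequantize_kv are illustrative names, not functions from this patch:

import torch

def quantize_kv(val):
    # Works for (batch, heads, seqlen, dim) and for the stacked
    # (layer, batch, heads, seqlen, dim) layout: dim=(-3, -1) always selects
    # (heads, dim), whereas positive axes (1, 3) only fit the 4-D case.
    scale = torch.amax(val.abs(), dim=(-3, -1), keepdim=True) / 127
    return (val / scale).to(torch.int8), scale

def dequantize_kv(q, scale):
    return q.to(scale.dtype) * scale

kv = torch.randn(2, 4, 8, 16, 32)  # layer, batch, heads, seqlen, dim
q, s = quantize_kv(kv)
print((dequantize_kv(q, s) - kv).abs().max())  # small roundtrip error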
--- jetstream_pt/cache_manager.py | 30 +++++++++++++++--------------- tests/test_model_impl.py | 4 ++-- tests/test_quantization.py | 20 +++++++++----------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 1813dcc1..3a211850 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -101,7 +101,7 @@ def __init__( super().__init__() self.cache_k = cache_k self.cache_v = cache_v - self.pos = position + self.input_pos = position self.sharding = sharding self.env = env @@ -148,24 +148,24 @@ def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): def finalize(self): if not self.env.lazy_cache_update: return - # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_ks._elem, -2)) - # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.pos].set(jnp.squeeze(self.new_vs._elem, -2)) + # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_ks._elem, -2)) + # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_vs._elem, -2)) if self.env.ring_buffer: # Assume no cache stack for ring buffer - self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(self.new_vs._elem) + self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(self.new_ks._elem) + self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(self.new_vs._elem) else: if self.env.generate_cache_stacked: layer, b, head, len, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.pos) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) else: for i in range(self.env.num_layers): - self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) else: # Try to use shard_map to get rid of the data copy - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.pos) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) def update(self, key, value, layer_id:int): """Update kv cache""" @@ -191,27 +191,27 @@ def update(self, key, value, layer_id:int): elif self.env.ring_buffer: # Assume no cache stack for ring buffer # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[..., self.pos, :].set(keyj) - self.cache_v._elem = self.cache_v._elem.at[..., self.pos, :].set(valuej) + self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(keyj) + 
self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(valuej) return self.cache_k, self.cache_v else: if self.env.generate_cache_stacked: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[layer_id, self.batch, :, self.pos, :].set( + self.cache_k._elem = self.cache_k._elem.at[layer_id, self.batch, :, self.input_pos, :].set( keyj.squeeze(2) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[layer_id, self.batch, :, self.pos, :].set( + self.cache_v._elem = self.cache_v._elem.at[layer_id, self.batch, :, self.input_pos, :].set( valuej.squeeze(2) ) return self.cache_k[layer_id], self.cache_v[layer_id] else: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.pos, :].set( + self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.input_pos, :].set( keyj.squeeze(2) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.pos, :].set( + self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.input_pos, :].set( valuej.squeeze(2) ) return self.cache_k, self.cache_v diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 1732b6bd..a3472760 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -156,7 +156,7 @@ def test_attention(self): None, # mask is none for decode ) expected_out = attention_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) @@ -318,7 +318,7 @@ def test_transformer_block(self): None, # mask is none for decode ) expected_out = block_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 20056c5b..f53543d1 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -152,10 +152,8 @@ def update_env_data(env_data): key = jax.random.PRNGKey(123) key2 = jax.random.PRNGKey(456) - cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_dtype) - cache_v_jax = jax.random.normal(key2, cache_shape, dtype=env.default_dtype) - # cache_k_jax = jnp.zeros(cache_shape, dtype=env.default_dtype) - # cache_v_jax = jnp.zeros(cache_shape, dtype=env.default_dtype) + cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_type) + cache_v_jax = jax.random.normal(key2, cache_shape, dtype=env.default_type) start = jnp.zeros((batch,), dtype=jnp.int32) @@ -166,15 +164,15 @@ def update_env_data(env_data): cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) # 1 is seqlen - xq = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_dtype) - xk = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_dtype) - xv = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_dtype) + xq = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) + xk = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) + xv = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) xq, xk, xv = torchjax.to_torch((xq, xk, xv)) def get_var(position: int): pos = [position] if env.ring_buffer else 
jnp.array([position] * batch, dtype=jnp.int64) - mask = jax.lax.broadcast_in_dim(jnp.array([0] * position + [float("-inf")] * (100 - position)), (env.batch_size, 100), (1,)) + mask = jax.lax.broadcast_in_dim(jnp.array([0] * position + [float("-inf")] * (100 - position)), (env.batch_size, 1, 1, 100), (3,)) mask = torchjax.to_torch((mask)) return pos, mask @@ -201,14 +199,14 @@ def update_finalize_record(in_attention, in_cache, in_q, in_k, in_v, in_layer, i env._data.quant_config.enable_kv_quantization = True env = environment.JetEngineEnvironment(env._data) - cache_int = cache_manager.KVCacheGenerate( + cache_int = cache_manager.Int8KVCacheGenerate( cache_k_int, cache_v_int, + cache_k_scaler, + cache_v_scaler, None, None, env, - cache_k_scaler=cache_k_scaler, - cache_v_scaler=cache_v_scaler ) # layer_id doesn't matter, will assign later attention_quant = layers.Int8KVAttentionKernel(env, layer_id=0) From 65b19a8b5695d10d81d669f704c741fe1fe11249 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sat, 13 Jul 2024 04:39:46 +0000 Subject: [PATCH 29/57] Fix prefill cache insertion issue for stacked cache; Changes reduce dim for quantization from 1,3 to -3,-1 to make it more robust; --- jetstream_pt/cache_manager.py | 5 +++-- jetstream_pt/engine.py | 16 +++++++++++----- jetstream_pt/ray_worker.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 3a211850..8295c0e5 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -306,7 +306,8 @@ def __init__( none_pspec = self.env.partition_by_axis() in_specs = (*([cache_pspec] * 4), *([none_pspec] * 5)) out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) - self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) + self.update_single_cache_line = shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False) + self.update_single_cache_line = jax.jit(self.update_single_cache_line) def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, k_scaler, v_scaler, new_k_scaler, new_v_scaler, pos): b = cache_k.shape[-4] @@ -364,7 +365,7 @@ def empty(cls, shape, device, env): def quantize(self, val): """Quantize value""" # val is (batch, heads, seqlen, dim) - scale = torch.amax(val.abs(), axis=(1, 3), keepdim=True) + scale = torch.amax(val.abs(), axis=(-3, -1), keepdim=True) scale = scale / 127 return (val / scale).to(torch.int8), scale diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index b383a2d1..bf59d3bf 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -368,7 +368,7 @@ def insert(cache, new_entry, update_index): @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, scaler, new_entry, update_index): - reduce_axis = (1, 3) + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) @@ -392,14 +392,16 @@ def insert(cache, scaler, new_entry, update_index): return res, new_scaler if self.env.generate_cache_stacked: + cache_k, k_scale = decode_state.caches[0][0], decode_state.cache_scales[0][0] + cache_v, v_scale = decode_state.caches[0][1], decode_state.cache_scales[0][1] for idx, (newk, newv) in enumerate(prefix.caches): update_index = [idx, slot, 0, pos, 0] #newk = jnp.expand_dims(newk, 0) #newv = jnp.expand_dims(newv, 0) - cache_k, k_scale = insert(decode_state.caches[0][0], decode_state.cache_scales[0][0], newk, 
update_index) - cache_v, v_scale = insert(decode_state.caches[0][1], decode_state.cache_scales[0][1], newv, update_index) - caches = [(cache_k, cache_v)] - scales = [(k_scale, v_scale)] + cache_k, k_scale = insert(cache_k, k_scale, newk, update_index) + cache_v, v_scale = insert(cache_v, v_scale, newv, update_index) + caches = [(cache_k, cache_v)] + scales = [(k_scale, v_scale)] else: update_index = [slot, 0, pos, 0] for (k, v), (kscaler, vscaler), (newk, newv) in zip( @@ -649,6 +651,10 @@ def update_mask(): mask = update_mask() next_token = self._sampling(logits, self.env.batch_size) + # print(f"current input pos: {decode_state.input_pos} and generated token is {next_token}") + # # for layer, (k,v) in enumerate(new_caches[0]): + # data = new_caches[0][0] * new_scales[0][0] if self.env.quant_config.enable_kv_quantization else new_caches[0][0] + # print(f"layer 0, scaled back k is {data}") if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 lens = decode_state.lens + 1 diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 01b647db..b473e05c 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -596,7 +596,7 @@ def insert(cache, new_entry): @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) From a87c6088e79a3ea0fb47f6d01513543337be0ea1 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 15 Jul 2024 19:53:39 +0000 Subject: [PATCH 30/57] Adds lazy cache update with generate cache stacked new cache unstacked for performance validation. --- jetstream_pt/cache_manager.py | 58 +++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 8295c0e5..d9c84856 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -298,7 +298,7 @@ def __init__( new_scale_dim = (layer, batch, 1, 1, 1) self.new_k_scaler, self.new_v_scaler = torchjax.to_torch((jnp.zeros(new_scale_dim, dtype=self.env.default_type), jnp.zeros(new_scale_dim, dtype=self.env.default_type))) else: - self.new_ks, self.new_vs = [], [] + self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = [], [], [], [] else: # when generate cache is not stacked, new cache cannot stack assert not self.env.new_cache_stacked @@ -319,19 +319,36 @@ def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, k_scaler, v if self.env.new_cache_stacked: slice_dim = 1 update_start_indices = (0, bb, 0, pp, 0) - # We are not handling generate_cache_stacked=True new_cache_stacked=False here - - new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) - cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) - - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) - cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) - - new_k_scaler_slice = jax.lax.dynamic_slice_in_dim(new_k_scaler, bb, 1, slice_dim) - k_scaler = jax.lax.dynamic_update_slice(k_scaler, new_k_scaler_slice, update_start_indices) - - new_v_scaler_slice = jax.lax.dynamic_slice_in_dim(new_v_scaler, bb, 1, slice_dim) - v_scaler = jax.lax.dynamic_update_slice(v_scaler, new_v_scaler_slice, update_start_indices) + if self.env.generate_cache_stacked and not self.env.new_cache_stacked: + for slice in range(self.env.num_layers): + 
update_start_indices = (slice, bb, 0, pp, 0) + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks[slice], bb, 1, slice_dim) + new_ks_slice = jnp.expand_dims(new_ks_slice, 0) + cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs[slice], bb, 1, slice_dim) + new_vs_slice = jnp.expand_dims(new_vs_slice, 0) + cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim(new_k_scaler[slice], bb, 1, slice_dim) + new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) + k_scaler = jax.lax.dynamic_update_slice(k_scaler, new_k_scaler_slice, update_start_indices) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim(new_v_scaler[slice], bb, 1, slice_dim) + new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) + v_scaler = jax.lax.dynamic_update_slice(v_scaler, new_v_scaler_slice, update_start_indices) + else: + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim(new_k_scaler, bb, 1, slice_dim) + k_scaler = jax.lax.dynamic_update_slice(k_scaler, new_k_scaler_slice, update_start_indices) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim(new_v_scaler, bb, 1, slice_dim) + v_scaler = jax.lax.dynamic_update_slice(v_scaler, new_v_scaler_slice, update_start_indices) return cache_k, cache_v, k_scaler, v_scaler @@ -421,15 +438,18 @@ def finalize(self): if self.env.generate_cache_stacked: layer, b, head, _, dim = self.cache_k.shape if self.env.new_cache_stacked: + # new kv scaler also has to go through shard_map instead of indexing because it needs to reshape to (batch, layer) which mess up with the data caches = [self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler] self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, *caches, self.input_pos) else: # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
- for i in range(self.env.num_layers): - self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) - self.k_scaler._elem = self.k_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i]._elem.reshape(b, 1, 1)) - self.v_scaler._elem = self.v_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i]._elem.reshape(b, 1, 1)) + caches = [self.cache_k._elem, self.cache_v._elem, self.new_ks, self.new_vs, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler] + self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, *caches, self.input_pos) + # for i in range(self.env.num_layers): + # self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) + # self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + # self.k_scaler._elem = self.k_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i]._elem.reshape(b, 1, 1)) + # self.v_scaler._elem = self.v_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i]._elem.reshape(b, 1, 1)) else: # Try to use shard_map to get rid of the data copy b = self.cache_k.shape[-4] From 3170ef2c873adc29938929c1f1fe6f52cdbc24dc Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 15 Jul 2024 20:19:47 +0000 Subject: [PATCH 31/57] Fix the shard map sharding for stacked generate cache and unstacked new cache. --- jetstream_pt/cache_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index d9c84856..d9456eeb 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -303,8 +303,9 @@ def __init__( assert not self.env.new_cache_stacked cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads + new_cache_pspec = self.env.partition_by_axis(2) if self.env.new_cache_stacked else self.env.partition_by_axis(1) none_pspec = self.env.partition_by_axis() - in_specs = (*([cache_pspec] * 4), *([none_pspec] * 5)) + in_specs = (*([cache_pspec] * 2), *([new_cache_pspec] * 2), *([none_pspec] * 5)) out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) self.update_single_cache_line = shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False) self.update_single_cache_line = jax.jit(self.update_single_cache_line) From ee1c0113a1ce9327cf1b3cd53d865aec1e566621 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 15 Jul 2024 20:59:32 +0000 Subject: [PATCH 32/57] Using Jax API to slicing instead of Pytorch index slicing. 
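In pure JAX terms the change swaps Python integer indexing on the stacked cache for lax.dynamic_index_in_dim, which also works when the layer id is a traced value; inside the cache manager the call is routed through torch_xla2.interop.call_jax. A small illustrative sketch with made-up shapes:

    import jax
    import jax.numpy as jnp
    from jax import lax

    # toy (layer, batch, heads, seq, dim) stacked cache
    stacked_cache = jnp.zeros((2, 1, 4, 16, 8), dtype=jnp.float32)

    @jax.jit
    def pick_layer(cache, layer_id):
        # keepdims=False drops the layer axis, returning (batch, heads, seq, dim)
        return lax.dynamic_index_in_dim(cache, layer_id, axis=0, keepdims=False)

    assert pick_layer(stacked_cache, 1).shape == (1, 4, 16, 8)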
--- jetstream_pt/cache_manager.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index d9456eeb..dc619a12 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -421,10 +421,15 @@ def update(self, xk, xv, layer_id:int): self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) - ret_cache_k = self.cache_k[layer_id] if self.env.generate_cache_stacked else self.cache_k - ret_cache_v = self.cache_v[layer_id] if self.env.generate_cache_stacked else self.cache_v - ret_k_scaler = self.k_scaler[layer_id] if self.env.generate_cache_stacked else self.k_scaler - ret_v_scaler = self.v_scaler[layer_id] if self.env.generate_cache_stacked else self.v_scaler + # ret_cache_k = self.cache_k[layer_id] if self.env.generate_cache_stacked else self.cache_k + # ret_cache_v = self.cache_v[layer_id] if self.env.generate_cache_stacked else self.cache_v + # ret_k_scaler = self.k_scaler[layer_id] if self.env.generate_cache_stacked else self.k_scaler + # ret_v_scaler = self.v_scaler[layer_id] if self.env.generate_cache_stacked else self.v_scaler + + ret_cache_k = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.cache_k, layer_id, 0, False) if self.env.generate_cache_stacked else self.cache_k + ret_cache_v = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.cache_v, layer_id, 0, False) if self.env.generate_cache_stacked else self.cache_v + ret_k_scaler = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.k_scaler, layer_id, 0, False) if self.env.generate_cache_stacked else self.k_scaler + ret_v_scaler = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.v_scaler, layer_id, 0, False) if self.env.generate_cache_stacked else self.v_scaler return ret_cache_k, ret_cache_v, k_quant, v_quant, ret_k_scaler, ret_v_scaler, kscale, vscale From e08f31ff52bd88c9cbaddf058f07a85454606775 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 16 Jul 2024 00:55:16 +0000 Subject: [PATCH 33/57] Adds stacked cache support in ragged attention reference kernel. 
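The reference kernel decides that the cache is stacked from its rank (k.ndim == 5); when it is, every kv index map prepends a layer coordinate read from the prefetched layer scalar and the BlockSpec shapes grow a leading None axis, which also bumps the number of scalar-prefetch operands from 5 to 6. A hedged sketch of just that bookkeeping, with compute_block standing in for the kernel's _compute_ragged_block_indices:

    def make_kv_index_map(stacked: bool, compute_block):
        def kv_index_map(b, i, layer_ref, start_ref, end_ref, *_):
            b_next, i_next = compute_block(b, i, end_ref)
            if stacked:  # block coordinates into a (layer, batch, seq, dim) cache
                return layer_ref[0], b_next, i_next, 0
            return b_next, i_next, 0
        return kv_index_map

    def kv_block_shape(stacked: bool, bk: int, head_dim: int):
        # leading None axes are not tiled; only the (bk, head_dim) tail is blocked
        return (None, None, bk, head_dim) if stacked else (None, bk, head_dim)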
--- jetstream_pt/attention_kernel.py | 52 +++++++++++++++++++++++--------- jetstream_pt/cache_manager.py | 12 +------- jetstream_pt/layers.py | 2 ++ 3 files changed, 40 insertions(+), 26 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 60a9bef9..b3967e33 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -214,6 +214,7 @@ def scaler_index_map(b, i, *_): def ragged_mqa_kernel_reference( + layer_ref, start_ref, end_ref, line_end_ref, @@ -293,6 +294,7 @@ def ragged_mqa_reference( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, ragged_batch_index=None, @@ -304,10 +306,14 @@ def ragged_mqa_reference( normalize_var: bool = True, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" - batch_size, num_heads, head_dim = q.shape + batch_size, time, head_dim = q.shape #assert end.shape == (batch_size,) assert end.dtype == jnp.int32 - seq_len = k.shape[1] + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 5: + stacked = True def _compute_ragged_block_indices(b, i, lengths_ref): length = lengths_ref[b] @@ -327,20 +333,35 @@ def _compute_ragged_block_indices(b, i, lengths_ref): ) return b_next, i_next - def kv_index_map(b, i, start_ref, + def kv_index_map(b, i, layer_ref, start_ref, end_ref, line_end_ref, ragged_batch_index_ref, ragged_block_index_ref): b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, i_next, 0 return b_next, i_next, 0 + if stacked: + q_bp = (None, None, time, head_dim) + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + num_prefetch = 6 + else: + q_bp = (None, time, head_dim) + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + num_prefetch = 5 + in_specs = [ - pl.BlockSpec(kv_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(kv_index_map, q_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(kv_index_map, kv_bp), ] + inputs = ( + layer, start, end, end, # line_end, not actually used @@ -353,8 +374,8 @@ def kv_index_map(b, i, start_ref, quantized = False if k_scaler is not None: in_specs = in_specs + [ - pl.BlockSpec(kv_index_map, (None, 1, bk)), - pl.BlockSpec(kv_index_map, (None, 1, bk)), + pl.BlockSpec(kv_index_map, ks_bp), + pl.BlockSpec(kv_index_map, ks_bp), ] inputs = inputs + (k_scaler, v_scaler) quantized = True @@ -368,12 +389,12 @@ def kv_index_map(b, i, start_ref, quantized=quantized, ), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, + num_scalar_prefetch=num_prefetch, in_specs=in_specs, out_specs=[ - pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, num_heads, head_dim)), - pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, num_heads, head_dim)), - pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, num_heads, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), ], grid=(batch_size, seq_len // bk), ), @@ -381,10 +402,10 @@ def kv_index_map(b, i, start_ref, out_shape=[ q, jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 + (batch_size, time, head_dim), jnp.float32 ), jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 + (batch_size, time, head_dim), jnp.float32 ), ], )(*inputs) @@ -397,6 +418,7 @@ def ragged_mha( q: 
jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, ragged_batch_index: jax.Array, @@ -466,7 +488,7 @@ def ragged_mha( *([None] * replicated_in_axes), ), out_axes=shard_axis - )(q, k, v, start, end, *replicated_inputs) + )(q, k, v, layer, start, end, *replicated_inputs) return out, (m, l) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index dc619a12..b3c4fbc9 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -421,17 +421,7 @@ def update(self, xk, xv, layer_id:int): self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) - # ret_cache_k = self.cache_k[layer_id] if self.env.generate_cache_stacked else self.cache_k - # ret_cache_v = self.cache_v[layer_id] if self.env.generate_cache_stacked else self.cache_v - # ret_k_scaler = self.k_scaler[layer_id] if self.env.generate_cache_stacked else self.k_scaler - # ret_v_scaler = self.v_scaler[layer_id] if self.env.generate_cache_stacked else self.v_scaler - - ret_cache_k = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.cache_k, layer_id, 0, False) if self.env.generate_cache_stacked else self.cache_k - ret_cache_v = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.cache_v, layer_id, 0, False) if self.env.generate_cache_stacked else self.cache_v - ret_k_scaler = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.k_scaler, layer_id, 0, False) if self.env.generate_cache_stacked else self.k_scaler - ret_v_scaler = torch_xla2.interop.call_jax(jax.lax.dynamic_index_in_dim, self.v_scaler, layer_id, 0, False) if self.env.generate_cache_stacked else self.v_scaler - - return ret_cache_k, ret_cache_v, k_quant, v_quant, ret_k_scaler, ret_v_scaler, kscale, vscale + return self.cache_k, self.cache_v, k_quant, v_quant, self.k_scaler, self.v_scaler, kscale, vscale def finalize(self): if not self.env.lazy_cache_update: diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index b716fc17..6f7b81eb 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -432,6 +432,7 @@ def attend(xq, keys, values, local_mask=None): xq, keys, values, + self.layer_id, start, end, ragged_batch_index, @@ -549,6 +550,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): xq, keys, values, + self.layer_id, start, end, ragged_batch_index, From b8e6b857a8c23ec353db964a6807a98baafb5e62 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 16 Jul 2024 01:18:09 +0000 Subject: [PATCH 34/57] Adds stacked cache support for the modified ragged kernel. 
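This kernel drives the cache walk with two prefetched tables: for grid cell (b, i) the flat offset b * (seq_len // bk) + i picks which batch row and which kv block to stream next, and with a stacked cache the prefetched layer id is prepended to the coordinates. A plain-Python sketch of that lookup (names mirror the kernel's refs, values are hypothetical):

    def kv_coords(b, i, layer, ragged_batch_index, ragged_block_index,
                  seq_len, bk, stacked):
        index = b * (seq_len // bk) + i  # flat slot in the precomputed schedule
        coords = (ragged_batch_index[index], ragged_block_index[index], 0)
        return (layer,) + coords if stacked else coords

    # e.g. seq_len=1024, bk=128 gives 8 kv blocks per batch row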
--- jetstream_pt/attention_kernel.py | 56 ++++++++++++++++++++++++-------- jetstream_pt/layers.py | 16 +++++---- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index b3967e33..6665411f 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -16,6 +16,7 @@ def ragged_flash_attention_kernel( + layer_ref, start_ref, end_ref, line_end_ref, @@ -111,6 +112,7 @@ def ragged_mqa( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, ragged_batch_index=None, @@ -123,12 +125,17 @@ def ragged_mqa( ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] + batch_size, time, head_dim = q.shape + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 5: + stacked = True def kv_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -136,11 +143,15 @@ def kv_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + + if stacked: + return layer_ref[0], ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 def q_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -148,19 +159,36 @@ def q_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + if stacked: + return layer_ref[0], ragged_batch_index_ref[index], 0, 0 return ragged_batch_index_ref[index], 0, 0 - def scaler_index_map(b, i, *_): + def scaler_index_map(b, i, layer_ref, *_): + if stacked: + return layer_ref[0], b, 0, i return b, 0, i line_end = jnp.where(start < end, end, seq_len - 1) + + if stacked: + q_bp = (None, None, time, head_dim) + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + num_prefetch = 6 + else: + q_bp = (None, time, head_dim) + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + num_prefetch = 5 + in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(q_index_map, q_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(kv_index_map, kv_bp), ] inputs = ( + layer, start, end, line_end, @@ -173,8 +201,8 @@ def scaler_index_map(b, i, *_): quantized = False if k_scaler is not None: in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), + pl.BlockSpec(scaler_index_map, ks_bp), + pl.BlockSpec(scaler_index_map, ks_bp), ] inputs = inputs + (k_scaler, v_scaler) quantized = True @@ -188,12 +216,12 @@ def scaler_index_map(b, i, *_): quantized=quantized, ), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=5, + num_scalar_prefetch=num_prefetch, in_specs=in_specs, out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), ], grid=(batch_size, seq_len // bk), ), @@ -203,10 +231,10 @@ def scaler_index_map(b, i, *_): out_shape=[ q, jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 + (batch_size, time, head_dim), jnp.float32 ), jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 + 
(batch_size, time, head_dim), jnp.float32 ), ], )(*inputs) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 6f7b81eb..72b87580 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -426,7 +426,8 @@ def __call__( def attend(xq, keys, values, local_mask=None): # As of right now, ragged attention doesn't support attention calculation with prefill and new cache line - if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] != 1: + # We are not using ragged attention for prefill yet. + if self.env.ragged_mha and seqlen == 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, @@ -464,7 +465,8 @@ def attend(xq, keys, values, local_mask=None): with jax.named_scope("attn_insert_cache"): orig_keys, orig_values = cache.update(xk, xv, self.layer_id) - if not self.env.ragged_mha: + # We are not using ragged attention for prefill yet. + if not self.env.ragged_mha or seqlen > 1: orig_keys = repeat_kv(orig_keys, n_rep) orig_values = repeat_kv(orig_values, n_rep) @@ -479,7 +481,7 @@ def attend(xq, keys, values, local_mask=None): # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): - if not self.env.ragged_mha: + if not self.env.ragged_mha or seqlen > 1: xk = repeat_kv(xk, n_rep) xv = repeat_kv(xv, n_rep) new_output, (new_max, new_denom) = attend(xq, xk, xv, None) @@ -544,7 +546,8 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): if not self.env.ragged_mha and seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - if self.env.ragged_mha: + # We are not using ragged attention for prefill yet. + if self.env.ragged_mha and seqlen == 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, xq, @@ -582,7 +585,8 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): #import pdb; pdb.set_trace() with jax.named_scope("attn_insert_cache"): orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) - if not self.env.ragged_mha: + # We are not using ragged attention for prefill yet. + if not self.env.ragged_mha or seqlen > 1: orig_keys = repeat_kv(orig_keys, n_rep) orig_values = repeat_kv(orig_values, n_rep) with jax.named_scope("attn_qkv"): @@ -594,7 +598,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): - if not self.env.ragged_mha: + if not self.env.ragged_mha or seqlen > 1: new_key = repeat_kv(new_key, n_rep) new_value = repeat_kv(new_value, n_rep) new_output, (new_max, new_denom) = attend(xq, new_key, new_value, new_k_scaler, new_v_scaler, None) From 394e666003b830fa3e1efdbbe9b5175be10be559 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 16 Jul 2024 22:36:41 +0000 Subject: [PATCH 35/57] Llama2 70b int8 optimization done. Output not correct yet. 
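Part of this change is the grouped-query folding done before the ragged kernel call: the rep query heads that share one kv head are folded into the time axis, so the kernel only sees hkv heads and the (int8) kv cache never has to be repeated. Illustrative shapes only, with toy sizes:

    import jax.numpy as jnp

    bq, hq, tq, dq = 2, 8, 1, 4   # decode step: 8 query heads, seq length 1
    hkv = 2                       # kv heads in the (possibly stacked) cache
    rep = hq // hkv               # 4 query heads per kv head

    q = jnp.zeros((bq, hq, tq, dq))
    q_folded = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq)
    assert q_folded.shape == (bq, hkv, rep * tq, dq)

The same patch also splits the sharding: queries stay sharded on the head axis (axis 1, when sharding on heads rather than batch) while a stacked kv cache is sharded on axis 2, which is why RaggedAttentionKernel now takes q_shard_axis and kv_shard_axis separately.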
--- jetstream_pt/.attention_kernel.py.swp | Bin 0 -> 36864 bytes jetstream_pt/attention_kernel.py | 64 +++++++++++------------ jetstream_pt/layers.py | 71 ++++++++++++++++++++------ 3 files changed, 83 insertions(+), 52 deletions(-) create mode 100644 jetstream_pt/.attention_kernel.py.swp diff --git a/jetstream_pt/.attention_kernel.py.swp b/jetstream_pt/.attention_kernel.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..7ac8ff387f9864d2ad66bb2cf20c1d4a8c985251 GIT binary patch literal 36864 zcmeI539w{Ud4L;blf?y03KgltdoKNkem(PM*z%Yj9fgT3%D~L1XF=bMUXe<)ZiUNvDVJT&hvXT;`k%|?>U=+$NMoYnkLW4-Y|19_1+ppg; zqam1d!}q%Hz2~0gKj)t9|L;GgO|QK0wXELY#{GHQRocf~7{ol%EYC)qM zE(DDkxwguuJ25}r7@t1p?1@sidCU01_@=>uUl&KXaVF|kd!>#0!fvPA4ukr5YpL68 zm#h9vvDMuebh}}r8#Non*|6ORYx&mF;FyPJhY}b{U{wk9y3^z5o|xIPc|t#(qa&xQ zlTX^S>Mg^ILkSEeFqFVh0z(N5B`}o0Py#~<3?=Y?O9I{2F`26<(vxhl_S^3x`o530 z$8WXY2l}obV~^+U_t7>27tfLQ{G9#1vhVs)_Poo-kNU13Yp<8=_mBIoKiQto+3%e; zBA32BJ^UIBd=O^g z3OE~%hx?DqWNw8w!5g6jC&N?Vfv03L{|L9hn<0Sn;SBg8n-I6d)vyn?!O`#oHb1U` z^I#H=furH;Y?|B#pN4DTPvHv4!f9|U+{?zthoB5+!-?=+HcM`WC0K-)LJq#mCdU`y z!*Da)1lPiDxEx*ro8UA!0giwlQm^;GZSY>0hf85ANS!|ok{`vV9&~1l^Fgf_PO9lz zGw7;mA~_f0#_tF#2Eq`6sLK_VDwTRGi>OpWOE&ohEX{Pn^7vKgbDBIiBkf+J*lA98xnX2%(nm&W1;w8;5-m}l)SPLB z{-I4~fa|sjOw-k!jM`dI^gvXm^rre$vDPj2pPJF9ynIJ>_=K#=%|<1gjWWW9mrxQ3L`}MaZL}KcGjo0Cw@TsS8cvkaBxmJOYtBTB z1;u9DcN71^=^GjMr5_-}IJ3H30~#Z(LoQ`wx#VaG?jfA;LZaYA&tLuYl&0c{O)Xi0 zr7y1B3Pp6%1XcWTZE%BfL*#lNQzcDTx8IGr{dVTs-K;Ib-th^Z8kY!H9De zb}pxA``K)EBF8f>f8{GHf5zk|&-`@QZX>#|7S=J=EXr6^5E1H`!WBH{9auM79u-#<5k7n)E!pko;p!WsU*L+Z4X{4bfUG#3;wh7=d z(jOTa*%{PoK}Q)oWHsGvGcyi4RTbZmCs4-LvL8WJi8-8^HKw1p#05X7*y*+jZ)eu# zK#Mr1)F@9}v(=noPT6dzax-|wovu?WUP85-3k(m7`mC;HeQ&eZZS}exeKt7- zFli)v8*@fKTkaxHqWP(x4fB)7zZglg@`ojVm~3|@fNg~4KsQ_yIx&aVVlBpEwU*y9 z`(3^Sp{~a_<@=ms?zC)$-_J!mPLLVRL{8(OAC z+0aysq}8hNWAF_wUVEPo>~bdEYFr{7uEhDcboHQHb?4*NMMl>DBdm#k46^>u*~Q?e zSku2B{syjxC0K+7cn$mke2=yJCqdTsm%|3Qn|1nK@R#s*cs=xB7K$M2`Sal%$iYeQ z3)b%6hA+cEgRJjA0N24=;Q-WN3U9l)z8|LkSEeFqFVh z0z(N5B`}mgLIR49TXd^<7uYPvADA7)jHDo+gt|oz zwWg`#L(hgwiKE3ibEO@2d+mnmgth5>N%TN$XBgE`y^7X_a>d)YZfToP6#1g1POe-! 
zOO1tURIc)pGVd?dI;OKsDIDpGRedV|;2?F~6) zeW4F1y1Sd(Z0u*~(T8YtN1?rcBrnP%Y?K5tNU>ZM@Au2!$RFD~Y5k!S6P>>v(U5=$ zOZFvX2=@-fiIbNG8#7^6{L$Fj=6(FV5_Bpa%{o3(Znyc3Gw#Q>xcxYjhh^=j?}BMR zB|HLDp#~LkdXNsKa~xMcN6WLYk#fdEqc4g0w_as*95^<*kx^WyRJ`jH`3D*n=d{C= ztf^=FD(RS1F0zxiNBgdj)mT<3w=7&(-?}Ix{c^vaHlB{*w#6jwe7A69UzGDts_`~& zTpXpR0_HhfzvCIN+8pF}W|SF!;&Dni=BKpsNlD^aQqE3~SOFg}-=@7C zOgx#&UJ*w+?rI{N>OkV);D}Uy;si}G*(a>ln;m-Zv8?m60vg=kB;@(HcReO!%Q_$5 zgxAyQB%TFZXTN(7t?4+|{&#R_MeMok*4{W8?5_o0YvbI7vuka<<7>CwT9-56Ss(C( zi_}+rn&VAUY+o22$qUbo{z-FAKZMcjzIra2t=Go4@v}x3$G?SvA0XGCn7(jSjmemi zm2SFGL@CL?-XIEi4%yveCfl!%AC)9W)`~nMp+;B88W#sE$xweac$T}HwW+E^!BSRd zn$p=H112CY9mfalA~oJ{vCPi#2 z4z)6?7C|C1YQ8}wFmr*mq)wNH3jnDBQwRyqreHe^X0@Lf?RH6`*ZZ*2a7+&<|AYYl zZUg_F&Hr{oe}%cipxmL3cjHVv;;$`VcgrgBUDi5lzIE310BmY+jIqWK_-SL1QP&Fk z8oSJ&j8`<=BSiaV*A=yPZr2w-YdxoeHOsxssufK>Crm|a5fzy#`sal<%cMf@@5+4m zFgp!KtPXD6T&@Om8cz1M<8V#g9>!gYY;}_jt7Osd@n_{le$%|hvNGMEZ{Upn*Rf3T z#djd@q;E*MmJ;7ou2JglNBf#8Bz=XDkohED(w7q(vUc{lL7aqnUs*bazOX2d5FSqU zWuomBPZG^4my)^gH=$P)pIZg*{G%^1-B%g+P6e?m)LxzZ{gP;Et7T5Ae;Gb tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi head attention. Args: @@ -481,10 +472,10 @@ def ragged_mha( """ mask_value = DEFAULT_MASK_VALUE if k_scaler is None: - replicated_in_axes = 4 + replicated_in_axes = 5 replicated_inputs = (ragged_batch_index, ragged_block_index) else: - replicated_in_axes = 6 + replicated_in_axes = 7 replicated_inputs = ( ragged_batch_index, ragged_block_index, @@ -494,10 +485,11 @@ def ragged_mha( # New cache has t=1 bk = min(bk, k.shape[-2]) bq, hq, tq, dq = q.shape - hkv = k.shape[1] + hkv = k.shape[-3] rep = hq // hkv if rep > 1: - q = q.reshape(bq, hkv, rep * tq, dq) + #import pdb; pdb.set_trace() + q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( @@ -510,12 +502,12 @@ def ragged_mha( # out_dtype=out_dtype, ), in_axes=( - shard_axis, - shard_axis, - shard_axis, + q_shard_axis, + kv_shard_axis, + kv_shard_axis, *([None] * replicated_in_axes), ), - out_axes=shard_axis + out_axes=q_shard_axis )(q, k, v, layer, start, end, *replicated_inputs) return out, (m, l) @@ -607,13 +599,12 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n class RaggedAttentionKernel: """Ragged attention kernel.""" - - def __init__(self, env, input_specs, output_specs, sharding_axis): + def __init__(self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis): self.binded_ragged_mha = functools.partial( - ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ragged_mha, bk=env.block_size, q_shard_axis=q_shard_axis, kv_shard_axis=kv_shard_axis ) self.binded_ragged_mha = shard_map( - ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + self.binded_ragged_mha, env.mesh, input_specs, output_specs, check_rep=False ) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) @@ -622,17 +613,20 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, ragged_block_index, k_scaler=None, v_scaler=None, + ): return self.binded_ragged_mha( xq, keys, values, + layer, start, end, ragged_batch_index, diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 72b87580..1ec23d50 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -368,9 +368,20 @@ def apply_rotary_emb( def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" - bs, n_kv_heads, slen, head_dim = x.shape + bs, n_kv_heads, slen, head_dim = x.shape[-4], 
x.shape[-3], x.shape[-2], x.shape[-1] + if x.ndim == 5: + stacked = True + else: + stacked = False if n_rep == 1: return x + if stacked: + layer = x.shape[0] + return ( + x[:, :, :, None, :, :] + .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim) + ) return ( x[:, :, None, :, :] .expand(bs, n_kv_heads, n_rep, slen, head_dim) @@ -382,16 +393,17 @@ class AttentionKernel: def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = 0 if self.env.shard_on_batch else 2 if self.env.generate_cache_stacked else 1 + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention self.flash_attention = ak.flash_attention self.ragged_attention = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (qkv_pspec, qkv_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), ) self.layer_id = layer_id @@ -416,7 +428,8 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads # if not self.env.ragged_mha and seqlen == 1: @@ -427,6 +440,12 @@ def __call__( def attend(xq, keys, values, local_mask=None): # As of right now, ragged attention doesn't support attention calculation with prefill and new cache line # We are not using ragged attention for prefill yet. 
+ kv_shard_axis = self.kv_shard_axis + if self.kv_shard_axis > 0: + if keys.ndim == 4: + kv_shard_axis = 1 + else: + kv_shard_axis = 2 if self.env.ragged_mha and seqlen == 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( self.ragged_attention, @@ -438,6 +457,8 @@ def attend(xq, keys, values, local_mask=None): end, ragged_batch_index, ragged_block_index, + self.q_shard_axis, + kv_shard_axis, ) elif self.env.flash_attention: with torch_xla2.default_env(): @@ -459,7 +480,7 @@ def attend(xq, keys, values, local_mask=None): # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") # if local_max is not None and local_denom is not None: # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") - self.env.apply_sharding(local_output, axis=self.shard_axis) + self.env.apply_sharding(local_output, axis=self.q_shard_axis) return local_output, (local_max, local_denom) @@ -505,16 +526,26 @@ class Int8KVAttentionKernel: def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = 0 if self.env.shard_on_batch else 2 if self.env.generate_cache_stacked else 1 + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention self.flash_attention = ak.flash_attention_quantized - self.ragged_attention = ak.RaggedAttentionKernel( + self.ragged_attention_orig = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (qkv_pspec, qkv_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) self.layer_id = layer_id @@ -539,17 +570,22 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig if not self.env.ragged_mha and seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) # We are not using ragged attention for prefill yet. 
if self.env.ragged_mha and seqlen == 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( - self.ragged_attention, + impl, xq, keys, values, @@ -580,13 +616,14 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") # if local_max is not None and local_denom is not None: # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") - self.env.apply_sharding(local_output, axis=self.shard_axis) + self.env.apply_sharding(local_output, axis=self.q_shard_axis) return local_output, (local_max, local_denom) #import pdb; pdb.set_trace() with jax.named_scope("attn_insert_cache"): orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) # We are not using ragged attention for prefill yet. if not self.env.ragged_mha or seqlen > 1: + #import pdb; pdb.set_trace() orig_keys = repeat_kv(orig_keys, n_rep) orig_values = repeat_kv(orig_values, n_rep) with jax.named_scope("attn_qkv"): From 0f24b8e015fcd1416e9d2fc3ccdf8481179c21ad Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 16 Jul 2024 22:40:54 +0000 Subject: [PATCH 36/57] Remove testing temp output files. --- keys_original | Bin 66746 -> 0 bytes original_scores | Bin 2500 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 keys_original delete mode 100644 original_scores diff --git a/keys_original b/keys_original deleted file mode 100644 index 4ce7d0b90ca8bc8e433e6b89d5bc2a88b3a21190..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 66746 zcmZ^~c{Ejj^fzuOBtw}>BxxW?LP+=QZAwa{Xi`a1QA(k?WFAVACXoh8QY4||p1qaQ zTtb>>4Vq}4eDu5D_56O%AJ1<+XRWjDS!bQQ&-;Cbz4vQ4KAwtQZyQs+d|G!I&!n~;Ph zA~Y;`PSpIRk>L?@R$8krmoqb)t}KfmDLZoR!uiYP|JPqs-NVA?M*04)ovOsMX4@V88Uf!x2WZ+lb3h5^zl?y*{DKo1+vKhv-i~T?A|2@8roWuVS z|9`IPa4UBg+y8&?|9zJK`vwKg7#|WaH)Cj@f#`Cr40E*t8ARRY79lG|BVD`E`&@f{lyWT9P+glCzSmhHqdVT=E{*Z*X zKPN-+uN-mu@cVQvXSKLY@jYyeh~$PgBW{QtiZnD4Vn6Al?&6t3Th2^6p|+9oUE(S9 zkVx*XT{zuMj`vup(MG*1JpXnc9vgZFzKktl>DHy7nQ%mURR1?9oK@sFg%L31lp>y+ z7cVU6y$iGZgwfMcail&(lO@Wxg{r+SoQB%C)ix9Yw?skdf*9~K-wmA$C%{^FXTJPv zGHC8jgtr&}gCn7%gd?9^_-CIFR32J(c3waju2~?#1LyLDqZ)Oz{n{pI+3{9LE~$W* z1I?kxV-aOk*g?=Z1N6;_1*?BA#M={o2_4plabxIivH!zu7#o!UVcnm@p_Ri~nth7C zx~0JQ;Qlmel?dl;3$fz7)8de@Yg|F%t^O=%dR!JQ|^G>~^ke1cMQoxKlY>+yIt8yf;!k9u;H#YBFzc`8c(4v}`<{1Eik--KsBCgbFH zh2U5H0qSRV6+BP&f?el(-~iziZu+FmHS5#_<&_fgcaZ^`+t}ilmXBbe@CsDFnqsN; zbUd9i24<`tFAXl#!h3$PVnBV4XzTl2_!h~O>%EQFE!ayR|F+Os9Lw=jjzC(sJh)Y% z$2QFckX{#$&CkooUvnEiGu%qi8;}RS_kN?=xEOv_+ts`Ltfw2hNN-DIBbA0M#XBu({_xc(x)7KB>&&Pov#9&StKl z5a_{sdQ?z(%NB}~)K%M8oA8z|@4+&AA7^~D=a9)G(M;_V>=>+$6D&82Dvw5@QpS5) zF}Z`#F$j%&Z4)ZBaEJJ{4ajZ-5)OX7l2ui%IU?b!eLKg!G&U z3iXO;>)}>f=AOl#SqorIv?^|2ZN_O4d9c{X1MPd>0_A-oeY7nH1IbN#?QFoV%R1=v zmO;=^V1QP!u{`6#O*->5gLe8K!i3EXzY4=qG4T*}-@6fg$LGTAe#eC&H3td0Xh}mO zm2uO^ObQyU$Mu~bs$<%VAhWgv!uRCBQQu@b@hXOw{J9DC3WtPM?amN2Y#OdXx%1ipG#?!+*!8!?lZTGMAd^$#+DUta+yxAd zlBt|}u^;(|rQUR<`V%R<_?HXuHwv-D zF&cX>k0Fn-L$UeYZg5}*`v@mi5GtwYo-#{>6Ljpx6{*3&Qu1kNH?h_qwY^t`sxP*@P?}J&_73p@yIP6tEi@T*RhVY3y;m<2S zcHNPR+w%^Qn$sriJ>Po1|G85N|DR`*5FW$N|lO0~Ig3o=Q!89Xd z{=Uu#nj4m)`jsfGwpt+0?ox`%7uE2)i59+iSwTP93)k#a5ng;X zOAj7=X{A5kXT$K=7Q&wgg*V#jymtRMcziEGe7+$=XAy0Q$>h?j7hv>_0BCBfl>`{CB>7?$n6;rV8^6xLe-%fCcVYcFFkcG>R@Mm( z-%pdvp=oUPbO~8JRK&2-d-U$?ENPnEU18VxV7$?=4HmsmK&N#@AXni51MD|LtaJ#% zAQjZQo<}2!bSXqzTgce9fIUoRi&J0D#${8x@ayY(&=J^+(&{Q8tSK92d!6Ix+-a;` zRSBKG4!Cyh0p3&UiH^#Z;t;)V=pTL=440o4R`=V)2V2MB_1_!lqF+4su`;5M%P&I9 
ziXgt8-!5#`9L`;$15v*809o(+M~x0vqN{$MQA)@u_gy z*B?%6XT$ZFr}QS*oJaj0gnI+CMeT+CF?m#fKKkJvRV-2F)jB81a>Q2QW4s4nP1eCH zwnzEK1SND!apfI}KjFufOBB@U5B5iu=-1=lG^_VPVYOaw{v7F#yCMQ`fe} zs^&kU^QW@tjDa%_T$x6uWr-+Pp9wmHA-0 zIW&ijZ?6+eLw``CeqWq*Fpieg_NNINDU=oQS=iV+9%Eil6@4Ckg}y4~lr(WK9DnqR zf)(}AcbGbGrHWW>Jp#u4IZbobwvy?*tptJQkm8|$)12PG(TEgqS)waFJvRo*vdVeK zzzyO?`%|>Y=`w9W27W+gUl_2x&?e|R1)pE@lWr?rb(PoAa? z$-g1sh8j*I_b%KfV43^lA z!nI{fdDz$hIIq12Vm>%P&zKyL^Q?kG^{r&vcv^C5Rx|B;HWwBcUxD)DQ{Z~4I{rMI zKwi6y#Ig(*zVCUB#xy5zQfsUba3K-0JbST3caN->n<#v>G#5;}khpPe26vs9D#R}t zi3@Uiv7U|@Rt8N&>#a7t*RhuNG#v%&;Hfn4aw3?iP2yj_r-%)WGpPQ_7)Uhs#oSe< zj1_mVN0-mwdRHGRs&>;Fy8)DAkRev(9uRZqZHMB7Y&zZLH_U#mLS=ik*n8nrzJEE8 z{^m@g`Px5)Px*%Euug;@TYd0zQ4p9Wq;ZyU0RJZ%v7zf=jA-eHEqGA8o4E?Fe;7bR z+)~B&Ha&s%wNli{ci?bn4#d9+p{G+7;qFB>Ui7p;sJ-CLDwlJh`Jx9r?2`a*74G2l z^B={;h+}N2WJ3P4)VOR|9+=k*#I7q&!*%naIPB+H%2()%1`&VBzSoX!a{dA}HorNEJI08T<;R;~ z^e7Lu%+eIz=_G^P+rfDAS`OZ@Dx|TsEBL|A4r-X$o04=l zDdH!KV(dHapIDzU4`zJ$Pn0U{q{BvksmLr{7~5XX1$zt0A}o|+x*a6H@-8^*Z49q2 z8wgHo-SGa@4LEA}4lvm>nmaDWV#AYd6uEplS7@H(F<-4&J#i2ciSR7Xo`ke=I>}qT7fXnbk&sZ{PO+|^zbPl{epVhyZ58SSUH-1OpE8jnt|};{1S}G z(#H9xhvGQn-4H+Eo?zDehCCOZfK%hI!;S4tRNCnST{c_bT5~tv(W?k%_%(~sj|bzL z|31^U*%?%HA{+d*Mxe2}nZ#Q28x*vZKvc~k&}jN7s+)}zFPGHQb`7`#kagrX_>`7v;4`sMb@|pS=Zh=-)8(zIYV7KuDQ03-QR4{B1x~!hjZHRa{-jxRD`LgiFO?K(6%r$w!L#1>Q1WSvX==sO;?YE z;3{&QmqOFlb>m*!B^*~Jiu^g27ut{I9c}AT?l$5}X*7QMbqU5#9!(4WX>rcB9B%W^ z7kjTz#l_PSU`5?_@x;Ak=(YU~bbdWdGq;=bpk{&XgQmgLaUu@@~r9*a)-LRQU=i@`>}eX5ec;SY0 zoLP6456`{~>OW*y{A<2AwtXK=5odDSh+6o}->Cm}E9_M|jhmPDrv;tMc)ijEsJog# zePsDb<>Te*|E@Wb!p?N@y?GW*+@U1YAM?iXWGMN%{s6c)w~{bsHTx@Dl9$~EdVTyM zU7MHBW;v_qfGF?eaBij4E3$;j`?Z6`Mq^$$u|^EN87;n^UI{&>^a7Qvg1aM5~r)pPwUr=##+Z<8+ypOg-{R5q?J4GEES9q7)Nhf_KL1a!j{h9TPo>j-v{gQf0j+wT4&b_n|u9$oQ`DB-QGqi%D*Iw)yKiZr;A`l zN;kBw=@1>o!5FBu44tE{K#F@$Ov-x$yY;$Y{;(q2>7kBMfd^Q3XB!z;g+Y>a1cmr6 zf-&MhabHvw&5a(xCsb^h7M;Qw&cGfI9K?&`_4xSh7LJ*|n_F^s3p@AZ2z#JE>&Puc zW4}l)efCb+zPcN#yXuYZ-qRdA9A1Effq`H!rwfnNohif>sK8E*E$sBPQ+iKh3CoAQ zp-&fz=@Lgk*y z_+pVdm0h;Seaaq?dO#0e7HWwqm#W1N1}=Db$^?2~P_m}m=MQORRp8Nsq8kq?TYVB!q;9MRUI02e}09$T!5UUi2bFbsw$?w|{Tz2gw z{JQ^wOai9!qpvr_51FP=dvq?eUl~TxYwwZT&@!~HuLdREX7X5igqIvE!>n_)ka9Pl z)jx{jLFZ*+POvC_GUo|&nww(H_pQ=pZilGP)lb5X-c!*}&k~_=4g|g#g4;gd6g8r6 zz|tv);7;NLsviia%OScT_t>Q+Jl$v zx*+X6Lz7=v=EK;EDQK0ej$WHCk>L6XEcRaz_ST%He!mZcj=M8P7x=L9ODX#Ial?0o z{m?0_kFYBso!2;7;u+Wef?>>4=-jW%fA=1jY&KmlD)*m?7oA75gVG%c+Tl+c;!0S& z|CIP4rUTq=yRpieL13}ryC`^%z@At8vShTXaP8v&j=ALkhWER%`RQL^KQ2f#|1cdQ zYYX9rOE6d@3244kow^H&{AlfLm>vHOoXUJT{m21fYj_s*(F%f?L1w6J+X-bJ1~jbc z9Za0ulfyI3#DR(5XsX!{@P6+Mb=&vQ+nOmDQal8oPQODVle@wNB{dG#8w<0udg2+$ zM@aJ@i;2H_z@-_zc>K^N_*Qv@o!71szUSsb`5gu9mTCgW-j+hf=|tSo_?ntNWQbF< zcH!{hVbtYE0-WfY3k`h+vvS&J>FC-x{`VpY8+`-#Q@%Z?yDr3G7WXKz=OF0!?jLPx z{Us=_Z6iDB7|v_+hVBF8>37sXAv*6F=NUd1pY%4y{!cO>wN;*t-#vttL5$#zOgAjb6o?=$D2$wukF;poPm+g2+md-1w-4*b-D#Xo*QY@pIr2{dQ2IH{lQl+6+U_GrLjI&#l6#? 
zQ4#yI*6Rk+xojn@KeB}`R9W+cC{3Qk|TKV&s**g0E!~ zn*=HgJkM3!YOn-4cjgM7`wUP^##uWo%%xJP99r9{L&Mkp(mrd8z^qY+v)VV)5|94E zQ;SZzWjcko3vGBo?v5xI>Q6JICt!2R6LM%A%;BT7G2A;*Rs)OBsLm1hT{ukx%}#({ z+g$dq^_T zSPgw@T;Q(aMXqbCrF*(X^ml0hzb`mL>ZpQGE9MryL4jbgpQ$T$*W=uaTeSJd<%~HCP zN8Mg>x#0}aFHS*k>;$_?555?`g~TD{ysoDRXHxc!uCE#BSlsr9ii-02#pyvDyk{v_ z9*&3XhxzR5JQN2Pe+K)SL@YUFfXnp9@@?gPbatkVXi_F%uf?t`f87P2O>`vR$UoxB zxqYB$#xXolrYL#PzEFH0^_EtC=)sC}(%_3@0)4X?%*WrG!h1JY@R>c6q;J)@-TN5V zxf|o`Qa2jsor1ns&%*Pf^}si_pRX_fu58awZ zFEG6idczLmsaF|v@z!KMX8KC1*5bzY#)#1<)P8m3Qp9 z#J!EC^0ks7GM*+ILtK}l>FiXPbhu5bk>rW1AC*vuejZF$jz#^_x%gzLt!SR1gIo9R zr1!cd;5XC^d+*zVS;-d#-NGsO??xrex%C4kS8Smd0Sn2h@l>@kmWY`iHoVg5j}Y{2 z4(2w<^s=8z=$+j*F3On4kEb7m@s+FSu>D#5aVAN;_9q2>ocuT|WGnjW7YjOG{ zPqy?*q87_c=(gU5XDoK$-Glq{(mO@M7*3$P)BlMYtH)E&?IFk|R*+a3j>E<0g7FSV z^tV^S-si*d#o-{t`M2py@Uu#Su>-D%_lhPn@7e-ls~;G~ZNok$lhFHO z5ROzF#w{`}I8&`j=un9f9BTduy~jDT>n>At+84?T_ur!pqvUC-h9%a>YSfO#5J_x` z4IAlQ29p8WI7U5Qi0pNO{QV5LX;d+!J{`#3i#Bs-(S&M?w*m0z#A47Ym%aNFt_e?0 zcE!{^6`XfwH~RkD%)OLGu*bM4emf-+Hytma&b{7X)z_Pjna@RbsiTd48jv+o#;wHH zLRP<@pz70)X51*i>j9mVsqhO1EQ;fdDP1t5@q%EzxRH#N3dty`k#gSug~!{o=tyQg z$e+=dX_Ha>$max?D%4WOPeZmnGeY9_>5-tm=N}c<+hJG7pETH~Stz&bfR`Z~VALR8 zEH_t1tE)EPR(loJPv>Iur%6~Fk|XO=D9Jn;dQ1ZSDT<)LUYmJpXn)SIJ|_ukRimkfWzyL`0GB zXGyX+>c3y~q}31?$l@d^JOZ6*4xC=KRD9{7$m@1*L!Gx7Z0TXhE2b@=%Ud?#nS!5W zmpPEFR-K`)Uk;MN%QkVqkva-a>IvPnQgA7S!nX?parlW7)KRC*>ha@wq#du&49IJv0yD`)wT-t=sk^>GH9WJB991E_6&p$)BB4q6IzRPIGroOSU1mUZs?!d`!5#%8sR} z??LHkGiVK%j-6W<2yY8w`OEr6eC(kDMolZErtxzLj?d(N8T}dG2BV9dGYKVopd)z& zMzlVZnx)+!!yoPR&q|kn7a+d7@)NT4W(uZt@^rJNK%CtcC!UJA2KT<|@$dH=pj%r% zd}}sQ$eZp9Cms%>F9UW`Dy=3DS$z}+sd3DunN-Fzd3ylj39l5KKG^}{|MlhK{c>EW z@J;f;{}}a|sL4N{Jcf6^K7zQ|iBG)zLl)Dgv*Fa!qK}m_A3qvP5ied#UKV|&S-Uc@ z|MUU;z4j2q=AVPeYCmpoJt%!^yM;C-_J`q1(rCil$)qsenWg$<_$;A8VKXr^-w zHf*pVodr((x_BQODIEd}qe2LkVc$yquBdnE8=R4gran{rslH+qU0oT(PXkARQN>B2 z0f(cC))tDr_lq7>G=Wand7}4!SH;3%v3O!*A8!Ac3u|Unl2utG?-+Fi9o+=hbyem< zhYy0C(Lorlaf@a7|2~V;o(Fc? z?bNlqKFj?)O1(c0WRDBe@bLg$-Vq?jUYUro3m$`)juHJG8;-XXi(#g8Jjbe4?c<#UV!@!K#W3HVqeP zf^h^iwG5}@yQ(F&26kL*|B$L0Gsvm&cJ)6kb8b4Fik2I5(SG+F5RDH|>IH%|e|mG= z&0zlRpu$;t@5PdTpMdVn0zdCjoZT`En-2aWwU7s7B=LjDi)GL#Z$O_@J?Zed{v0nE zhSdi*L7dJc{PKP?KRK97gMWMAjf2I|qjA4rv#_1sCY6b6f91i;J=vtwaz^U1akh|C z+a!*$bEQICXOQmFB%OWtphEkl_@#U?BnN4-g4$!j*6lr|7%b+feiiWRNf!+Cl;^ko zG=v)$EAgdAcS;>;09U$Az>-~AP;hntk8Vw(A9u#Vp>4f!cu042mW*MWi+^dE`6dco zxLn-OEQj}M18L>caJ+JHDehE@Axa&~*|)t|vbRO}WFN+3Q)0NH^9sD09K|Z`Dd16k zKoV!+K+~g}gfGTL7-q5%T#}0L&0!fYY7$T5t4`AzTSwXwS_PL|aD_UdgbbU1b{+SkiutaT-o78?zj&3>l2udFx(5r|CVvl>n@!R6nTz6-$ z=u*-HG=8Ov{vD&CL$^tYx*mpV!_v9!%P1`Rpdcilx=eXi2Pn3D5<~^Y!RE{~n!e~W zUMsl#0=FPL&Kv5o9K%IR}-3g67!1MZUkcwC0V z>({&EZ}lp{FQ*Vcm*vA6-v-EfKAJTZ7vjx?J76~MH}zh#pG?Xe@rS1c8gG19ozSVm zx}zuZQq2^sXqwAQ;zq!W&Uc`4MaIRgR2Fx4>}I(`o>iysjc12N!^H_-o%m6;CHgIu z()=T%VAWv_oV+`e(*>eyhOVd;aE_Ft2ead;Loof(e)w~Lw$O1Ph7t#lfm+XD7&dJf z*}fPoj8Moy|KW}lrLD|HP46jcpBp$IN#c`fZeo5*E`4~Ujpmb2kd50uY6y0r9}lhx zs=B+tM?;qDN>5Q(Te^%>Y7^*hF-&nkE6m^fP;4qKp~$Zr;b_NqYPTK1#_3H$^z~GV zNcM(xwVmRD+6mIZOI7*3f+o2SIwIWCQ{v;xeP~wCJCdU=pJ;~YfftAOM4N1NC=QZ? ziJ__(GDV*Kp6?)A?E_%iugsC!)cygC! 
z9k+(*cjEa#+G}aYf7$T*a3?7|Z>QK{VYFgiGGs@+qo)c+SafFuk8Vzc&+hJIyZS2V zj|t}kIS0WyI+7al=23lN0+ofG6au{dK$N2q6jk}~v9&R9IPk9Y=rJw2m3tm^v`7I42;|LVA}r#WYIBlQuzQFp05x^V>T)@X&cA*DG0V$QXjpgzG;2#?=_E@&W3-5G)`jw0g(EP*)RY{q z7zvs;`a+Du9LI#veVDDC3HxHN)5~e4IA^|uk0@rt=c7HyZrM8cH2DNH)%zisJ|PA8 z10I31P}Bbrb^I|x&k;R2Kz0u<_#C0xCx%hwk^;&azKq;Q_5?l0-FSUf110#pfa5oh zl2L#go5;Lek;ip$;w>%G8hR9_&8(wu5gkG@$I-*ISn!+X#^K|hLU)%I@=5f==3nFZ z&%v9}@L@jK>l@;uwPvDG_*OVx3*kZKb6R?SCpdU_$FkZ{ zFyY(`$iJk@PFpf5E-Zxm8x(<~)(y~)jfQ*QazxJz8TaLVKoX+3ih9peaPmpq3`?G< zVy$^^d>=Rz!yMwcrjBF#q+?zCI1Jeoi!1HTVb>W2jCq-g z+b$fWZ#C0-b^aBy`Zt0r%?<&KH=)&bD|lk@6!t!2#shABB6r<$(no`DkiqcL3?=sH zn_|jt(F3{Zc0I|f-i99T?;(BKHh71!vEH}CXxqU!d5#Sl_Q(|OtWo19b`#hz)f0^R zq<~p_qEL6h1pD>-M#E<>%i6~CqD7UOVko`j30-XPZ!OHTaK zk6QE^#AvY#uFdh4ac>7`>FPZ4H<~I=oH+rT2U~H^PewSfy-n11n}&^}mFSqkKhfH8 z6|dT70m|ZN9zAUf{OByBpK=@MtmhXBwCm0Z6{dLNk|G{H+YhFnHR1$?QR3CC7C}99 zAH}^;5R596U{rQ2I+}RV(=$rwZ90LaV*3fNWkgQp&-iq%`YFuqV=A=V&I%%?4VUA{I%~#i+*>* zy_3UX-_3B|c1{KRTlU4VS(kaQ>^|Op>wukuCvmG+HT)hr5kn8xQkiu(EDDXr@+CfW zTqlMCa-#&F(kYm4UkEmrGx_QfQ#NT;tf?7R43(jSgw}}}Xmr?;^Oh~AE?tzt-Lw(j z_=nNe%`^C$-ezbB-3uqy#!$7}I1;Yfa@XRW^yR3P)O*efsCwTWJnqC|WnK-)AJK<# zBTX^xr74^n_#JdFuEmV1SYG`s6Bd2cfs6(5)!yHAQ17ZGJ{mnJ1^OQl0MVy@aZ0S?!vl&HKs~xlQX98NYZ?9epy467Lp^3O|p*-FG9zkmZS_ zad9VptWrj&04>xtyD1p0%;WE)H{&}G3$!(D5xNwn(}UBI-1KXKP~-a@q81Ec_UOTZ zulGY{SSdhG1EiA)FKDpf(t-N8tMDj=+RWgxd~Z^Bi{mzzOYpGs4xPI12SGFDi!IWd z;J(3(Ua7o+-j+f5?1C<>xS2UmEhITKDXv_W{oJF=gu-7!wXk15Y zb9RZls)nK6)4s6deHgxco=k^@AEaMd1lOj>!PMM#AzppA)Ok@1?4MapBaX&^Y#dbH zxK@~b;gG2OpCPpmPM{B`WW7YHOw#(^N!3%`=*|;OKA(6?EY92oosBLSaQh;?GJJps zEk97!tIbj+V{e&Xwq3{=umpr^doFG7Nt15}<5BO=WSFf83G0TTTz;*n^+^f5@0P=q z`>8xAc>sPWe@Vvr`?0}w6W+>Q%staK!?J_noIfy@;x4D)?M?0Aupos#{TB%tH{L+o zizpua-J0(vYH*=Zm09j+rpDoMzF*$CXaD1hP@7YaQ-%%NTHN`+@* z-i$u8UKwI)(y(sc3t^(yCKe+iw7kKscH z>%}6`9{sYhlj@6{AYlNxZ(z5T$T&nLL5dqGGK*-zft`4p16h!m#z z;!?S0Fv!|WI*W91&HQA3R4>8(lU~5Cb>mPndm48{Sa7iRc=&#Bhv=yH7yQljSp9k` zuADm#TOteuZ|hIO`mJqrY~Ob};=6>$&-EcAJ*nswsVU3d4Rk!GO>~^o2Lj3gO^+T0 z!&*I-3-zR_o?UTbloAKm4aMLY%CO^FGZYR9+-bp=d?t28V-RE&*!;q#q`-_T)6r zpXBA$5A56&kj7-w`=TVg_}G>fnCH{;fc3D>cpl{|>A;4}gY?S9mX3w(CA+L@E27pA_rw{-oIGxwNBpFRtz51R6c$Fl5VQ2+{cr77qQ; zdE*)``FK#w9bX`(`S<2OCp1}6vnS5$+Xe;c1^Bq65%$&n0o#fk_-a3pu04+jyQorl zyy!L@4fzQ_CtKi789wxlQ^n&i>bO!{JsiPoth=HLNsaL#%ySx zt{@mMO#z3o!SLZ%AMx4PbYXIF9SClIm^@?{)Cuo}|F-Ueb@$uEQQgdWfAnH1?@Xl? zIZ1dw*$y@-cjIN}Ueo!6GvewAYGReRNc7RNhHdk=pj_@mzHp^klBW-ne=pPzPx zTC>NA0@jiHM+r~uF7rgr)#8nx@6ps&1$~dwlbd== z|J_TW6rK3dKh?^4N%jf6;sVg@p+$;YW`f<+1R=fQ2w8IhyhxBh+q%UR-NS%xcQu4s z{ezq}<_bm(GUJBgJrb>!TTsz!6S^1f!GoV}QQUM#9@aPrbuX-^-#?GTpr;Bn@NbZq zlIcZ{WE#r3GxeaOc$nHsETQ+CIVkOG${Q4eIjuDwGkWE+)3R#u&#!#JF8;4LHb$1` zj}67w#Z$?D&|0kBZ-SXVuV{T@U-UCJfpZV*L9dq~?pi;QzEl-L-~13RNos|eJry}- zq6eA8?H6|*QRnWhbK%{?e!L_=cBMvr1BFBZ_sw}F4(_}tp4gp68~-dJkI~6ge^-kq zjz7VD!`1Og+hdw<)tlcfos5P@jG;@yU2$nkH(`BpAe`;$hM9!}K)+v+Z2V(o{9U>D z&1w-oT9FG&tyS>Ko_xA`EPzx>wt!p3G1!o`R@gtGO1Kc+9e*A(;?Yi?!uSQsJaz6? 
zFpN|IgHdS^`Kwxtt&U{9Vf!g#{4!{-v`6=Z>9}BJ8x?)2p+b|-RFbv=wtbM$m_gwS%Z^N7ne!4N)#TK3LUMd z>1Le^@19bE()-ck;IupBZS5wGE9}ZsY)^{2>i0qU-4g6}><2WPCu7o=Rd~o|F&&rj zRP?KZq&aDVT%@&lyix_7gYu!6fo(>|qVI}8_%CTLuNw9Vjt5QUm_crm?qjx6)VCV= zxT}QgO`_nOxfN>uJwgW?-omXIRmzRJ3^ z2Ii=E;R4f4L9UPW=z|_+cwQrm2KDHK!e$melnP!s|1UyzL(Ia{<^5HUM05P z0~i?64N6yxqngl%!pgvPu(N+JHcCdOf*+I~s|gJm#^m=m2?ox-EL5C5PDx8M;O0{j z%Le75fBiiyy!`-v-6`OglNH%i)Wyr+V$gOUiMFP_vEqn2J-A*cJR7;IdSA;;9Ds;{ zWlc1E%woFC{Uq08YiZo2k+d^)EBzS}jKMMv(=0$>oz^qrEN_vV-L+V^Jqzj-JaBl{ zOx!Q4J5>R5=%800&bYOe%}>elm1`gJ4m?6W#iNAlzx%<#T6uhw=0(dMPvKGDqhQP% zMOr4~z&$SyrPT^cgr0r7!;7IiIO~pt`)eAZSZ@g0q46}KMWf!lMxu{P-+vZBC|?v(3V-n+2dzRR=F- zNrit~E|BFpiP#?BTg>wkVD5hlAWwHP|7cnOFiKsR?L7sWgL`wkRVtnxN_b|I0_Zf> zz?t@ZnZ`}wiv?%dHy;rS?$rCpTevgG3_jmZ zhtt=yAkup7&!lYJ^;*o-xMkX>3jUj^%LoupD<(Oyj7;3oxNT2UTu^Q1>d5 z*WHc;bDKG!bZRwE4H^lBkNdIv{JChNV}+NR2BV?DUD1lq1Edr{Q?NB@y-~!BAPbz8 z{Sof>+XkzrItb2co_u+G6s?J@7PNzc&}Z0gQoEFZ(0UYN)gAcGiz#&A{Qw^Lp&kCD z1f#RuGH`eH6<7O*Q{b*slBm2)ZaH|2%o^{2#_(~d(RTnT%(y{Ef@EHKUk5OXb;PV0 zZ>ajkE77e%f!A#pFzc5St5>Swby@FmNqZizZ*=9=S$Ci^&=K3d*NJMlkeiIF#JZsl zm||~&PTu21%k{_5X+$FSJg9|xH7~(ZW#E1x?wI-GtWYMI1ul|AF+Q(9*ZElEkiQ#9 z&(;t>h3w`o)8V>>?LpfRZNaq-vkJWLZKgFPMa~Q7@ih-5=I6r<11nw<^qD5Jt z*FFHuT6uN!?z)@vx*ds#Gw46GkB(Jj?~=p04sXEXTzd?P(DeQ6;9oT z#I_i$|8@xq%YE?XrY%DEq$H|oGlZSmQ=r`NhG;GGitRn%!Lem~z;xgnVd{HB+#2|i zo^+jqCdp}V(&8GQ$&3xr@KUZ&wJ(ohAcoa?y^WnqUf!IIhFeMj-uxadM zo_^!JX#Au|3~Fi?jK&s_b(lT}Zvw*74TvF@C0bqUlIXOWu58DaFi zUo>}+jQ8%Q0nPfa=x(?L6_+w?ExLzMXAa>aF66+CS0up}gVAjDFk1g+6+Bs=$a@B> z@$%qte5q_A1{N)2_xP{SHAk28e)hz2&GpjRd%Hk^%{F?c;lri*Z|I_QBSkHjPoJVgZ&Yy2y(-TR-%OpMBcT0EG!$9tu$tXC90}%F zFwYk+{GJ4>)nz+YxB3AG z@~2BQ>k}}=a0fkFzmNk*O^`mkST5eYa7Ee_I~bf_9TbZm#No$ZR^&)7*uFDckhA^i zpw(|SJaMw6)Jw$#gZttRF;8$)I0qYBB51ai7tUR;%C~RC@r%1M9Dm&z%VoOC`w8DD zQ0DWw?v+ga1Ild*n|-HLEP`RpFF0>i)ItvL+(st?yEEnr`?`RJIks@yG1GT zVAUP$pl{8=>Ydc>R2y~dcNI!y-)wkBZUeViMRxSu4nB9QXrb38zTGqst9xF8F&%*C zz9f;}fj<;j`xx9@)v^6hfaFGgGT1KoD%2SnV543=D4I%yM`!Y-FUq^30;gbR@n)?@ zzIAvr@o#v2^PBkoOA0G*m_vC+=Ac{i9?&$9JnuURQ>8iAhCZQi@Olr-ULSz5i)N5* zR0XnLg<_IeAQg-Oaf35!h^dBsb9|Xv%O#ZWNW?ClwW7&3 zA2jgo55;HJpWf_UK=$gx*p;+WnEu2G`af}}lmLcH$0?HU-~&*&R1Y%qcA{hH7eTS& zweZX&7gXlE;j0cZ+%)o~kU359^oNvrFvhMG6y8~r`tE(~tCl*8k@gRlrtUX7UK561 zorZd4!{MTyA8t01YBIF5WN)PTo;@x*;do08s5xyGJ`{|_BNNqVtdl)e%}St6t430p z!XvU>e*`*M--Fp76LEjtXIZ4?OlJOCnO^2QQS~!B%<}9m%>Frn{V~ym9qT@v;-hs~ zKgS(-KFo;u{}#kfb%{oc@Il~^s)T)i>P`c9XNX~s2=8ehhplUWhZpj-FyN#+DJyq@ ziBjFw*ImJac*39EHLu2^E{V)rd@LK`P2j)c2)#Rho>t{*;3xa8XNm(tLwgUIV#Jywi5D*xJX7;`@tO0zu|3L0UF%#w28 zc<-CAXZmwEmrzJsXB#v5AU`SyuLld~D`H`|4n^sNK$-V&YHchNR*%|5W~bV~A+ibe zEw92wHwAj!ES-zj<$%B8IJ&g%4p>+{gsXLbiMozN0m1n&h0bACi9J1=V?u>>eW*6M z6MZ-tM!{`eG21GLEms>w0lx&ZjSGTl_7*+nJGDw&5>g{LJ=daqj{&g0ULCJnXJGA$ z-Q>Qe7p1YgkW#KCje`lK_dbQX%<72==bdo-P(#vM?TS}7bfp=pgW;CVMOf4@ko@A~ z@u!)cm{zhY1!`MUjQ?`sN~9(Wadc!0b{~bf2r~-b(u*00oy670Qv}Z;C)tQXJ(!T^ z%U;x*VU*(l2)eyTust}C8TPe82eoy~b*U=tt1%_#UrkAnS72n_9n>q5i?JU4nPJ%_ z?3!K=8B&dLkxh4etkr{AUc4>3zV@K=H#BMQ&{+1g?Tqm8<6~H=aoxdJr(Sdq*iP!j z^Fb}}3iJ8&i}6ea8rv@I!ZqEKrTkK94qvKEtn=K6Rw;Af!_%Xn z++#3%`0Y0gkDksdmJJZ!KV420S53+O&OR`{EUg)GXEUVSlJd~fvst2X3UlkIMRD7T z$`y_u6RSUVb~>~riGr?7VGdqpP-qwkD=V|{>aH0#e)GlO3q8Lj9w5RuD9sN5_M92m#^x|p(4xiWg4F;;Gna?o*!{{XtaEmN$s*vsBoZHJ4_+hj0x5|mHr zMSEi&L3;EMYWuYZTipaq99$%3=AOXQH?-;C&T-T*?GMp$z9Ige+=&gb>%%@M^978W*Y^Vu5H2NDSc07W$XQXuK6d4wsx-6<3 zX$9}~O0t%P-VkW@1fR^@hE-Q$NJY9<862cbMXv^l4lj+Vs`LXU^@xVX9&?%fzQg#X z<{&HnGYTFFQvINL2FC5Ygn?fS&^UH2d@R@@o_nV$eO?Ktj&ujYCPlJl={-lBQZBZB zs>DoPcY5Gs#eRWkreT)?_R~bLzTYH9C61-=8fUU>JRs9LGKIZbkR@*zBgfuSZc?LC 
ze|D`~7)|<|Al&I~&I-l{<6}}LI~c;^D_YRSA{q^srNHos``GwsC+0S0Kltpt1GSnN zl)Gaawp1mvLPI-NxH%RL;*8kcw%+WjLmSA`okX48aQU@g@w=xd7_XpTHdouwLP^>~YV&U{P&L4=dBy>`}=u_mUa3 z-rO$l_4)|-jm&}cB`Pe*sT1Bg>_W4YbH$g{K<~^{nSpJP^x0m-TO~${VZcg$ ztsB(N>IvPZl!|HwUaZyTzNn>}NG&eSQr({ljE}Y_mFMg6P0e%BCB%Xv{QPNi$r<5% zzkJx;=tOhPt=RxWMYg@`Om^!^6vX^ljl;7q!@RaMvY)sfFLq02fjxDYU;bMNPoBkI z*A~LeXl3@O(NWxCJ6ITccaLyl*J-FvsK;ni7j^&+OXrAWao?>E&_gE=tE`t|)2lI{ zo0DZcn@8!P5}J{C&ja6s%+{X;beBfi8*Aiq63;L7aF{&M5-%q6`H?E>u+@% z#vG3HrL!koNc0V2DHXbuZKz27W|ZT*{5PVfSr>LPdYxEiD6P5lIfc$H3t>YK^rHtk zpTRR)hhBbiz*VC?n9hoEDEp&MXw3`5%9^fhUXe11J?h1f+@IK2%eSI}rzMU!6AGuB zidlAOp6KaZiCL#`P zgx1(wU_X8+JxET3lH3k#h3W`WP5+FWeEsPpN$U}uU2|F`q$gF5nMR-LZy>Cgi*D*3 z(sg6AuvXUz2lbhY#`WvrZ-vhE;=oYm6k6>tyU7g(DJKgCu0ia)?_g@M31SY%Rzk-M z7R*RLNn92c38so};FH=D-sGvlcKHxYvcDxJxu2rixA#PxavPrAj};$&v1aF8=7NFi zILPVI3*Ihff=5jXgm^B*k4YCW_28eVxnu~Hh^|;Rts5#Ntr0sr`msq@wAss3=drrn zgBo@hi7^RgaA%VhEj?*Z)=z#D&z>5_mb#?l1gRF+`Evq=ugrzovo~;F>oC^&eX3M5 zU`nwmdd$kzf}CY%arD}r829ETKASY2_T|mNlBJulCcu>9f&@60FcuBIT!N5~x~zwf z9#pUTf|k!!nEKEL@x5acoVncw=RHec(!nO^Qm_g%d^DsQzJZiiu|k}Ez7b}>wvlH~ ze+Q}h$KhOdE&TM^OH4rkrOgx}TV4(iwrz8) zI#h%S1|wPep~3P){Z}*hm3BnI&xDxL>$p{UIqsUei!Hvj7v2ednbP?kIA8cnOnaIS z2eO}Gb@6sg_Vj~h|IXBws!hSoBU!GRHyieP5nJINFBGgFfX$(Lm>BFW>tLG-nqzfw za+EyRcGemjbS-h5W9`e+W(#vv9^8RHc4n8qY zfPOQCD9r>kQBuR>xxhkZHK7o@0AfyRk-Exa$B4V`D0938&HGlvD8FwQ?YKmAbJ)#{ zt#*sC+JB27QG*x`k;aOrEMiOQys>kaHXJAAtHsT}Dj2O?hvlt%p;SkecC~o0;^6C! z3HJyU2HQZ|zDB%oZz0pp>dXeko09H|rL5k66!q?Z}LWy_nz`|niIdF8L-@2$t=efsZ1_qn%4tJMKG{GJs# zuQ?^ozBZQ*JEk%F0n;(PrZ+VE=!+K%`%qlga#`C!EezPG#yWPC-Y@qh7MU@Zeox;IUFST4ScA2=K;?um0o3u`co7})V}<&kwdqrVK&}e;_^|aTen?D* zH*b?^+#PkC65xcBj&G#%&Up}Z`x6Fd1u~6*-$nZfFHF{*Onq)i*O~_unXWWOVY^3{ zeRNn(Nj+zfN3a^}rf&g;H%Gux6K5&U>I`ZbOr#P^9op|Sge~%!4jNK!a)!MYCA8S0 zeS|VO-7}CWS>)h)cmYP6r7{1&8mRLt!DaXNGImUhT`pA+7WygEgl7Xq&FP(}<62G8 zzG5O%VY;H%+g6ypD-AV$r1g!y8WBA-pr~cupn6#bVX2nrTiJz0zg8vl1U)==p#-GG zCVArmT!0I30WQD=xBwU60$hL#Z~-pB1-Jkg-~wEL3vdB0zy-Jf7vKV1fD3Q|F2Du2 z02kl_T!0I30WQD=xBwU60$hL#Z~-pB1-Jkg-~wEL3vdB0zy-Jf7vKV1fD3Q|F2Du2 z02kl_T!0I30WQD=xBwU60$hL#Z~-pB1-Jkg-~wEL3vdB0zy-Jf7vKV1fD3Q|F2Du2 z02kl_T!0I30WQD=xBwU60$hL#Z~-pB1-Jkg-~wEL3vdB0zy-Jf7vKV1fD3Q|F2Du2 z02kl_T!0I30WQD=xBwU60$hL#Z~-pB1-Jkg-~wEL3vdB0zy-Jf7vKV1fD3Q|F2Du2 z02kl_T!0I30WQD=xBwU60$hL#{J#+J_E1(ynSbA1L7{`x-wq1i9!eb*6jW6HYv;8T z1ckYw%NEUujG7ZRCn9j(u*IQKi{?Z|j2k^tA>jYoth=`N;A&3TPm0pNY~TLp9vCZF z{Lq7;A%W3>!xn``%?X@0CpIuz>dlNfA>&35`d@d*(#Fo#-p1bE#?IQ_X5>g)D?58D zOQ{XFvl}sDxZOz05z=q1M%uRbvNT;oZIIN%_RZU)LwjfcQ)-n8zQ~jAw{N}M|L|S= z-mSgULl?}Q_hXDT!Xur|(thn*pYQg6E8Y7(M$oe8(8#Ee(5N4S{OiHg{4R*EbcDYT z^6z?8A46}C&Ydm(xu?F;kpGO) v>7PIS$L-Aar>)f9OL5zqw4ZYOtzXA~+;9Kx?a`^7RH2jf)&BT@JoeuJlpN$? 
diff --git a/original_scores b/original_scores
deleted file mode 100644
index e8c777a2efc4be8601a37ea3c4d3866069fcc635..0000000000000000000000000000000000000000
GIT binary patch
[binary patch payload elided]

From f9058600c0e6cec4d132347319febf88bb0bdc62 Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Tue, 16 Jul 2024 22:42:13 +0000
Subject: [PATCH 37/57] Fix the llama 70b output accuracy resulting from gqa.

---
 jetstream_pt/.attention_kernel.py.swp        | Bin 36864 -> 0 bytes
 jetstream_pt/attention_kernel.py             |  13 ++---
 jetstream_pt/engine.py                       |  10 ++--
 jetstream_pt/layers.py                       |  51 +++++++++++--------
 jetstream_pt/third_party/llama/model_args.py |   3 +-
 tests/test_llama_e2e.py                      |  43 +++++++++++++++-
 6 files changed, 85 insertions(+), 35 deletions(-)
 delete mode 100644 jetstream_pt/.attention_kernel.py.swp

diff --git a/jetstream_pt/.attention_kernel.py.swp b/jetstream_pt/.attention_kernel.py.swp
deleted file mode 100644
index 7ac8ff387f9864d2ad66bb2cf20c1d4a8c985251..0000000000000000000000000000000000000000
GIT binary patch
[binary patch payload elided]
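The GQA fix in this patch hinges on how the per-KV-head query replication is handled: the tiny test config gains n_kv_heads=8 alongside n_heads=64 so grouped-query attention is actually exercised, the dense and flash attention paths repeat the K/V heads with repeat_kv, and the ragged kernel instead folds the replication factor into the query time axis. The sketch below is a minimal illustration of those two equivalent strategies, assuming (batch, heads, time, head_dim) layouts; the helper names and bodies here are assumptions for illustration, not the repository's exact implementation.

import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
  # Llama-style head replication: each of the n_kv_heads K/V heads is
  # repeated n_rep times so keys/values line up with the query heads.
  if n_rep == 1:
    return hidden
  bs, n_kv_heads, seqlen, head_dim = hidden.shape
  return (
      hidden[:, :, None, :, :]
      .expand(bs, n_kv_heads, n_rep, seqlen, head_dim)
      .reshape(bs, n_kv_heads * n_rep, seqlen, head_dim)
  )

def fold_queries_for_mqa(q: torch.Tensor, n_kv_heads: int) -> torch.Tensor:
  # The ragged multi-query kernel takes the opposite route: leave K/V at
  # n_kv_heads and fold the query replication factor into the time axis,
  # (b, hq, tq, d) -> (b, hkv, rep * tq, d), so one kernel invocation per
  # KV head covers all of its query heads.
  b, hq, tq, d = q.shape
  rep = hq // n_kv_heads
  return q.reshape(b, n_kv_heads, rep, tq, d).reshape(b, n_kv_heads, rep * tq, d)

With 64 query heads over 8 KV heads, rep is 8, which is a case the previous tiny config (8 heads, no n_kv_heads) never exercised.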
diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py
tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" batch_size, time, head_dim = q.shape
@@ -371,7 +372,6 @@ def kv_index_map(b, i, layer_ref, start_ref, else: kv_bp = (None, bk, head_dim) ks_bp = (None, 1, bk) - #import pdb; pdb.set_trace() in_specs = [ pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), pl.BlockSpec(kv_index_map, kv_bp), pl.BlockSpec(kv_index_map, kv_bp),
@@ -416,6 +416,7 @@ def kv_index_map(b, i, layer_ref, start_ref, ], grid=(batch_size, seq_len // bk), ), + interpret=testing, compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, out_shape=[ q,
@@ -430,7 +431,7 @@ def kv_index_map(b, i, layer_ref, start_ref, return out, (m[..., 0], l[...,
0]) @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "q_shard_axis", "kv_shard_axis"] + jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "q_shard_axis", "kv_shard_axis", "testing"] ) def ragged_mha( q: jax.Array, @@ -448,6 +449,7 @@ def ragged_mha( normalize_var: bool = True, q_shard_axis:int = 0, kv_shard_axis:int = 0, + testing: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi head attention. Args: @@ -488,7 +490,6 @@ def ragged_mha( hkv = k.shape[-3] rep = hq // hkv if rep > 1: - #import pdb; pdb.set_trace() q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) with jax.named_scope("ragged_mha_vmap"): @@ -499,6 +500,7 @@ def ragged_mha( bk=bk, mask_value=mask_value, normalize_var=normalize_var, + testing=testing, # out_dtype=out_dtype, ), in_axes=( @@ -588,7 +590,6 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n denominator = unnormalized.sum(axis=-1, keepdim=True) unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[0], 1, 1, v_scaler.shape[2]) - #import pdb; pdb.set_trace() o = ( torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) / denominator @@ -601,7 +602,7 @@ class RaggedAttentionKernel: """Ragged attention kernel.""" def __init__(self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis): self.binded_ragged_mha = functools.partial( - ragged_mha, bk=env.block_size, q_shard_axis=q_shard_axis, kv_shard_axis=kv_shard_axis + ragged_mha, bk=env.block_size, q_shard_axis=q_shard_axis, kv_shard_axis=kv_shard_axis, testing=env.testing ) self.binded_ragged_mha = shard_map( self.binded_ragged_mha, env.mesh, input_specs, output_specs, check_rep=False diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index bf59d3bf..8daf991f 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -111,11 +111,11 @@ def __init__( donate_argnums=(0, 1), out_shardings=self.get_decode_state_sharding(), ) - self.generate = jax.jit( - self.generate, - donate_argnums=(1,), - out_shardings=(self.get_decode_state_sharding(), None), - ) + # self.generate = jax.jit( + # self.generate, + # donate_argnums=(1,), + # out_shardings=(self.get_decode_state_sharding(), None), + # ) # self._insert_wrap = jax.jit(self._insert_wrap, donate_argnums=(0, 1), # out_shardings=self.get_decode_state_sharding()) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 1ec23d50..eb26f37a 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -400,10 +400,19 @@ def __init__(self, env, layer_id): others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention self.flash_attention = ak.flash_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.ragged_attention_orig = ak.RaggedAttentionKernel( env, input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) self.layer_id = layer_id @@ -432,23 +441,17 @@ def __call__( kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - # if not self.env.ragged_mha and seqlen == 1: - # xq_expanded = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - # else: - # xq_expanded = xq 
- def attend(xq, keys, values, local_mask=None): # As of right now, ragged attention doesn't support attention calculation with prefill and new cache line # We are not using ragged attention for prefill yet. - kv_shard_axis = self.kv_shard_axis - if self.kv_shard_axis > 0: - if keys.ndim == 4: - kv_shard_axis = 1 - else: - kv_shard_axis = 2 + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + if self.env.ragged_mha and seqlen == 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( - self.ragged_attention, + impl, xq, keys, values, @@ -457,8 +460,6 @@ def attend(xq, keys, values, local_mask=None): end, ragged_batch_index, ragged_block_index, - self.q_shard_axis, - kv_shard_axis, ) elif self.env.flash_attention: with torch_xla2.default_env(): @@ -494,7 +495,10 @@ def attend(xq, keys, values, local_mask=None): # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, mask) - + cache_len = orig_keys.shape[-2] + existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) + existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) + existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) # Updating cache during each step still has very large impact on latency. # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: @@ -506,6 +510,9 @@ def attend(xq, keys, values, local_mask=None): xk = repeat_kv(xk, n_rep) xv = repeat_kv(xv, n_rep) new_output, (new_max, new_denom) = attend(xq, xk, xv, None) + new_output = new_output.reshape(bsz, num_heads, 1, head_dim) + new_max = new_max.reshape(bsz, num_heads, 1, 1) + new_denom = new_denom.reshape(bsz, num_heads, 1, 1) # if cache.cache_k is None: # Prefill # return new_output @@ -618,17 +625,18 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") self.env.apply_sharding(local_output, axis=self.q_shard_axis) return local_output, (local_max, local_denom) - #import pdb; pdb.set_trace() with jax.named_scope("attn_insert_cache"): orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) # We are not using ragged attention for prefill yet. 
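# Illustrative aside (not part of this diff): the lazy cache update path around
# this point produces two partial attention results, one over the existing
# cache (existing_output, existing_max, existing_denom) and one over the newly
# written token (new_output, new_max, new_denom), and the "attn_global" block
# below combines them. That combination is the usual online-softmax merge of
# (output, running max, denominator) triples; a standalone version, with the
# max/denominator shapes assumed to be (batch, heads, time, 1), could look like:
import torch

def merge_attention_partials(o1, m1, l1, o2, m2, l2):
  """Merge two attention partials computed over disjoint sets of keys."""
  m = torch.maximum(m1, m2)            # new running max
  a1, a2 = torch.exp(m1 - m), torch.exp(m2 - m)
  l = l1 * a1 + l2 * a2                # merged softmax denominator
  o = (o1 * l1 * a1 + o2 * l2 * a2) / l
  return o, m, l

# The attn_global code below folds the rescaling in directly via
# torch.exp(existing_max) and torch.exp(new_max) instead of subtracting a
# shared max first; the max-subtracted form above is the numerically safer
# equivalent. End of illustrative aside.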
if not self.env.ragged_mha or seqlen > 1: - #import pdb; pdb.set_trace() orig_keys = repeat_kv(orig_keys, n_rep) orig_values = repeat_kv(orig_values, n_rep) with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, k_scaler, v_scaler, mask) - + cache_len = orig_keys.shape[-2] + existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) + existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) + existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: return existing_output @@ -639,8 +647,9 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): new_key = repeat_kv(new_key, n_rep) new_value = repeat_kv(new_value, n_rep) new_output, (new_max, new_denom) = attend(xq, new_key, new_value, new_k_scaler, new_v_scaler, None) - # if cache.cache_k is None: # Prefill - # return new_output + new_output = new_output.reshape(bsz, num_heads, 1, head_dim) + new_max = new_max.reshape(bsz, num_heads, 1, 1) + new_denom = new_denom.reshape(bsz, num_heads, 1, 1) with jax.named_scope("attn_global"): global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index 7956667d..1b72c0a7 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -45,7 +45,8 @@ def get_arg( "dim": 128, "vocab_size": 32000, "multiple_of": 32, - "n_heads": 8, + "n_heads": 64, + "n_kv_heads": 8, "n_layers": 3, "norm_eps": 1e-05, } diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 916485f3..4991b681 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -23,13 +23,12 @@ import torch_xla2 from torch.utils import _pytree as pytree - from jetstream_pt.engine import PyTorchEngine from jetstream_pt.third_party.llama import model_exportable, model_args from jetstream_pt.third_party.llama.generation_original import LlamaOriginal from jetstream_pt import environment from tests import helpers - +from jetstream_pt import torchjax class LlamaE2ETest(unittest.TestCase): """This test class includes all E2E test for llama2""" @@ -187,6 +186,9 @@ def _llama_e2e(self, env, model_arg): model_ours = model_exportable.Transformer(model_arg, env) + for k, v in model_ours.state_dict().items(): + if "scale" in k: + state_dict[k] =helpers.to_xla_tensor(v) engine = PyTorchEngine(pt_model=model_ours, env=env) params = self._from_torch(state_dict) @@ -351,6 +353,43 @@ def update_env_data(env_data): out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertEqual(out_tokens, expected_output_tokens) + def test_llama_e2e_int8_left_aligned_lazy_cache_update_generate_cache_stacked_new_cache_nonstacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.ragged_mha=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=False + env_data.lazy_cache_update=True + env_data.quant_config.enable_kv_quantization=True + env, model_arg = helpers.make_env_tiny(True, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, 
expected_output_tokens) + + + def test_llama_e2e_int8_left_aligned_lazy_cache_update_all_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.ragged_mha=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=True + env_data.lazy_cache_update=True + env_data.quant_config.enable_kv_quantization=True + env_data.ragged_mha=True + + env, model_arg = helpers.make_env_tiny(True, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + def test_llama_e2e_bfloat16(self): "end to end jetstream llama test with bfloat16" From 58dda184695cf91deebccf217a994a99acac60c2 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 17 Jul 2024 19:22:59 +0000 Subject: [PATCH 38/57] Fixes the attention output slicing issue when not using flash attention. Refactor to use only 1 flash attention kernel. Changes the modified ring buffer ragged attention kernel with quantization, layer, etc. --- jetstream_pt/attention_kernel.py | 139 ++++++++++++++----------------- jetstream_pt/engine.py | 10 +-- jetstream_pt/layers.py | 15 ++-- tests/test_llama_e2e.py | 20 +++++ tests/test_quantization.py | 3 +- 5 files changed, 98 insertions(+), 89 deletions(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index cdd33f64..e6eb91c2 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -16,6 +16,7 @@ def ragged_flash_attention_kernel( + layer_ref, start_ref, end_ref, line_end_ref, @@ -105,7 +106,7 @@ def run(): @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var"] + jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "testing", "quantized"] ) def ragged_mqa( q: jax.Array, @@ -121,6 +122,8 @@ def ragged_mqa( bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" with jax.named_scope("ragged_mqa"): @@ -134,6 +137,7 @@ def ragged_mqa( def kv_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -143,12 +147,13 @@ def kv_index_map( index = b * (seq_len // bk) + i if stacked: - return layer, ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 + return layer_ref[0], ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 def q_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -157,12 +162,12 @@ def q_index_map( ): index = b * (seq_len // bk) + i if stacked: - return layer, ragged_batch_index_ref[index], 0, 0 + return layer_ref[0], ragged_batch_index_ref[index], 0, 0 return ragged_batch_index_ref[index], 0, 0 - def scaler_index_map(b, i, *_): + def scaler_index_map(b, i, layer_ref, *_): if stacked: - return layer, b, 0, i + return layer_ref[0], b, 0, i return b, 0, i line_end = jnp.where(start < end, end, seq_len - 1) @@ -181,6 +186,8 @@ def scaler_index_map(b, i, *_): pl.BlockSpec(q_index_map, q_bp), pl.BlockSpec(kv_index_map, kv_bp), pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(scaler_index_map, ks_bp), + pl.BlockSpec(scaler_index_map, ks_bp), ] inputs = ( start, @@ -191,15 +198,9 @@ def scaler_index_map(b, i, *_): q, k, v, + 
k_scaler, + v_scaler ) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, ks_bp), - pl.BlockSpec(scaler_index_map, ks_bp), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True out, m, l = pl.pallas_call( functools.partial( @@ -220,8 +221,7 @@ def scaler_index_map(b, i, *_): grid=(batch_size, seq_len // bk), ), compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, - #interpret=True, - #debug=True, + interpret=testing, out_shape=[ q, jax.ShapeDtypeStruct( @@ -311,7 +311,7 @@ def run(): (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe ).astype(o_ref.dtype) -@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "testing"]) +@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "testing", "quantized"]) def ragged_mqa_reference( q: jax.Array, k: jax.Array, @@ -321,17 +321,17 @@ def ragged_mqa_reference( end: jax.Array, ragged_batch_index=None, ragged_block_index=None, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, + k_scaler: jax.Array = None, + v_scaler: jax.Array = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, testing: bool = False, + quantized: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" batch_size, time, head_dim = q.shape #assert end.shape == (batch_size,) - assert end.dtype == jnp.int32 seq_len = k.shape[-2] stacked = False @@ -358,9 +358,7 @@ def _compute_ragged_block_indices(b, i, lengths_ref): def kv_index_map(b, i, layer_ref, start_ref, end_ref, - line_end_ref, - ragged_batch_index_ref, - ragged_block_index_ref): + *_): b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) if stacked: return layer_ref[0], b_next, i_next, 0 @@ -372,10 +370,13 @@ def kv_index_map(b, i, layer_ref, start_ref, else: kv_bp = (None, bk, head_dim) ks_bp = (None, 1, bk) + in_specs = [ - pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), - pl.BlockSpec(kv_index_map, kv_bp), - pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # k + pl.BlockSpec(kv_index_map, kv_bp), # q + pl.BlockSpec(kv_index_map, kv_bp), # v + pl.BlockSpec(kv_index_map, ks_bp), # k_scaler + pl.BlockSpec(kv_index_map, ks_bp), # v_scaler ] inputs = ( @@ -388,15 +389,9 @@ def kv_index_map(b, i, layer_ref, start_ref, q, k, v, + k_scaler, + v_scaler, ) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(kv_index_map, ks_bp), - pl.BlockSpec(kv_index_map, ks_bp), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True out, m, l = pl.pallas_call( functools.partial( @@ -473,24 +468,35 @@ def ragged_mha( softmax denominator ([batch_size, num_heads, compute_dim, 1]). 
""" mask_value = DEFAULT_MASK_VALUE - if k_scaler is None: - replicated_in_axes = 5 - replicated_inputs = (ragged_batch_index, ragged_block_index) - else: - replicated_in_axes = 7 - replicated_inputs = ( - ragged_batch_index, - ragged_block_index, - jnp.squeeze(k_scaler, -1), - jnp.squeeze(v_scaler, -1), - ) - # New cache has t=1 bk = min(bk, k.shape[-2]) bq, hq, tq, dq = q.shape hkv = k.shape[-3] + tk = k.shape[-2] rep = hq // hkv if rep > 1: q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) + + replicated_in_axes = 7 + if k_scaler is None: + quantized = False + if k.ndim == 5: + kv_scale_shape = (1, 1, 1, tk) + else: + kv_scale_shape = (1, 1, tk) + k_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) + v_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) + else: + quantized = True + k_scale = jnp.squeeze(k_scaler, -1) + v_scale = jnp.squeeze(v_scaler, -1) + + replicated_inputs = ( + ragged_batch_index, + ragged_block_index, + k_scale, + v_scale, + ) + # New cache has t=1 with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( @@ -501,6 +507,7 @@ def ragged_mha( mask_value=mask_value, normalize_var=normalize_var, testing=testing, + quantized=quantized, # out_dtype=out_dtype, ), in_axes=( @@ -540,34 +547,14 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): output = torch.einsum("ikjm,ikml->ikjl", scores, values) return output -def flash_attention(xq, keys, values, mask=None, normalize_var=True): - """The vanilla attention kernel implementation.""" - # mask_value: float = DEFAULT_MASK_VALUE - logits = torch.einsum( - "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) - ) - - if normalize_var: - logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama - if mask is not None: - # logits = logits + torch.where(mask, 0.0, mask_value)[:, None] - logits = logits + mask - - logits_max, _ = torch.max(logits, axis=-1, keepdim=True) - # unnormalized = torch.exp(logits - logits_max[..., None]) - unnormalized = torch.exp(logits - logits_max) - denominator = unnormalized.sum(axis=-1, keepdim=True) - # print(f"logits {logits.shape} logits_max {logits_max.shape} denominator {denominator}") - o = ( - torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) - # / denominator[..., None] - / denominator - ) - return o, (logits_max, denominator) - - -def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, normalize_var=True): +def flash_attention(xq, keys, values, layer, k_scaler = None, v_scaler= None, mask=None, normalize_var=True): mask_value: float = DEFAULT_MASK_VALUE + if keys.ndim == 5: + keys = keys[layer] + values = values[layer] + k_scaler = k_scaler[layer] if k_scaler is not None else None + v_scaler = v_scaler[layer] if v_scaler is not None else None + logits = torch.einsum( "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) ) @@ -575,8 +562,9 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n if normalize_var: logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama # Quantized - bs, hs, ls, ds = k_scaler.shape - logits = logits * k_scaler.reshape(k_scaler.shape[0], 1, 1, k_scaler.shape[2]) + if k_scaler is not None: + bs, hs, ls, ds = k_scaler.shape + logits = logits * k_scaler.reshape(k_scaler.shape[-4], 1, 1, k_scaler.shape[-2]) # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] if mask is not None: @@ -589,7 +577,8 @@ def flash_attention_quantized(xq, keys, values, k_scaler, v_scaler, mask=None, n # 
unnormalized = unnormalized * v_scaler denominator = unnormalized.sum(axis=-1, keepdim=True) - unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[0], 1, 1, v_scaler.shape[2]) + if v_scaler is not None: + unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[-4], 1, 1, v_scaler.shape[-2]) o = ( torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) / denominator diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 8daf991f..bf59d3bf 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -111,11 +111,11 @@ def __init__( donate_argnums=(0, 1), out_shardings=self.get_decode_state_sharding(), ) - # self.generate = jax.jit( - # self.generate, - # donate_argnums=(1,), - # out_shardings=(self.get_decode_state_sharding(), None), - # ) + self.generate = jax.jit( + self.generate, + donate_argnums=(1,), + out_shardings=(self.get_decode_state_sharding(), None), + ) # self._insert_wrap = jax.jit(self._insert_wrap, donate_argnums=(0, 1), # out_shardings=self.get_decode_state_sharding()) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index eb26f37a..6ae2826b 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -463,7 +463,7 @@ def attend(xq, keys, values, local_mask=None): ) elif self.env.flash_attention: with torch_xla2.default_env(): - local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, mask=local_mask) + local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, self.layer_id, mask=local_mask) else: if seqlen == 1: xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) @@ -539,7 +539,7 @@ def __init__(self, env, layer_id): kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.flash_attention = ak.flash_attention_quantized + self.flash_attention = ak.flash_attention self.ragged_attention_orig = ak.RaggedAttentionKernel( env, input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), @@ -608,7 +608,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): local_denom = local_denom.reshape(*local_denom.shape, 1) elif self.env.flash_attention: with torch_xla2.default_env(): - local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, k_scaler, v_scaler, mask=local_mask) + local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, self.layer_id, k_scaler, v_scaler, mask=local_mask) else: local_output = self.dense_attention(xq, keys, values, k_scaler, v_scaler, local_mask) local_max = None @@ -620,9 +620,6 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): local_max = local_max[:, :, 0:1, :] local_denom = local_denom[:, :, 0:1, :] - # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") - # if local_max is not None and local_denom is not None: - # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") self.env.apply_sharding(local_output, axis=self.q_shard_axis) return local_output, (local_max, local_denom) with jax.named_scope("attn_insert_cache"): @@ -635,14 +632,16 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, k_scaler, v_scaler, mask) cache_len = orig_keys.shape[-2] existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) - existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) - 
existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: return existing_output # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): + # At this point, flash attention or ragged attention must have been enabled + existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) + existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) + if not self.env.ragged_mha or seqlen > 1: new_key = repeat_kv(new_key, n_rep) new_value = repeat_kv(new_value, n_rep) diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 4991b681..5fbbc42e 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -371,6 +371,26 @@ def update_env_data(env_data): self.assertEqual(out_tokens, expected_output_tokens) + def test_llama_e2e_int8_left_aligned_lazy_cache_update_all_cache_stacked(self): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer=False + env_data.ragged_mha=False + env_data.flash_attention=True + env_data.generate_cache_stacked=True + env_data.new_cache_stacked=True + env_data.lazy_cache_update=True + env_data.quant_config.enable_kv_quantization=True + env_data.ragged_mha=False + + env, model_arg = helpers.make_env_tiny(True, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + def test_llama_e2e_int8_left_aligned_lazy_cache_update_all_cache_stacked(self): """end to end jetstream llama test with float32""" jax.config.update("jax_platform_name", "cpu") diff --git a/tests/test_quantization.py b/tests/test_quantization.py index f53543d1..c9336628 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -138,8 +138,9 @@ def update_env_data(env_data): env_data.generate_cache_stacked=True env_data.new_cache_stacked=True env_data.lazy_cache_update=True - env_data.quant_config.enable_kv_quantization=False + env_data.quant_config.enable_kv_quantization=True env_data.batch_size = 4 + env_data.ragged_mha = True env, _ = helpers.make_env_tiny(False, update_env_data) batch = env.batch_size From ba80c1960336cb8e807a525970f6e234ac068f57 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 18 Jul 2024 18:05:37 +0000 Subject: [PATCH 39/57] Fix the pallas kernel OOB issue --- benchmarks/run_offline.py | 20 ++++++++++-------- jetstream_pt/attention_kernel.py | 34 ++++++++++++++++++++++++------- jetstream_pt/layers.py | 31 ++++++++++++++++++---------- keys_original | Bin 0 -> 11706 bytes original_scores | Bin 0 -> 6468 bytes tests/test_quantization.py | 2 +- 6 files changed, 60 insertions(+), 27 deletions(-) create mode 100644 keys_original create mode 100644 original_scores diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 2abda049..e06ac30f 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -52,14 +52,18 @@ def run_prefill_time(engine, params, decode_state, seqlen): nums = 5 start = time.perf_counter() - for i in range(nums): - prefill_result, _ = engine.prefill( - params=params, padded_tokens=tokens, true_length=true_length - ) - decode_state = engine.insert( - prefill_result, decode_state, slot=jnp.int32(i) - ) - jax.block_until_ready(decode_state) + if FLAGS.profiling_prefill: + for i in 
range(nums): + if i == nums - 1: + jax.profiler.start_trace(FLAGS.profiling_output) + prefill_result, _ = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) + decode_state = engine.insert( + prefill_result, decode_state, slot=jnp.int32(i) + ) + jax.block_until_ready(decode_state) + jax.profiler.stop_trace() end = time.perf_counter() return (end - start) / nums, decode_state diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index e6eb91c2..b81acaa6 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -364,6 +364,14 @@ def kv_index_map(b, i, layer_ref, start_ref, return layer_ref[0], b_next, i_next, 0 return b_next, i_next, 0 + def kv_scale_index_map(b, i, layer_ref, start_ref, + end_ref, + *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, 0, i_next + return b_next, 0, i_next + if stacked: kv_bp = (None, None, bk, head_dim) ks_bp = (None, None, 1, bk) @@ -372,11 +380,11 @@ def kv_index_map(b, i, layer_ref, start_ref, ks_bp = (None, 1, bk) in_specs = [ - pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # k - pl.BlockSpec(kv_index_map, kv_bp), # q + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # q + pl.BlockSpec(kv_index_map, kv_bp), # k pl.BlockSpec(kv_index_map, kv_bp), # v - pl.BlockSpec(kv_index_map, ks_bp), # k_scaler - pl.BlockSpec(kv_index_map, ks_bp), # v_scaler + pl.BlockSpec(kv_scale_index_map, ks_bp), # k_scaler + pl.BlockSpec(kv_scale_index_map, ks_bp), # v_scaler ] inputs = ( @@ -412,6 +420,7 @@ def kv_index_map(b, i, layer_ref, start_ref, grid=(batch_size, seq_len // bk), ), interpret=testing, + #debug=True, compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, out_shape=[ q, @@ -470,19 +479,25 @@ def ragged_mha( mask_value = DEFAULT_MASK_VALUE bk = min(bk, k.shape[-2]) bq, hq, tq, dq = q.shape + dk = k.shape[-1] hkv = k.shape[-3] tk = k.shape[-2] + + assert k.shape[-1] == q.shape[-1] + assert k.shape[-4] == q.shape[-4] + rep = hq // hkv if rep > 1: q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) - + stacked = True if k.ndim == 5 else False + replicated_in_axes = 7 if k_scaler is None: quantized = False if k.ndim == 5: - kv_scale_shape = (1, 1, 1, tk) + kv_scale_shape = (k.shape[0], bq, 1, tk) else: - kv_scale_shape = (1, 1, tk) + kv_scale_shape = (bq, 1, tk) k_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) v_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) else: @@ -490,6 +505,11 @@ def ragged_mha( k_scale = jnp.squeeze(k_scaler, -1) v_scale = jnp.squeeze(v_scaler, -1) + if stacked: + assert k_scale.shape == (k.shape[0], bq, 1, tk) + else: + assert k_scale.shape == (bq, 1, tk) + replicated_inputs = ( ragged_batch_index, ragged_block_index, diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 6ae2826b..4b7e6c99 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -586,8 +586,12 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): impl = self.ragged_attention_new else: impl = self.ragged_attention_orig - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + #if not self.env.ragged_mha and seqlen == 1: + true_len = seqlen + if n_rep == 1 and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad(xq, (0, 0, 0, true_len - seqlen), "constant", 0) + #xq = torch.broadcast_to(xq, (bsz, num_heads, true_len, head_dim)) # We are not using 
ragged attention for prefill yet. if self.env.ragged_mha and seqlen == 1: @@ -604,8 +608,8 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): k_scaler, v_scaler, ) - local_max = local_max.reshape(*local_max.shape, 1) - local_denom = local_denom.reshape(*local_denom.shape, 1) + #local_max = local_max.reshape(*local_max.shape, 1) + #local_denom = local_denom.reshape(*local_denom.shape, 1) elif self.env.flash_attention: with torch_xla2.default_env(): local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, self.layer_id, k_scaler, v_scaler, mask=local_mask) @@ -614,7 +618,12 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): local_max = None local_denom = None - if local_output.shape[-2] == 2: + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom= local_denom.reshape(bsz, num_heads, true_len, 1) + + if true_len != seqlen: local_output = local_output[:, :, 0:1, :] if local_max is not None: local_max = local_max[:, :, 0:1, :] @@ -631,7 +640,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): with jax.named_scope("attn_qkv"): existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, k_scaler, v_scaler, mask) cache_len = orig_keys.shape[-2] - existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) + #existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: return existing_output @@ -639,16 +648,16 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): # At this point, flash attention or ragged attention must have been enabled - existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) - existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) + #existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) + #existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) if not self.env.ragged_mha or seqlen > 1: new_key = repeat_kv(new_key, n_rep) new_value = repeat_kv(new_value, n_rep) new_output, (new_max, new_denom) = attend(xq, new_key, new_value, new_k_scaler, new_v_scaler, None) - new_output = new_output.reshape(bsz, num_heads, 1, head_dim) - new_max = new_max.reshape(bsz, num_heads, 1, 1) - new_denom = new_denom.reshape(bsz, num_heads, 1, 1) + #new_output = new_output.reshape(bsz, num_heads, 1, head_dim) + #new_max = new_max.reshape(bsz, num_heads, 1, 1) + #new_denom = new_denom.reshape(bsz, num_heads, 1, 1) with jax.named_scope("attn_global"): global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) diff --git a/keys_original b/keys_original new file mode 100644 index 0000000000000000000000000000000000000000..d18bd8dd3ddcfb0140450b3cfde59b54a931865f GIT binary patch literal 11706 zcmbW-3vg7`83*u9AnzhV1A$hAXuu5#AtdBwINuk71e~2fcn(l831LH5R+dc`3uVRz z0-*#D2#gksAQ3yoDo<;9?ESvB+6PlR2#mGSK1OAVwibmJg;_g{qRC%gWq4Rn3Ywb;9&xi-+`TjX0(?;pLeK2)E#bg{pxGTAln zMr(<02nCuNzMWUQEadYyr`9%k7cBGn7uJTnOPT{swT~2}23>CVoFtpyY}c$`Tqu~dzM&!D_j?;czQB^^)TvDiy-nVQ_xTz^slnvz>A{pxFm-w`Ew9Ft 
zlsxNN(Sm(8Iie=!hG)3?M{W3m6(MgR5?obz+s!S%z61QeP{{9%gs(O44o7FG+U8hhW9#1GnFfHCd`T)mQ8~EjOHq0s9F%WtPpDTsIwuubD=h{&WriqnE}XzffBlq@H)R+hXX*EZjG-Zl9Opzu#yejZwjAUxLA}Q=*q=s7wnZ)e|nJ0r7X*PL;tl&IF-lCrf za{72*MjoIS%8`AvOc9UlWMmT$BV-Z{S7c4tGM(PDYQuL4@^W;rJQY4d+&uAJW)72qlu9+#wY+5M@ z+p}*ZP2(APQc4`8gGVUxSoo|#-29$FM#@ox>~+2`$TH_0LH=f56l5MTBa2Odk)w_X zl1Ud0@(<}Eq=Iu93Go33$>#qE^0;YH;^0fIxkTqt$L6$pH9Hdb16692Pydd4q zDMEIXPhsR6yoQn2ox2tJttlgJPFluTmeGYZ)GDYhcMLy>qMWM`1rF32uH zHfp0H2kDqW3OQMkvre5s#+gb+g7knOZMu$-3|>yiIeT@Bd`uq+GD=GX`M{q4BbUji z$Pb*;gtYR%1bJ0{PDmemMv+10VL@J{e+x2$x(%|ETNp{A{fr!uI~A$ZIzig_nuE2O zHhYB=%T=l+WV=0oM!Km*ky3t85tluGLNetqg3RXwjD*CeNS|w*el%F zI^`om()q`P)XC$7r0XF@9%5$Xi21-EljyJ@iPEOXNISbDPtk`48O+&)OjTB7G99o_ zP-4rGfP?Iij|Dj#{Ioc%1E3})C zM!rXpH2s8-a_z8B;A4B2o)E+<`wa3eZYQLj?`7l)u^`#{XF@up4^?5;d(WNkQLXSAeHt8f}7>`n<)6PD~ZrWsLpx7p9nIhTz z8$#Zf`w4m3&K^jJ&l{vTJdu%2TA;{f`}_d8CUJz!aOxBZaFQacb(a1A#CB2KicA=n z$jCh1t;k6;hmdL~m5{4aOh^O&(Lsi4jv~u-F(X}04k5pz)q=bs^BEb-2kaBv*e-K2 zBa3*XAhX4<$XTigBLg&vklXn*A(6^Ur#t1KE8CV`TCpoH2e?dB<>e#Zre<3-T2)#0 zf0bT-UYJ>!8X3XqU5jSL+W|zx9zm8XH);fVpBH~Ndz*aqe&2WNBPFc0zJ*nl?yo$c z{GyWL(xTGRqLRYWqO!8$f|AmLeA~v1DJ(86EGQ@(TUJn3TviepWuHS^JKT1J5jCE; z$Y5_){@z*|J23J?kM`wtz|k4ER={TT=zrCI>X~t~q&7O_)`OYeBV@KM4(=Ipy>vD@ z;nssOphv=MBcl^;`~{*TZrxK^JtC&(L`Qu2F|6?450~ literal 0 HcmV?d00001 diff --git a/original_scores b/original_scores new file mode 100644 index 0000000000000000000000000000000000000000..bd56f64431517fd785076a21d79b25f6016f6c16 GIT binary patch literal 6468 zcmbVR3rv<*ljbgR(|SWJUQj{QB7#7H@5~&6E5;j(Kdxmpm8cM8C8AvP*J!*U(jt{o z6!C&Zs)aL=qLb-E(PWCd_ zUmo#EPRYxkUlbLW{>yovWo{X+a@+W2?iud#xo~OmGLIQaK1ItsXHIhopXNF(NsiG) z%e=zn?6lFcyB00$@zKJC`FVNS3yX5|mlXQ^b3s;iL3Y+Z`?ZZ_9cXE%W)? zGT(7&(|UMs`Hutt`juGjKGObj;z2IG|1|LhD~q!8yG^Gg`~5oeU*?~eTU3;n-EDjm z{8x~d816p)+V}vMxIY>1_G;w#l;r3?SpLQS|MoZ8cEWmmi4~Bp``Jl4$d1EccGzJ1 zhG$}`{gN8-H1^Ubu$eRX5?;V$uI6Yuiu2G5zrthejbHE?_yOyp$#$+z#z^|a`6F+@VCn=<^3<;O7}cl}JjB~D#`eY}ybTxd3QR&zuE5#ORkT3u zr$Vm5@qC?6liZC*a3WsfQxGiknu_gIrnYh@rqN;;!56Wf57AWWhr4+-de{(jLK7TB zKa9avDnl=|mqO?9Q45>lz}rGIkeUs zq%avnGga^b8coUg0k*(;o(Lzo9dAGc6+4jc!Fi)#K814_mg8J%PzI9l9F6BB+J_r> z7wv_|@HOXS3hY%)>MNd&A$T6s@gNs+I#;O>>d#f|k1g_S6`iA8p3G0pdfbI)xr;~O zD928T+*q7}Lu7s2^f5S189Y><;c~iWQq3K_&bN4ioxobxVH1YQycgjkoWtQ*PCLbe zHP|e9(#G%MRD8ozIDtFmnTHUrb9tb>0-svt7!Qw8aLQ z%K_}6_wWg5husoSsPnPS+~PV&paGokFh)CWQ8K)O6UWDMJ7mBdT0}$9m0z0aW{*Tn zBONuZh`>0W3|1_V{|zl;A+|+8NZwrynza=$C26-6R3`t)7K^t`$8lf zqL&at<9GtZU_Mn-6(qr9a>XE6uUw&wgJ2iMVkw?*q;V&0gF1(yVf;OKQ=_RPLzB6l zZYK7?8Vb`{e2rIewyx&wCQ{#%6+FgO;>ROOGx8IfL!q1{zT2oBRA)i&XzWLyn=Wb)^w_FB z^bC90?!w-Dh9`2hN#NBgL@ZB|oUM__?jbR~4>p=bY6xkrH&baCM&VenbY4Am?j#?v zwbR@%wdNYt!#Q=;%*Gh5#NnJlJE&P*h22mAskn+R!wOD7AO2Y4VK^s1wPfxhI7t1t z#o55td@ufPz%#fT+rWbqdx=L6P&z&bA9mL-IZiiVD;@(kt`q!O#0J|1rONp!PN7F! 
Date: Thu, 18 Jul 2024 22:33:47 +0000
Subject: [PATCH 40/57] Fix tests; Fix lint issues;

---
 benchmarks/run_offline.py                     |  31 +-
 jetstream_pt/attention_kernel.py              | 148 ++++---
 jetstream_pt/cache_manager.py                 | 393 +++++++++++++-----
 jetstream_pt/config.py                        |   8 +-
 jetstream_pt/engine.py                        |  67 +--
 jetstream_pt/environment.py                   |  45 +-
 jetstream_pt/layers.py                        | 191 ++++---
 .../third_party/llama/model_exportable.py     |   2 +-
 jetstream_pt/third_party/mixtral/model.py     |   5 +-
 keys_original                                 | Bin 11706 -> 0 bytes
 original_scores                               | Bin 6468 -> 0 bytes
 run_interactive.py                            |  33 +-
 tests/test_llama_e2e.py                       | 222 ++--------
 tests/test_quantization.py                    | 155 +++++--
 14 files changed, 761 insertions(+), 539 deletions(-)
 delete mode 100644 keys_original
 delete mode 100644 original_scores

diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index e06ac30f..bcbe9704
100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -31,6 +31,7 @@ flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") +profiler_started = False def run_prefill_time(engine, params, decode_state, seqlen): """Run prefill and measure time.""" @@ -52,18 +53,19 @@ def run_prefill_time(engine, params, decode_state, seqlen): nums = 5 start = time.perf_counter() - if FLAGS.profiling_prefill: - for i in range(nums): - if i == nums - 1: - jax.profiler.start_trace(FLAGS.profiling_output) - prefill_result, _ = engine.prefill( - params=params, padded_tokens=tokens, true_length=true_length - ) - decode_state = engine.insert( - prefill_result, decode_state, slot=jnp.int32(i) - ) - jax.block_until_ready(decode_state) - jax.profiler.stop_trace() + for i in range(nums): + if i == nums - 1 and FLAGS.profiling_prefill: + jax.profiler.start_trace(FLAGS.profiling_output) + profiler_started = True + + prefill_result, _ = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) + decode_state = engine.insert( + prefill_result, decode_state, slot=jnp.int32(i) + ) + jax.block_until_ready(decode_state) + end = time.perf_counter() return (end - start) / nums, decode_state @@ -109,8 +111,9 @@ def main(argv): print("======= decode starting ===") dec_times = [] for i in range(10): - if profiling_output and i == 7: + if profiling_output and i == 7 and not profiler_started: jax.profiler.start_trace(profiling_output) + profiler_started = True start = time.perf_counter() # pylint: disable-next=all decode_state, sampled_tokens = engine.generate(params, decode_state) @@ -120,7 +123,7 @@ def main(argv): dec_times.append(end - start) print(i, "decode time", (end - start)) - if profiling_output: + if profiler_started: jax.profiler.stop_trace() print("prefill ", prefill_times) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index b81acaa6..3acf992c 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -106,7 +106,14 @@ def run(): @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "testing", "quantized"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], ) def ragged_mqa( q: jax.Array, @@ -147,7 +154,12 @@ def kv_index_map( index = b * (seq_len // bk) + i if stacked: - return layer_ref[0], ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 + return ( + layer_ref[0], + ragged_batch_index_ref[index], + ragged_block_index_ref[index], + 0, + ) return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 def q_index_map( @@ -172,7 +184,6 @@ def scaler_index_map(b, i, layer_ref, *_): line_end = jnp.where(start < end, end, seq_len - 1) - if stacked: q_bp = (None, None, time, head_dim) kv_bp = (None, None, bk, head_dim) @@ -199,7 +210,7 @@ def scaler_index_map(b, i, layer_ref, *_): k, v, k_scaler, - v_scaler + v_scaler, ) out, m, l = pl.pallas_call( @@ -224,12 +235,8 @@ def scaler_index_map(b, i, layer_ref, *_): interpret=testing, out_shape=[ q, - jax.ShapeDtypeStruct( - (batch_size, time, head_dim), jnp.float32 - ), - jax.ShapeDtypeStruct( - (batch_size, time, head_dim), jnp.float32 - ), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), ], )(*inputs) return out, (m[..., 0], l[..., 0]) @@ -258,6 +265,7 @@ def ragged_mqa_kernel_reference( """Pallas kernel for flash attention.""" b, i = pl.program_id(0), 
pl.program_id(1) del layer_ref + @pl.when(i == 0) def init(): m_ref[...] = jnp.full_like(m_ref, -jnp.inf) @@ -280,7 +288,7 @@ def run(): ) if normalize_var: - qk = qk / math.sqrt(k.shape[-1]) # Align with meta llama + qk = qk / math.sqrt(k.shape[-1]) # Align with meta llama # Quantized if quantized: qk = qk * k_scaler_ref[...] @@ -291,7 +299,6 @@ def run(): s_curr = jnp.exp(qk - m_curr[..., None]) - l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) # Quantized if quantized: @@ -311,7 +318,17 @@ def run(): (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe ).astype(o_ref.dtype) -@functools.partial(jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "testing", "quantized"]) + +@functools.partial( + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], +) def ragged_mqa_reference( q: jax.Array, k: jax.Array, @@ -331,7 +348,7 @@ def ragged_mqa_reference( ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" batch_size, time, head_dim = q.shape - #assert end.shape == (batch_size,) + # assert end.shape == (batch_size,) seq_len = k.shape[-2] stacked = False @@ -349,24 +366,20 @@ def _compute_ragged_block_indices(b, i, lengths_ref): b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1)) # if not done, i next = i # if done - #if last batch, previous good block - #if not last batch, i next = 0 + # if last batch, previous good block + # if not last batch, i next = 0 i_next = jnp.where( not_done, i, jnp.where(am_last_batch, last_good_block, 0) ) return b_next, i_next - def kv_index_map(b, i, layer_ref, start_ref, - end_ref, - *_): + def kv_index_map(b, i, layer_ref, start_ref, end_ref, *_): b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) if stacked: return layer_ref[0], b_next, i_next, 0 return b_next, i_next, 0 - def kv_scale_index_map(b, i, layer_ref, start_ref, - end_ref, - *_): + def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_): b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) if stacked: return layer_ref[0], b_next, 0, i_next @@ -380,18 +393,18 @@ def kv_scale_index_map(b, i, layer_ref, start_ref, ks_bp = (None, 1, bk) in_specs = [ - pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # q - pl.BlockSpec(kv_index_map, kv_bp), # k - pl.BlockSpec(kv_index_map, kv_bp), # v - pl.BlockSpec(kv_scale_index_map, ks_bp), # k_scaler - pl.BlockSpec(kv_scale_index_map, ks_bp), # v_scaler - ] + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # q + pl.BlockSpec(kv_index_map, kv_bp), # k + pl.BlockSpec(kv_index_map, kv_bp), # v + pl.BlockSpec(kv_scale_index_map, ks_bp), # k_scaler + pl.BlockSpec(kv_scale_index_map, ks_bp), # v_scaler + ] inputs = ( jnp.array([layer]), start, end, - end, # line_end, not actually used + end, # line_end, not actually used ragged_batch_index, ragged_block_index, q, @@ -420,22 +433,27 @@ def kv_scale_index_map(b, i, layer_ref, start_ref, grid=(batch_size, seq_len // bk), ), interpret=testing, - #debug=True, + # debug=True, compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, out_shape=[ q, - jax.ShapeDtypeStruct( - (batch_size, time, head_dim), jnp.float32 - ), - jax.ShapeDtypeStruct( - (batch_size, time, head_dim), jnp.float32 - ), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), ], )(*inputs) return out, (m[..., 0], l[..., 0]) + @functools.partial( - 
jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "q_shard_axis", "kv_shard_axis", "testing"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "q_shard_axis", + "kv_shard_axis", + "testing", + ], ) def ragged_mha( q: jax.Array, @@ -451,8 +469,8 @@ def ragged_mha( bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, - q_shard_axis:int = 0, - kv_shard_axis:int = 0, + q_shard_axis: int = 0, + kv_shard_axis: int = 0, testing: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi head attention. @@ -506,9 +524,9 @@ def ragged_mha( v_scale = jnp.squeeze(v_scaler, -1) if stacked: - assert k_scale.shape == (k.shape[0], bq, 1, tk) + assert k_scale.shape == (k.shape[0], bq, 1, tk) else: - assert k_scale.shape == (bq, 1, tk) + assert k_scale.shape == (bq, 1, tk) replicated_inputs = ( ragged_batch_index, @@ -521,7 +539,7 @@ def ragged_mha( with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - #ragged_mqa, + # ragged_mqa, ragged_mqa_reference, bk=bk, mask_value=mask_value, @@ -536,7 +554,7 @@ def ragged_mha( kv_shard_axis, *([None] * replicated_in_axes), ), - out_axes=q_shard_axis + out_axes=q_shard_axis, )(q, k, v, layer, start, end, *replicated_inputs) return out, (m, l) @@ -567,24 +585,36 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): output = torch.einsum("ikjm,ikml->ikjl", scores, values) return output -def flash_attention(xq, keys, values, layer, k_scaler = None, v_scaler= None, mask=None, normalize_var=True): + +def flash_attention( + xq, + keys, + values, + layer, + k_scaler=None, + v_scaler=None, + mask=None, + normalize_var=True, +): mask_value: float = DEFAULT_MASK_VALUE if keys.ndim == 5: keys = keys[layer] values = values[layer] k_scaler = k_scaler[layer] if k_scaler is not None else None v_scaler = v_scaler[layer] if v_scaler is not None else None - + logits = torch.einsum( "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) ) if normalize_var: - logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama + logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama # Quantized if k_scaler is not None: bs, hs, ls, ds = k_scaler.shape - logits = logits * k_scaler.reshape(k_scaler.shape[-4], 1, 1, k_scaler.shape[-2]) + logits = logits * k_scaler.reshape( + k_scaler.shape[-4], 1, 1, k_scaler.shape[-2] + ) # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] if mask is not None: @@ -593,12 +623,14 @@ def flash_attention(xq, keys, values, layer, k_scaler = None, v_scaler= None, ma logits_max, _ = torch.max(logits, axis=-1, keepdim=True) unnormalized = torch.exp(logits - logits_max) - #Quantized, should not put here, otherwise sum will have this too, which cancels with denominator + # Quantized, should not put here, otherwise sum will have this too, which cancels with denominator # unnormalized = unnormalized * v_scaler denominator = unnormalized.sum(axis=-1, keepdim=True) if v_scaler is not None: - unnormalized = unnormalized * v_scaler.reshape(v_scaler.shape[-4], 1, 1, v_scaler.shape[-2]) + unnormalized = unnormalized * v_scaler.reshape( + v_scaler.shape[-4], 1, 1, v_scaler.shape[-2] + ) o = ( torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) / denominator @@ -609,12 +641,23 @@ def flash_attention(xq, keys, values, layer, k_scaler = None, v_scaler= None, ma class RaggedAttentionKernel: """Ragged attention kernel.""" - def __init__(self, env, input_specs, output_specs, 
q_shard_axis, kv_shard_axis): + + def __init__( + self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis + ): self.binded_ragged_mha = functools.partial( - ragged_mha, bk=env.block_size, q_shard_axis=q_shard_axis, kv_shard_axis=kv_shard_axis, testing=env.testing + ragged_mha, + bk=env.block_size, + q_shard_axis=q_shard_axis, + kv_shard_axis=kv_shard_axis, + testing=env.testing, ) self.binded_ragged_mha = shard_map( - self.binded_ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + self.binded_ragged_mha, + env.mesh, + input_specs, + output_specs, + check_rep=False, ) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) @@ -630,7 +673,6 @@ def __call__( ragged_block_index, k_scaler=None, v_scaler=None, - ): return self.binded_ragged_mha( xq, diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index b3c4fbc9..59220149 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -19,6 +19,7 @@ from jax.experimental.shard_map import shard_map import torch_xla2 + # pylint: disable-next=all class CacheInterface: """Kv cache interface""" @@ -64,6 +65,7 @@ def state(self): def finalize(self): return + # pylint: disable-next=all def KVCachePrefill_flatten(cache): return ( @@ -117,32 +119,51 @@ def __init__( if self.env.new_cache_stacked: layer, batch, heads, time, dim = self.cache_k.shape new_dim = (layer, batch, heads, 1, dim) - self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_dim, dtype=self.env.default_type), jnp.zeros(new_dim, dtype=self.env.default_type))) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_dim, dtype=self.env.default_type), + jnp.zeros(new_dim, dtype=self.env.default_type), + ) + ) else: self.new_ks, self.new_vs = [], [] else: # when generate cache is not stacked, new cache cannot stack assert not self.env.new_cache_stacked - cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads none_pspec = self.env.partition_by_axis() in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) out_specs = (cache_pspec, cache_pspec) - self.update_single_cache_line = jax.jit(shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False)) + self.update_single_cache_line = jax.jit( + shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + ) def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): b = cache_k.shape[-4] for bb, pp in enumerate(pos.reshape(b)): - slice_dim = 0 - update_start_indices = (bb, 0, pp, 0) - if self.env.generate_cache_stacked: - if self.env.new_cache_stacked: - slice_dim = 1 - update_start_indices = (0, bb, 0, pp, 0) - # We are not handling generate_cache_stacked=True new_cache_stacked=False here - new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) - cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) - cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + # We are not handling generate_cache_stacked=True new_cache_stacked=False here + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 
1, slice_dim) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) return cache_k, cache_v def finalize(self): @@ -152,22 +173,44 @@ def finalize(self): # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_vs._elem, -2)) if self.env.ring_buffer: # Assume no cache stack for ring buffer - self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(self.new_vs._elem) + self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set( + self.new_ks._elem + ) + self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set( + self.new_vs._elem + ) else: if self.env.generate_cache_stacked: layer, b, head, len, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k._elem, + self.cache_v._elem, + self.new_ks._elem, + self.new_vs._elem, + self.input_pos, + ) else: for i in range(self.env.num_layers): - self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + self.cache_k._elem = self.cache_k._elem.at[ + i, self.batch, :, self.input_pos, : + ].set(self.new_ks[i]._elem.reshape(b, head, dim)) + self.cache_v._elem = self.cache_v._elem.at[ + i, self.batch, :, self.input_pos, : + ].set(self.new_vs[i]._elem.reshape(b, head, dim)) else: # Try to use shard_map to get rid of the data copy - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.input_pos) + self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k._elem, + self.cache_v._elem, + self.new_ks._elem, + self.new_vs._elem, + self.input_pos, + ) - def update(self, key, value, layer_id:int): + def update(self, key, value, layer_id: int): """Update kv cache""" # Will process in insert() at the end of the transformer forward pass keyj, valuej = torchjax.to_torch((key, value)) @@ -186,34 +229,37 @@ def update(self, key, value, layer_id:int): self.new_ks = keyj self.new_vs = valuej return self.cache_k, self.cache_v - elif self.env.ring_buffer: # Assume no cache stack for ring buffer # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(keyj) - self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(valuej) + self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set( + keyj + ) + self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set( + valuej + ) return self.cache_k, self.cache_v else: if self.env.generate_cache_stacked: - # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[layer_id, self.batch, :, self.input_pos, :].set( - keyj.squeeze(2) - ) # pylint: disable-next=all - self.cache_v._elem = 
self.cache_v._elem.at[layer_id, self.batch, :, self.input_pos, :].set( - valuej.squeeze(2) - ) + self.cache_k._elem = self.cache_k._elem.at[ + layer_id, self.batch, :, self.input_pos, : + ].set(keyj.squeeze(2)) + # pylint: disable-next=all + self.cache_v._elem = self.cache_v._elem.at[ + layer_id, self.batch, :, self.input_pos, : + ].set(valuej.squeeze(2)) return self.cache_k[layer_id], self.cache_v[layer_id] else: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[self.batch, :, self.input_pos, :].set( - keyj.squeeze(2) - ) + self.cache_k._elem = self.cache_k._elem.at[ + self.batch, :, self.input_pos, : + ].set(keyj.squeeze(2)) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[self.batch, :, self.input_pos, :].set( - valuej.squeeze(2) - ) + self.cache_v._elem = self.cache_v._elem.at[ + self.batch, :, self.input_pos, : + ].set(valuej.squeeze(2)) return self.cache_k, self.cache_v def state(self): @@ -293,63 +339,133 @@ def __init__( if self.env.generate_cache_stacked: layer, batch, heads, time, dim = self.cache_k.shape new_kv_dim = (layer, batch, heads, 1, dim) - self.new_ks, self.new_vs = torchjax.to_torch((jnp.zeros(new_kv_dim, dtype=jnp.int8), jnp.zeros(new_kv_dim, dtype=jnp.int8))) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_kv_dim, dtype=jnp.int8), + jnp.zeros(new_kv_dim, dtype=jnp.int8), + ) + ) if self.env.new_cache_stacked: new_scale_dim = (layer, batch, 1, 1, 1) - self.new_k_scaler, self.new_v_scaler = torchjax.to_torch((jnp.zeros(new_scale_dim, dtype=self.env.default_type), jnp.zeros(new_scale_dim, dtype=self.env.default_type))) + self.new_k_scaler, self.new_v_scaler = torchjax.to_torch( + ( + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + ) + ) else: - self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = [], [], [], [] + self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = ( + [], + [], + [], + [], + ) else: # when generate cache is not stacked, new cache cannot stack assert not self.env.new_cache_stacked - cache_pspec = self.env.partition_by_axis(self.env.cache_sharding_axis) # Number of heads - new_cache_pspec = self.env.partition_by_axis(2) if self.env.new_cache_stacked else self.env.partition_by_axis(1) + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + new_cache_pspec = ( + self.env.partition_by_axis(2) + if self.env.new_cache_stacked + else self.env.partition_by_axis(1) + ) none_pspec = self.env.partition_by_axis() - in_specs = (*([cache_pspec] * 2), *([new_cache_pspec] * 2), *([none_pspec] * 5)) + in_specs = ( + *([cache_pspec] * 2), + *([new_cache_pspec] * 2), + *([none_pspec] * 5), + ) out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) - self.update_single_cache_line = shard_map(self.update_single_cache_line, self.env.mesh, in_specs, out_specs, check_rep=False) + self.update_single_cache_line = shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) self.update_single_cache_line = jax.jit(self.update_single_cache_line) - def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, k_scaler, v_scaler, new_k_scaler, new_v_scaler, pos): + def update_single_cache_line( + self, + cache_k, + cache_v, + new_ks, + new_vs, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + pos, + ): b = cache_k.shape[-4] for bb, pp in enumerate(pos.reshape(b)): - slice_dim = 0 - update_start_indices = (bb, 0, pp, 0) 
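# --- Illustrative sketch, not from the patch itself: the per-batch cache write that
# update_single_cache_line performs above. Each batch row gets its new token written at
# its own position via dynamic_slice_in_dim / dynamic_update_slice; the patch additionally
# wraps this in shard_map and jax.jit, which is omitted here. Names, shapes and the
# unstacked (batch, heads, seq_len, head_dim) layout are illustrative assumptions.
import jax
import jax.numpy as jnp

def write_new_token(cache_k, new_ks, pos):
  # cache_k: (batch, heads, seq_len, head_dim); new_ks: (batch, heads, 1, head_dim)
  # pos: (batch,) per-row insert position (left-aligned cache, no ring buffer).
  b = cache_k.shape[0]
  for bb, pp in enumerate(pos.reshape(b)):
    new_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, 0)  # row bb of the new token
    cache_k = jax.lax.dynamic_update_slice(cache_k, new_slice, (bb, 0, pp, 0))
  return cache_k

cache = jnp.zeros((2, 4, 16, 8))
new = jnp.ones((2, 4, 1, 8))
updated = write_new_token(cache, new, jnp.array([3, 7]))
assert updated[0, 0, 3, 0] == 1 and updated[1, 0, 7, 0] == 1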
- if self.env.generate_cache_stacked: - if self.env.new_cache_stacked: - slice_dim = 1 - update_start_indices = (0, bb, 0, pp, 0) - if self.env.generate_cache_stacked and not self.env.new_cache_stacked: - for slice in range(self.env.num_layers): - update_start_indices = (slice, bb, 0, pp, 0) - new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks[slice], bb, 1, slice_dim) - new_ks_slice = jnp.expand_dims(new_ks_slice, 0) - cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) - - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs[slice], bb, 1, slice_dim) - new_vs_slice = jnp.expand_dims(new_vs_slice, 0) - cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) - - new_k_scaler_slice = jax.lax.dynamic_slice_in_dim(new_k_scaler[slice], bb, 1, slice_dim) - new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) - k_scaler = jax.lax.dynamic_update_slice(k_scaler, new_k_scaler_slice, update_start_indices) - - new_v_scaler_slice = jax.lax.dynamic_slice_in_dim(new_v_scaler[slice], bb, 1, slice_dim) - new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) - v_scaler = jax.lax.dynamic_update_slice(v_scaler, new_v_scaler_slice, update_start_indices) - else: - new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) - cache_k = jax.lax.dynamic_update_slice(cache_k, new_ks_slice, update_start_indices) - - new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) - cache_v = jax.lax.dynamic_update_slice(cache_v, new_vs_slice, update_start_indices) - - new_k_scaler_slice = jax.lax.dynamic_slice_in_dim(new_k_scaler, bb, 1, slice_dim) - k_scaler = jax.lax.dynamic_update_slice(k_scaler, new_k_scaler_slice, update_start_indices) - - new_v_scaler_slice = jax.lax.dynamic_slice_in_dim(new_v_scaler, bb, 1, slice_dim) - v_scaler = jax.lax.dynamic_update_slice(v_scaler, new_v_scaler_slice, update_start_indices) + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + if self.env.generate_cache_stacked and not self.env.new_cache_stacked: + for slice in range(self.env.num_layers): + update_start_indices = (slice, bb, 0, pp, 0) + new_ks_slice = jax.lax.dynamic_slice_in_dim( + new_ks[slice], bb, 1, slice_dim + ) + new_ks_slice = jnp.expand_dims(new_ks_slice, 0) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim( + new_vs[slice], bb, 1, slice_dim + ) + new_vs_slice = jnp.expand_dims(new_vs_slice, 0) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler[slice], bb, 1, slice_dim + ) + new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler[slice], bb, 1, slice_dim + ) + new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + else: + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, 
update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler, bb, 1, slice_dim + ) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler, bb, 1, slice_dim + ) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) return cache_k, cache_v, k_scaler, v_scaler @@ -367,11 +483,11 @@ def empty(cls, shape, device, env): """Create empty kv caches""" cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - + if env.generate_cache_stacked: - s_shape = (shape[0], shape[1], 1, shape[3], 1) + s_shape = (shape[0], shape[1], 1, shape[3], 1) else: - s_shape = (shape[0], 1, shape[2], 1) + s_shape = (shape[0], 1, shape[2], 1) kscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) vscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) @@ -387,7 +503,7 @@ def quantize(self, val): scale = scale / 127 return (val / scale).to(torch.int8), scale - def update(self, xk, xv, layer_id:int): + def update(self, xk, xv, layer_id: int): """Update kv cache""" k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) @@ -421,34 +537,95 @@ def update(self, xk, xv, layer_id:int): self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) - return self.cache_k, self.cache_v, k_quant, v_quant, self.k_scaler, self.v_scaler, kscale, vscale + return ( + self.cache_k, + self.cache_v, + k_quant, + v_quant, + self.k_scaler, + self.v_scaler, + kscale, + vscale, + ) def finalize(self): if not self.env.lazy_cache_update: return if self.env.ring_buffer: # Assume no cache stack for ring buffer - self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set(self.new_ks._elem) - self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set(self.new_vs._elem) + self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set( + self.new_ks._elem + ) + self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set( + self.new_vs._elem + ) else: - if self.env.generate_cache_stacked: - layer, b, head, _, dim = self.cache_k.shape - if self.env.new_cache_stacked: - # new kv scaler also has to go through shard_map instead of indexing because it needs to reshape to (batch, layer) which mess up with the data - caches = [self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler] - self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, *caches, self.input_pos) - else: - # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
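# --- Illustrative sketch, not from the patch itself: the symmetric int8 scheme behind the
# k/v scalers handled here, roughly scale = max(|x|) / 127 so that dequantization is just
# q * scale. The quantize() in this class may reduce over different axes; the per-token
# layout below is an assumption for illustration.
import torch

def quantize_int8(x):
  scale = torch.amax(x.abs(), dim=-1, keepdim=True) / 127.0
  return (x / scale).to(torch.int8), scale

x = torch.randn(2, 4, 16, 8)          # (batch, heads, tokens, head_dim)
q, scale = quantize_int8(x)
x_hat = q.to(torch.float32) * scale   # dequantized cache values
assert (x - x_hat).abs().max() <= scale.max()  # error bounded by one quantization step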
- caches = [self.cache_k._elem, self.cache_v._elem, self.new_ks, self.new_vs, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler] - self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, *caches, self.input_pos) - # for i in range(self.env.num_layers): - # self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) - # self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) - # self.k_scaler._elem = self.k_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i]._elem.reshape(b, 1, 1)) - # self.v_scaler._elem = self.v_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i]._elem.reshape(b, 1, 1)) + if self.env.generate_cache_stacked: + layer, b, head, _, dim = self.cache_k.shape + if self.env.new_cache_stacked: + # new kv scaler also has to go through shard_map instead of indexing because it needs to reshape to (batch, layer) which mess up with the data + caches = [ + self.cache_k._elem, + self.cache_v._elem, + self.new_ks._elem, + self.new_vs._elem, + self.k_scaler._elem, + self.v_scaler._elem, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k._elem, + self.cache_v._elem, + self.k_scaler._elem, + self.v_scaler._elem, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) else: - # Try to use shard_map to get rid of the data copy - b = self.cache_k.shape[-4] - self.cache_k._elem, self.cache_v._elem, self.k_scaler._elem, self.v_scaler._elem = torch_xla2.interop.call_jax(self.update_single_cache_line, self.cache_k._elem, self.cache_v._elem, self.new_ks._elem, self.new_vs._elem, self.k_scaler._elem, self.v_scaler._elem, self.new_k_scaler, self.new_v_scaler, self.input_pos) - #self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) - #self.v_scaler._elem = self.v_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) + # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
+ caches = [ + self.cache_k._elem, + self.cache_v._elem, + self.new_ks, + self.new_vs, + self.k_scaler._elem, + self.v_scaler._elem, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k._elem, + self.cache_v._elem, + self.k_scaler._elem, + self.v_scaler._elem, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + # for i in range(self.env.num_layers): + # self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) + # self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) + # self.k_scaler._elem = self.k_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i]._elem.reshape(b, 1, 1)) + # self.v_scaler._elem = self.v_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i]._elem.reshape(b, 1, 1)) + else: + # Try to use shard_map to get rid of the data copy + b = self.cache_k.shape[-4] + ( + self.cache_k._elem, + self.cache_v._elem, + self.k_scaler._elem, + self.v_scaler._elem, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k._elem, + self.cache_v._elem, + self.new_ks._elem, + self.new_vs._elem, + self.k_scaler._elem, + self.v_scaler._elem, + self.new_k_scaler, + self.new_v_scaler, + self.input_pos, + ) + # self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) + # self.v_scaler._elem = self.v_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index c2e44f6e..75e11bab 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -93,25 +93,25 @@ flags.DEFINE_bool( "flash_attention", True, - "Whether to enable flas attention", + "Whether to enable flas attention. Only takes effect at test mode", required=False, ) flags.DEFINE_bool( "generate_cache_stacked", True, - "Whether to stack the generate cache to the layer dimension", + "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", required=False, ) flags.DEFINE_bool( "new_cache_stacked", True, - "Whether to stack the generate cache to the layer dimension", + "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", required=False, ) flags.DEFINE_bool( "lazy_cache_update", True, - "Whether to update the cache during attention or delayed until all the layers are done", + "Whether to update the cache during attention or delayed until all the layers are done. 
Only takes effect at test mode", required=False, ) flags.DEFINE_float( diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index bf59d3bf..56dc915f 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -357,7 +357,12 @@ def insert(cache, new_entry, update_index): update_index = [idx, slot, 0, pos, 0] newk = jnp.expand_dims(newk, 0) newv = jnp.expand_dims(newv, 0) - caches = [(insert(caches[0][0], newk, update_index),insert(caches[0][1], newv, update_index))] + caches = [ + ( + insert(caches[0][0], newk, update_index), + insert(caches[0][1], newv, update_index), + ) + ] else: update_index = [slot, 0, pos, 0] caches = [ @@ -373,8 +378,8 @@ def insert(cache, scaler, new_entry, update_index): quantize.quantize_tensor, new_entry, reduce_axis ) if self.env.generate_cache_stacked: - vals = jnp.expand_dims(vals, 0) - scales = jnp.expand_dims(scales, 0) + vals = jnp.expand_dims(vals, 0) + scales = jnp.expand_dims(scales, 0) new_scaler = jax.lax.dynamic_update_slice( scaler, scales, @@ -392,25 +397,31 @@ def insert(cache, scaler, new_entry, update_index): return res, new_scaler if self.env.generate_cache_stacked: - cache_k, k_scale = decode_state.caches[0][0], decode_state.cache_scales[0][0] - cache_v, v_scale = decode_state.caches[0][1], decode_state.cache_scales[0][1] - for idx, (newk, newv) in enumerate(prefix.caches): - update_index = [idx, slot, 0, pos, 0] - #newk = jnp.expand_dims(newk, 0) - #newv = jnp.expand_dims(newv, 0) - cache_k, k_scale = insert(cache_k, k_scale, newk, update_index) - cache_v, v_scale = insert(cache_v, v_scale, newv, update_index) - caches = [(cache_k, cache_v)] - scales = [(k_scale, v_scale)] + cache_k, k_scale = ( + decode_state.caches[0][0], + decode_state.cache_scales[0][0], + ) + cache_v, v_scale = ( + decode_state.caches[0][1], + decode_state.cache_scales[0][1], + ) + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + # newk = jnp.expand_dims(newk, 0) + # newv = jnp.expand_dims(newv, 0) + cache_k, k_scale = insert(cache_k, k_scale, newk, update_index) + cache_v, v_scale = insert(cache_v, v_scale, newv, update_index) + caches = [(cache_k, cache_v)] + scales = [(k_scale, v_scale)] else: update_index = [slot, 0, pos, 0] for (k, v), (kscaler, vscaler), (newk, newv) in zip( decode_state.caches, decode_state.cache_scales, prefix.caches ): - kcache, kscale = insert(k, kscaler, newk, update_index) - vcache, vscale = insert(v, vscaler, newv, update_index) - caches.append((kcache, vcache)) - scales.append((kscale, vscale)) + kcache, kscale = insert(k, kscaler, newk, update_index) + vcache, vscale = insert(v, vscaler, newv, update_index) + caches.append((kcache, vcache)) + scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) return DecodeState( tokens, @@ -463,25 +474,21 @@ def _insert_wrap( old_scales = decode_state.cache_scales cache_inserts = prefix.caches - # print(f"YY old_caches: {len(decode_state.caches)} cache_inserts: {len(cache_inserts)}") scales = [] caches = [] if not self.env.quant_config.enable_kv_quantization: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, new_entry, layer_id): + def insert(cache, new_entry): new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2)) - if self.env.generate_cache_stacked: - res = cache.at[layer_id, slot, :, update_indexes, :].set(new_entry) - else: - res = cache.at[slot, :, update_indexes, :].set(new_entry) + res = cache.at[slot, :, update_indexes, :].set(new_entry) res = jax.lax.with_sharding_constraint(res, 
self.cache_sharding) return res - for idx, (newk, newv) in enumerate(prefix.caches): - caches = [ - (insert(old_caches[0][0], newk, idx), insert(old_caches[0][1], newv, idx)) - ] + caches = [ + (insert(k, newk), insert(v, newv)) + for (k, v), (newk, newv) in zip(old_caches, cache_inserts) + ] else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) @@ -526,11 +533,6 @@ def insert( decode_state: DecodeState, slot: int, ) -> DecodeState: - # logging.info( - # 'Jet input prefix: %s, decode state before insert: %s', - # prefix, - # decode_state, - # ) if self.env.ring_buffer: start_insert = decode_state.current_position - prefix.seq_len end_insert = start_insert + prefix.caches[0][0].shape[2] # padded seclen @@ -622,7 +624,6 @@ def generate( (-1) ), ragged_block_index.reshape((-1)) - def update_mask(): if self.env.ring_buffer: return decode_state.mask.at[:, decode_state.current_position].set(0) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 1836cf58..3e646232 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -86,7 +86,7 @@ class JetEngineEnvironmentData: shard_on_batch: bool = False # Whether to enable ragged multi head attention. - ragged_mha: bool = True + ragged_mha: bool = False # The block size for the ragged attention. block_size: int = 512 @@ -95,15 +95,15 @@ class JetEngineEnvironmentData: starting_position: int = 512 # Ring buffer - ring_buffer: bool = False + ring_buffer: bool = True - flash_attention: bool = True + flash_attention: bool = False - generate_cache_stacked: bool = True + generate_cache_stacked: bool = False - new_cache_stacked: bool = True + new_cache_stacked: bool = False - lazy_cache_update: bool = True + lazy_cache_update: bool = False # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -131,18 +131,29 @@ def __init__(self, data: JetEngineEnvironmentData): self.batch_size = self._data.batch_size self.seq_len = self._data.max_input_sequence_length self.cache_len = self._data.cache_sequence_length - self.ragged_mha = self._data.ragged_mha self.block_size = self._data.block_size self.starting_position = self._data.starting_position - self.flash_attention = self._data.flash_attention - self.generate_cache_stacked = self._data.generate_cache_stacked - self.new_cache_stacked = self._data.new_cache_stacked self.num_layers = self._data.num_layers - self.ring_buffer = self._data.ring_buffer - self.lazy_cache_update = self._data.lazy_cache_update self.testing = self._data.testing self.testing_seed = self._data.testing_seed + self.ring_buffer = self._data.ring_buffer + + if not self.ring_buffer: + self.lazy_cache_update = True + self.ragged_mha = True + self.flash_attention = True + self.generate_cache_stacked = True + self.new_cache_stacked = True + + if self.testing: + self.lazy_cache_update = self._data.lazy_cache_update + self.ragged_mha = self._data.ragged_mha + self.flash_attention = self._data.flash_attention + self.generate_cache_stacked = self._data.generate_cache_stacked + self.new_cache_stacked = self._data.new_cache_stacked + self.default_type = jnp.bfloat16 if self._data.bf16_enable else jnp.float32 + if self.generate_cache_stacked: self.cache_shape = (self.num_layers, *self._data.cache_shape) else: @@ -163,11 +174,11 @@ def __init__(self, data: JetEngineEnvironmentData): if self.generate_cache_stacked: self.attention_kv_axis_names = ( - "layer", - "batch", - "num_attn_heads", - "sequence_length", - 
"head_dim", + "layer", + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", ) if data.shard_on_batch: self.kv_cache_shard_axis = "batch" diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 4b7e6c99..e9542703 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -368,19 +368,24 @@ def apply_rotary_emb( def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" - bs, n_kv_heads, slen, head_dim = x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1] + bs, n_kv_heads, slen, head_dim = ( + x.shape[-4], + x.shape[-3], + x.shape[-2], + x.shape[-1], + ) if x.ndim == 5: - stacked = True + stacked = True else: - stacked = False + stacked = False if n_rep == 1: return x if stacked: layer = x.shape[0] return ( - x[:, :, :, None, :, :] - .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim) - .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim) + x[:, :, :, None, :, :] + .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim) ) return ( x[:, :, None, :, :] @@ -393,8 +398,14 @@ class AttentionKernel: def __init__(self, env, layer_id): self.env = env - self.q_shard_axis = 0 if self.env.shard_on_batch else 1 - self.kv_shard_axis = 0 if self.env.shard_on_batch else 2 if self.env.generate_cache_stacked else 1 + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = ( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() @@ -442,12 +453,18 @@ def __call__( n_rep = num_heads // num_kv_heads def attend(xq, keys, values, local_mask=None): - # As of right now, ragged attention doesn't support attention calculation with prefill and new cache line - # We are not using ragged attention for prefill yet. 
if keys.ndim == 4: - impl = self.ragged_attention_new + impl = self.ragged_attention_new else: - impl = self.ragged_attention_orig + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if n_rep == 1 and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) if self.env.ragged_mha and seqlen == 1: local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( @@ -463,27 +480,33 @@ def attend(xq, keys, values, local_mask=None): ) elif self.env.flash_attention: with torch_xla2.default_env(): - local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, self.layer_id, mask=local_mask) + local_output, (local_max, local_denom) = self.flash_attention( + xq, keys, values, self.layer_id, mask=local_mask + ) else: - if seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - local_output = self.dense_attention(xq, keys, values, None, None, local_mask) + local_output = self.dense_attention( + xq, keys, values, None, None, local_mask + ) local_max = None local_denom = None - if seqlen == 1: - local_output = local_output[:, :, 0:1, :] - if local_max is not None: - local_max = local_max[:, :, 0:1, :] - if local_denom is not None: - local_denom = local_denom[:, :, 0:1, :] + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) + + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + if local_denom is not None: + local_denom = local_denom[:, :, 0:seqlen, :] # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") # if local_max is not None and local_denom is not None: # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") self.env.apply_sharding(local_output, axis=self.q_shard_axis) return local_output, (local_max, local_denom) - with jax.named_scope("attn_insert_cache"): orig_keys, orig_values = cache.update(xk, xv, self.layer_id) @@ -494,11 +517,9 @@ def attend(xq, keys, values, local_mask=None): # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") with jax.named_scope("attn_qkv"): - existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, mask) - cache_len = orig_keys.shape[-2] - existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) - existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) - existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, mask + ) # Updating cache during each step still has very large impact on latency. 
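# --- Illustrative sketch, not from the patch itself: the true_len padding used in attend()
# above. With a single query position and n_rep == 1, xq is zero-padded to two query rows
# before the ragged kernel (which appears to need more than one row) and the outputs are
# sliced back to the real seqlen afterwards. Shapes are illustrative.
import torch

bsz, num_heads, seqlen, head_dim = 2, 4, 1, 8
xq = torch.randn(bsz, num_heads, seqlen, head_dim)
true_len = 2
# F.pad consumes pairs starting from the last dim: (0, 0, 0, 1) leaves head_dim
# untouched and appends one zero row on the query axis.
padded = torch.nn.functional.pad(xq, (0, 0, 0, true_len - seqlen), "constant", 0)
assert padded.shape == (bsz, num_heads, true_len, head_dim)
out = padded[:, :, 0:seqlen, :]  # slice back to the real sequence length
assert torch.equal(out, xq)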
# For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: @@ -510,22 +531,23 @@ def attend(xq, keys, values, local_mask=None): xk = repeat_kv(xk, n_rep) xv = repeat_kv(xv, n_rep) new_output, (new_max, new_denom) = attend(xq, xk, xv, None) - new_output = new_output.reshape(bsz, num_heads, 1, head_dim) - new_max = new_max.reshape(bsz, num_heads, 1, 1) - new_denom = new_denom.reshape(bsz, num_heads, 1, 1) - # if cache.cache_k is None: # Prefill - # return new_output with jax.named_scope("attn_global"): # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") - global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) - existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) new_output = new_output * new_denom * torch.exp(new_max) / global_sum attn_out = existing_output + new_output - return attn_out @@ -534,7 +556,13 @@ class Int8KVAttentionKernel: def __init__(self, env, layer_id): self.env = env self.q_shard_axis = 0 if self.env.shard_on_batch else 1 - self.kv_shard_axis = 0 if self.env.shard_on_batch else 2 if self.env.generate_cache_stacked else 1 + self.kv_shard_axis = ( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() @@ -580,18 +608,21 @@ def __call__( num_kv_heads = xk.shape[-3] kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - + def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): if keys.ndim == 4: - impl = self.ragged_attention_new + impl = self.ragged_attention_new else: - impl = self.ragged_attention_orig - #if not self.env.ragged_mha and seqlen == 1: + impl = self.ragged_attention_orig + true_len = seqlen + # When GQA is enabled, it not necessary to expand if n_rep == 1 and seqlen == 1: true_len = 2 - xq = torch.nn.functional.pad(xq, (0, 0, 0, true_len - seqlen), "constant", 0) - #xq = torch.broadcast_to(xq, (bsz, num_heads, true_len, head_dim)) + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) + # xq = torch.broadcast_to(xq, (bsz, num_heads, true_len, head_dim)) # We are not using ragged attention for prefill yet. 
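# --- Standalone numerical check, not from the patch itself, of the "attn_global" merge used
# in these kernels: attention over the existing cache and over the newly written entry is
# computed separately, each returning its softmax max and denominator, and the two partial
# results are then combined exactly as in the code above. Shapes are illustrative.
import torch

def partial_attn(q, k, v):
  logits = torch.einsum("bhqd,bhkd->bhqk", q, k) / k.shape[-1] ** 0.5
  m, _ = torch.max(logits, dim=-1, keepdim=True)
  unnormalized = torch.exp(logits - m)
  denom = unnormalized.sum(dim=-1, keepdim=True)
  return torch.einsum("bhqk,bhkd->bhqd", unnormalized, v) / denom, m, denom

q = torch.randn(1, 2, 1, 8)
k, v = torch.randn(1, 2, 6, 8), torch.randn(1, 2, 6, 8)
out_a, m_a, d_a = partial_attn(q, k[:, :, :5], v[:, :, :5])  # existing cache
out_b, m_b, d_b = partial_attn(q, k[:, :, 5:], v[:, :, 5:])  # new cache line
global_sum = d_a * torch.exp(m_a) + d_b * torch.exp(m_b)
merged = (out_a * d_a * torch.exp(m_a) + out_b * d_b * torch.exp(m_b)) / global_sum
full, _, _ = partial_attn(q, k, v)
assert torch.allclose(merged, full, atol=1e-5)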
if self.env.ragged_mha and seqlen == 1: @@ -608,39 +639,58 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): k_scaler, v_scaler, ) - #local_max = local_max.reshape(*local_max.shape, 1) - #local_denom = local_denom.reshape(*local_denom.shape, 1) elif self.env.flash_attention: with torch_xla2.default_env(): - local_output, (local_max, local_denom) = self.flash_attention(xq, keys, values, self.layer_id, k_scaler, v_scaler, mask=local_mask) + local_output, (local_max, local_denom) = self.flash_attention( + xq, + keys, + values, + self.layer_id, + k_scaler, + v_scaler, + mask=local_mask, + ) else: - local_output = self.dense_attention(xq, keys, values, k_scaler, v_scaler, local_mask) + local_output = self.dense_attention( + xq, keys, values, k_scaler, v_scaler, local_mask + ) local_max = None local_denom = None local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) if local_max is not None: local_max = local_max.reshape(bsz, num_heads, true_len, 1) - local_denom= local_denom.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) if true_len != seqlen: - local_output = local_output[:, :, 0:1, :] + local_output = local_output[:, :, 0:seqlen, :] if local_max is not None: - local_max = local_max[:, :, 0:1, :] - local_denom = local_denom[:, :, 0:1, :] + local_max = local_max[:, :, 0:seqlen, :] + local_denom = local_denom[:, :, 0:seqlen, :] self.env.apply_sharding(local_output, axis=self.q_shard_axis) return local_output, (local_max, local_denom) + with jax.named_scope("attn_insert_cache"): - orig_keys, orig_values, new_key, new_value, k_scaler, v_scaler, new_k_scaler, new_v_scaler = cache.update(xk, xv, self.layer_id) + ( + orig_keys, + orig_values, + new_key, + new_value, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + ) = cache.update(xk, xv, self.layer_id) # We are not using ragged attention for prefill yet. 
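# --- Standalone check, not from the patch itself, of how the int8 KV scales enter the
# attention math in attend(): the q @ k_int8 logits are multiplied by the k scale and the
# softmax weights by the v scale, which matches attention against a dequantized float
# cache. For simplicity the scales here are per (batch, head, token), whereas the cache
# above keeps one scale per (batch, token); the 1/sqrt(d) factor is omitted.
import torch

def quantize_int8(x):
  scale = torch.amax(x.abs(), dim=-1, keepdim=True) / 127.0
  return (x / scale).to(torch.int8), scale

q = torch.randn(1, 2, 1, 8)
k, v = torch.randn(1, 2, 6, 8), torch.randn(1, 2, 6, 8)
k_q, k_s = quantize_int8(k)
v_q, v_s = quantize_int8(v)

def attn(q, k, v):  # reference attention on the dequantized cache
  w = torch.softmax(torch.einsum("bhqd,bhkd->bhqk", q, k), dim=-1)
  return torch.einsum("bhqk,bhkd->bhqd", w, v)

ref = attn(q, k_q.float() * k_s, v_q.float() * v_s)
logits = torch.einsum("bhqd,bhkd->bhqk", q, k_q.float()) * k_s.reshape(1, 2, 1, 6)
weights = torch.softmax(logits, dim=-1) * v_s.reshape(1, 2, 1, 6)
out = torch.einsum("bhqk,bhkd->bhqd", weights, v_q.float())
assert torch.allclose(out, ref, atol=1e-5)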
if not self.env.ragged_mha or seqlen > 1: orig_keys = repeat_kv(orig_keys, n_rep) orig_values = repeat_kv(orig_values, n_rep) with jax.named_scope("attn_qkv"): - existing_output, (existing_max, existing_denom) = attend(xq, orig_keys, orig_values, k_scaler, v_scaler, mask) - cache_len = orig_keys.shape[-2] - #existing_output = existing_output.reshape(bsz, num_heads, seqlen, head_dim) + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, k_scaler, v_scaler, mask + ) + # For non flash attention or prefill, existing output contains everything if not self.env.lazy_cache_update or seqlen > 1: return existing_output @@ -648,29 +698,34 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): # For flash attention, existing output contains the existing kv cache generated logits with jax.named_scope("attn_new_qkv"): # At this point, flash attention or ragged attention must have been enabled - #existing_max = existing_max.reshape(bsz, num_heads, seqlen, 1) - #existing_denom = existing_denom.reshape(bsz, num_heads, seqlen, 1) - if not self.env.ragged_mha or seqlen > 1: new_key = repeat_kv(new_key, n_rep) new_value = repeat_kv(new_value, n_rep) - new_output, (new_max, new_denom) = attend(xq, new_key, new_value, new_k_scaler, new_v_scaler, None) - #new_output = new_output.reshape(bsz, num_heads, 1, head_dim) - #new_max = new_max.reshape(bsz, num_heads, 1, 1) - #new_denom = new_denom.reshape(bsz, num_heads, 1, 1) + new_output, (new_max, new_denom) = attend( + xq, new_key, new_value, new_k_scaler, new_v_scaler, None + ) with jax.named_scope("attn_global"): - global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max) - existing_output = existing_output * existing_denom * torch.exp(existing_max) / global_sum + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) new_output = new_output * new_denom * torch.exp(new_max) / global_sum attn_out = existing_output + new_output - + return attn_out class Attention(ModuleBase): """Attention module.""" - def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id): + def __init__( + self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads @@ -778,9 +833,9 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) if mask.ndim == 2: - mask = mask[:, None, None, :] + mask = mask[:, None, None, :] # if cache is not None and cache.cache_k is not None: - # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") + # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, xk, diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 1bb4a1c0..55497596 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -258,7 +258,7 @@ def forward( bsz, seqlen = tokens.shape freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - + # Should check more thoroughly, as of now, when prefill, it's always not stacked. When generate, it's controlled by the parameter. 
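# --- Small illustration, not from the patch itself, of the 2-D mask handling in
# Attention.forward above: a (batch, kv_len) additive mask is broadcast to
# (batch, 1, 1, kv_len) so the same mask applies to every head and query position of the
# (batch, heads, q_len, kv_len) logits. Shapes are illustrative.
import torch

mask = torch.zeros(2, 16)
mask[:, 8:] = float("-inf")        # e.g. cache slots that are not filled yet
logits = torch.randn(2, 4, 1, 16)  # (batch, heads, q_len, kv_len)
masked = logits + mask[:, None, None, :]
assert masked.shape == (2, 4, 1, 16)
assert torch.isinf(masked[:, :, :, 8:]).all()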
# target_cache_layers = 1 if self.env.generate_cache_stacked else len(self.layers) # assert len(caches) == target_cache_layers, f"Number of caches ({len(caches)}) and layers ({target_cache_layers}) dont match" diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index df0e0056..f85688b0 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -38,7 +38,8 @@ def __init__(self, config: ModelArgs, env) -> None: config.vocab_size, config.dim, device=config.device ) self.layers = nn.ModuleList( - TransformerBlock(config, env, layer_id) for layer_id, _ in enumerate(range(config.n_layer)) + TransformerBlock(config, env, layer_id) + for layer_id, _ in enumerate(range(config.n_layer)) ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) LinearLayer = get_quantized_linear_layer(env.quant_config) @@ -151,7 +152,7 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None: config.dim, env=env, device=config.device, - layer_id=layer_id + layer_id=layer_id, ) self.block_sparse_moe = MOEFeedForward(config, config.device, env) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) diff --git a/keys_original b/keys_original deleted file mode 100644 index d18bd8dd3ddcfb0140450b3cfde59b54a931865f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11706 zcmbW-3vg7`83*u9AnzhV1A$hAXuu5#AtdBwINuk71e~2fcn(l831LH5R+dc`3uVRz z0-*#D2#gksAQ3yoDo<;9?ESvB+6PlR2#mGSK1OAVwibmJg;_g{qRC%gWq4Rn3Ywb;9&xi-+`TjX0(?;pLeK2)E#bg{pxGTAln zMr(<02nCuNzMWUQEadYyr`9%k7cBGn7uJTnOPT{swT~2}23>CVoFtpyY}c$`Tqu~dzM&!D_j?;czQB^^)TvDiy-nVQ_xTz^slnvz>A{pxFm-w`Ew9Ft zlsxNN(Sm(8Iie=!hG)3?M{W3m6(MgR5?obz+s!S%z61QeP{{9%gs(O44o7FG+U8hhW9#1GnFfHCd`T)mQ8~EjOHq0s9F%WtPpDTsIwuubD=h{&WriqnE}XzffBlq@H)R+hXX*EZjG-Zl9Opzu#yejZwjAUxLA}Q=*q=s7wnZ)e|nJ0r7X*PL;tl&IF-lCrf za{72*MjoIS%8`AvOc9UlWMmT$BV-Z{S7c4tGM(PDYQuL4@^W;rJQY4d+&uAJW)72qlu9+#wY+5M@ z+p}*ZP2(APQc4`8gGVUxSoo|#-29$FM#@ox>~+2`$TH_0LH=f56l5MTBa2Odk)w_X zl1Ud0@(<}Eq=Iu93Go33$>#qE^0;YH;^0fIxkTqt$L6$pH9Hdb16692Pydd4q zDMEIXPhsR6yoQn2ox2tJttlgJPFluTmeGYZ)GDYhcMLy>qMWM`1rF32uH zHfp0H2kDqW3OQMkvre5s#+gb+g7knOZMu$-3|>yiIeT@Bd`uq+GD=GX`M{q4BbUji z$Pb*;gtYR%1bJ0{PDmemMv+10VL@J{e+x2$x(%|ETNp{A{fr!uI~A$ZIzig_nuE2O zHhYB=%T=l+WV=0oM!Km*ky3t85tluGLNetqg3RXwjD*CeNS|w*el%F zI^`om()q`P)XC$7r0XF@9%5$Xi21-EljyJ@iPEOXNISbDPtk`48O+&)OjTB7G99o_ zP-4rGfP?Iij|Dj#{Ioc%1E3})C zM!rXpH2s8-a_z8B;A4B2o)E+<`wa3eZYQLj?`7l)u^`#{XF@up4^?5;d(WNkQLXSAeHt8f}7>`n<)6PD~ZrWsLpx7p9nIhTz z8$#Zf`w4m3&K^jJ&l{vTJdu%2TA;{f`}_d8CUJz!aOxBZaFQacb(a1A#CB2KicA=n z$jCh1t;k6;hmdL~m5{4aOh^O&(Lsi4jv~u-F(X}04k5pz)q=bs^BEb-2kaBv*e-K2 zBa3*XAhX4<$XTigBLg&vklXn*A(6^Ur#t1KE8CV`TCpoH2e?dB<>e#Zre<3-T2)#0 zf0bT-UYJ>!8X3XqU5jSL+W|zx9zm8XH);fVpBH~Ndz*aqe&2WNBPFc0zJ*nl?yo$c z{GyWL(xTGRqLRYWqO!8$f|AmLeA~v1DJ(86EGQ@(TUJn3TviepWuHS^JKT1J5jCE; z$Y5_){@z*|J23J?kM`wtz|k4ER={TT=zrCI>X~t~q&7O_)`OYeBV@KM4(=Ipy>vD@ z;nssOphv=MBcl^;`~{*TZrxK^JtC&(L`Qu2F|6?450~ diff --git a/original_scores b/original_scores deleted file mode 100644 index bd56f64431517fd785076a21d79b25f6016f6c16..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6468 zcmbVR3rv<*ljbgR(|SWJUQj{QB7#7H@5~&6E5;j(Kdxmpm8cM8C8AvP*J!*U(jt{o z6!C&Zs)aL=qLb-E(PWCd_ zUmo#EPRYxkUlbLW{>yovWo{X+a@+W2?iud#xo~OmGLIQaK1ItsXHIhopXNF(NsiG) z%e=zn?6lFcyB00$@zKJC`FVNS3yX5|mlXQ^b3s;iL3Y+Z`?ZZ_9cXE%W)? 
zSkRL8a2a9ukv7e>d$7NC-hbbk9@u^158Ge$;or9M`yO=AA1u6-|K7IpTT}aOGrwQZ z0e>(vTN?CloB5@E{kDzYFG=7ZY@~_Fw{83y2yYwty~o1-VC3Kbfsuc_mTA+xyx#lu z*&*%`zghA8_1ph^{#*BN53>FJ*U`#z-9K8mbo{y;-Ivp*d3Lw=T|8y)zW`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", @@ -62,21 +63,24 @@ def main(argv): print(f"---- Encoded tokens are: {tokens}") # pylint: disable-next=all - prefill_result, _ = engine.prefill( - params=params, padded_tokens=tokens, true_length=true_length - ) - # pylint: disable-next=all - decode_state = engine.insert(prefill_result, decode_state, slot=slot) + if profiling_prefill: + jax.profiler.start_trace(profiling_output) + prefill_result, _ = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) + # pylint: disable-next=all + decode_state = engine.insert(prefill_result, decode_state, slot=slot) + jax.profiler.stop_trace() + sampled_tokens_list = [] print(f"---- Streaming decode started on #slot{slot}.") complete = np.zeros((1,), dtype=np.bool_) while True: - if profiling_output and not profiling_prefill: + if profiling_output: jax.profiler.start_trace(profiling_output) - decode_state, result_tokens = engine.generate(params, decode_state) - if profiling_output and not profiling_prefill: + decode_state, result_tokens = engine.generate(params, decode_state) + result_tokens = result_tokens.convert_to_numpy() jax.profiler.stop_trace() - result_tokens = result_tokens.convert_to_numpy() output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, slot=slot, @@ -94,9 +98,6 @@ def main(argv): print("---- All output text.") print(tokenizer.decode(sampled_tokens_list)) - if profiling_output and profiling_prefill: - jax.profiler.stop_trace() - if __name__ == "__main__": os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 5fbbc42e..73d0ce6c 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -29,8 +29,10 @@ from jetstream_pt import environment from tests import helpers from jetstream_pt import torchjax +from absl.testing import parameterized -class LlamaE2ETest(unittest.TestCase): + +class LlamaE2ETest(parameterized.TestCase): """This test class includes all E2E test for llama2""" def _from_torch(self, tree): @@ -188,7 +190,7 @@ def _llama_e2e(self, env, model_arg): for k, v in model_ours.state_dict().items(): if "scale" in k: - state_dict[k] =helpers.to_xla_tensor(v) + state_dict[k] = helpers.to_xla_tensor(v) engine = PyTorchEngine(pt_model=model_ours, env=env) params = self._from_torch(state_dict) @@ -225,199 +227,65 @@ def test_llama_e2e_float32(self): out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertEqual(out_tokens, expected_output_tokens) - def test_llama_e2e_float32_left_aligned_cache(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=False - env_data.generate_cache_stacked=False - env_data.new_cache_stacked=False - env_data.lazy_cache_update=False 
- env, model_arg = helpers.make_env_tiny(False, update_env_data) - - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - - def test_llama_e2e_float32_left_aligned_generate_cache_stacked(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=False - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=False - env_data.lazy_cache_update=False - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - def test_llama_e2e_float32_left_aligned_new_cache_stacked(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=False - env_data.generate_cache_stacked=False - env_data.new_cache_stacked=True - env_data.lazy_cache_update=False - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - def test_llama_e2e_float32_left_aligned_all_cache_stacked(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=False - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=True - env_data.lazy_cache_update=False - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - - def test_llama_e2e_float32_left_aligned_lazy_cache_update(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=True - env_data.generate_cache_stacked=False - env_data.new_cache_stacked=False - env_data.lazy_cache_update=True - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - - # Won't work after removed the cache.finalize() in the Transformer - def test_llama_e2e_float32_left_aligned_lazy_cache_update_generate_cache_stacked(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=False - env_data.lazy_cache_update=True - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - - @unittest.skip("When generate cache is not stacked, new cache cannot stack") - def test_llama_e2e_float32_left_aligned_lazy_cache_update_new_cache_stacked(self): - """end to end jetstream llama test with float32""" - 
jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=True - env_data.generate_cache_stacked=False - env_data.new_cache_stacked=True - env_data.lazy_cache_update=True - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - - def test_llama_e2e_float32_left_aligned_lazy_cache_update_all_cache_stacked(self): - """end to end jetstream llama test with float32""" - jax.config.update("jax_platform_name", "cpu") - print(f"---------> {jax.devices()}") - - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=True - env_data.lazy_cache_update=True - env, model_arg = helpers.make_env_tiny(False, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - def test_llama_e2e_int8_left_aligned_lazy_cache_update_generate_cache_stacked_new_cache_nonstacked(self): - """end to end jetstream llama test with float32""" + def test_llama_e2e_bfloat16(self): + "end to end jetstream llama test with bfloat16" jax.config.update("jax_platform_name", "cpu") + jax.config.update("jax_default_matmul_precision", jax.lax.Precision.HIGHEST) print(f"---------> {jax.devices()}") - def update_env_data(env_data): - env_data.ring_buffer=False - env_data.ragged_mha=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=False - env_data.lazy_cache_update=True - env_data.quant_config.enable_kv_quantization=True - env, model_arg = helpers.make_env_tiny(True, update_env_data) + env, model_arg = helpers.make_env_tiny(bf16_enable=True) out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - + self.assertNotEqual(out_tokens, expected_output_tokens) - def test_llama_e2e_int8_left_aligned_lazy_cache_update_all_cache_stacked(self): + @parameterized.named_parameters( + ("ring_buffer_f32", True, False, False), + ("left_aligned_f32", False, False, False), + ) + def test_llama_e2e_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): """end to end jetstream llama test with float32""" jax.config.update("jax_platform_name", "cpu") print(f"---------> {jax.devices()}") def update_env_data(env_data): - env_data.ring_buffer=False - env_data.ragged_mha=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=True - env_data.lazy_cache_update=True - env_data.quant_config.enable_kv_quantization=True - env_data.ragged_mha=False - - env, model_arg = helpers.make_env_tiny(True, update_env_data) + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertEqual(out_tokens, expected_output_tokens) - - def 
test_llama_e2e_int8_left_aligned_lazy_cache_update_all_cache_stacked(self): + @parameterized.named_parameters( + ("ring_buffer_int8", True, True, True), + ("ring_buffer_bf16", True, False, True), + ("left_aligned_int8", False, True, True), + ("left_aligned_bf16", False, False, True), + ) + def test_llama_e2e_no_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): """end to end jetstream llama test with float32""" jax.config.update("jax_platform_name", "cpu") print(f"---------> {jax.devices()}") def update_env_data(env_data): - env_data.ring_buffer=False - env_data.ragged_mha=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=True - env_data.lazy_cache_update=True - env_data.quant_config.enable_kv_quantization=True - env_data.ragged_mha=True - - env, model_arg = helpers.make_env_tiny(True, update_env_data) - out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) - self.assertEqual(out_tokens, expected_output_tokens) - - - def test_llama_e2e_bfloat16(self): - "end to end jetstream llama test with bfloat16" - jax.config.update("jax_platform_name", "cpu") - jax.config.update("jax_default_matmul_precision", jax.lax.Precision.HIGHEST) - print(f"---------> {jax.devices()}") - - env, model_arg = helpers.make_env_tiny(bf16_enable=True) + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertNotEqual(out_tokens, expected_output_tokens) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index a960cd7e..f48809ea 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -34,11 +34,12 @@ from torch.utils import _pytree as pytree from torch_xla2 import tensor import copy +from absl.testing import parameterized torch.manual_seed(12345) -class QuantizationTest(unittest.TestCase): +class QuantizationTest(parameterized.TestCase): """test kv cache quantization""" def _xla_tensor(self, shape): @@ -71,35 +72,47 @@ def _print_diff(self, w, w_dq): print(" norm: ", (w - w_dq).norm()) print(" cosine dist: ", self._calc_cosine_dist(w, w_dq)) - def test_kv_cache(self): + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + def test_kv_cache(self, ring_buffer): """test kv cache quantization""" + def update_env_data(env_data): - env_data.ring_buffer=False - env_data.ragged_mha=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=True - env_data.lazy_cache_update=True + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer env_data.quant_config.enable_kv_quantization = True env_data.batch_size = 4 + env, _ = helpers.make_env_tiny(True, update_env_data) - + batch = env.batch_size if env.generate_cache_stacked: - cache_shape = (env.num_layers, batch, 2, 100, 2) # layer, bs, num heads, seqlen, dim + cache_shape = ( + env.num_layers, + batch, + 2, + 
100, + 2, + ) # layer, bs, num heads, seqlen, dim else: cache_shape = (batch, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - - cache = cache_manager.Int8KVCacheGenerate.empty( - cache_shape, None, env - ) + + cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, env) # seqlen is 1 k = self._xla_tensor((batch, 2, 1, 2)) v = self._xla_tensor((batch, 2, 1, 2)) def update_finalize_compare(in_k, in_v, in_layer, in_pos): - cache.input_pos = [in_pos] if env.ring_buffer else jnp.array([in_pos] * batch) + cache.input_pos = ( + [in_pos] if env.ring_buffer else jnp.array([in_pos] * batch) + ) # layer id may or may not take effect, depends on the env config. cache.update(in_k, in_v, layer_id=in_layer) @@ -113,44 +126,68 @@ def update_finalize_compare(in_k, in_v, in_layer, in_pos): if env.generate_cache_stacked: self.assertTrue( - jnp.allclose(k._elem, new_k._elem[in_layer, :, :, in_pos:(in_pos + 1), :], atol=0.1) + jnp.allclose( + k._elem, + new_k._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) ) self.assertTrue( - jnp.allclose(v._elem, new_v._elem[in_layer, :, :, in_pos:(in_pos + 1), :], atol=0.1) + jnp.allclose( + v._elem, + new_v._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) ) else: self.assertTrue( - jnp.allclose(k._elem, new_k._elem[:, :, in_pos:(in_pos + 1), :], atol=0.1) + jnp.allclose( + k._elem, new_k._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) ) self.assertTrue( - jnp.allclose(v._elem, new_v._elem[:, :, in_pos:(in_pos + 1), :], atol=0.1) + jnp.allclose( + v._elem, new_v._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) ) + update_finalize_compare(k, v, in_layer=1, in_pos=57) update_finalize_compare(k, v, in_layer=1, in_pos=58) update_finalize_compare(k, v, in_layer=2, in_pos=3) - def test_kv_kernel(self): + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + def test_kv_kernel(self, ring_buffer): """test kv cache quantization""" + def update_env_data(env_data): - env_data.ring_buffer=False - env_data.ragged_mha=False - env_data.flash_attention=True - env_data.generate_cache_stacked=True - env_data.new_cache_stacked=True - env_data.lazy_cache_update=True - env_data.quant_config.enable_kv_quantization=True + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True env_data.batch_size = 4 - env_data.ragged_mha = True + env, _ = helpers.make_env_tiny(False, update_env_data) batch = env.batch_size if env.generate_cache_stacked: - cache_shape = (env.num_layers, batch, 2, 100, 2) # bs, num heads, seqlen, dim + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # bs, num heads, seqlen, dim else: cache_shape = (batch, 2, 100, 2) # layers, bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - + key = jax.random.PRNGKey(123) key2 = jax.random.PRNGKey(456) cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_type) @@ -158,8 +195,10 @@ def update_env_data(env_data): start = jnp.zeros((batch,), dtype=jnp.int32) - cache_k, cache_v, start = torchjax.to_torch((cache_k_jax, cache_v_jax, start)) - + cache_k, cache_v, start = torchjax.to_torch( + (cache_k_jax, cache_v_jax, start) + ) + # Prepare quantized cache before written in cache_k_int, cache_k_scaler, _ = 
quantize_tensor(cache_k, (-3, -1)) cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) @@ -172,30 +211,48 @@ def update_env_data(env_data): xq, xk, xv = torchjax.to_torch((xq, xk, xv)) def get_var(position: int): - pos = [position] if env.ring_buffer else jnp.array([position] * batch, dtype=jnp.int64) - mask = jax.lax.broadcast_in_dim(jnp.array([0] * position + [float("-inf")] * (100 - position)), (env.batch_size, 1, 1, 100), (3,)) + pos = ( + [position] + if env.ring_buffer + else jnp.array([position] * batch, dtype=jnp.int64) + ) + mask = jax.lax.broadcast_in_dim( + jnp.array([0] * position + [float("-inf")] * (100 - position)), + (env.batch_size, 1, 1, 100), + (3,), + ) mask = torchjax.to_torch((mask)) return pos, mask - cache = cache_manager.KVCacheGenerate(cache_k, cache_v, None, None, env) # layer_id doesn't matter, will assign later attention_float = layers.AttentionKernel(env, layer_id=0) float_res = [] - def update_finalize_record(in_attention, in_cache, in_q, in_k, in_v, in_layer, in_pos): + + def update_finalize_record( + in_attention, in_cache, in_q, in_k, in_v, in_layer, in_pos + ): pos, mask = get_var(in_pos) - in_attention.layer_id=in_layer + in_attention.layer_id = in_layer in_cache.input_pos = pos - ret = in_attention(in_q, in_k, in_v, mask, in_cache, start=start, end=pos) + ret = in_attention( + in_q, in_k, in_v, mask, in_cache, start=start, end=pos + ) in_cache.finalize() return ret - float_res.append(update_finalize_record(attention_float, cache, xq, xk, xv, 1, 57)) - float_res.append(update_finalize_record(attention_float, cache, xq, xk, xv, 1, 58)) - float_res.append(update_finalize_record(attention_float, cache, xq, xk, xv, 2, 3)) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 1, 57) + ) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 1, 58) + ) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 2, 3) + ) - # Running into the issue of multiple env object always share the same quant_config. + # Running into the issue of multiple env object always share the same quant_config. # Record the results and compare as a workaround. env._data.quant_config.enable_kv_quantization = True env = environment.JetEngineEnvironment(env._data) @@ -213,13 +270,19 @@ def update_finalize_record(in_attention, in_cache, in_q, in_k, in_v, in_layer, i attention_quant = layers.Int8KVAttentionKernel(env, layer_id=0) int_res = [] - int_res.append(update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 57)) - int_res.append(update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 58)) - int_res.append(update_finalize_record(attention_quant, cache_int, xq, xk, xv, 2, 3)) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 57) + ) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 58) + ) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 2, 3) + ) for f, i in zip(float_res, int_res): self.assertTrue(jnp.allclose(f.jax(), i.jax(), atol=0.01)) - + def test_quantize_dequantize_tensor(self): def quantize_dequantize_weight(w, n_bit): From 4b6bfcb090553baa5e6bb076379cae6087b21fcb Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 18 Jul 2024 22:45:03 +0000 Subject: [PATCH 41/57] Fix the interactive script. 
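
This patch repairs the profiler bracketing in run_interactive.py: every `jax.profiler.start_trace` now has a matching, properly guarded `jax.profiler.stop_trace` around the init/prefill/insert steps and the decode loop. Below is a minimal sketch of the intended pattern, not the exact script; the helper name `profiled_prefill` is illustrative, and `engine`, `params`, `decode_state`, `tokens`, `true_length` and `slot` are assumed to exist as they do in the script.

```python
import jax


def profiled_prefill(
    engine, params, decode_state, tokens, true_length, slot, profiling_output=None
):
  """Run prefill + insert, optionally under a JAX profiler trace."""
  if profiling_output:
    # Traces are written to `profiling_output` and can be viewed in TensorBoard.
    jax.profiler.start_trace(profiling_output)
  prefill_result, _ = engine.prefill(
      params=params, padded_tokens=tokens, true_length=true_length
  )
  decode_state = engine.insert(prefill_result, decode_state, slot=slot)
  if profiling_output:
    jax.profiler.stop_trace()
  return decode_state
```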
--- run_interactive.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 4fe2a681..914ca6a2 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -46,7 +46,8 @@ def main(argv): if profiling_prefill: jax.profiler.start_trace(profiling_output) - decode_state = engine.init_decode_state() + decode_state = engine.init_decode_state() + if profiling_prefill: jax.profiler.stop_trace() prompts: List[str] = [ "I believe the meaning of life is", @@ -69,7 +70,8 @@ def main(argv): params=params, padded_tokens=tokens, true_length=true_length ) # pylint: disable-next=all - decode_state = engine.insert(prefill_result, decode_state, slot=slot) + decode_state = engine.insert(prefill_result, decode_state, slot=slot) + if profiling_prefill: jax.profiler.stop_trace() sampled_tokens_list = [] @@ -78,8 +80,10 @@ def main(argv): while True: if profiling_output: jax.profiler.start_trace(profiling_output) - decode_state, result_tokens = engine.generate(params, decode_state) - result_tokens = result_tokens.convert_to_numpy() + decode_state, result_tokens = engine.generate(params, decode_state) + result_tokens = result_tokens.convert_to_numpy() + + if profiling_output: jax.profiler.stop_trace() output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, From 57cd1ed577a28cf2369c5ab415880b3aced351e5 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 00:08:03 +0000 Subject: [PATCH 42/57] Fix lint errors. --- jetstream_pt/engine.py | 9 --------- jetstream_pt/third_party/gemma/model.py | 2 +- jetstream_pt/third_party/llama/model_exportable.py | 9 +++------ jetstream_pt/third_party/mixtral/model.py | 2 +- 4 files changed, 5 insertions(+), 17 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 56dc915f..3194259d 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -652,10 +652,6 @@ def update_mask(): mask = update_mask() next_token = self._sampling(logits, self.env.batch_size) - # print(f"current input pos: {decode_state.input_pos} and generated token is {next_token}") - # # for layer, (k,v) in enumerate(new_caches[0]): - # data = new_caches[0][0] * new_scales[0][0] if self.env.quant_config.enable_kv_quantization else new_caches[0][0] - # print(f"layer 0, scaled back k is {data}") if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 lens = decode_state.lens + 1 @@ -698,11 +694,6 @@ def update_mask(): input_pos, mask, ) - # print( - # "new_pos", - # (decode_state.current_position + 1) % self.env.cache_sequence_length, - # ) - # print(f"new_token: {jnp.squeeze(next_token)}") return new_decode_state, result_tokens # pylint: disable-next=all diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index ff0a903e..112c5813 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -342,7 +342,7 @@ def __init__(self, config: gemma_config.GemmaConfig, env): self.env = env self.layers = nn.ModuleList() - for layer_id, _ in enumerate(range(config.num_hidden_layers)): + for layer_id in range(config.num_hidden_layers): self.layers.append(GemmaDecoderLayer(config, env, layer_id)) self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, device=config.device diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 55497596..19b848a3 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ 
b/jetstream_pt/third_party/llama/model_exportable.py @@ -259,19 +259,16 @@ def forward( freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - # Should check more thoroughly, as of now, when prefill, it's always not stacked. When generate, it's controlled by the parameter. - # target_cache_layers = 1 if self.env.generate_cache_stacked else len(self.layers) - # assert len(caches) == target_cache_layers, f"Number of caches ({len(caches)}) and layers ({target_cache_layers}) dont match" end = None if start is None else (start + input_pos) % self.env.cache_len # For stacked case, cannot get cache inside the loop which will cause cache copy - for layer_id, layer in enumerate(self.layers): + for layer in range(self.layers): if caches[0].stacked: cache = caches[0] else: - cache = caches[layer_id] + cache = caches[layer] # else: # For stacked case, there is only 1 yer of kv cache - with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): + with jax.named_scope("TransformerBlock_Layer_" + str(layer)): h = layer( h, freqs_cis, diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index f85688b0..89a66378 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -39,7 +39,7 @@ def __init__(self, config: ModelArgs, env) -> None: ) self.layers = nn.ModuleList( TransformerBlock(config, env, layer_id) - for layer_id, _ in enumerate(range(config.n_layer)) + for layer_id in range(config.n_layer) ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) LinearLayer = get_quantized_linear_layer(env.quant_config) From 1f5153651eb90db22f2619315ff8e4f61b1828c0 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 00:18:42 +0000 Subject: [PATCH 43/57] Fix errors. --- jetstream_pt/third_party/llama/model_exportable.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 19b848a3..0a986f2a 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -261,14 +261,14 @@ def forward( end = None if start is None else (start + input_pos) % self.env.cache_len # For stacked case, cannot get cache inside the loop which will cause cache copy - for layer in range(self.layers): + for layer_id, layer in enumerate(self.layers): if caches[0].stacked: cache = caches[0] else: - cache = caches[layer] + cache = caches[layer_id] # else: # For stacked case, there is only 1 yer of kv cache - with jax.named_scope("TransformerBlock_Layer_" + str(layer)): + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): h = layer( h, freqs_cis, From 3893e50aea69400f8c685b3dab2e2ecb58a1e647 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 04:38:30 +0000 Subject: [PATCH 44/57] Fix the comments. 
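
Context for the two patches above: the decoder's forward loop has to pick the right KV cache object per layer, because the generate cache may be stacked along the layer dimension, and patch 43 restores `enumerate(self.layers)` so the loop walks the layer modules rather than a `range` over the ModuleList. A rough sketch of the selection logic follows, illustrative only; `caches` mirrors the list the forward pass receives.

```python
def cache_for_layer(caches, layer_id):
  """Pick the cache object a given decoder layer should use."""
  if caches[0].stacked:
    # A single stacked cache holds K/V for every layer; the layer index is
    # passed to cache.update() instead of selecting a different object.
    return caches[0]
  # Unstacked: one cache object per decoder layer.
  return caches[layer_id]


# for layer_id, layer in enumerate(model.layers):
#   cache = cache_for_layer(caches, layer_id)
#   ... run the layer with `cache` ...
```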
--- jetstream_pt/attention_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 3acf992c..1efe40d0 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -262,7 +262,7 @@ def ragged_mqa_kernel_reference( normalize_var: bool, quantized: bool, ): - """Pallas kernel for flash attention.""" + """Pallas kernel for ragged attention.""" b, i = pl.program_id(0), pl.program_id(1) del layer_ref From 89c4e88dfcbe138218a01fe497dac175fde01d78 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 21:39:45 +0000 Subject: [PATCH 45/57] Fix based on comments; Fix all the unit tests. --- jetstream_pt/cache_manager.py | 166 +++++++++++++----------- jetstream_pt/engine.py | 2 +- jetstream_pt/environment.py | 2 +- jetstream_pt/layers.py | 21 ++- jetstream_pt/third_party/gemma/model.py | 2 +- tests/test_model_impl.py | 13 +- 6 files changed, 111 insertions(+), 95 deletions(-) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 59220149..70f1bfa9 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -169,44 +169,48 @@ def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): def finalize(self): if not self.env.lazy_cache_update: return - # self.cache_k._elem = self.cache_k._elem.at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_ks._elem, -2)) - # self.cache_v._elem = self.cache_v._elem.at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_vs._elem, -2)) + # self.cache_k._elem = self.cache_k.jax().at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_ks.jax(), -2)) + # self.cache_v._elem = self.cache_v.jax().at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_vs.jax(), -2)) if self.env.ring_buffer: # Assume no cache stack for ring buffer - self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set( - self.new_ks._elem + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) ) - self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set( - self.new_vs._elem + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) ) else: if self.env.generate_cache_stacked: - layer, b, head, len, dim = self.cache_k.shape + _, b, head, _, dim = self.cache_k.shape if self.env.new_cache_stacked: - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax( + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( self.update_single_cache_line, - self.cache_k._elem, - self.cache_v._elem, - self.new_ks._elem, - self.new_vs._elem, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, self.input_pos, ) else: for i in range(self.env.num_layers): - self.cache_k._elem = self.cache_k._elem.at[ - i, self.batch, :, self.input_pos, : - ].set(self.new_ks[i]._elem.reshape(b, head, dim)) - self.cache_v._elem = self.cache_v._elem.at[ - i, self.batch, :, self.input_pos, : - ].set(self.new_vs[i]._elem.reshape(b, head, dim)) + self.cache_k._elem = ( + self.cache_k.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_ks[i].jax().reshape(b, head, dim)) + ) + self.cache_v._elem = ( + self.cache_v.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_vs[i].jax().reshape(b, head, dim)) + ) else: # Try to use shard_map to get rid of the data copy - self.cache_k._elem, self.cache_v._elem = torch_xla2.interop.call_jax( + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( 
self.update_single_cache_line, - self.cache_k._elem, - self.cache_v._elem, - self.new_ks._elem, - self.new_vs._elem, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, self.input_pos, ) @@ -233,33 +237,41 @@ def update(self, key, value, layer_id: int): elif self.env.ring_buffer: # Assume no cache stack for ring buffer # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set( - keyj + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(keyj) ) - self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set( - valuej + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(valuej) ) return self.cache_k, self.cache_v else: if self.env.generate_cache_stacked: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[ - layer_id, self.batch, :, self.input_pos, : - ].set(keyj.squeeze(2)) + self.cache_k._elem = ( + self.cache_k.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[ - layer_id, self.batch, :, self.input_pos, : - ].set(valuej.squeeze(2)) + self.cache_v._elem = ( + self.cache_v.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) return self.cache_k[layer_id], self.cache_v[layer_id] else: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[ - self.batch, :, self.input_pos, : - ].set(keyj.squeeze(2)) + self.cache_k._elem = ( + self.cache_k.jax() + .at[self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[ - self.batch, :, self.input_pos, : - ].set(valuej.squeeze(2)) + self.cache_v._elem = ( + self.cache_v.jax() + .at[self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) return self.cache_k, self.cache_v def state(self): @@ -553,11 +565,11 @@ def finalize(self): return if self.env.ring_buffer: # Assume no cache stack for ring buffer - self.cache_k._elem = self.cache_k._elem.at[..., self.input_pos, :].set( - self.new_ks._elem + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) ) - self.cache_v._elem = self.cache_v._elem.at[..., self.input_pos, :].set( - self.new_vs._elem + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) ) else: if self.env.generate_cache_stacked: @@ -565,67 +577,67 @@ def finalize(self): if self.env.new_cache_stacked: # new kv scaler also has to go through shard_map instead of indexing because it needs to reshape to (batch, layer) which mess up with the data caches = [ - self.cache_k._elem, - self.cache_v._elem, - self.new_ks._elem, - self.new_vs._elem, - self.k_scaler._elem, - self.v_scaler._elem, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, self.new_k_scaler, self.new_v_scaler, ] ( - self.cache_k._elem, - self.cache_v._elem, - self.k_scaler._elem, - self.v_scaler._elem, + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, ) = torch_xla2.interop.call_jax( self.update_single_cache_line, *caches, self.input_pos ) else: # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
caches = [ - self.cache_k._elem, - self.cache_v._elem, + self.cache_k, + self.cache_v, self.new_ks, self.new_vs, - self.k_scaler._elem, - self.v_scaler._elem, + self.k_scaler, + self.v_scaler, self.new_k_scaler, self.new_v_scaler, ] ( - self.cache_k._elem, - self.cache_v._elem, - self.k_scaler._elem, - self.v_scaler._elem, + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, ) = torch_xla2.interop.call_jax( self.update_single_cache_line, *caches, self.input_pos ) # for i in range(self.env.num_layers): - # self.cache_k._elem = self.cache_k._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i]._elem.reshape(b, head, dim)) - # self.cache_v._elem = self.cache_v._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i]._elem.reshape(b, head, dim)) - # self.k_scaler._elem = self.k_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i]._elem.reshape(b, 1, 1)) - # self.v_scaler._elem = self.v_scaler._elem.at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i]._elem.reshape(b, 1, 1)) + # self.cache_k._elem = self.cache_k.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i].jax().reshape(b, head, dim)) + # self.cache_v._elem = self.cache_v.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i].jax().reshape(b, head, dim)) + # self.k_scaler._elem = self.k_scaler.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i].jax().reshape(b, 1, 1)) + # self.v_scaler._elem = self.v_scaler.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i].jax().reshape(b, 1, 1)) else: # Try to use shard_map to get rid of the data copy b = self.cache_k.shape[-4] ( - self.cache_k._elem, - self.cache_v._elem, - self.k_scaler._elem, - self.v_scaler._elem, + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, ) = torch_xla2.interop.call_jax( self.update_single_cache_line, - self.cache_k._elem, - self.cache_v._elem, - self.new_ks._elem, - self.new_vs._elem, - self.k_scaler._elem, - self.v_scaler._elem, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, self.new_k_scaler, self.new_v_scaler, self.input_pos, ) - # self.k_scaler._elem = self.k_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_k_scaler._elem.reshape(b, 1, 1)) - # self.v_scaler._elem = self.v_scaler._elem.at[self.batch, :, self.input_pos, :].set(self.new_v_scaler._elem.reshape(b, 1, 1)) + # self.k_scaler._elem = self.k_scaler.jax().at[self.batch, :, self.input_pos, :].set(self.new_k_scaler.jax().reshape(b, 1, 1)) + # self.v_scaler._elem = self.v_scaler.jax().at[self.batch, :, self.input_pos, :].set(self.new_v_scaler.jax().reshape(b, 1, 1)) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 3194259d..79dfb945 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -145,7 +145,7 @@ def init_decode_state( (self.env.batch_size, self.env.cache_sequence_length), float("-inf"), dtype=self.default_dtype, - ), + ), # mask ) # pylint: disable-next=all diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 3e646232..84289d90 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -92,7 +92,7 @@ class JetEngineEnvironmentData: block_size: int = 512 # Starting position - starting_position: int = 512 + starting_position: int = 0 # Ring buffer ring_buffer: bool = True diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index e9542703..607cf17f 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -368,18 
+368,12 @@ def apply_rotary_emb( def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" - bs, n_kv_heads, slen, head_dim = ( - x.shape[-4], - x.shape[-3], - x.shape[-2], - x.shape[-1], - ) - if x.ndim == 5: - stacked = True - else: - stacked = False + *_, bs, n_kv_heads, slen, head_dim = x.shape + stacked = x.ndim == 5 + if n_rep == 1: return x + if stacked: layer = x.shape[0] return ( @@ -832,8 +826,13 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) + if mask.ndim == 2: - mask = mask[:, None, None, :] + if seqlen == 1: + mask = mask[:, None, None, :] + else: + mask = mask[None, None, :, :] + # if cache is not None and cache.cache_k is not None: # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 112c5813..5773b8bd 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -136,7 +136,7 @@ def __init__( if env.quant_config.enable_kv_quantization else layers.AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, layer_id) def forward( self, diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index a3472760..703ce444 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -65,9 +65,13 @@ def _make_freqs_cis(self, model_arg, seqlen, start_pos): freqs_cis = freqs_cis[start_pos : start_pos + seqlen] return freqs_cis - def _generate_mask(self, cache_length, pos, seqlen): + def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True): x = jnp.arange(0, cache_length) - cond = jnp.logical_and(x < pos, x >= pos - seqlen) + if ring_buffer: + cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + else: + # Left aligned buffer we postpone the cache update + cond = jnp.logical_and(x < pos, x >= pos - seqlen) res = jnp.where(cond, 0, float("-inf")) return torchjax.to_torch(res) @@ -205,6 +209,7 @@ def init_weights(model): head_dim=head_dim, device="meta", env=env, + layer_id=0, ) def load_hook(state_dict, prefix, *args): @@ -230,8 +235,8 @@ def load_hook(state_dict, prefix, *args): freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) mask = self._prefill_mask(seqlen, start_pos) kv_write_indexes = torch.arange(0, seqlen) - cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) - cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) + cache_k = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) + cache_v = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) expected_out = attention_orig(*inputs_orig) From 004269b09af0b95207648096ef428f021b251b1d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 21:41:34 +0000 Subject: [PATCH 46/57] Fix the remaining pylint errors. 
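
Aside on the `repeat_kv` rewrite in the patch above: the helper implements grouped-query attention by repeating each KV head `n_rep` times so keys and values line up with the query heads. A minimal unstacked sketch, assuming the (batch, kv_heads, seqlen, head_dim) layout; the in-tree helper additionally handles a leading stacked layer dimension.

```python
import torch


def repeat_kv_unstacked(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  """Repeat each KV head n_rep times along the head dimension."""
  bs, n_kv_heads, slen, head_dim = x.shape
  if n_rep == 1:
    return x
  return (
      x[:, :, None, :, :]
      .expand(bs, n_kv_heads, n_rep, slen, head_dim)
      .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
  )
```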
--- benchmarks/run_offline.py | 7 ++++--- run_interactive.py | 8 ++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index bcbe9704..e43344b0 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -33,6 +33,7 @@ profiler_started = False + def run_prefill_time(engine, params, decode_state, seqlen): """Run prefill and measure time.""" metadata = engine.get_tokenizer() @@ -55,9 +56,9 @@ def run_prefill_time(engine, params, decode_state, seqlen): start = time.perf_counter() for i in range(nums): if i == nums - 1 and FLAGS.profiling_prefill: - jax.profiler.start_trace(FLAGS.profiling_output) - profiler_started = True - + jax.profiler.start_trace(FLAGS.profiling_output) + profiler_started = True + prefill_result, _ = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) diff --git a/run_interactive.py b/run_interactive.py index 914ca6a2..a4ab8053 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -42,7 +42,11 @@ def main(argv): max_output_length = 1024 profiling_output = FLAGS.profiling_output - profiling_prefill = FLAGS.profiling_prefill and profiling_output is not None and profiling_output != "" + profiling_prefill = ( + FLAGS.profiling_prefill + and profiling_output is not None + and profiling_output != "" + ) if profiling_prefill: jax.profiler.start_trace(profiling_output) @@ -82,7 +86,7 @@ def main(argv): jax.profiler.start_trace(profiling_output) decode_state, result_tokens = engine.generate(params, decode_state) result_tokens = result_tokens.convert_to_numpy() - + if profiling_output: jax.profiler.stop_trace() output, complete = token_utils.process_result_tokens( From d0777fd4524e7864e4a1464376b24fb1341415e7 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 21:53:35 +0000 Subject: [PATCH 47/57] Default ring buffer back to true so that all the test_run_server and run_interactive in CPU mode can work. When we default ring buffer to false, should add additional flags to run_interactive CI to set test mode to true so that pallas kernel can run. --- jetstream_pt/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 75e11bab..9b1498d3 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -86,31 +86,31 @@ ) flags.DEFINE_bool( "ring_buffer", - False, + True, "Whether to enable ring buffer", required=False, ) flags.DEFINE_bool( "flash_attention", - True, + False, "Whether to enable flas attention. Only takes effect at test mode", required=False, ) flags.DEFINE_bool( "generate_cache_stacked", - True, + False, "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", required=False, ) flags.DEFINE_bool( "new_cache_stacked", - True, + False, "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", required=False, ) flags.DEFINE_bool( "lazy_cache_update", - True, + False, "Whether to update the cache during attention or delayed until all the layers are done. Only takes effect at test mode", required=False, ) From e99a815a061c94b43cd4ed3aa0a3f8ce37c642a7 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jul 2024 22:29:57 +0000 Subject: [PATCH 48/57] Fix all the lint errors. 
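
Note on the defaults restored in the patch above: `ring_buffer` is back to True so the CPU server and interactive paths keep working, while the left-aligned features (`flash_attention`, `generate_cache_stacked`, `new_cache_stacked`, `lazy_cache_update`) stay off and only take effect in test mode. The unit tests flip them as a group; here is a sketch of the `update_env_data` hook they pass to `helpers.make_env_tiny`, with the two booleans as assumed parameters.

```python
def update_env_data(env_data, ring_buffer=False, quantized=False):
  """Toggle the left-aligned cache features together with the ring buffer."""
  # The left-aligned features only make sense when the ring buffer is off.
  env_data.ring_buffer = ring_buffer
  env_data.ragged_mha = not ring_buffer
  env_data.flash_attention = not ring_buffer
  env_data.generate_cache_stacked = not ring_buffer
  env_data.new_cache_stacked = not ring_buffer
  env_data.lazy_cache_update = not ring_buffer
  env_data.quant_config.enable_kv_quantization = quantized
```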
--- benchmarks/run_offline.py | 7 +- jetstream_pt/attention_kernel.py | 6 +- jetstream_pt/cache_manager.py | 153 +++++++++++++++++-------------- jetstream_pt/config.py | 3 +- 4 files changed, 90 insertions(+), 79 deletions(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index e43344b0..3b4dcc42 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -31,10 +31,8 @@ flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") -profiler_started = False - -def run_prefill_time(engine, params, decode_state, seqlen): +def run_prefill_time(engine, params, decode_state, seqlen, profiler_started): """Run prefill and measure time.""" metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) @@ -93,9 +91,10 @@ def main(argv): prefill_times = {} decode_state = engine.init_decode_state() + profiler_started = False for batch, _ in MAXTEXT_PREFILL.items(): runtime, decode_state = run_prefill_time( - engine, params, decode_state, batch + engine, params, decode_state, batch, profiler_started ) prefill_times[batch] = runtime diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 1efe40d0..6d571d2c 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ -497,7 +497,6 @@ def ragged_mha( mask_value = DEFAULT_MASK_VALUE bk = min(bk, k.shape[-2]) bq, hq, tq, dq = q.shape - dk = k.shape[-1] hkv = k.shape[-3] tk = k.shape[-2] @@ -507,7 +506,7 @@ def ragged_mha( rep = hq // hkv if rep > 1: q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) - stacked = True if k.ndim == 5 else False + stacked = k.ndim == 5 replicated_in_axes = 7 if k_scaler is None: @@ -596,7 +595,7 @@ def flash_attention( mask=None, normalize_var=True, ): - mask_value: float = DEFAULT_MASK_VALUE + """Flash attention kernel.""" if keys.ndim == 5: keys = keys[layer] values = values[layer] @@ -611,7 +610,6 @@ def flash_attention( logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama # Quantized if k_scaler is not None: - bs, hs, ls, ds = k_scaler.shape logits = logits * k_scaler.reshape( k_scaler.shape[-4], 1, 1, k_scaler.shape[-2] ) diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 70f1bfa9..76f44120 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -14,11 +14,12 @@ import jax import jax.numpy as jnp -import torch -from jetstream_pt import torchjax from jax.experimental.shard_map import shard_map +import torch import torch_xla2 +from jetstream_pt import torchjax + # pylint: disable-next=all class CacheInterface: @@ -63,6 +64,7 @@ def state(self): # Placeholder, to match with GenerateCache def finalize(self): + """Finalize the cache operation and updates the cache.""" return @@ -87,11 +89,11 @@ def KVCachePrefill_unflatten(auxdata, data): ) -# Refactor out cache management -# Easier to test for quantized kv cache class KVCacheGenerate: """Kvache generator without quantization""" + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. 
def __init__( self, cache_k: torch.Tensor, # previous cache @@ -117,7 +119,7 @@ def __init__( if self.env.lazy_cache_update: if self.env.generate_cache_stacked: if self.env.new_cache_stacked: - layer, batch, heads, time, dim = self.cache_k.shape + layer, batch, heads, _, dim = self.cache_k.shape new_dim = (layer, batch, heads, 1, dim) self.new_ks, self.new_vs = torchjax.to_torch( ( @@ -146,7 +148,10 @@ def __init__( ) ) + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): + """The shard map version of single cache line update.""" b = cache_k.shape[-4] for bb, pp in enumerate(pos.reshape(b)): slice_dim = 0 @@ -167,15 +172,17 @@ def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): return cache_k, cache_v def finalize(self): + """Finalize the cache operation and updates the cache.""" if not self.env.lazy_cache_update: return - # self.cache_k._elem = self.cache_k.jax().at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_ks.jax(), -2)) - # self.cache_v._elem = self.cache_v.jax().at[:, :, :, self.input_pos].set(jnp.squeeze(self.new_vs.jax(), -2)) + if self.env.ring_buffer: # Assume no cache stack for ring buffer + # pylint: disable-next=all self.cache_k._elem = ( self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) ) + # pylint: disable-next=all self.cache_v._elem = ( self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) ) @@ -193,11 +200,13 @@ def finalize(self): ) else: for i in range(self.env.num_layers): + # pylint: disable-next=all self.cache_k._elem = ( self.cache_k.jax() .at[i, self.batch, :, self.input_pos, :] .set(self.new_ks[i].jax().reshape(b, head, dim)) ) + # pylint: disable-next=all self.cache_v._elem = ( self.cache_v.jax() .at[i, self.batch, :, self.input_pos, :] @@ -216,67 +225,74 @@ def finalize(self): def update(self, key, value, layer_id: int): """Update kv cache""" - # Will process in insert() at the end of the transformer forward pass keyj, valuej = torchjax.to_torch((key, value)) if self.env.lazy_cache_update: - # When new cache stacked, must have generate_cache_stacked if self.env.new_cache_stacked: + assert ( + self.env.generate_cache_stacked + ), "When new cache stacked, must have generate_cache_stacked!" self.new_ks[layer_id, ...] = keyj self.new_vs[layer_id, ...] = valuej return self.cache_k[layer_id], self.cache_v[layer_id] - else: - if self.env.generate_cache_stacked: - self.new_ks.append(keyj) - self.new_vs.append(valuej) - return self.cache_k[layer_id], self.cache_v[layer_id] - else: - self.new_ks = keyj - self.new_vs = valuej - return self.cache_k, self.cache_v - elif self.env.ring_buffer: - # Assume no cache stack for ring buffer + # Generate cache stacked, but new cache unstacked + if self.env.generate_cache_stacked: + self.new_ks.append(keyj) + self.new_vs.append(valuej) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # all cache unstacked + self.new_ks = keyj + self.new_vs = valuej + return self.cache_k, self.cache_v + + if self.env.ring_buffer: + assert ( + not self.env.new_cache_stacked and not self.env.generate_cache_stacked + ), "Ring buffer doesn't support stacked cache." 
# pylint: disable-next=all self.cache_k._elem = ( self.cache_k.jax().at[..., self.input_pos, :].set(keyj) ) + # pylint: disable-next=all self.cache_v._elem = ( self.cache_v.jax().at[..., self.input_pos, :].set(valuej) ) return self.cache_k, self.cache_v - else: - if self.env.generate_cache_stacked: - # pylint: disable-next=all - self.cache_k._elem = ( - self.cache_k.jax() - .at[layer_id, self.batch, :, self.input_pos, :] - .set(keyj.squeeze(2)) - ) - # pylint: disable-next=all - self.cache_v._elem = ( - self.cache_v.jax() - .at[layer_id, self.batch, :, self.input_pos, :] - .set(valuej.squeeze(2)) - ) - return self.cache_k[layer_id], self.cache_v[layer_id] - else: - # pylint: disable-next=all - self.cache_k._elem = ( - self.cache_k.jax() - .at[self.batch, :, self.input_pos, :] - .set(keyj.squeeze(2)) - ) - # pylint: disable-next=all - self.cache_v._elem = ( - self.cache_v.jax() - .at[self.batch, :, self.input_pos, :] - .set(valuej.squeeze(2)) - ) - return self.cache_k, self.cache_v + + # Non lazy cache update, non ring buffer, generate cache stacked + if self.env.generate_cache_stacked: + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Non lazy cache update, non ring buffer, generate cache non stacked + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) + return self.cache_k, self.cache_v def state(self): """Get kv cache state""" - # pylint: disable-next=all return self.cache_k.jax(), self.cache_v.jax() @classmethod @@ -320,7 +336,8 @@ def KVCacheGenerate_unflatten(auxdata, data): class Int8KVCacheGenerate: """Int8 quantized kvache with scalers""" - # pylint: disable-next=all + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. def __init__( self, cache_k, @@ -349,7 +366,7 @@ def __init__( if self.env.lazy_cache_update: if self.env.generate_cache_stacked: - layer, batch, heads, time, dim = self.cache_k.shape + layer, batch, heads, _, dim = self.cache_k.shape new_kv_dim = (layer, batch, heads, 1, dim) self.new_ks, self.new_vs = torchjax.to_torch( ( @@ -399,6 +416,8 @@ def __init__( ) self.update_single_cache_line = jax.jit(self.update_single_cache_line) + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. 
def update_single_cache_line( self, cache_k, @@ -411,6 +430,7 @@ def update_single_cache_line( new_v_scaler, pos, ): + """The shard map version of single cache line update.""" b = cache_k.shape[-4] for bb, pp in enumerate(pos.reshape(b)): @@ -421,10 +441,10 @@ def update_single_cache_line( slice_dim = 1 update_start_indices = (0, bb, 0, pp, 0) if self.env.generate_cache_stacked and not self.env.new_cache_stacked: - for slice in range(self.env.num_layers): - update_start_indices = (slice, bb, 0, pp, 0) + for layer in range(self.env.num_layers): + update_start_indices = (layer, bb, 0, pp, 0) new_ks_slice = jax.lax.dynamic_slice_in_dim( - new_ks[slice], bb, 1, slice_dim + new_ks[layer], bb, 1, slice_dim ) new_ks_slice = jnp.expand_dims(new_ks_slice, 0) cache_k = jax.lax.dynamic_update_slice( @@ -432,7 +452,7 @@ def update_single_cache_line( ) new_vs_slice = jax.lax.dynamic_slice_in_dim( - new_vs[slice], bb, 1, slice_dim + new_vs[layer], bb, 1, slice_dim ) new_vs_slice = jnp.expand_dims(new_vs_slice, 0) cache_v = jax.lax.dynamic_update_slice( @@ -440,7 +460,7 @@ def update_single_cache_line( ) new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( - new_k_scaler[slice], bb, 1, slice_dim + new_k_scaler[layer], bb, 1, slice_dim ) new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) k_scaler = jax.lax.dynamic_update_slice( @@ -448,7 +468,7 @@ def update_single_cache_line( ) new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( - new_v_scaler[slice], bb, 1, slice_dim + new_v_scaler[layer], bb, 1, slice_dim ) new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) v_scaler = jax.lax.dynamic_update_slice( @@ -561,21 +581,24 @@ def update(self, xk, xv, layer_id: int): ) def finalize(self): + """Finalize the cache operation and updates the cache.""" if not self.env.lazy_cache_update: return if self.env.ring_buffer: # Assume no cache stack for ring buffer + # pylint: disable-next=all self.cache_k._elem = ( self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) ) + # pylint: disable-next=all self.cache_v._elem = ( self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) ) else: if self.env.generate_cache_stacked: - layer, b, head, _, dim = self.cache_k.shape if self.env.new_cache_stacked: - # new kv scaler also has to go through shard_map instead of indexing because it needs to reshape to (batch, layer) which mess up with the data + # new kv scaler also has to go through shard_map instead of indexing + # because it needs to reshape to (batch, layer) which mess up with the data caches = [ self.cache_k, self.cache_v, @@ -595,7 +618,6 @@ def finalize(self): self.update_single_cache_line, *caches, self.input_pos ) else: - # We don't optimize generate_cache_stacked=True but new_cache_stacked=False yet. 
caches = [ self.cache_k, self.cache_v, @@ -614,14 +636,7 @@ def finalize(self): ) = torch_xla2.interop.call_jax( self.update_single_cache_line, *caches, self.input_pos ) - # for i in range(self.env.num_layers): - # self.cache_k._elem = self.cache_k.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_ks[i].jax().reshape(b, head, dim)) - # self.cache_v._elem = self.cache_v.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_vs[i].jax().reshape(b, head, dim)) - # self.k_scaler._elem = self.k_scaler.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_k_scaler[i].jax().reshape(b, 1, 1)) - # self.v_scaler._elem = self.v_scaler.jax().at[i, self.batch, :, self.input_pos, :].set(self.new_v_scaler[i].jax().reshape(b, 1, 1)) else: - # Try to use shard_map to get rid of the data copy - b = self.cache_k.shape[-4] ( self.cache_k, self.cache_v, @@ -639,5 +654,3 @@ def finalize(self): self.new_v_scaler, self.input_pos, ) - # self.k_scaler._elem = self.k_scaler.jax().at[self.batch, :, self.input_pos, :].set(self.new_k_scaler.jax().reshape(b, 1, 1)) - # self.v_scaler._elem = self.v_scaler.jax().at[self.batch, :, self.input_pos, :].set(self.new_v_scaler.jax().reshape(b, 1, 1)) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 9b1498d3..70b530fc 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -111,7 +111,8 @@ flags.DEFINE_bool( "lazy_cache_update", False, - "Whether to update the cache during attention or delayed until all the layers are done. Only takes effect at test mode", + "Whether to update the cache during attention or delayed until all the layers are done. " + "Only takes effect at test mode", required=False, ) flags.DEFINE_float( From 223338f971ce3047e9d824ca7f412f910ebfc1bb Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sat, 20 Jul 2024 01:23:03 +0000 Subject: [PATCH 49/57] Fix run_offline script. 
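
This patch threads the `profiler_started` flag through `run_prefill_time` instead of relying on a module-level global, so the trace is started at most once across all prefill buckets. The resulting call pattern is roughly the sketch below; `run_prefill_time` refers to the function in benchmarks/run_offline.py with its new signature, and `benchmark_prefill_buckets` is an illustrative name.

```python
def benchmark_prefill_buckets(engine, params, decode_state, prefill_lengths):
  """Time prefill for each bucket, starting the profiler trace at most once."""
  prefill_times = {}
  profiler_started = False
  for seqlen in prefill_lengths:
    # run_prefill_time starts the trace on its final iteration when prefill
    # profiling is enabled and no trace is active yet, and reports that back.
    runtime, decode_state, profiler_started = run_prefill_time(
        engine, params, decode_state, seqlen, profiler_started
    )
    prefill_times[seqlen] = runtime
  return prefill_times, decode_state, profiler_started
```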
--- benchmarks/run_offline.py | 7 ++++--- deps/JetStream | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 3b4dcc42..d5e85d2f 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -53,7 +53,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started): nums = 5 start = time.perf_counter() for i in range(nums): - if i == nums - 1 and FLAGS.profiling_prefill: + if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started: jax.profiler.start_trace(FLAGS.profiling_output) profiler_started = True @@ -66,7 +66,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started): jax.block_until_ready(decode_state) end = time.perf_counter() - return (end - start) / nums, decode_state + return (end - start) / nums, decode_state, profiler_started MAXTEXT_PREFILL = { @@ -93,7 +93,7 @@ def main(argv): decode_state = engine.init_decode_state() profiler_started = False for batch, _ in MAXTEXT_PREFILL.items(): - runtime, decode_state = run_prefill_time( + runtime, decode_state, profiler_started = run_prefill_time( engine, params, decode_state, batch, profiler_started ) prefill_times[batch] = runtime @@ -109,6 +109,7 @@ def main(argv): profiling_output = FLAGS.profiling_output print("======= decode starting ===") + dec_times = [] for i in range(10): if profiling_output and i == 7 and not profiler_started: diff --git a/deps/JetStream b/deps/JetStream index 69ce8a26..26872c3c 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit 69ce8a2646ac32bea9194019078248b49e69728e +Subproject commit 26872c3c6e726f52f5bac1cb63e60a9a2a0bbe8a From 1444e07a341548f6395a16d35d1dd451f7b66af7 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 30 Jul 2024 00:06:13 +0000 Subject: [PATCH 50/57] Fix the ring buffer mode long latency issue. --- jetstream_pt/layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 607cf17f..7fd62c6f 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -454,8 +454,9 @@ def attend(xq, keys, values, local_mask=None): true_len = seqlen # When GQA is enabled, it not necessary to expand - if n_rep == 1 and seqlen == 1: + if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1: true_len = 2 + #xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) xq = torch.nn.functional.pad( xq, (0, 0, 0, true_len - seqlen), "constant", 0 ) @@ -611,12 +612,11 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): true_len = seqlen # When GQA is enabled, it not necessary to expand - if n_rep == 1 and seqlen == 1: + if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1: true_len = 2 xq = torch.nn.functional.pad( xq, (0, 0, 0, true_len - seqlen), "constant", 0 ) - # xq = torch.broadcast_to(xq, (bsz, num_heads, true_len, head_dim)) # We are not using ragged attention for prefill yet. if self.env.ragged_mha and seqlen == 1: From 595ead2b4718f75d867e6d6a92169f66d395613f Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 5 Aug 2024 18:51:14 +0000 Subject: [PATCH 51/57] Rebase to main. 
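
Context for the latency fix in patch 50 above: on the decode path the query has sequence length 1, and unless ragged multi-head attention is handling the GQA expansion, the query is padded to length 2 before the attention call. A minimal sketch of that padding step, illustrative only; the real logic lives in the attention kernels in jetstream_pt/layers.py, and `use_ragged_gqa` stands in for the `self.env.ragged_mha and n_rep > 1` condition.

```python
import torch


def maybe_pad_decode_query(xq: torch.Tensor, use_ragged_gqa: bool):
  """xq has shape (batch, heads, seqlen, head_dim); returns (xq, true_len)."""
  seqlen = xq.shape[2]
  true_len = seqlen
  if not use_ragged_gqa and seqlen == 1:
    true_len = 2
    # F.pad's last two entries pad the second-to-last (sequence) dimension,
    # appending one all-zero query position.
    xq = torch.nn.functional.pad(
        xq, (0, 0, 0, true_len - seqlen), "constant", 0
    )
  return xq, true_len
```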
--- deps/JetStream | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/JetStream b/deps/JetStream index 26872c3c..69ce8a26 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit 26872c3c6e726f52f5bac1cb63e60a9a2a0bbe8a +Subproject commit 69ce8a2646ac32bea9194019078248b49e69728e From d14e7f5a9a9e8a279a8e8367605d5545e6d4c76e Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 6 Aug 2024 00:20:21 +0000 Subject: [PATCH 52/57] Fix all the lint issues. --- benchmarks/run_offline.py | 2 +- jetstream_pt/layers.py | 3 +- run_interactive.py | 10 ++-- run_interactive_disaggregated.py | 9 ++-- run_interactive_multiple_host.py | 1 - run_ray_serve_interleave.py | 12 ++++- run_server.py | 1 - run_server_with_ray.py | 1 - tests/helpers.py | 6 ++- tests/test_hf_names.py | 16 +++++++ tests/test_llama_e2e.py | 4 +- tests/test_model_impl.py | 41 +++++++++------- tests/test_quantization.py | 65 ++++++++++++++----------- tests/test_run_server.py | 81 ++++++++++++++++++-------------- 14 files changed, 153 insertions(+), 99 deletions(-) diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index d5e85d2f..1fdc0cb7 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -109,7 +109,7 @@ def main(argv): profiling_output = FLAGS.profiling_output print("======= decode starting ===") - + dec_times = [] for i in range(10): if profiling_output and i == 7 and not profiler_started: diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 7fd62c6f..d484df93 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -456,7 +456,7 @@ def attend(xq, keys, values, local_mask=None): # When GQA is enabled, it not necessary to expand if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1: true_len = 2 - #xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + # xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) xq = torch.nn.functional.pad( xq, (0, 0, 0, true_len - seqlen), "constant", 0 ) @@ -714,6 +714,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): return attn_out + class Attention(ModuleBase): """Attention module.""" diff --git a/run_interactive.py b/run_interactive.py index a4ab8053..86cd6df2 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -18,13 +18,10 @@ from typing import List # import torch_xla2 first! -import torch_xla2 # pylint: disable import jax import numpy as np -from absl import app, flags -from colorama import Fore, Style +from absl import app from jetstream.engine import token_utils -from jetstream_pt import engine as je from jetstream_pt.config import FLAGS, create_engine_from_config_flags @@ -54,10 +51,15 @@ def main(argv): if profiling_prefill: jax.profiler.stop_trace() prompts: List[str] = [ + # pylint: disable-next=all "I believe the meaning of life is", + # pylint: disable-next=all "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. 
Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", + # pylint: disable-next=all "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", + # pylint: disable-next=all "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", + # pylint: disable-next=all "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py index 0d11796e..b6ffb43c 100644 --- a/run_interactive_disaggregated.py +++ b/run_interactive_disaggregated.py @@ -19,9 +19,7 @@ from typing import List from absl import app from absl import flags -from colorama import Fore, Style -import numpy as np import jax from jetstream.engine import token_utils @@ -129,7 +127,6 @@ def main(argv): print("Load params ", time.perf_counter() - start) metadata = prefill_engine.get_tokenizer() - tokenizer = prefill_engine.build_tokenizer(metadata) vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) stop_tokens = [vocab.eos_id, vocab.pad_id] max_output_length = 1024 @@ -157,19 +154,21 @@ def main(argv): print(f"---- Input prompts are: {prompt}") print(f"---- Encoded tokens are: {tokens}") - # pylint: disable-next=all print( + # pylint: disable-next=all f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}" ) prefill_result, _ = prefill_engine.prefill( params=None, padded_tokens=tokens, true_length=true_length ) print( + # pylint: disable-next=all f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}" ) decode_engine.transfer(prefill_result) - # pylint: disable-next=all + print( + # pylint: disable-next=all f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}" ) decode_state = decode_engine.insert(prefill_result, None, slot=slot) diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 24b27987..9192076a 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -19,7 +19,6 @@ import jax from absl import app, flags -from colorama import Fore, Style from jetstream.engine import token_utils from jetstream_pt import ray_engine from jetstream_pt.config import FLAGS diff --git a/run_ray_serve_interleave.py b/run_ray_serve_interleave.py index 6d4edb5d..a6b3421d 100644 --- a/run_ray_serve_interleave.py +++ b/run_ray_serve_interleave.py @@ -40,6 +40,7 @@ def create_head_resource_name(generation, tpu_chips): + """Create head resource name.""" return f"TPU-{generation}-{tpu_chips}-head" @@ -73,6 +74,7 @@ def create_engine(**kwargs): @serve.deployment class JetStreamDeployment: + """JetStream deployment.""" def __init__(self, **kwargs): os.environ["XLA_FLAGS"] = ( @@ -111,18 +113,24 @@ def __init__(self, **kwargs): print("Started jetstream driver....") + # pylint: disable-next=all async def Decode( - self, request: jetstream_pb2.DecodeRequest + self, + # pylint: disable-next=all + request: jetstream_pb2.DecodeRequest, + # pylint: disable-next=all ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: - + """Async decode function.""" return self.orchestrator.Decode(request) def main(_argv): + """Main function""" resource_name = create_head_resource_name( FLAGS.tpu_generation, FLAGS.tpu_chips ) print(f"Using head resource {resource_name}") + # pylint: disable-next=all deployment = JetStreamDeployment.options( ray_actor_options={"resources": {resource_name: 1}} ).bind( diff --git a/run_server.py b/run_server.py index be5933ec..6f9fbed8 100644 --- a/run_server.py +++ b/run_server.py @@ -17,7 +17,6 @@ from typing import Sequence # import torch_xla2 first! 
-import torch_xla2 # pylint: disable import jax from absl import app, flags from jetstream.core import server_lib diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 03489e1a..97592804 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -19,7 +19,6 @@ from absl import app, flags # import torch_xla2 first! -import torch_xla2 # pylint: disable import jax from jetstream.core import server_lib from jetstream.core.config_lib import ServerConfig diff --git a/tests/helpers.py b/tests/helpers.py index 62c0789b..3c5cb4ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,6 +6,7 @@ from jetstream_pt import environment +# pylint: disable-next=all def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) @@ -33,6 +34,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): return env, config +# pylint: disable-next=all def make_mixtral_env(bf16_enable=True): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) @@ -57,14 +59,16 @@ def make_mixtral_env(bf16_enable=True): return env, config +# pylint: disable-next=all def to_xla_tensor(tree): return torch_xla2.default_env().to_xla(tree) +# pylint: disable-next=all def call_xla_model(model, weights, args): with jax.default_device(jax.devices("cpu")[0]): xla_weights, xla_inputs = to_xla_tensor((weights, args)) with torch_xla2.default_env(): result = torch.func.functional_call(model, xla_weights, xla_inputs) - result_torch = torch_xla2.tensor.j2t(result._elem) + result_torch = torch_xla2.tensor.j2t(result.jax()) return result_torch diff --git a/tests/test_hf_names.py b/tests/test_hf_names.py index c2230cde..83b76425 100644 --- a/tests/test_hf_names.py +++ b/tests/test_hf_names.py @@ -4,10 +4,13 @@ class TestModuleBase(unittest.TestCase): + """Test module base.""" def test_get_hf_names_to_real_name(self): + """Test get hugginface names to real name.""" class MyModule(ModuleBase): + """My module.""" def __init__(self): super().__init__() @@ -18,6 +21,9 @@ def __init__(self): self.param = torch.nn.Parameter(torch.randn(10)) self.hf_name("param", "model.param") + def forward(self): + """Forward function.""" + module = MyModule() expected_mapping = { "model.my_linear1.weight": "linear1.weight", @@ -30,7 +36,10 @@ def __init__(self): self.assertEqual(module.get_hf_names_to_real_name(), expected_mapping) def test_get_sharding_annotations(self): + """Test get sharding annotations.""" + class MyModule(ModuleBase): + """MyModule.""" def __init__(self): super().__init__() @@ -38,12 +47,19 @@ def __init__(self): self.embedding = torch.nn.Embedding(100, 50) self.inner = InnerModule() + def forward(self): + """Forward function.""" + class InnerModule(ModuleBase): + """Inner modeule.""" def __init__(self): super().__init__() self.fc = torch.nn.Linear(50, 100) + def forward(self): + """Forward function.""" + module = MyModule() module.annotate_sharding("linear.weight", 0) module.annotate_sharding("embedding.weight", 1) diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 73d0ce6c..6ea6dd1b 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -22,14 +22,13 @@ import torch import torch_xla2 from torch.utils import _pytree as pytree +from absl.testing import parameterized from jetstream_pt.engine import PyTorchEngine from jetstream_pt.third_party.llama import model_exportable, model_args from 
jetstream_pt.third_party.llama.generation_original import LlamaOriginal from jetstream_pt import environment from tests import helpers -from jetstream_pt import torchjax -from absl.testing import parameterized class LlamaE2ETest(parameterized.TestCase): @@ -43,6 +42,7 @@ def _make_env(self, bf16_enable=True): torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) jax.config.update("jax_traceback_filtering", "off") + # pylint: disable-next=all config = model_args.get_model_args("tiny", 128, 1, 32000, True) environment_data = environment.JetEngineEnvironmentData() environment_data.max_input_sequence_length = 128 diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 703ce444..efbaa09b 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -17,7 +17,6 @@ import jax.numpy as jnp import torch import torch_xla2 -from . import helpers from jetstream_pt.third_party.llama import model_exportable from jetstream_pt.third_party.llama import model_original @@ -30,6 +29,8 @@ from jetstream_pt import layers from jetstream_pt import cache_manager +from . import helpers + class ModelComponentTest(unittest.TestCase): """Test diff between original model and xla model for transformer, @@ -77,7 +78,7 @@ def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True): def _compare_cache(self, cache_torch, cache_jax): _, seq, _, _ = cache_torch.shape - cache_j = torch_xla2.tensor.j2t(cache_jax._elem) + cache_j = torch_xla2.tensor.j2t(cache_jax.jax()) for s in range(seq): print("diff ", (cache_torch[0, s] - cache_j[0, :, s]).norm()) @@ -141,13 +142,14 @@ def test_attention(self): cache_decode = self._make_one_cache_for_generate(env, pos) # insert prefilled cache entry - cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - ..., :pos, : - ].set(cache.cache_k._elem) - - cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - ..., :pos, : - ].set(cache.cache_v._elem) + # pylint: disable-next=all + cache_decode.cache_k._elem = ( + cache_decode.cache_k.jax().at[..., :pos, :].set(cache.cache_k.jax()) + ) + # pylint: disable-next=all + cache_decode.cache_v._elem = ( + cache_decode.cache_v.jax().at[..., :pos, :].set(cache.cache_v.jax()) + ) # self._compare_cache(attention_orig.cache_k, cache_decode.cache_k) # Now do one with decode @@ -176,6 +178,7 @@ def test_attention(self): self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) def test_gemma_attention(self): + """Test gemma attention.""" with jax.default_matmul_precision("float32"): env, model_arg = helpers.make_env_tiny(False) @@ -306,12 +309,14 @@ def test_transformer_block(self): cache_decode = self._make_one_cache_for_generate(env, pos) # insert prefilled cache entry - cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - ..., :pos, : - ].set(cache.cache_k._elem) - cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - ..., :pos, : - ].set(cache.cache_v._elem) + # pylint: disable-next=all + cache_decode.cache_k._elem = ( + cache_decode.cache_k.jax().at[..., :pos, :].set(cache.cache_k.jax()) + ) + # pylint: disable-next=all + cache_decode.cache_v._elem = ( + cache_decode.cache_v.jax().at[..., :pos, :].set(cache.cache_v.jax()) + ) # Now do one with decode x2 = torch.randn((1, 1, model_arg.dim)) @@ -433,14 +438,16 @@ def test_mixtral_transformer(self): self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) def test_mixtral_moe(self): + """Test mixtral moe module.""" config = mixtral_config.ModelArgs() config.intermediate_size = 16 
config.dim = 16 m = mixtral.ConditionalFeedForward(config) # random init states = m.state_dict() - for k, v in states.items(): - states[k].normal_() + for _, v in states.items(): + # pylint: disable-next=all + v.normal_() m.load_state_dict(states, assign=True) seqlen = 3 diff --git a/tests/test_quantization.py b/tests/test_quantization.py index f48809ea..d150c67b 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import functools import unittest import jax import jax.numpy as jnp -import jax.sharding as jsharding import torch import torch_xla2 -from jax.experimental import mesh_utils -from jetstream_pt import cache_manager, layers, quantize, torchjax, environment +from absl.testing import parameterized +from tests import helpers + + +from jetstream_pt import cache_manager, layers, torchjax, environment from jetstream_pt.environment import QuantizationConfig from jetstream_pt.layers import ( WeightOnlyBlockwiseQuantizedLinear, @@ -30,11 +31,7 @@ ) from jetstream_pt.quantize_model import quantize_model from jetstream_pt.quantize import dequantize_tensor, quantize_tensor -from tests import helpers -from torch.utils import _pytree as pytree -from torch_xla2 import tensor -import copy -from absl.testing import parameterized + torch.manual_seed(12345) @@ -127,27 +124,27 @@ def update_finalize_compare(in_k, in_v, in_layer, in_pos): if env.generate_cache_stacked: self.assertTrue( jnp.allclose( - k._elem, - new_k._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + k.jax(), + new_k.jax()[in_layer, :, :, in_pos : (in_pos + 1), :], atol=0.1, ) ) self.assertTrue( jnp.allclose( - v._elem, - new_v._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + v.jax(), + new_v.jax()[in_layer, :, :, in_pos : (in_pos + 1), :], atol=0.1, ) ) else: self.assertTrue( jnp.allclose( - k._elem, new_k._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + k.jax(), new_k.jax()[:, :, in_pos : (in_pos + 1), :], atol=0.1 ) ) self.assertTrue( jnp.allclose( - v._elem, new_v._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + v.jax(), new_v.jax()[:, :, in_pos : (in_pos + 1), :], atol=0.1 ) ) @@ -159,6 +156,7 @@ def update_finalize_compare(in_k, in_v, in_layer, in_pos): ("ring_buffer", True), ("left_aligned", False), ) + # pylint: disable-next=all def test_kv_kernel(self, ring_buffer): """test kv cache quantization""" @@ -254,7 +252,9 @@ def update_finalize_record( # Running into the issue of multiple env object always share the same quant_config. # Record the results and compare as a workaround. 
+ # pylint: disable-next=all env._data.quant_config.enable_kv_quantization = True + # pylint: disable-next=all env = environment.JetEngineEnvironment(env._data) cache_int = cache_manager.Int8KVCacheGenerate( @@ -284,6 +284,7 @@ def update_finalize_record( self.assertTrue(jnp.allclose(f.jax(), i.jax(), atol=0.01)) def test_quantize_dequantize_tensor(self): + """Test quantize and dequantize tensor.""" def quantize_dequantize_weight(w, n_bit): # print(f"original w {w}") @@ -333,10 +334,9 @@ def quantize_dequantize_weight(w, n_bit): quantize_dequantize_weight(w, bit) def test_weight_only_quant(self): - + """Test weight only quantization.""" out_features = 2048 in_features = 2048 - block_size = 128 arg = torch.randn(2, 16, in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( @@ -377,24 +377,27 @@ def test_weight_only_quant(self): in_features, out_features, quant_config=quant_config ) # block_q_linear.run_fake_quantize = True - res, torch_res, block_diff2 = self._nn_linear_run_and_compare( + res, torch_res, _ = self._nn_linear_run_and_compare( nn_linear, block_q_linear, arg ) # self._print_diff(res, torch_res) self.assertLess(per_channel_diff2.norm(), per_channel_diff.norm()) + # pylint: disable-next=all # FIXME: Now asymmetric blockwise quant has higher error than asymmetric per-channel. # self.assertLess(block_diff2.norm(), per_channel_diff2.norm()) def test_int4_weight_loading(self): + """Test int4 weight loading.""" layer = WeightOnlyBlockwiseQuantizedLinear(1024, 2048) state_dict_jax = torchjax.from_torch( helpers.to_xla_tensor(layer.state_dict()) ) state_dict_jax["weight"] = state_dict_jax["weight"].astype(jnp.int4) state_dict_torch = torchjax.to_torch(state_dict_jax) - self.assertTrue(state_dict_torch["weight"]._elem.dtype == jnp.int4) + self.assertTrue(state_dict_torch["weight"].jax().dtype == jnp.int4) def test_blockwise_quantized_linear_sharding(self): + """Test blockwise quantized linear sharding.""" @functools.partial( jax.jit, @@ -410,19 +413,20 @@ def f(layer, weights, args): state_dict_jax = torchjax.from_torch( helpers.to_xla_tensor(layer.state_dict()) ) - input = jax.random.normal( + inputs = jax.random.normal( jax.random.key(0), shape=(2, 32, 1024), dtype=jnp.bfloat16 ) - def shard_and_lower(f, layer, state_dict_jax, input, shardings): + def shard_and_lower(f, layer, state_dict_jax, inputs, shardings): for k, v in state_dict_jax.items(): if k == "weight": state_dict_jax[k] = v.astype(jnp.int4) state_dict_jax[k] = jax.device_put(v, sharding[0]) if k == "weight_scaler": state_dict_jax[k] = jax.device_put(v, sharding[1]) - pre_opt = f.lower(layer, state_dict_jax, input).as_text("hlo") - post_opt = f.lower(layer, state_dict_jax, input).compile().as_text() + # pre opt, for debugging + _ = f.lower(layer, state_dict_jax, inputs).as_text("hlo") + post_opt = f.lower(layer, state_dict_jax, inputs).compile().as_text() return post_opt env, _ = helpers.make_env_tiny() @@ -432,15 +436,16 @@ def shard_and_lower(f, layer, state_dict_jax, input, shardings): # (sharding_by_axis(1), sharding_by_axis(0)), # bad sharding ] for sharding in shardings: - opt_hlo = shard_and_lower(f, layer, state_dict_jax, input, sharding) + opt_hlo = shard_and_lower(f, layer, state_dict_jax, inputs, sharding) self.assertFalse("all-to-all" in opt_hlo) self.assertFalse("all-reduce-scatter" in opt_hlo) def test_activation_quant_per_channel(self): - + """Test activation quantization channel mode.""" out_features = 8 in_features = 4 - block_size = 128 + # Block size + _ = 128 arg = torch.randn(2, 1, 
in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( @@ -459,10 +464,11 @@ def test_activation_quant_per_channel(self): self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999) def test_quant_creator(self): - + """Test quantization creator.""" out_features = 8 in_features = 4 - block_size = 128 + # Block size + _ = 128 arg = torch.randn(2, 1, in_features).to(torch.bfloat16) nn_linear = torch.nn.Linear( @@ -479,8 +485,10 @@ def test_quant_creator(self): self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999) def test_3_layers(self): + """Test 3 layers.""" class Model(torch.nn.Module): + """Model.""" def __init__(self): super().__init__() @@ -489,6 +497,7 @@ def __init__(self): self.linear3 = torch.nn.Linear(2048, 1024, bias=False) def forward(self, x): + """Forward function.""" x = self.linear1(x) x = self.linear2(x) x = self.linear3(x) diff --git a/tests/test_run_server.py b/tests/test_run_server.py index 849af329..759a0177 100644 --- a/tests/test_run_server.py +++ b/tests/test_run_server.py @@ -17,32 +17,40 @@ from absl import app from absl.testing import flagsaver from parameterized import parameterized, param +from run_server import flags class MockServer(MagicMock): + """Mock server.""" def run(self, **kwargs): + """Run.""" return self def wait_for_termination(self): + """Wait for termination.""" raise SystemExit("Successfully exited test.") def mock_engine(**kwargs): + """Mock engine.""" return kwargs class ServerRunTest(unittest.TestCase): + """Server run test.""" def reset_flags(self): + """Reset flag.""" flagsaver.restore_flag_values(self.original) def setup(self): - from run_server import flags + """Setup.""" - FLAGS = flags.FLAGS + f = flags.FLAGS + # pylint: disable-next=all self.original = flagsaver.save_flag_values() - return FLAGS + return f @parameterized.expand( [ @@ -61,50 +69,51 @@ def test_no_change_from_defaults(self, args, expected): args (List): List to simulate sys.argv with dummy first entry at index 0. 
expected (str): model_name flag value to inspect """ + # pylint: disable-next=all from run_server import main - FLAGS = self.setup() + f = self.setup() with self.assertRaisesRegex(SystemExit, "Successfully exited test."): app.run(main, args) # run_server - self.assertEqual(FLAGS.port, 9000) - self.assertEqual(FLAGS.threads, 64) - self.assertEqual(FLAGS.config, "InterleavedCPUTestServer") - self.assertEqual(FLAGS.prometheus_port, 0) - self.assertEqual(FLAGS.enable_jax_profiler, False) - self.assertEqual(FLAGS.jax_profiler_port, 9999) + self.assertEqual(f.port, 9000) + self.assertEqual(f.threads, 64) + self.assertEqual(f.config, "InterleavedCPUTestServer") + self.assertEqual(f.prometheus_port, 0) + self.assertEqual(f.enable_jax_profiler, False) + self.assertEqual(f.jax_profiler_port, 9999) # quantization configs - self.assertEqual(FLAGS.quantize_weights, False) - self.assertEqual(FLAGS.quantize_activation, False) - self.assertEqual(FLAGS.quantize_type, "int8_per_channel") - self.assertEqual(FLAGS.quantize_kv_cache, False) + self.assertEqual(f.quantize_weights, False) + self.assertEqual(f.quantize_activation, False) + self.assertEqual(f.quantize_type, "int8_per_channel") + self.assertEqual(f.quantize_kv_cache, False) # engine configs - self.assertEqual(FLAGS.tokenizer_path, None) - self.assertEqual(FLAGS.checkpoint_path, None) - self.assertEqual(FLAGS.bf16_enable, True) - self.assertEqual(FLAGS.context_length, 1024) - self.assertEqual(FLAGS.batch_size, 32) - self.assertEqual(FLAGS.size, "tiny") - self.assertEqual(FLAGS.max_cache_length, 1024) - self.assertEqual(FLAGS.shard_on_batch, False) - self.assertEqual(FLAGS.sharding_config, "") - self.assertEqual(FLAGS.ragged_mha, False) - self.assertEqual(FLAGS.starting_position, 512) - self.assertEqual(FLAGS.temperature, 1.0) - self.assertEqual(FLAGS.sampling_algorithm, "greedy") - self.assertEqual(FLAGS.nucleus_topp, 0.0) - self.assertEqual(FLAGS.topk, 0) - self.assertEqual(FLAGS.ring_buffer, True) + self.assertEqual(f.tokenizer_path, None) + self.assertEqual(f.checkpoint_path, None) + self.assertEqual(f.bf16_enable, True) + self.assertEqual(f.context_length, 1024) + self.assertEqual(f.batch_size, 32) + self.assertEqual(f.size, "tiny") + self.assertEqual(f.max_cache_length, 1024) + self.assertEqual(f.shard_on_batch, False) + self.assertEqual(f.sharding_config, "") + self.assertEqual(f.ragged_mha, False) + self.assertEqual(f.starting_position, 512) + self.assertEqual(f.temperature, 1.0) + self.assertEqual(f.sampling_algorithm, "greedy") + self.assertEqual(f.nucleus_topp, 0.0) + self.assertEqual(f.topk, 0) + self.assertEqual(f.ring_buffer, True) # profiling configs - self.assertEqual(FLAGS.profiling_prefill, False) - self.assertEqual(FLAGS.profiling_output, "") + self.assertEqual(f.profiling_prefill, False) + self.assertEqual(f.profiling_output, "") # model_name flag updates - self.assertEqual(FLAGS.model_name, expected) + self.assertEqual(f.model_name, expected) # reset back to original flags self.reset_flags() @@ -112,7 +121,8 @@ def test_no_change_from_defaults(self, args, expected): @parameterized.expand([param(["test1", "--model_name", "llama3"])]) @patch("jetstream_pt.engine.create_pytorch_engine", mock_engine) def test_call_server_object(self, args): - """tests whether running the main script from absl.app.run launches a server and waits for termination + """tests whether running the main script from absl.app.run launches a server + and waits for termination Args: args (List): List to simulate sys.argv with dummy first entry at index 0. 
@@ -120,9 +130,10 @@ def test_call_server_object(self, args): with patch( "jetstream.core.server_lib.run", autospec=MockServer().run ) as mock_server: + # pylint: disable-next=all from run_server import main - FLAGS = self.setup() + _ = self.setup() with self.assertRaises(SystemExit): app.run(main, args) self.assertEqual(mock_server.call_count, 1) From 62f3c518208c9e04d1d8e90b749834cb19a67f73 Mon Sep 17 00:00:00 2001 From: Richard Liu <39319471+richardsliu@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:40:23 +0000 Subject: [PATCH 53/57] Fix Ray engine crash on multihost (#164) --- jetstream_pt/ray_worker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index b473e05c..1a4f15e7 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -466,6 +466,9 @@ def prefill_ray( logits = logits[0] token = np.argmax(logits[true_length - 1]) + updated_caches = multihost_utils.process_allgather( + updated_caches, tiled=True + ) prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False) From 743c0e505278216bac34966b9439c72f0d0521b3 Mon Sep 17 00:00:00 2001 From: Richard Liu <39319471+richardsliu@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:34:31 +0000 Subject: [PATCH 54/57] Fix TPU head resource name for v4 and v5e (#165) * Fix TPU head resource name for v4 and v5e * fix format --- run_ray_serve_interleave.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/run_ray_serve_interleave.py b/run_ray_serve_interleave.py index a6b3421d..853ce068 100644 --- a/run_ray_serve_interleave.py +++ b/run_ray_serve_interleave.py @@ -40,8 +40,11 @@ def create_head_resource_name(generation, tpu_chips): - """Create head resource name.""" - return f"TPU-{generation}-{tpu_chips}-head" + if generation == "v5litepod": + return f"TPU-{generation}-{tpu_chips}-head" + else: + tpu_cores = tpu_chips * 2 + return f"TPU-{generation}-{tpu_cores}-head" def create_engine(**kwargs): From 784801f981fb380aaa107a916e48ba63489dc03f Mon Sep 17 00:00:00 2001 From: Fanhai Lu <154379058+FanhaiLu1@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:22:04 -0700 Subject: [PATCH 55/57] Fixed exhausted bug between head and workers (#163) * add xla2 fix * update jax version * revert jax TPU version --- README.md | 1 + deps/xla | 2 +- run_interactive_multiple_host.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5292e0c7..ca6ec4ba 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,7 @@ Note: Get address ip and port information from ray head. 
Here is an example to run the server with ray for llama2 7B model: ```bash +export DISABLE_XLA2_PJRT_TEST="true" python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` diff --git a/deps/xla b/deps/xla index c2753715..fb2d4e14 160000 --- a/deps/xla +++ b/deps/xla @@ -1 +1 @@ -Subproject commit c27537153f3ea983a7ba9b0e1bfdae4b37ca5e9e +Subproject commit fb2d4e1464dfd96f38a343c0e6f512629e28b48c diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 9192076a..f9307126 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -56,7 +56,7 @@ def create_engine(): sharding_config=FLAGS.sharding_config, num_hosts=_NUM_HOSTS.value, worker_chips=_WORKER_CHIPS.value, - tpu_chips=_TPU_CHIPS, + tpu_chips=_TPU_CHIPS.value, ) print("Initialize engine", time.perf_counter() - start) From d318ce47646b2bb635de1780027523045948a568 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 6 Aug 2024 19:16:46 +0000 Subject: [PATCH 56/57] Fix test_run_server issue from fixing the lint; Fix run_interactive from merge; Fix lints; --- run_interactive.py | 15 +++++++++++---- tests/test_run_server.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 86cd6df2..8463658c 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -47,9 +47,12 @@ def main(argv): if profiling_prefill: jax.profiler.start_trace(profiling_output) + decode_state = engine.init_decode_state() + if profiling_prefill: jax.profiler.stop_trace() + prompts: List[str] = [ # pylint: disable-next=all "I believe the meaning of life is", @@ -72,11 +75,13 @@ def main(argv): # pylint: disable-next=all if profiling_prefill: jax.profiler.start_trace(profiling_output) - prefill_result, _ = engine.prefill( - params=params, padded_tokens=tokens, true_length=true_length - ) - # pylint: disable-next=all + + prefill_result, _ = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) + # pylint: disable-next=all decode_state = engine.insert(prefill_result, decode_state, slot=slot) + if profiling_prefill: jax.profiler.stop_trace() @@ -86,11 +91,13 @@ def main(argv): while True: if profiling_output: jax.profiler.start_trace(profiling_output) + decode_state, result_tokens = engine.generate(params, decode_state) result_tokens = result_tokens.convert_to_numpy() if profiling_output: jax.profiler.stop_trace() + output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, slot=slot, diff --git a/tests/test_run_server.py b/tests/test_run_server.py index 759a0177..73022a74 100644 --- a/tests/test_run_server.py +++ b/tests/test_run_server.py @@ -17,7 +17,6 @@ from absl import app from absl.testing import flagsaver from parameterized import parameterized, param -from run_server import flags class MockServer(MagicMock): @@ -46,6 +45,7 @@ def reset_flags(self): def setup(self): """Setup.""" + from run_server import flags f = flags.FLAGS # pylint: disable-next=all From 8b26e9f31639a5e8031f8d56db5bba526b233eed Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 6 Aug 2024 19:29:51 +0000 Subject: [PATCH 57/57] Revert xla changes. 
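Regarding the tests/test_run_server.py change in PATCH 56 above, a minimal sketch of the deferred-import shape it adopts (the class body is trimmed to the relevant method; run_server is this repository's module): importing run_server inside setup() defers registration of its absl flags until a test actually runs, rather than at module import time.

    from absl.testing import flagsaver

    class ServerRunTestSketch:
        """Illustrative skeleton only; mirrors the deferred import in the real test."""

        def setup(self):
            # Deferred import: run_server (and the flags it defines) is loaded
            # only when setup() runs, not when the test module is imported.
            from run_server import flags  # pylint: disable=import-outside-toplevel

            f = flags.FLAGS
            # Save current flag values so tests can restore them afterwards.
            self.original = flagsaver.save_flag_values()
            return f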
--- deps/xla | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/xla b/deps/xla index fb2d4e14..c2753715 160000 --- a/deps/xla +++ b/deps/xla @@ -1 +1 @@ -Subproject commit fb2d4e1464dfd96f38a343c0e6f512629e28b48c +Subproject commit c27537153f3ea983a7ba9b0e1bfdae4b37ca5e9e
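For quick reference, the head-resource naming rule introduced in PATCH 54 above reduces to the following restatement of the patched helper (example values are illustrative):

    def create_head_resource_name(generation: str, tpu_chips: int) -> str:
        """v5litepod head resources are keyed by chip count; other generations
        (e.g. v4) are keyed by core count, at two cores per chip."""
        if generation == "v5litepod":
            return f"TPU-{generation}-{tpu_chips}-head"
        tpu_cores = tpu_chips * 2
        return f"TPU-{generation}-{tpu_cores}-head"

    # Example values (illustrative):
    #   create_head_resource_name("v5litepod", 8) -> "TPU-v5litepod-8-head"
    #   create_head_resource_name("v4", 8)        -> "TPU-v4-16-head"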