diff --git a/benchmarks/analyze_sharegpt.py b/benchmarks/analyze_sharegpt.py
index 9dd7df97..12042de9 100644
--- a/benchmarks/analyze_sharegpt.py
+++ b/benchmarks/analyze_sharegpt.py
@@ -17,61 +17,11 @@
 CUTOFF_INPUT = 1024
 CUTOFF_OUTPUT = 1024
 
 
-# batch size 60, ful cache, bfloat
-prefill_bucket_size_to_s = {
-    64: 0.007696230011060834,
-    128: 0.011508351005613805,
-    256: 0.01721684739459306,
-    512: 0.03257157760672271,
-    1024: 0.08185497261583805,
-}
-
-# batch size 96, ful cache, quantized
-prefill_bucket_size_to_s = {
-    64: 0.006911616190336645,
-    128: 0.011646182998083532,
-    256: 0.01875854718964547,
-    512: 0.0334438294172287,
-    1024: 0.0643601292045787,
-}
-
-# batch size 96, rolling, bfloat
-prefill_bucket_size_to_s = {
-    64: 0.007730783987790346,
-    128: 0.011515899002552033,
-    256: 0.01780580161139369,
-    512: 0.03115477201063186,
-    1024: 0.07443338260054588,
-}
-
-# batch size 160, rolling, quantized
-prefill_bucket_size_to_s = {
-    64: 0.006821704190224409,
-    128: 0.01175499300006777,
-    256: 0.018776051187887787,
-    512: 0.03392685519065708,
-    1024: 0.06476318498607725,
-}
-
-prefill_bucket_size_to_ms = {
-    k: p * 1000 for k, p in prefill_bucket_size_to_s.items()
-}
-
-# batch size 60, ful cache, bfloat
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.55 / 60
-
-# batch size 96, ful cache, quantized
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.0 / 96
-
-# batch size 96, rolling, bfloat
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 28.18 / 96
-
-# batch size 160, rolling, quantized
-SYSTEM_TIME_PER_DECODE_TOKEN_MS = 30 / 160
-
 # pylint: disable-next=all
-def do_simulation(prefill_bucket_size_to_ms, system_time_per_decode_token_ms):
+def do_simulation(
+    sharegpt_path, prefill_bucket_size_to_ms, system_time_per_decode_token_ms
+):
   def next_power_of_2(x):
     return 1 if x == 0 else 2 ** (x - 1).bit_length()
 
@@ -82,10 +32,9 @@ def tokens_in_input_str(s):
 
   convo_numbers = []
   # Please update with your own data file path
-  loaded_share_gpt = json.load(
-      # pylint: disable-next=all
-      open("~/data/ShareGPT_V3_unfiltered_cleaned_split.json", "r")
-  )
+
+  with open(sharegpt_path, "r", encoding="utf-8") as f:
+    loaded_share_gpt = json.load(f)
   for example in loaded_share_gpt:
     if len(example["conversations"]) < 2:
       continue
diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index a41ca986..2f480930 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -67,6 +67,13 @@
 _MAX_CACHE_LENGTH = flags.DEFINE_integer(
     "max_cache_length", 1024, "kv_cache_quantize"
 )
+_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name")
+_SHARDING_CONFIG = flags.DEFINE_string(
+    "sharding_config", "", "path to sharding config"
+)
+_SHAREGPT_PATH = flags.DEFINE_string(
+    "sharegpt_path", "", "path to sharegpt json file"
+)
 
 
 def create_engine():
@@ -87,6 +94,8 @@ def create_engine():
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
      max_cache_length=_MAX_CACHE_LENGTH.value,
+      model_name=_MODEL_NAME.value,
+      sharding_config=_SHARDING_CONFIG.value,
   )
 
   print("Initialize engine", time.perf_counter() - start)
@@ -185,7 +194,10 @@ def main(argv):
   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
   decode_time_ms = sum(dec_times) * 1000 / 10 / _BATCH_SIZE.value
 
-  analyze_sharegpt.do_simulation(prefill_times_ms, decode_time_ms)
+  if _SHAREGPT_PATH.value:
+    analyze_sharegpt.do_simulation(
+        _SHAREGPT_PATH.value, prefill_times_ms, decode_time_ms
+    )
 
 
 if __name__ == "__main__":
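Note: with the hardcoded timing tables deleted, `do_simulation` is driven entirely by caller-supplied measurements, and `run_offline.py` only runs the simulation when `--sharegpt_path` is given. A minimal sketch of calling it directly; the bucket timings below are placeholders lifted from the removed "batch size 96, rolling, bfloat" table, and the dataset path is illustrative:

    from benchmarks import analyze_sharegpt

    # Placeholder prefill latencies in ms, keyed by power-of-2 bucket size.
    prefill_ms = {64: 7.73, 128: 11.52, 256: 17.81, 512: 31.15, 1024: 74.43}
    # Per-token decode time in ms: decode step wall time / batch size.
    decode_ms_per_token = 28.18 / 96

    analyze_sharegpt.do_simulation(
        "/data/ShareGPT_V3_unfiltered_cleaned_split.json",
        prefill_ms,
        decode_ms_per_token,
    )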
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 7b234782..6ec196ac 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -86,7 +86,7 @@ def __init__(
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
-    self.cache_sharding = self.y_sharding
+    self.cache_sharding = self.env.cache_sharding
 
     self.prefill = jax.jit(
         self.prefill, out_shardings=self.get_prefix_destination_sharding()
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 453172ef..7458636e 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -97,13 +97,16 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.x_sharding = jsharding.NamedSharding(self._mesh, P("x"))
     self.replicated = jsharding.NamedSharding(self._mesh, P())
 
-    cache_sharding = (
-        "x" if axis == self._data.kv_cache_shard_axis else None
-        for axis in self._data.attention_kv_axis_names
-    )
-    self.cache_sharding = jsharding.NamedSharding(
-        self._mesh, P(*cache_sharding)
+    cache_sharding_axis = self.attention_kv_axis_names.index(
+        self.kv_cache_shard_axis
     )
+
+    if self.cache_shape[cache_sharding_axis] == 1:
+      # cannot shard on an axis that is 1
+      # default to last
+      cache_sharding_axis = len(self.cache_shape) - 1
+
+    self.cache_sharding = self.sharding_by_axis(cache_sharding_axis)
     self._load_sharding_config()
 
   def _load_sharding_config(self):
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index d85c98e2..f0b1fe50 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -151,8 +151,6 @@ def __call__(self, xq, xk, xv, mask, cache):
 
     with jax.named_scope("attn_insert_cache"):
       keys, values = cache.update(xk, xv)
-      self.env.apply_sharding(keys, axis=1)
-      self.env.apply_sharding(values, axis=1)
       keys = repeat_kv(keys, n_rep)
       values = repeat_kv(values, n_rep)
     with jax.named_scope("attn_mat1"):
@@ -206,8 +204,6 @@ def __call__(self, xq, xk, xv, mask, cache):
 
     with jax.named_scope("attn_insert_cache"):
       keys, values, k_scaler, v_scaler = cache.update(xk, xv)
-      self.env.apply_sharding(keys, axis=1)
-      self.env.apply_sharding(values, axis=1)
       keys = repeat_kv(keys, n_rep)
       values = repeat_kv(values, n_rep)
     with jax.named_scope("attn_mat1"):
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index a8573519..fad0dda6 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -148,9 +148,15 @@ def forward(
     xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
     xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
 
-    self.env.apply_sharding(xq, axis=2)
-    self.env.apply_sharding(xk, axis=2)
-    self.env.apply_sharding(xv, axis=2)
+    if self.num_kv_heads > 1:
+      self.env.apply_sharding(xq, axis=2)
+      self.env.apply_sharding(xk, axis=2)
+      self.env.apply_sharding(xv, axis=2)
+    else:
+      # Gemma 2B
+      self.env.apply_sharding(xq, axis=3)
+      self.env.apply_sharding(xk, axis=3)
+      self.env.apply_sharding(xv, axis=3)
 
     # Positional embedding.
     xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
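The environment.py change swaps the generator-based PartitionSpec for an explicit axis index with a fallback when the requested axis cannot be sharded. A standalone sketch of that selection logic; the axis names follow the diff, while the cache shape is an assumed (batch, kv_heads, sequence, head_dim) layout with a single KV head:

    # Assumed KV-cache metadata; only the axis arithmetic is the point here.
    attention_kv_axis_names = ("batch", "num_attn_heads", "sequence_length", "head_dim")
    kv_cache_shard_axis = "num_attn_heads"
    cache_shape = (96, 1, 1024, 256)  # hypothetical single-KV-head cache

    cache_sharding_axis = attention_kv_axis_names.index(kv_cache_shard_axis)  # -> 1
    if cache_shape[cache_sharding_axis] == 1:
      # A size-1 axis cannot be split across devices; fall back to the
      # last axis (head_dim), matching the diff's behavior.
      cache_sharding_axis = len(cache_shape) - 1  # -> 3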
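The Gemma branch applies the same rule at the attention inputs: with multiple KV heads the head axis (axis 2) is sharded, but when `num_kv_heads == 1` (multi-query attention, as in Gemma 2B) that axis has size 1, so sharding moves to head_dim (axis 3) for xq, xk, and xv alike, keeping the three tensors on a consistent layout. A quick shape check under assumed Gemma 2B-like dimensions:

    import torch

    batch_size, seqlen = 8, 1
    num_heads, num_kv_heads, head_dim = 8, 1, 256  # assumed Gemma 2B MQA config

    xk = torch.zeros(batch_size, seqlen, num_kv_heads, head_dim)
    # xk.shape == (8, 1, 1, 256): axis 2 is size 1 and cannot be sharded,
    # so the head_dim axis is sharded instead.
    shard_axis = 2 if num_kv_heads > 1 else 3
    assert xk.shape[shard_axis] > 1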