63 changes: 6 additions & 57 deletions benchmarks/analyze_sharegpt.py
@@ -17,61 +17,11 @@
CUTOFF_INPUT = 1024
CUTOFF_OUTPUT = 1024

# batch size 60, ful cache, bfloat
prefill_bucket_size_to_s = {
64: 0.007696230011060834,
128: 0.011508351005613805,
256: 0.01721684739459306,
512: 0.03257157760672271,
1024: 0.08185497261583805,
}

# batch size 96, ful cache, quantized
prefill_bucket_size_to_s = {
64: 0.006911616190336645,
128: 0.011646182998083532,
256: 0.01875854718964547,
512: 0.0334438294172287,
1024: 0.0643601292045787,
}

# batch size 96, rolling, bfloat
prefill_bucket_size_to_s = {
64: 0.007730783987790346,
128: 0.011515899002552033,
256: 0.01780580161139369,
512: 0.03115477201063186,
1024: 0.07443338260054588,
}

# batch size 160, rolling, quantized
prefill_bucket_size_to_s = {
64: 0.006821704190224409,
128: 0.01175499300006777,
256: 0.018776051187887787,
512: 0.03392685519065708,
1024: 0.06476318498607725,
}

prefill_bucket_size_to_ms = {
k: p * 1000 for k, p in prefill_bucket_size_to_s.items()
}

# batch size 60, ful cache, bfloat
SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.55 / 60

# batch size 96, ful cache, quantized
SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.0 / 96

# batch size 96, rolling, bfloat
SYSTEM_TIME_PER_DECODE_TOKEN_MS = 28.18 / 96

# batch size 160, rolling, quantized
SYSTEM_TIME_PER_DECODE_TOKEN_MS = 30 / 160


# pylint: disable-next=all
def do_simulation(prefill_bucket_size_to_ms, system_time_per_decode_token_ms):
def do_simulation(
sharegpt_path, prefill_bucket_size_to_ms, system_time_per_decode_token_ms
):
def next_power_of_2(x):
return 1 if x == 0 else 2 ** (x - 1).bit_length()

@@ -82,10 +32,9 @@ def tokens_in_input_str(s):

convo_numbers = []
# Please update with your own data file path
loaded_share_gpt = json.load(
# pylint: disable-next=all
open("~/data/ShareGPT_V3_unfiltered_cleaned_split.json", "r")
)

with open(sharegpt_path, "r", encoding="utf-8") as f:
loaded_share_gpt = json.load(f)
for example in loaded_share_gpt:
if len(example["conversations"]) < 2:
continue
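With this change the alternate hard-coded timing tables and the fixed dataset path are gone: `do_simulation` now receives the ShareGPT path and timing data from its caller. Below is a minimal sketch of calling the updated function directly; the import path is assumed, and the timing numbers are illustrative values taken from one of the removed comment blocks ("batch size 96, rolling, bfloat"), not measurements from this PR.

```python
import os
import analyze_sharegpt  # assumes benchmarks/ is on sys.path

# Illustrative prefill timings in ms per prefill bucket and per-token decode
# time; in run_offline.py these are measured at runtime, not hard-coded.
prefill_times_ms = {64: 7.73, 128: 11.52, 256: 17.81, 512: 31.15, 1024: 74.43}
decode_time_ms = 28.18 / 96

# The new code opens sharegpt_path directly, so expand "~" before passing it.
analyze_sharegpt.do_simulation(
    os.path.expanduser("~/data/ShareGPT_V3_unfiltered_cleaned_split.json"),
    prefill_times_ms,
    decode_time_ms,
)
```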
14 changes: 13 additions & 1 deletion benchmarks/run_offline.py
@@ -67,6 +67,13 @@
_MAX_CACHE_LENGTH = flags.DEFINE_integer(
"max_cache_length", 1024, "kv_cache_quantize"
)
_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name")
_SHARDING_CONFIG = flags.DEFINE_string(
"sharding_config", "", "path to sharding config"
)
_SHAREGPT_PATH = flags.DEFINE_string(
"sharegpt_path", "", "path to sharegpt json file"
)


def create_engine():
@@ -87,6 +94,8 @@ def create_engine():
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
model_name=_MODEL_NAME.value,
sharding_config=_SHARDING_CONFIG.value,
)

print("Initialize engine", time.perf_counter() - start)
@@ -185,7 +194,10 @@ def main(argv):
prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
decode_time_ms = sum(dec_times) * 1000 / 10 / _BATCH_SIZE.value

analyze_sharegpt.do_simulation(prefill_times_ms, decode_time_ms)
if _SHAREGPT_PATH.value:
analyze_sharegpt.do_simulation(
_SHAREGPT_PATH.value, prefill_times_ms, decode_time_ms
)


if __name__ == "__main__":
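For reference, a hedged sketch of how the three new flags might be exercised; the model name and paths below are placeholders, not values from this PR, and the import path is assumed. The script is normally run from the command line, so this simply forwards an equivalent argv through absl:

```python
from absl import app
import run_offline  # benchmarks/run_offline.py

argv = [
    "run_offline.py",
    "--model_name=gemma",  # placeholder model name
    "--sharding_config=/path/to/sharding.yaml",  # placeholder path
    "--sharegpt_path=/path/to/ShareGPT_V3_unfiltered_cleaned_split.json",
]
# Equivalent to passing the same flags on the command line.
app.run(run_offline.main, argv=argv)
```

When `--sharegpt_path` is left empty, the simulation step at the end of `main` is skipped, matching the new guard above.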
2 changes: 1 addition & 1 deletion jetstream_pt/engine.py
@@ -86,7 +86,7 @@ def __init__(
self.y_sharding = env.sharding_by_axis(1)
self.x_sharding = env.sharding_by_axis(0)
self.replicated = env.sharding_by_axis(-1) # replicated
self.cache_sharding = self.y_sharding
self.cache_sharding = self.env.cache_sharding

self.prefill = jax.jit(
self.prefill, out_shardings=self.get_prefix_destination_sharding()
15 changes: 9 additions & 6 deletions jetstream_pt/environment.py
@@ -97,13 +97,16 @@ def __init__(self, data: JetEngineEnvironmentData):
self.x_sharding = jsharding.NamedSharding(self._mesh, P("x"))
self.replicated = jsharding.NamedSharding(self._mesh, P())

cache_sharding = (
"x" if axis == self._data.kv_cache_shard_axis else None
for axis in self._data.attention_kv_axis_names
)
self.cache_sharding = jsharding.NamedSharding(
self._mesh, P(*cache_sharding)
cache_sharding_axis = self.attention_kv_axis_names.index(
self.kv_cache_shard_axis
)

if self.cache_shape[cache_sharding_axis] == 1:
# cannot shard on an axis that is 1
# default to last
cache_sharding_axis = len(self.cache_shape) - 1

self.cache_sharding = self.sharding_by_axis(cache_sharding_axis)
self._load_sharding_config()

def _load_sharding_config(self):
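The rewritten block derives the sharding axis from the configured axis name instead of building a PartitionSpec inline, and falls back to the last axis when the selected dimension has size 1, since a size-1 dimension cannot be meaningfully sharded. A self-contained sketch of that selection rule; the axis names and cache shapes below are made up for illustration:

```python
def pick_cache_sharding_axis(axis_names, shard_axis_name, cache_shape):
  """Mirrors the selection rule above (illustrative only)."""
  axis = axis_names.index(shard_axis_name)
  if cache_shape[axis] == 1:
    # Cannot shard an axis of size 1; default to the last axis instead.
    axis = len(cache_shape) - 1
  return axis

# Hypothetical KV-cache layout: (batch, num_kv_heads, seq_len, head_dim).
names = ("batch", "num_heads", "seq_len", "head_dim")
print(pick_cache_sharding_axis(names, "num_heads", (8, 32, 1024, 128)))  # -> 1
# With a single KV head (e.g. Gemma 2B) the heads dimension is 1, so the
# sharding falls back to the last axis (head_dim):
print(pick_cache_sharding_axis(names, "num_heads", (8, 1, 1024, 128)))   # -> 3
```

This lines up with the engine.py change above, where `cache_sharding` now comes from `env.cache_sharding` instead of being fixed to the y axis.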
4 changes: 0 additions & 4 deletions jetstream_pt/layers.py
@@ -151,8 +151,6 @@ def __call__(self, xq, xk, xv, mask, cache):

with jax.named_scope("attn_insert_cache"):
keys, values = cache.update(xk, xv)
self.env.apply_sharding(keys, axis=1)
[Reviewer comment, Collaborator] Any reason to remove the keys, values sharding?
self.env.apply_sharding(values, axis=1)
keys = repeat_kv(keys, n_rep)
values = repeat_kv(values, n_rep)
with jax.named_scope("attn_mat1"):
@@ -206,8 +204,6 @@ def __call__(self, xq, xk, xv, mask, cache):

with jax.named_scope("attn_insert_cache"):
keys, values, k_scaler, v_scaler = cache.update(xk, xv)
self.env.apply_sharding(keys, axis=1)
self.env.apply_sharding(values, axis=1)
keys = repeat_kv(keys, n_rep)
values = repeat_kv(values, n_rep)
with jax.named_scope("attn_mat1"):
12 changes: 9 additions & 3 deletions jetstream_pt/third_party/gemma/model.py
@@ -148,9 +148,15 @@ def forward(
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

self.env.apply_sharding(xq, axis=2)
self.env.apply_sharding(xk, axis=2)
self.env.apply_sharding(xv, axis=2)
if self.num_kv_heads > 1:
self.env.apply_sharding(xq, axis=2)
self.env.apply_sharding(xk, axis=2)
self.env.apply_sharding(xv, axis=2)
else:
# Gemma 2B
self.env.apply_sharding(xq, axis=3)
self.env.apply_sharding(xk, axis=3)
self.env.apply_sharding(xv, axis=3)

# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
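The branch above mirrors the environment change: with several KV heads the projections are sharded on the heads axis (axis 2 of a (batch, seq, heads, head_dim) tensor), but Gemma 2B has a single KV head, so that axis has size 1 and the head_dim axis (3) is sharded instead. A small shape sketch of why the heads axis is unusable in the single-head case, with made-up dimensions:

```python
import torch

# Hypothetical shapes after the .view() calls above (illustrative numbers).
multi_head = torch.zeros(2, 16, 8, 128)   # num_kv_heads=8 -> heads axis shardable
single_head = torch.zeros(2, 16, 1, 256)  # num_kv_heads=1 -> heads axis has size 1

def shard_axis(x):
  # Sharding a size-1 axis across devices would be a no-op, so fall back to
  # the head_dim axis, matching the num_kv_heads branch in forward() above.
  return 2 if x.shape[2] > 1 else 3

print(shard_axis(multi_head))   # 2
print(shard_axis(single_head))  # 3
```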