Conversation

@bvrockwell (Contributor) commented Jun 12, 2024

  • Added sampling methods (greedy, weighted, nucleus, topk) to the decoder in engine.py (see the sketch below).
  • Added the corresponding sampling configuration options to the different launch methods.
  • Rolled the JetStream submodule forward to main.
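
For context, the four strategies roughly correspond to the sketch below. This is an illustrative, self-contained JAX version, not the exact code added to engine.py; the function name `sample_token` and the parameter names (`topk`, `nucleus_topp`, `temperature`) are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def sample_token(logits, rng, algorithm, topk=0, nucleus_topp=0.0, temperature=1.0):
    """Pick the next token id from a 1-D logits vector (illustrative sketch)."""
    if algorithm == "greedy":
        return jnp.argmax(logits, axis=-1)

    scaled = logits / temperature

    if algorithm == "topk":
        # Keep only the topk highest-scoring tokens; mask everything else out.
        kth_best = jnp.sort(scaled)[-topk]
        scaled = jnp.where(scaled < kth_best, -jnp.inf, scaled)
    elif algorithm == "nucleus":
        # Keep the smallest set of tokens whose cumulative probability reaches
        # nucleus_topp; mask the long tail out.
        sorted_desc = jnp.sort(scaled)[::-1]
        cum_probs = jnp.cumsum(jax.nn.softmax(sorted_desc))
        cutoff = sorted_desc[jnp.sum(cum_probs < nucleus_topp)]
        scaled = jnp.where(scaled < cutoff, -jnp.inf, scaled)
    # "weighted" falls through to plain temperature-scaled categorical sampling.

    return jax.random.categorical(rng, scaled, axis=-1)
```

The benchmark runs further down map onto this as, for example, `algorithm="topk", topk=10, temperature=0.5` or `algorithm="nucleus", nucleus_topp=0.8, temperature=0.5`.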

@bvrockwell force-pushed the add-decoder-temperature branch 3 times, most recently from 134f77a to b692f77 on June 12, 2024 05:40
@bvrockwell force-pushed the add-decoder-temperature branch from b692f77 to 8d672f4 on June 12, 2024 05:45
@bvrockwell force-pushed the add-decoder-temperature branch from 32609da to 409d118 on June 12, 2024 18:49
@bvrockwell marked this pull request as ready for review on June 12, 2024 19:00
@FanhaiLu1 (Collaborator) left a comment

Thanks for adding the different token sampling methods! Can you start the PyTorch engine server and share the inference results with the different sampling algorithms?

Collaborator

We don't need to create a duplicated copy if sampling_utils.py is the same as JetStream's; JetStream is one of the dependencies of the PyTorch engine.

@bvrockwell (Contributor Author)

Ah ok, I wasn't sure whether to keep the previously pinned JetStream commit or roll up to the most recent main. 13 files have changed since then (including the addition of sampling_utils.py, so I changed the import to point to jetstream.engine instead). I'll run the tests with the different sampling algorithms too, thanks!
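
In other words, roughly the following import change (the exact module path inside JetStream is an assumption based on the comment above):

```python
# Reuse the sampling helpers that ship with the JetStream dependency
# instead of keeping a duplicated local copy of sampling_utils.py.
from jetstream.engine import sampling_utils
```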

Inline review context (engine.py):

```python
logits = jnp.expand_dims(logits, 0)
return (
    jnp.argmax(logits[:, -1], axis=-1)
    sampling_utils.sampling(
```
Collaborator

Can you add a unit test for the sampling?

@bvrockwell (Contributor Author)

Added tests to test_engine.py
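
For reference, tests of this kind (a hedged sketch written against the illustrative `sample_token` helper from the PR description above, not the actual contents of test_engine.py) could look like:

```python
import jax
import jax.numpy as jnp

# Assumes the sample_token sketch shown earlier is importable in the test module.


def test_greedy_picks_argmax():
    logits = jnp.array([0.1, 2.0, -1.0, 0.5])
    rng = jax.random.PRNGKey(0)
    # Greedy decoding must return the index of the largest logit.
    assert int(sample_token(logits, rng, "greedy")) == 1


def test_topk_stays_inside_top_k_set():
    logits = jnp.array([5.0, 4.0, -10.0, -10.0])
    for seed in range(20):
        rng = jax.random.PRNGKey(seed)
        token = int(sample_token(logits, rng, "topk", topk=2, temperature=0.5))
        # With topk=2, only the two highest-scoring tokens may ever be drawn.
        assert token in (0, 1)
```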

Collaborator

Thanks for adding it, looks good to me!


@bvrockwell (Contributor Author) commented Jun 13, 2024

> Thanks for adding the different token sampling methods! Can you start the PyTorch engine server and share the inference results with the different sampling algorithms?

Here are the inference results comparing main and the proposed changes:

```shell
python run_server.py --size=7b --batch_size=1 --max_cache_length=16 \
  --quantize_weights=true --quantize_kv_cache=true \
  --checkpoint_path=".../model/llama2" \
  --tokenizer_path="deps/JetStream/jetstream/tests/engine/third_party/llama2/tokenizer.model" \
  --model_name=llama-2 --sharding_config="default_shardings/llama.yaml" &> server.log &
```

Call to the benchmark:

```shell
python deps/JetStream/benchmarks/benchmark_serving.py \
  --tokenizer="deps/JetStream/jetstream/tests/engine/third_party/llama2/tokenizer.model" \
  --num-prompts 100 \
  --warmup-mode sampled
```

main: v5e-1, llama2-7b, "greedy", quantized "int8_per_channel", batch=1, max_cache_length=16

Successful requests: 130
Benchmark duration: 23.104631 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.63 requests/s
Input token throughput: 382.26 tokens/s
Output token throughput: 28.57 tokens/s
Mean TTFT: 11684.67 ms
Median TTFT: 11878.46 ms
P99 TTFT: 22641.39 ms
Mean TPOT: 3531.82 ms
Median TPOT: 2214.95 ms
P99 TPOT: 18868.32 ms

add-decoder-temperature: "greedy", same settings

Successful requests: 130
Benchmark duration: 23.105529 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.63 requests/s
Input token throughput: 382.25 tokens/s
Output token throughput: 28.56 tokens/s
Mean TTFT: 11684.65 ms
Median TTFT: 11878.57 ms
P99 TTFT: 22642.09 ms
Mean TPOT: 3531.79 ms
Median TPOT: 2214.98 ms
P99 TPOT: 18869.06 ms

add-decoder-temperature: "weighted", temperature=10, same settings

Successful requests: 130
Benchmark duration: 23.113795 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.62 requests/s
Input token throughput: 382.11 tokens/s
Output token throughput: 28.55 tokens/s
Mean TTFT: 11689.17 ms
Median TTFT: 11883.04 ms
P99 TTFT: 22650.20 ms
Mean TPOT: 3533.19 ms
Median TPOT: 2215.81 ms
P99 TPOT: 18875.58 ms

add-decoder-temperature: "topk", topk=10, temperature=0.5, same settings

Successful requests: 130
Benchmark duration: 26.119014 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 4.98 requests/s
Input token throughput: 338.14 tokens/s
Output token throughput: 25.27 tokens/s
Mean TTFT: 11691.81 ms
Median TTFT: 11885.63 ms
P99 TTFT: 22655.60 ms
Mean TPOT: 3533.97 ms
Median TPOT: 2216.27 ms
P99 TPOT: 18880.23 ms

add-decoder-temperature: "nucleus", nucleus_topp=0.8, temperature=0.5, same settings

Successful requests: 130
Benchmark duration: 23.645459 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.50 requests/s
Input token throughput: 373.52 tokens/s
Output token throughput: 27.91 tokens/s
Mean TTFT: 11957.90 ms
Median TTFT: 12156.01 ms
P99 TTFT: 23170.30 ms
Mean TPOT: 3614.40 ms
Median TPOT: 2266.83 ms
P99 TPOT: 19309.17 ms

v5e-4

main: "greedy"

Successful requests: 130
Benchmark duration: 7.190380 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 18.08 requests/s
Input token throughput: 1228.31 tokens/s
Output token throughput: 91.79 tokens/s
Mean TTFT: 3626.83 ms
Median TTFT: 3692.30 ms
P99 TTFT: 7034.99 ms
Mean TPOT: 1085.40 ms
Median TPOT: 680.39 ms
P99 TPOT: 5863.28 ms

add-decoder-temperature: "greedy", same settings

Successful requests: 130
Benchmark duration: 9.167303 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 14.18 requests/s
Input token throughput: 963.42 tokens/s
Output token throughput: 72.00 tokens/s
Mean TTFT: 3612.90 ms
Median TTFT: 3649.38 ms
P99 TTFT: 7026.81 ms
Mean TPOT: 1092.39 ms
Median TPOT: 686.06 ms
P99 TPOT: 5855.20 ms

add-decoder-temperature: "weighted", temperature=10, same settings

Successful requests: 130
Benchmark duration: 7.201671 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 18.05 requests/s
Input token throughput: 1226.38 tokens/s
Output token throughput: 91.65 tokens/s
Mean TTFT: 3633.50 ms
Median TTFT: 3697.00 ms
P99 TTFT: 7045.76 ms
Mean TPOT: 1093.84 ms
Median TPOT: 689.30 ms
P99 TPOT: 5809.58 ms

@qihqi merged commit 97aaeae into main on Jun 14, 2024
@bvrockwell deleted the add-decoder-temperature branch on June 14, 2024 00:44