2 changes: 1 addition & 1 deletion .pylintrc
@@ -1,2 +1,2 @@
[MESSAGES CONTROL]
disable=C0114,R0801,E1102,W0613,R1711
disable=C0114,R0801,E1102,W0613,R1711,too-many-locals
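The new `too-many-locals` entry (pylint code R0914) turns the check off for the whole repository. If a narrower suppression is ever preferred, pylint also accepts an inline disable on the offending function; a minimal sketch with a hypothetical function:

```python
# Hypothetical alternative to the repo-wide disable in .pylintrc:
# suppress R0914 (too-many-locals) for a single function only.
def summarize_times(times):  # pylint: disable=too-many-locals
  """Toy function standing in for one that genuinely needs many locals."""
  total = sum(times)
  count = len(times)
  mean = total / count if count else 0.0
  return {"total": total, "count": count, "mean": mean}


if __name__ == "__main__":
  print(summarize_times([0.12, 0.15, 0.11]))
```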
4 changes: 2 additions & 2 deletions README.md
@@ -104,10 +104,10 @@ python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --m


# Run the server
Here is an example to run the server with llama2 7B config. Note that the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`).
Here is an example to run the server with llama2 7B config.

```bash
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name --sharding_config="default_shardings/llama.yaml"
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```

Now you can send gRPC requests to it.
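As a rough illustration, a client might look like the sketch below. The proto module path, stub and message names, request fields, and the port are all assumptions (check `jetstream/core/proto` and the server startup log for the real ones); only `grpc.insecure_channel` is a known grpcio call.

```python
# Hedged client sketch: module path, stub/message names, request fields,
# and port are assumptions -- verify against your installed JetStream proto.
import grpc

from jetstream.core.proto import jetstream_pb2, jetstream_pb2_grpc  # assumed path


def main():
  # Port 9000 is assumed; use whatever run_server.py reports on startup.
  with grpc.insecure_channel("localhost:9000") as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)  # assumed stub name
    request = jetstream_pb2.DecodeRequest(  # assumed message and fields
        additional_text="The capital of France is",
        priority=0,
        max_tokens=64,
    )
    # Decode is assumed to stream partial results back.
    for response in stub.Decode(request):
      print(response)


if __name__ == "__main__":
  main()
```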
85 changes: 28 additions & 57 deletions benchmarks/prefill_offline.py
@@ -12,57 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import time
import functools
import humanize

# pylint: disable-next=all
from absl import app
from absl import flags
import numpy as np
import humanize
import jax

import numpy as np
# pylint: disable-next=all
from absl import app, flags
from jetstream_pt import engine as je

FLAGS = flags.FLAGS

_TOKENIZER_PATH = flags.DEFINE_string(
"tokenizer_path",
"tokenizer.model",
"The tokenizer model path",
required=False,
)
_CKPT_PATH = flags.DEFINE_string(
"checkpoint_path", None, "Directory for .pth checkpoints", required=False
)
_BF16_ENABLE = flags.DEFINE_bool(
"bf16_enable", False, "Whether to enable bf16", required=False
)
_CONTEXT_LENGTH = flags.DEFINE_integer(
"context_length", 1024, "The context length", required=False
)
_BATCH_SIZE = flags.DEFINE_integer(
"batch_size", 32, "The batch size", required=False
)
_PROFILING_OUTPUT = flags.DEFINE_string(
"profiling_output",
"",
"The profiling output",
required=False,
from jetstream_pt.config import (
FLAGS,
create_engine_from_config_flags,
define_profiling_flags,
)

_SIZE = flags.DEFINE_string("size", "tiny", "size of model")

_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
"quantize_weights", False, "weight quantization"
)
_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
"quantize_kv_cache", False, "kv_cache_quantize"
)
_MAX_CACHE_LENGTH = flags.DEFINE_integer(
"max_cache_length", 1024, "kv_cache_quantize"
)
define_profiling_flags()


def create_engine():
@@ -73,16 +39,19 @@ def create_engine():
devices = jax.devices()
start = time.perf_counter()
engine = je.create_pytorch_engine(
model_name=FLAGS.model_name,
devices=devices,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
tokenizer_path=FLAGS.tokenizer_path,
ckpt_path=FLAGS.checkpoint_path,
bf16_enable=FLAGS.bf16_enable,
param_size=FLAGS.size,
context_length=FLAGS.context_length,
batch_size=FLAGS.batch_size,
quantize_weights=FLAGS.quantize_weights,
quantize_kv=FLAGS.quantize_kv_cache,
max_cache_length=FLAGS.max_cache_length,
sharding_config=FLAGS.sharding_config,
shard_on_batch=FLAGS.shard_on_batch,
)

print("Initialize engine", time.perf_counter() - start)
@@ -163,14 +132,16 @@ def prefill_benchmark(tokens_list, engine, params, warmup)
# pylint: disable-next=all
def main(argv):

engine = create_engine()
engine = create_engine_from_config_flags()

start = time.perf_counter()
params = engine.load_params()
print("Load params ", time.perf_counter() - start)

if _PROFILING_OUTPUT.value:
jax.profiler.start_trace(_PROFILING_OUTPUT.value)
profiling_output = FLAGS.profiling_output
if profiling_output:
jax.profiler.start_trace(profiling_output)

print_mem_usage()
tokens_list = create_prefill_tokens()
for _ in range(3):
@@ -189,7 +160,7 @@ def main(argv):
tokens_list=tokens_list, engine=engine, params=params, warmup=False
)

if _PROFILING_OUTPUT.value:
if profiling_output:
jax.profiler.stop_trace()


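The import block above shows the new pattern: per-script flag definitions are replaced by shared helpers in `jetstream_pt.config`. A minimal sketch of a benchmark script written against that pattern, using only the names visible in this diff (the prefill/decode body is elided):

```python
# Minimal sketch of a benchmark script built on the shared config helpers;
# FLAGS, create_engine_from_config_flags and define_profiling_flags are the
# names introduced by this PR, everything else is ordinary boilerplate.
import time

import jax
# pylint: disable-next=all
from absl import app
from jetstream_pt.config import (
    FLAGS,
    create_engine_from_config_flags,
    define_profiling_flags,
)

define_profiling_flags()


def main(argv):
  del argv
  # All engine options (model size, batch size, quantization, sharding, ...)
  # now come from the shared flags instead of per-script definitions.
  engine = create_engine_from_config_flags()

  start = time.perf_counter()
  params = engine.load_params()
  print("Load params ", time.perf_counter() - start)

  if FLAGS.profiling_output:
    jax.profiler.start_trace(FLAGS.profiling_output)

  # ... run prefill / decode steps against `engine` and `params` here ...

  if FLAGS.profiling_output:
    jax.profiler.stop_trace()


if __name__ == "__main__":
  app.run(main)
```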
101 changes: 31 additions & 70 deletions benchmarks/run_offline.py
@@ -15,66 +15,24 @@
import logging
import os
import time
# pylint: disable-next=all
from absl import app
from absl import flags

import jax
import jax.numpy as jnp

from jetstream_pt import engine as je
# pylint: disable-next=all
from absl import app, flags
# pylint: disable-next=all
from benchmarks import analyze_sharegpt


logging.getLogger().setLevel(logging.ERROR)


FLAGS = flags.FLAGS

_TOKENIZER_PATH = flags.DEFINE_string(
"tokenizer_path",
"tokenizer.model",
"The tokenizer model path",
required=False,
)
_CKPT_PATH = flags.DEFINE_string(
"checkpoint_path", None, "Directory for .pth checkpoints", required=False
)
_BF16_ENABLE = flags.DEFINE_bool(
"bf16_enable", False, "Whether to enable bf16", required=False
)
_CONTEXT_LENGTH = flags.DEFINE_integer(
"context_length", 1024, "The context length", required=False
)
_BATCH_SIZE = flags.DEFINE_integer(
"batch_size", 32, "The batch size", required=False
)
_PROFILING_OUTPUT = flags.DEFINE_string(
"profiling_output",
"",
"The profiling output",
required=False,
from jetstream_pt import engine as je
from jetstream_pt.config import (
FLAGS,
create_engine_from_config_flags,
define_profiling_flags,
)

_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
logging.getLogger().setLevel(logging.ERROR)

_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
"quantize_weights", False, "weight quantization"
)
_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
"quantize_kv_cache", False, "kv_cache_quantize"
)
_MAX_CACHE_LENGTH = flags.DEFINE_integer(
"max_cache_length", 1024, "kv_cache_quantize"
)
_MODEL_NAME = flags.DEFINE_string("model_name", "", "model_name")
_SHARDING_CONFIG = flags.DEFINE_string(
"sharding_config", "", "path to sharding config"
)
_SHAREGPT_PATH = flags.DEFINE_string(
"sharegpt_path", "", "path to sharegpt json file"
)
define_profiling_flags()
flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")


def create_engine():
@@ -85,18 +43,19 @@ def create_engine():
devices = jax.devices()
start = time.perf_counter()
engine = je.create_pytorch_engine(
model_name=FLAGS.model_name,
devices=devices,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
model_name=_MODEL_NAME.value,
sharding_config=_SHARDING_CONFIG.value,
tokenizer_path=FLAGS.tokenizer_path,
ckpt_path=FLAGS.checkpoint_path,
bf16_enable=FLAGS.bf16_enable,
param_size=FLAGS.size,
context_length=FLAGS.context_length,
batch_size=FLAGS.batch_size,
quantize_weights=FLAGS.quantize_weights,
quantize_kv=FLAGS.quantize_kv_cache,
max_cache_length=FLAGS.max_cache_length,
sharding_config=FLAGS.sharding_config,
shard_on_batch=FLAGS.shard_on_batch,
)

print("Initialize engine", time.perf_counter() - start)
@@ -148,16 +107,17 @@ def run_prefill_time(engine, params, decode_state, seqlen):

def main(argv):
"""Main function to run engine offline."""
engine = create_engine()
engine = create_engine_from_config_flags()

start = time.perf_counter()
params = engine.load_params()
print("Load params ", time.perf_counter() - start)

prefill_times = {}

if _PROFILING_OUTPUT.value:
jax.profiler.start_trace(_PROFILING_OUTPUT.value)
profiling_output = FLAGS.profiling_output
if profiling_output:
jax.profiler.start_trace(profiling_output)
decode_state = engine.init_decode_state()
for batch, _ in MAXTEXT_PREFILL.items():
runtime, decode_state = run_prefill_time(
@@ -186,18 +146,19 @@ def main(argv):
dec_times.append(end - start)
print(i, "decode time", (end - start))

if _PROFILING_OUTPUT.value:
if profiling_output:
jax.profiler.stop_trace()

print("prefill ", prefill_times)
print("decode", sum(dec_times) / 10)

prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
decode_time_ms = sum(dec_times) * 1000 / 10 / _BATCH_SIZE.value
decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size

if _SHAREGPT_PATH.value:
sharegpt_path = FLAGS.sharegpt_path
if sharegpt_path:
analyze_sharegpt.do_simulation(
_SHAREGPT_PATH.value, prefill_times_ms, decode_time_ms
sharegpt_path, prefill_times_ms, decode_time_ms
)


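One step in the timing code above is easy to misread: `sum(dec_times)` covers ten decode steps over a full batch, so the reported figure divides by both. A worked example with assumed numbers (my reading is that the division by `batch_size` amortizes each step's time across the sequences in the batch):

```python
# Worked example of the decode-latency arithmetic in run_offline.py.
# The timings and batch size are assumed; only the formula comes from the diff.
dec_times = [0.8] * 10  # seconds per timed decode step (assumed)
batch_size = 128        # FLAGS.batch_size in the real script (assumed)

# total seconds -> ms, averaged over the 10 timed steps,
# then spread across the batch of sequences.
decode_time_ms = sum(dec_times) * 1000 / 10 / batch_size
print(decode_time_ms)  # 6.25 ms per step per sequence
```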