Merged
7 changes: 3 additions & 4 deletions benchmarks/run_offline.py
@@ -21,7 +21,6 @@
import jax
import jax.numpy as jnp

from jetstream.engine import token_utils
from jetstream_pt import engine as je
# pylint: disable-next=all
from benchmarks import analyze_sharegpt
@@ -97,11 +96,11 @@ def create_engine():
def run_prefill_time(engine, params, decode_state, seqlen):
"""Run prefill and measure time."""
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokenizer = engine.build_tokenizer(metadata)

text = "This is a beautiful day"
tokens, true_length = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[seqlen]
tokens, true_length = tokenizer.encode(
text, is_bos=True, prefill_lengths=[seqlen]
)

for _ in range(3):
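This hunk is the caller-side half of the tokenizer refactor: instead of loading a vocab with token_utils.load_vocab and padding with token_utils.tokenize_and_pad, benchmark code now asks the engine for a tokenizer and calls encode on it. A minimal sketch of the new pattern, assuming an engine already built with create_pytorch_engine (the surrounding prefill plumbing is omitted):

# Sketch only -- mirrors the updated run_prefill_time above.
metadata = engine.get_tokenizer()
tokenizer = engine.build_tokenizer(metadata)  # SentencePiece for llama-2, TikToken for llama-3
tokens, true_length = tokenizer.encode(
    "This is a beautiful day", is_bos=True, prefill_lengths=[seqlen]
)
# tokens is padded to the requested prefill length; true_length is the unpadded token count.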
2 changes: 1 addition & 1 deletion install_everything.sh
@@ -13,7 +13,7 @@
# limitations under the License.

TORCHXLA_TAG=jetstream-pytorch
JETSTREAM_TAG=v0.2.0
JETSTREAM_TAG=v0.2.1

# Uninstall existing jax
pip3 show jax && pip3 uninstall -y jax
25 changes: 17 additions & 8 deletions jetstream_pt/engine.py
@@ -26,14 +26,14 @@
import torch
import numpy as np

from jetstream.engine import engine_api, tokenizer_pb2, token_utils
from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
import torch_xla2
from torch.utils import _pytree as pytree

from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.llama2 import model_exportable, model_args
from jetstream_pt.third_party.llama import model_exportable, model_args


Mesh = jax.sharding.Mesh
@@ -526,6 +526,14 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
# pylint: disable-next=all
return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path)

def build_tokenizer(
self, metadata: tokenizer_pb2.TokenizerParameters # pylint: disable=all
) -> tokenizer_api.Tokenizer:
if "llama-3" in self.env.model_type:
return token_utils.TikToken(metadata)

return token_utils.SentencePieceTokenizer(metadata)

def join_prefixes(
self,
prefix1: engine_api.Prefix,
@@ -652,13 +660,18 @@ def create_pytorch_engine(
context_length: int = 1024,
batch_size: int = 1,
max_decode_length: int = 4096,
model_name="llama",
model_name="llama-2",
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

supported_models = ["llama-2", "llama-3"]
if model_name not in supported_models:
raise NotImplementedError(
f"Model name should be one of{','.join(supported_models)}"
)
# See issue b/309529778 if it's turned on.
jax.config.update("jax_dynamic_shapes", False)
# Pytorch exports has int64 constants.
@@ -696,11 +709,7 @@ def create_pytorch_engine(
if model_name.startswith("llama"):

args = model_args.get_model_args(
param_size,
context_length,
batch_size,
tokenizer.vocab_size,
bf16_enable,
model_name + "-" + param_size, context_length, batch_size, bf16_enable
)
args.device = "meta"
args.quantize = quantize_weights
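create_pytorch_engine now takes a full model family name, validates it against supported_models, and joins it with param_size before looking up the configuration. A hedged sketch of the naming convention this introduces (values are illustrative):

# Illustrative only: how model_name and param_size combine under the new scheme.
from jetstream_pt.third_party.llama import model_args

model_name = "llama-3"                       # must be "llama-2" or "llama-3"
param_size = "8b"
full_name = model_name + "-" + param_size    # -> "llama-3-8b", the key get_model_args expects
args = model_args.get_model_args(full_name, 1024, 1, True)  # context_length, batch_size, bf16_enable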
12 changes: 5 additions & 7 deletions jetstream_pt/ray_worker.py
@@ -32,9 +32,9 @@
from torch.utils import _pytree as pytree
import torch_xla2

from jetstream.engine import engine_api, tokenizer_pb2, token_utils
from jetstream.engine import engine_api, tokenizer_pb2

from jetstream_pt.third_party.llama2 import model_exportable, model_args
from jetstream_pt.third_party.llama import model_exportable, model_args

from jetstream_pt import cache_manager
from jetstream_pt import quantize
@@ -99,7 +99,7 @@ def __init__(
context_length: int = 1024,
batch_size: int = 1,
max_decode_length: int = 4096,
model_name="llama",
model_name="llama-2",
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
@@ -159,14 +159,12 @@ def __init__(
)
env = JetEngineEnvironment(env_data)

tokenizer = token_utils.load_vocab(tokenizer_path)
pt_model = None
if model_name == "llama":
if "llama" in model_name:
args = model_args.get_model_args(
param_size,
model_name + "-" + param_size,
context_length,
batch_size,
tokenizer.vocab_size,
bf16_enable,
)
args.device = "meta"
@@ -5,9 +5,9 @@
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
from jetstream_pt.third_party.llama2 import model_original
from jetstream_pt.third_party.llama import model_original
from flax import struct
from jetstream_pt.third_party.llama2.tokenizer import Tokenizer
from jetstream_pt.third_party.llama.tokenizer import Tokenizer

Role = Literal["system", "user", "assistant"]

@@ -34,68 +34,81 @@ class ModelArgs:
device = "cpu"
quantize = False

rope_theta: float = 10000.0


def get_arg(
param_size: str,
model_name: str,
seqlen,
batch_size,
vocab_size: int,
bf16_enable: bool = False,
) -> ModelArgs:
"""Gets model args."""

data = {}
if param_size == "tiny":
if model_name == "llama-2-tiny":
data = {
"dim": 128,
"vocab_size": 32000,
"multiple_of": 32,
"n_heads": 8,
"n_layers": 3,
"norm_eps": 1e-05,
}
elif param_size == "7b":
elif model_name == "llama-2-7b":
data = {
"dim": 4096,
"vocab_size": 32000,
"multiple_of": 256,
"n_heads": 32,
"n_layers": 32,
"norm_eps": 1e-05,
}
elif param_size == "13b":
elif model_name == "llama-2-13b":
data = {
"dim": 5120,
"vocab_size": 32000,
"multiple_of": 256,
"n_heads": 40,
"n_layers": 40,
"norm_eps": 1e-05,
}
elif param_size == "70b":
elif model_name == "llama-2-70b":
data = {
"dim": 8192,
"vocab_size": 32000,
"multiple_of": 4096,
"ffn_dim_multiplier": 1.3,
"n_heads": 64,
"n_kv_heads": 8,
"n_layers": 80,
"norm_eps": 1e-05,
}
elif model_name == "llama-3-8b":
data = {
"dim": 4096,
"vocab_size": 128256,
"multiple_of": 1024,
"ffn_dim_multiplier": 1.3,
"n_layers": 32,
"n_heads": 32,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 500000.0,
}
return ModelArgs(
max_seq_len=seqlen,
max_batch_size=batch_size,
vocab_size=vocab_size,
bf16_enable=bf16_enable,
**data,
)


def get_model_args(
param_size, context_length, batch_size, vocab_size, bf16_enable
):
def get_model_args(model_name, context_length, batch_size, bf16_enable):
model_args = get_arg(
param_size=param_size,
model_name=model_name,
seqlen=context_length,
batch_size=batch_size,
vocab_size=vocab_size,
bf16_enable=bf16_enable,
)
model_args.n_kv_heads = (
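With this rewrite get_arg keys its lookup table on the full model name and folds vocab_size (and, for llama-3-8b, rope_theta) into the per-model data, so callers stop passing vocab_size explicitly. A small sketch of what the updated helper returns, assuming the values shown in the table above:

# Sketch: selecting configs through the new name-based interface.
tiny = get_model_args("llama-2-tiny", 128, 1, True)    # tiny.vocab_size == 32000, tiny.rope_theta == 10000.0
llama3 = get_model_args("llama-3-8b", 8192, 1, True)   # llama3.vocab_size == 128256, llama3.rope_theta == 500000.0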
@@ -157,7 +157,9 @@ def __init__(
)
# TODO what to do with this
freqs_cis = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
self.params.dim // self.params.n_heads,
self.params.max_seq_len * 2,
theta=self.params.rope_theta,
)

self.register_buffer("freqs_cis", freqs_cis)
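rope_theta is what actually differentiates the two families at runtime: it flows from ModelArgs into precompute_freqs_cis, so llama-3-8b rotates positions with a 500000.0 base while the Llama 2 configs keep 10000.0. For orientation, the standard Llama-style frequency precomputation looks roughly like the sketch below (a generic reference, not necessarily this repository's exact code):

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of embedding dimensions; a larger theta lengthens the
    # rotation periods, which is how llama-3 stretches its usable context.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)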
41 changes: 24 additions & 17 deletions run_interactive.py
@@ -24,6 +24,11 @@
import jax

from jetstream.engine import token_utils
from colorama import Fore, Style
import numpy as np

import os

from jetstream_pt import engine as je

FLAGS = flags.FLAGS
@@ -64,6 +69,11 @@
_MAX_CACHE_LENGTH = flags.DEFINE_integer(
"max_cache_length", 1024, "kv_cache_quantize"
)
_MODEL_NAME = flags.DEFINE_string(
"model",
"llama-2",
"name of the model. Supported options are llama-2 and llama-3",
)


def create_engine():
@@ -81,6 +91,7 @@ def create_engine():
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
model_name=_MODEL_NAME.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
@@ -100,8 +111,7 @@ def main(argv):
print("Load params ", time.perf_counter() - start)

metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
stop_tokens = [vocab.eos_id, vocab.pad_id]
tokenizer = engine.build_tokenizer(metadata)
max_output_length = 1024

if _PROFILING_OUTPUT.value:
Expand All @@ -121,9 +131,8 @@ def main(argv):
]
for prompt in prompts:
slot = random.randint(0, _BATCH_SIZE.value - 1)
tokens, true_length = token_utils.tokenize_and_pad(
prompt, vocab, is_bos=True
)
tokens, true_length = tokenizer.encode(prompt, is_bos=True)

print(f"---- Input prompts are: {prompt}")
print(f"---- Encoded tokens are: {tokens}")

Expand All @@ -135,29 +144,27 @@ def main(argv):
decode_state = engine.insert(prefill_result, decode_state, slot=slot)
sampled_tokens_list = []
print(f"---- Streaming decode started on #slot{slot}.")
complete = np.zeros((1,), dtype=np.bool_)
while True:
# pylint: disable-next=all
decode_state, result_tokens = engine.generate(params, decode_state)

slot_data = result_tokens.get_result_at_slot(slot)
slot_tokens = slot_data.tokens
slot_lengths = slot_data.lengths

token_id = slot_tokens[slot, 0].item()
if slot_lengths > max_output_length or token_id in stop_tokens:
result_tokens = result_tokens.convert_to_numpy()
output, complete = tokenizer.decode(
slot, max_output_length, result_tokens, complete
)
if complete[0]:
break

token_id = output[0][0]
sampled_tokens_list.append(token_id)
# output = token_utils.mix_decode(vocab, token_id)
# print(Fore.GREEN + output, end="", flush=True)
# output_str = tokenizer.decode_str([token_id])
# print(Fore.GREEN + output_str, end="", flush=True)

# print(Style.RESET_ALL + "\n")
# print("---- Streaming decode finished.")

print("---- All output tokens.")
print(sampled_tokens_list)
print("---- All output text.")
print(vocab.tokenizer.decode(sampled_tokens_list))
print(tokenizer.decode_str(sampled_tokens_list))

if _PROFILING_OUTPUT.value:
jax.profiler.stop_trace()
6 changes: 6 additions & 0 deletions run_server.py
@@ -71,6 +71,11 @@
"The model size the server runs on.",
required=False,
)
_MODEL_NAME = flags.DEFINE_string(
"model",
"llama-2",
"name of the model. Supported options are llama-2 and llama-3",
)

_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
"quantize_weights", False, "weight quantization"
@@ -98,6 +103,7 @@ def main(argv: Sequence[str]):
param_size=_PARAM_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
model_name=_MODEL_NAME.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
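Both entry points register the model family under the new --model flag with a llama-2 default, so switching families from the command line needs no other changes. An assumed invocation (only the --model flag is added by this change; every other flag keeps its existing name and default):

# Assumed example, not part of the diff.
python run_server.py --model=llama-3 ...
python run_interactive.py --model=llama-3 ...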
4 changes: 2 additions & 2 deletions tests/helpers.py
@@ -1,6 +1,6 @@
import torch
import jax
from jetstream_pt.third_party.llama2 import model_args
from jetstream_pt.third_party.llama import model_args
from jetstream_pt import environment


@@ -9,7 +9,7 @@ def make_env_tiny(bf16_enable=True):
torch.set_default_dtype(torch_dtype)
jax.config.update("jax_dynamic_shapes", False)
jax.config.update("jax_traceback_filtering", "off")
config = model_args.get_model_args("tiny", 128, 1, 32000, True)
config = model_args.get_model_args("llama-2-tiny", 128, 1, True)
environment_data = environment.JetEngineEnvironmentData()
environment_data.max_input_sequence_length = 128
environment_data.max_input_sequence_length = 128
2 changes: 1 addition & 1 deletion tests/test_engine.py
@@ -21,7 +21,7 @@

from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.engine import PyTorchEngine, Prefix, DecodeState
from jetstream_pt.third_party.llama2 import model_exportable, model_original
from jetstream_pt.third_party.llama import model_exportable, model_original

# This model will output tokens with value of 2
# and will update caches with value of 1.0