126 changes: 112 additions & 14 deletions jetstream_pt/cli.py
@@ -1,12 +1,16 @@
from typing import List
import random
import sys

import time
# import torch_xla2 first!
import torch_xla2  # pylint: disable=unused-import
import jax
from absl import app, flags
from jetstream.engine import token_utils
from jetstream.core import server_lib
from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
import torch
import numpy as np
from transformers import AutoTokenizer

from jetstream_pt import fetch_models
@@ -55,9 +59,12 @@ def create_engine(devices):
model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
if quant_config.enable_weight_quantization:
quantize_model.quantize_model(model, quant_config)
print("====== model =======")
print(model)

weight_shardings = model.get_sharding_annotations()
sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
env_data.quant_config = quant_config

return engine.PyTorchEngine(
pt_model=model,
@@ -74,11 +81,7 @@ def list_model():

def serve():
"""Run gRPC server."""
if FLAGS.model_id == "":
print("Please specify model_id with --model_id")
print("valid model ids are:")
list_model()
sys.exit(1)
_check_model_id()
devices = server_lib.get_devices()
print(f"devices: {devices}")

@@ -103,9 +106,105 @@ def serve():
jetstream_server.wait_for_termination()


def _check_model_id():
if FLAGS.model_id == "":
print("Please specify model_id with --model_id")
print("valid model ids are:")
list_model()
sys.exit(1)


def interactive():
"""Run interactive"""
raise RuntimeError("Not implemented")
_check_model_id()
devices = server_lib.get_devices()
print(f"devices: {devices}")
pt_engine = create_engine(devices)

start = time.perf_counter()
params = pt_engine.load_params()
print("Load params ", time.perf_counter() - start)

metadata = pt_engine.get_tokenizer()
tokenizer = pt_engine.build_tokenizer(metadata)
max_output_length = 1024

profiling_output = FLAGS.profiling_output
profiling_prefill = (
FLAGS.profiling_prefill
and profiling_output is not None
and profiling_output != ""
)

if profiling_prefill:
jax.profiler.start_trace(profiling_output)

decode_state = pt_engine.init_decode_state()

if profiling_prefill:
jax.profiler.stop_trace()

prompts: List[str] = [
# pylint: disable-next=all
"I believe the meaning of life is",
# pylint: disable-next=all
"To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList<Person>`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList<Person> peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
# pylint: disable-next=all
"<s>[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<</SYS>>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
# pylint: disable-next=all
"<s>[INST] <<SYS>>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<</SYS>>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
# pylint: disable-next=all
"<s>[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
]
for prompt in prompts:
slot = random.randint(0, FLAGS.batch_size - 1)
tokens, true_length = tokenizer.encode(prompt)

print(f"---- Input prompts are: {prompt}")
print(f"---- Encoded tokens are: {tokens}")

# pylint: disable-next=all
if profiling_prefill:
jax.profiler.start_trace(profiling_output)

prefill_result, _ = pt_engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
# pylint: disable-next=all
decode_state = pt_engine.insert(prefill_result, decode_state, slot=slot)

if profiling_prefill:
jax.profiler.stop_trace()

sampled_tokens_list = []
print(f"---- Streaming decode started on #slot{slot}.")
complete = np.zeros((1,), dtype=np.bool_)
while True:
if profiling_output:
jax.profiler.start_trace(profiling_output)

decode_state, result_tokens = pt_engine.generate(params, decode_state)
result_tokens = result_tokens.convert_to_numpy()

if profiling_output:
jax.profiler.stop_trace()

output, complete = token_utils.process_result_tokens(
tokenizer=tokenizer,
slot=slot,
slot_max_length=max_output_length,
result_tokens=result_tokens,
complete=complete,
)
if complete[0]:
break
token_ids = output[0].token_ids
sampled_tokens_list.extend(token_ids)

print("---- All output tokens.")
print(sampled_tokens_list)
print("---- All output text.")
print(tokenizer.decode(sampled_tokens_list))


def main(argv):
@@ -115,15 +214,14 @@ def main(argv):

if argv[1] == "list":
list_model()
return

if argv[1] == "serve":
elif argv[1] == "serve":
serve()
return

if argv[1] == "interative":
elif argv[1] == "interactive":
interactive()
return
else:
print(
"Invalid arguments. please specify 'list', 'serve', or 'interactive'."
)


if __name__ == "__main__":
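Note on interactive() above: prefill and each generate step are wrapped in optional jax.profiler traces, gated on FLAGS.profiling_output. A minimal standalone sketch of that gating pattern, assuming only the public jax.profiler API (the step function and its workload are invented for illustration):

import jax
import jax.numpy as jnp


def traced_step(x, profile_dir=None):
  # Trace this step only when a profile directory was supplied,
  # mirroring how interactive() gates on FLAGS.profiling_output.
  if profile_dir:
    jax.profiler.start_trace(profile_dir)
  y = jnp.dot(x, x.T)  # stand-in for pt_engine.generate(...)
  y.block_until_ready()  # ensure the device work lands inside the trace
  if profile_dir:
    jax.profiler.stop_trace()
  return y


out = traced_step(jnp.ones((4, 4)))  # passing a directory enables tracing

Tracing only the selected regions keeps the profile small while still capturing the prefill and decode work.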
18 changes: 9 additions & 9 deletions jetstream_pt/fetch_models.py
@@ -13,6 +13,7 @@
)
from jetstream_pt.third_party.llama import model_exportable as llama_model
from jetstream_pt.third_party.mixtral import model as mixtral_model
from jetstream_pt.third_party.gemma import model as gemma_model

FLAGS = flags.FLAGS

@@ -49,6 +50,9 @@ class ModelInfo:

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)

_gemma_2b = ModelInfo(gemma_model.GemmaModel, 18, 1, 256, 8)
_gemma_7b = ModelInfo(gemma_model.GemmaModel, 28, 16, 256, 1)


model_id_to_class = {
"meta-llama/Llama-2-7b-chat-hf": _llama2_7,
@@ -57,10 +61,10 @@ class ModelInfo:
"meta-llama/Llama-2-13b-hf": _llama2_13,
"meta-llama/Meta-Llama-3-8B": _llama3_8,
"meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
"google/gemma-2b": None,
"google/gemma-2b-it": None,
"google/gemma-7b": None,
"google/gemma-7b-it": None,
"google/gemma-2b": _gemma_2b,
"google/gemma-2b-it": _gemma_2b,
"google/gemma-7b": _gemma_7b,
"google/gemma-7b-it": _gemma_7b,
"mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
"mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
}
@@ -162,11 +166,7 @@ def instantiate_model_from_repo_id(
env.device = "meta"
model = model_info.model_class.from_hf_model_id(repo_id, env)
weights = _load_weights(model_dir)
updated_keys = model.get_hf_names_to_real_name()
for name, updated in updated_keys.items():
if name in weights:
val = weights.pop(name)
weights[updated] = val
weights = model.convert_hf_weights(weights)

model.load_state_dict(weights, assign=True, strict=False)

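Note on the registry change above: the four gemma ids now map to populated ModelInfo entries instead of None. A sketch of the lookup contract this enables; the field names are guesses for illustration, the real definition is fetch_models.ModelInfo:

from dataclasses import dataclass


@dataclass
class ModelInfo:
  # Field names are assumptions; compare with fetch_models.ModelInfo.
  model_class: type
  num_layers: int
  num_kv_heads: int
  head_dim: int
  n_reps: int


_REGISTRY = {"google/gemma-2b": ModelInfo(object, 18, 1, 256, 8)}


def lookup(model_id):
  info = _REGISTRY.get(model_id)
  if info is None:
    # A None entry (the previous state of the gemma ids) fails here.
    raise ValueError(f"model id {model_id!r} is not supported yet")
  return info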
11 changes: 7 additions & 4 deletions jetstream_pt/hf_tokenizer.py
@@ -1,4 +1,4 @@
from jetstream.engine import tokenizer_api
from jetstream.engine import tokenizer_api, token_utils


class HFTokenizerAdapter(tokenizer_api.Tokenizer):
@@ -17,7 +17,10 @@ def encode(self, s: str, **kwargs):
true_length: Actual length of the non-padded sequence
if padding is used.
"""
return self(s)
res = self.tokenizer.encode(s, add_special_tokens=False)
return token_utils.pad_tokens(
res, self.bos_id, self.pad_id, jax_padding=True
)

def decode(self, token_ids: list[int], **kwargs) -> str:
"""Processess input token ids to generate a string.
@@ -27,7 +30,7 @@ def decode(self, token_ids: list[int], **kwargs) -> str:
Returns:
str: String generated from the token ids.
"""
return self.decode(token_ids)
return self.tokenizer.decode(token_ids)

@property
def pad_id(self) -> int:
@@ -47,4 +50,4 @@ def bos_id(self) -> int:
@property
def stop_tokens(self) -> set[int]:
"""ID of the stop token."""
return {self.eos_id, self.pad_id}
return {self.eos_id}
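Note on encode() above: it now tokenizes without special tokens and defers padding to token_utils.pad_tokens, which returns both the padded sequence and its true length. A self-contained approximation of that behavior, assuming bucket-style padding (the real pad_tokens may choose bucket sizes and the padding side differently):

def pad_to_bucket(tokens, bos_id, pad_id, buckets=(16, 32, 64, 128)):
  # Prepend BOS (encode() above passes add_special_tokens=False), then
  # pad to the smallest bucket that fits, returning the unpadded length.
  tokens = [bos_id] + list(tokens)
  true_length = len(tokens)
  target = next(b for b in buckets if b >= true_length)
  return tokens + [pad_id] * (target - true_length), true_length


padded, true_length = pad_to_bucket([101, 102, 103], bos_id=1, pad_id=0)
print(len(padded), true_length)  # 16 4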
4 changes: 2 additions & 2 deletions jetstream_pt/layers.py
@@ -328,9 +328,9 @@ def create_quantized_from_nn_embedding(
float_embedding.num_embeddings,
float_embedding.embedding_dim,
)
weights, scaler, _ = quantize_tensor(float_embedding.weight, 1)
weights, scaler, _ = quantize_tensor(float_embedding.weight, 0)
obj.weight = weights
obj.scaler = scaler
obj.weight_scaler = scaler
return obj


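Note on the axis change above: a sketch of symmetric int8 quantization in which the second argument picks the reduction axis, so axis 0 yields one scaler per embedding_dim channel. The real quantize_tensor's signature and semantics are assumptions here:

import torch


def quantize_tensor(w, axis):
  # Symmetric int8: reducing |w| over `axis` leaves one scale per slice
  # of the remaining dimension.
  scale = w.abs().amax(dim=axis, keepdim=True).clamp(min=1e-8) / 127.0
  q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
  return q, scale


w = torch.randn(8, 4)             # (num_embeddings, embedding_dim)
q, scale = quantize_tensor(w, 0)  # scale.shape == (1, 4)
print((q.float() * scale - w).abs().max())  # small round-off error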
32 changes: 28 additions & 4 deletions jetstream_pt/model_base.py
@@ -46,7 +46,18 @@ class AttrProperty:


class ModuleBase(torch.nn.Module, metaclass=abc.ABCMeta):
"""nn Module that allows attaching properties"""
"""nn Module that allows attaching properties.

This class currently serves 2 goals:
1. Allow model to specify alternative names for submodules / weights
this is needed so that it can *also* load HuggingFace checkpoints
without need to do massive rewrites.

2. Allow model to attach information to weights, such as sharding config.

Quantization config could be another thing to attach, but right now it's not used
this way.
"""

attr_to_property: Dict[str, Any]

@@ -74,6 +85,19 @@ def annotate_sharding(self, name, axis):
"""Set sharding name for a attribute or submodule."""
self.attr_to_property[name].sharding_axis = axis

def drop_weight(self, key):
"""Return True if the weight named `key` should be discarded."""
return False

def convert_hf_weights(
self, hf_weights: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Convert a HuggingFace state_dict into this model's naming scheme."""
weights = {}
updated_keys = self.get_hf_names_to_real_name()
for name, updated in updated_keys.items():
if name in hf_weights:
weights[updated] = hf_weights[name]

for name in list(weights.keys()):
if "inv_freq" in name:
weights.pop(name)
if hasattr(self, "freqs_cis"):
weights["freqs_cis"] = self.freqs_cis
return weights
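
Note on convert_hf_weights above: it renames HuggingFace keys via get_hf_names_to_real_name(), drops rotary inv_freq buffers, and re-attaches precomputed freqs_cis when the model has one. A toy run of the renaming step; the mapping below is invented for illustration, not taken from the PR:

import torch

hf_weights = {
    "model.embed_tokens.weight": torch.zeros(4, 2),
    "model.layers.0.self_attn.rotary_emb.inv_freq": torch.zeros(2),
}
# Hypothetical mapping; real models return theirs from
# get_hf_names_to_real_name().
name_map = {"model.embed_tokens.weight": "tok_embeddings.weight"}

converted = {name_map[k]: v for k, v in hf_weights.items() if k in name_map}
# inv_freq buffers never make it into `converted`; convert_hf_weights
# additionally pops any surviving "inv_freq" keys before load_state_dict
# is called with strict=False.
print(list(converted))  # ['tok_embeddings.weight']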