Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions run_interactive.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from absl import app
from absl import flags
from absl import logging
import random
from typing import List
import sys
import jax
import jax.numpy as jnp
import numpy as np

from jetstream.engine import token_utils
from absl.testing import absltest
from colorama import Fore, Back, Style


import os
import sys
Expand All @@ -16,6 +20,7 @@
import time
import logging


# Set the root logger's level to ERROR.
logging.getLogger().setLevel(logging.ERROR)


Expand Down Expand Up @@ -86,21 +91,28 @@ def main(argv):
params = engine.load_params()
print('Load params ', time.perf_counter() - start)

prefill_times = {}
slot = jnp.int32(0)
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(
metadata.path, metadata.extra_ids)
tokenizer = vocab.tokenizer
stop_tokens = [vocab.eos_id, vocab.pad_id]
max_output_length = 1024

while True:
# text = input('Text >>>> ')
text = 'I believe the meaning of life is'
decode_state = engine.init_decode_state()
tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True)
# tokens = tokenizer.encode(text)
# tokens = [tokenizer.bos_id()] + tokens
print('Encoded tokens are: ', tokens)
if _PROFILING_OUTPUT.value:
jax.profiler.start_trace(_PROFILING_OUTPUT.value)

decode_state = engine.init_decode_state()
# Sample prompts that are prefilled and decoded one at a time: one plain
# completion prompt, one long free-form passage, and three instruction-style
# prompts wrapped in `<s>[INST] ... [/INST]` markers. Every instruction-style
# prompt must end with the closing `[/INST]` marker.
prompts: List[str] = [
  "I believe the meaning of life is",
  "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList<Person>`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList<Person> peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
  "<s>[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<</SYS>>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
  # Fixed: this prompt previously ended with a truncated "[/INST" marker.
  "<s>[INST] <<SYS>>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<</SYS>>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST]",
  "<s>[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
]
for prompt in prompts:
# Pick a random decode slot for this prompt. random.randint() includes its
# upper bound, so randint(0, batch_size) could return batch_size itself;
# randrange(batch_size) draws from 0..batch_size-1 only (assuming slots are
# zero-indexed per batch entry).
slot = random.randrange(_BATCH_SIZE.value)
tokens, true_length = token_utils.tokenize_and_pad(prompt, vocab, is_bos=True)
print(f"---- Input prompts are: {prompt}")
print(f"---- Encoded tokens are: {tokens}")

prefill_result = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
Expand All @@ -109,25 +121,36 @@ def main(argv):
prefill_result, decode_state, slot=slot
)
sampled_tokens_list = []
for i in range(100):
decode_state, sampled_tokens = engine.generate(
print(f"---- Streaming decode started on #slot{slot}.")
while True:
decode_state, result_tokens = engine.generate(
params, decode_state
)
tstart, end = sampled_tokens.tokens_idx
sampled_tokens_list.append(sampled_tokens.data[0, 0].item())

print('---- ans ----')
print(sampled_tokens_list)
print(tokenizer.decode(sampled_tokens_list))
break
slot_data = result_tokens.get_result_at_slot(slot)
slot_tokens = slot_data.tokens
slot_lengths = slot_data.lengths

token_id = slot_tokens[slot, 0].item()
if slot_lengths > max_output_length or token_id in stop_tokens:
break

sampled_tokens_list.append(token_id)
output = token_utils.mix_decode(vocab, token_id)
print(Fore.GREEN + output, end="", flush=True)

print(Style.RESET_ALL + "\n")
print("---- Streaming decode finished.")


print("---- All output tokens.")
print(sampled_tokens_list)


if _PROFILING_OUTPUT.value:
jax.profiler.stop_trace()



if __name__ == "__main__":
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
Expand Down