From 4e0b15165ca0297b40c143510acde4afd30e72cd Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Thu, 4 Apr 2024 14:23:40 +0000 Subject: [PATCH] Refactor interactive --- run_interactive.py | 63 +++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/run_interactive.py b/run_interactive.py index 56c797f1..f459d41e 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -1,6 +1,8 @@ from absl import app from absl import flags from absl import logging +import random +from typing import List import sys import jax import jax.numpy as jnp @@ -8,6 +10,8 @@ from jetstream.engine import token_utils from absl.testing import absltest +from colorama import Fore, Back, Style + import os import sys @@ -16,6 +20,7 @@ import time import logging + logging.getLogger().setLevel(logging.ERROR) @@ -86,21 +91,28 @@ def main(argv): params = engine.load_params() print('Load params ', time.perf_counter() - start) - prefill_times = {} - slot = jnp.int32(0) metadata = engine.get_tokenizer() vocab = token_utils.load_vocab( metadata.path, metadata.extra_ids) - tokenizer = vocab.tokenizer + stop_tokens = [vocab.eos_id, vocab.pad_id] + max_output_length = 1024 - while True: - # text = input('Text >>>> ') - text = 'I believe the meaning of life is' - decode_state = engine.init_decode_state() - tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True) - # tokens = tokenizer.encode(text) - # tokens = [tokenizer.bos_id()] + tokens - print('Encoded tokens are: ', tokens) + if _PROFILING_OUTPUT.value: + jax.profiler.start_trace(_PROFILING_OUTPUT.value) + + decode_state = engine.init_decode_state() + prompts: List[str] = [ + "I believe the meaning of life is", + "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", + "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", + "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", + "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]", + ] + for prompt in prompts: + slot = random.randint(0, _BATCH_SIZE.value) + tokens, true_length = token_utils.tokenize_and_pad(prompt, vocab, is_bos=True) + print(f"---- Input prompts are: {prompt}") + print(f"---- Encoded tokens are: {tokens}") prefill_result = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length @@ -109,25 +121,36 @@ def main(argv): prefill_result, decode_state, slot=slot ) sampled_tokens_list = [] - for i in range(100): - decode_state, sampled_tokens = engine.generate( + print(f"---- Streaming decode started on #slot{slot}.") + while True: + decode_state, result_tokens = engine.generate( params, decode_state ) - tstart, end = sampled_tokens.tokens_idx - sampled_tokens_list.append(sampled_tokens.data[0, 0].item()) - print('---- ans ----') - print(sampled_tokens_list) - print(tokenizer.decode(sampled_tokens_list)) - break + slot_data = result_tokens.get_result_at_slot(slot) + slot_tokens = slot_data.tokens + slot_lengths = slot_data.lengths + + token_id = slot_tokens[slot, 0].item() + if slot_lengths > max_output_length or token_id in stop_tokens: + break + + sampled_tokens_list.append(token_id) + output = token_utils.mix_decode(vocab, token_id) + print(Fore.GREEN + output, end="", flush=True) + print(Style.RESET_ALL + "\n") + print("---- Streaming decode finished.") + + + print("---- All output tokens.") + print(sampled_tokens_list) if _PROFILING_OUTPUT.value: jax.profiler.stop_trace() - if __name__ == "__main__": import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"