From 8ed79dff16bd8c2d3769fe4155056fe914a4540a Mon Sep 17 00:00:00 2001 From: Sam Pink Date: Mon, 15 Apr 2024 10:08:55 +0100 Subject: [PATCH] Working on adding streaming --- agent/robot.py | 72 ++++++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/agent/robot.py b/agent/robot.py index 278e34b..ee36869 100644 --- a/agent/robot.py +++ b/agent/robot.py @@ -290,35 +290,44 @@ def get_moves_from_llm( return [random.choice(list(MOVES.values()))] while len(valid_moves) == 0: - llm_response = self.call_llm() - - # The response is a bullet point list of moves. Use regex - matches = re.findall(r"- ([\w ]+)", llm_response) - moves = ["".join(match) for match in matches] - invalid_moves = [] - valid_moves = [] - for move in moves: - cleaned_move_name = move.strip().lower() - if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys(): - if self.player_nb == 1: - print( - f"[red] Player {self.player_nb} move: {cleaned_move_name}" - ) - elif self.player_nb == 2: - print( - f"[green] Player {self.player_nb} move: {cleaned_move_name}" - ) - valid_moves.append(cleaned_move_name) - else: - logger.debug(f"Invalid completion: {move}") - logger.debug(f"Cleaned move name: {cleaned_move_name}") - invalid_moves.append(move) - - if len(invalid_moves) > 1: - logger.warning(f"Many invalid moves: {invalid_moves}") - - logger.debug(f"Next moves: {valid_moves}") - return valid_moves + llm_stream = self.call_llm() + + # adding support for streaming the response + # this should make the players faster! + + llm_response = "" + + for r in llm_stream: + print(r.delta, end="") + llm_response += r.delta + + # The response is a bullet point list of moves. Use regex + matches = re.findall(r"- ([\w ]+)", llm_response) + moves = ["".join(match) for match in matches] + invalid_moves = [] + valid_moves = [] + for move in moves: + cleaned_move_name = move.strip().lower() + if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys(): + if self.player_nb == 1: + print( + f"[red] Player {self.player_nb} move: {cleaned_move_name}" + ) + elif self.player_nb == 2: + print( + f"[green] Player {self.player_nb} move: {cleaned_move_name}" + ) + valid_moves.append(cleaned_move_name) + else: + logger.debug(f"Invalid completion: {move}") + logger.debug(f"Cleaned move name: {cleaned_move_name}") + invalid_moves.append(move) + + if len(invalid_moves) > 1: + logger.warning(f"Many invalid moves: {invalid_moves}") + + logger.debug(f"Next moves: {valid_moves}") + return valid_moves def call_llm( self, @@ -357,8 +366,9 @@ def call_llm( ChatMessage(role="system", content=system_prompt), ChatMessage(role="user", content="Your next moves are:"), ] - llm_response = client.chat(messages).message.content + resp = client.stream_chat(messages) logger.debug(f"LLM call to {self.model}: {system_prompt}") logger.debug(f"LLM call to {self.model}: {time.time() - start_time}s") - return llm_response + + return resp