diff --git a/agent/__init__.py b/agent/__init__.py
index 33d8954..e11832f 100644
--- a/agent/__init__.py
+++ b/agent/__init__.py
@@ -5,3 +5,4 @@
 from .robot import Robot
 from .observer import KEN_GREEN, KEN_RED
+from .llm import get_client
diff --git a/agent/config.py b/agent/config.py
index 25c26ca..64e1ffa 100644
--- a/agent/config.py
+++ b/agent/config.py
@@ -15,6 +15,8 @@
         "mistral:mistral-large-latest",
         # "groq:mistral-8x6b-32768",
     },
+    "GROQ": {"groq:gemma-7b-it"},
+    "ANTHROPIC": {"anthropic:claude-3-haiku-20240307"},
 }
diff --git a/agent/llm.py b/agent/llm.py
new file mode 100644
index 0000000..738eca3
--- /dev/null
+++ b/agent/llm.py
@@ -0,0 +1,28 @@
+def get_client(model_str):
+    split_result = model_str.split(":")
+    if len(split_result) == 1:
+        # Assume the default provider is openai
+        provider = "openai"
+        model_name = split_result[0]
+    elif len(split_result) > 2:
+        # Some model names contain ":", so rejoin the rest of the string
+        provider = split_result[0]
+        model_name = ":".join(split_result[1:])
+    else:
+        provider = split_result[0]
+        model_name = split_result[1]
+    if provider == "openai":
+        from llama_index.llms.openai import OpenAI
+
+        return OpenAI(model=model_name)
+    elif provider == "anthropic":
+        from llama_index.llms.anthropic import Anthropic
+
+        return Anthropic(model=model_name)
+    elif provider == "mixtral" or provider == "groq":
+        from llama_index.llms.groq import Groq
+
+        return Groq(model=model_name)
+    else:
+        # Fail loudly instead of silently returning None for unsupported providers
+        raise ValueError(f"Unsupported provider: {provider}")
diff --git a/agent/robot.py b/agent/robot.py
index 80a04a9..ee36869 100644
--- a/agent/robot.py
+++ b/agent/robot.py
@@ -8,7 +8,7 @@
 import numpy as np
 from gymnasium import spaces
 from loguru import logger
-from phospho.lab import get_provider_and_model, get_sync_client
+from llama_index.core.llms import ChatMessage
 from rich import print
 
 from .config import (
@@ -21,6 +21,7 @@
     Y_SIZE,
 )
 from .observer import detect_position_from_color
+from .llm import get_client
 
 
 class Robot:
@@ -289,35 +290,44 @@ def get_moves_from_llm(
             return [random.choice(list(MOVES.values()))]
 
         while len(valid_moves) == 0:
-            llm_response = self.call_llm()
-
-            # The response is a bullet point list of moves. Use regex
-            matches = re.findall(r"- ([\w ]+)", llm_response)
-            moves = ["".join(match) for match in matches]
-            invalid_moves = []
-            valid_moves = []
-            for move in moves:
-                cleaned_move_name = move.strip().lower()
-                if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys():
-                    if self.player_nb == 1:
-                        print(
-                            f"[red] Player {self.player_nb} move: {cleaned_move_name}"
-                        )
-                    elif self.player_nb == 2:
-                        print(
-                            f"[green] Player {self.player_nb} move: {cleaned_move_name}"
-                        )
-                    valid_moves.append(cleaned_move_name)
-                else:
-                    logger.debug(f"Invalid completion: {move}")
-                    logger.debug(f"Cleaned move name: {cleaned_move_name}")
-                    invalid_moves.append(move)
-
-            if len(invalid_moves) > 1:
-                logger.warning(f"Many invalid moves: {invalid_moves}")
-
-            logger.debug(f"Next moves: {valid_moves}")
-            return valid_moves
+            llm_stream = self.call_llm()
+
+            # Add support for streaming the response:
+            # this should make the players faster!
+
+            llm_response = ""
+
+            for r in llm_stream:
+                print(r.delta, end="")
+                llm_response += r.delta
+
+            # The response is a bullet point list of moves. Use regex
+            matches = re.findall(r"- ([\w ]+)", llm_response)
+            moves = ["".join(match) for match in matches]
+            invalid_moves = []
+            valid_moves = []
+            for move in moves:
+                cleaned_move_name = move.strip().lower()
+                if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys():
+                    if self.player_nb == 1:
+                        print(
+                            f"[red] Player {self.player_nb} move: {cleaned_move_name}"
+                        )
+                    elif self.player_nb == 2:
+                        print(
+                            f"[green] Player {self.player_nb} move: {cleaned_move_name}"
+                        )
+                    valid_moves.append(cleaned_move_name)
+                else:
+                    logger.debug(f"Invalid completion: {move}")
+                    logger.debug(f"Cleaned move name: {cleaned_move_name}")
+                    invalid_moves.append(move)
+
+            if len(invalid_moves) > 1:
+                logger.warning(f"Many invalid moves: {invalid_moves}")
+
+            logger.debug(f"Next moves: {valid_moves}")
+            return valid_moves
 
     def call_llm(
         self,
@@ -330,8 +340,6 @@ def call_llm(
     ) -> str:
         """
         Make an API call to the language model.
 
         Edit this method to change the behavior of the robot!
         """
-        provider_name, model_name = get_provider_and_model(self.model)
-        client = get_sync_client(provider_name)
 
         # Generate the prompts
         move_list = "- " + "\n - ".join([move for move in META_INSTRUCTIONS])
@@ -351,17 +359,16 @@ def call_llm(
         - Move closer"""
 
         start_time = time.time()
-        completion = client.chat.completions.create(
-            model=model_name,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": "Your next moves are:"},
-            ],
-            temperature=temperature,
-            max_tokens=max_tokens,
-            top_p=top_p,
-        )
+
+        client = get_client(self.model)
+
+        messages = [
+            ChatMessage(role="system", content=system_prompt),
+            ChatMessage(role="user", content="Your next moves are:"),
+        ]
+        resp = client.stream_chat(messages)
+
         logger.debug(f"LLM call to {self.model}: {system_prompt}")
         logger.debug(f"LLM call to {self.model}: {time.time() - start_time}s")
-        llm_response = completion.choices[0].message.content.strip()
-        return llm_response
+
+        return resp
diff --git a/requirements.txt b/requirements.txt
index cf401ff..b591484 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -47,3 +47,7 @@
 tqdm==4.66.2; python_version >= '3.7'
 typing-extensions==4.10.0; python_version >= '3.8'
 tzdata==2024.1; python_version >= '2'
 rich==13.7.1
+llama-index
+llama-index-llms-openai
+llama-index-llms-anthropic
+llama-index-llms-groq
\ No newline at end of file
diff --git a/script.py b/script.py
index 0525a3a..27fa62e 100644
--- a/script.py
+++ b/script.py
@@ -1,7 +1,7 @@
 import sys
 
 from dotenv import load_dotenv
-from eval.game import Game, Player1, Player2, generate_random_model
+from eval.game import Game, Player1, Player2
 from loguru import logger
 
 logger.remove()
@@ -16,11 +16,11 @@ def main():
         render=True,
         player_1=Player1(
             nickname="Daddy",
-            model=generate_random_model(mistral=True),
+            model="groq:gemma-7b-it",
         ),
         player_2=Player2(
             nickname="Baby",
-            model=generate_random_model(openai=True),
+            model="anthropic:claude-3-haiku-20240307",
         ),
     )
     return game.run()
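Usage note (illustrative, not part of the patch): a minimal sketch of how the new get_client helper and the streamed response fit together, mirroring the call_llm and get_moves_from_llm changes above. The model string and prompt text are examples only; this assumes the llama-index provider packages added to requirements.txt are installed and that the matching API key (e.g. ANTHROPIC_API_KEY) is set in the environment.

    from llama_index.core.llms import ChatMessage

    from agent.llm import get_client

    # "provider:model" strings resolve to a llama-index client:
    #   "gpt-4" (no prefix)                 -> OpenAI, the default provider
    #   "anthropic:claude-3-haiku-20240307" -> Anthropic
    #   "groq:gemma-7b-it"                  -> Groq
    client = get_client("anthropic:claude-3-haiku-20240307")

    messages = [
        ChatMessage(role="system", content="Reply with a bullet point list of moves."),
        ChatMessage(role="user", content="Your next moves are:"),
    ]

    # stream_chat returns a generator of ChatResponse chunks; each chunk's
    # .delta carries the newly generated text, which robot.py both prints
    # as it arrives and accumulates into the full response.
    llm_response = ""
    for chunk in client.stream_chat(messages):
        llm_response += chunk.delta

    print(llm_response)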