
Merge pull request #44 from SamPink/llama
Added support for anthropic and others
oulianov committed Apr 15, 2024
2 parents 2206c4e + 8ed79df commit 5701efb
Showing 6 changed files with 86 additions and 47 deletions.
1 change: 1 addition & 0 deletions agent/__init__.py
@@ -5,3 +5,4 @@
 
 from .robot import Robot
 from .observer import KEN_GREEN, KEN_RED
+from .llm import get_client
2 changes: 2 additions & 0 deletions agent/config.py
@@ -15,6 +15,8 @@
         "mistral:mistral-large-latest",
         # "groq:mistral-8x6b-32768",
     },
+    "GROQ": {"groq:gemma-7b-it"},
+    "ANTHROPIC": {"anthropic:claude-3-haiku-20240307"},
 }


25 changes: 25 additions & 0 deletions agent/llm.py
@@ -0,0 +1,25 @@
+def get_client(model_str):
+    split_result = model_str.split(":")
+    if len(split_result) == 1:
+        # Assume default provider to be openai
+        provider = "openai"
+        model_name = split_result[0]
+    elif len(split_result) > 2:
+        # Some model names have :, so we need to join the rest of the string
+        provider = split_result[0]
+        model_name = ":".join(split_result[1:])
+    else:
+        provider = split_result[0]
+        model_name = split_result[1]
+    if provider == "openai":
+        from llama_index.llms.openai import OpenAI
+
+        return OpenAI(model=model_name)
+    elif provider == "anthropic":
+        from llama_index.llms.anthropic import Anthropic
+
+        return Anthropic(model=model_name)
+    elif provider == "mixtral" or provider == "groq":
+        from llama_index.llms.groq import Groq
+
+        return Groq(model=model_name)
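
For context, a quick illustrative sketch (not part of the commit) of how these model strings resolve through get_client; the unprefixed OpenAI example is an assumption, not a string used in this repo:

from agent.llm import get_client

# Illustrative only: "provider:model" strings route to a llama-index client.
anthropic_llm = get_client("anthropic:claude-3-haiku-20240307")  # -> Anthropic
groq_llm = get_client("groq:gemma-7b-it")                        # -> Groq ("mixtral" routes here too)
openai_llm = get_client("gpt-4")                                 # no prefix -> defaults to OpenAI

Note that an unrecognized provider falls off the end of the if/elif chain and returns None, so callers are expected to pass one of the providers registered in config.py.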
95 changes: 51 additions & 44 deletions agent/robot.py
@@ -8,7 +8,7 @@
 import numpy as np
 from gymnasium import spaces
 from loguru import logger
-from phospho.lab import get_provider_and_model, get_sync_client
+from llama_index.core.llms import ChatMessage
 from rich import print
 
 from .config import (
@@ -21,6 +21,7 @@
     Y_SIZE,
 )
 from .observer import detect_position_from_color
+from .llm import get_client
 
 
 class Robot:
@@ -289,35 +290,44 @@ def get_moves_from_llm(
             return [random.choice(list(MOVES.values()))]
 
         while len(valid_moves) == 0:
-            llm_response = self.call_llm()
-
-            # The response is a bullet point list of moves. Use regex
-            matches = re.findall(r"- ([\w ]+)", llm_response)
-            moves = ["".join(match) for match in matches]
-            invalid_moves = []
-            valid_moves = []
-            for move in moves:
-                cleaned_move_name = move.strip().lower()
-                if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys():
-                    if self.player_nb == 1:
-                        print(
-                            f"[red] Player {self.player_nb} move: {cleaned_move_name}"
-                        )
-                    elif self.player_nb == 2:
-                        print(
-                            f"[green] Player {self.player_nb} move: {cleaned_move_name}"
-                        )
-                    valid_moves.append(cleaned_move_name)
-                else:
-                    logger.debug(f"Invalid completion: {move}")
-                    logger.debug(f"Cleaned move name: {cleaned_move_name}")
-                    invalid_moves.append(move)
-
-            if len(invalid_moves) > 1:
-                logger.warning(f"Many invalid moves: {invalid_moves}")
-
-        logger.debug(f"Next moves: {valid_moves}")
-        return valid_moves
+            llm_stream = self.call_llm()
+
+            # adding support for streaming the response
+            # this should make the players faster!
+
+            llm_response = ""
+
+            for r in llm_stream:
+                print(r.delta, end="")
+                llm_response += r.delta
+
+            # The response is a bullet point list of moves. Use regex
+            matches = re.findall(r"- ([\w ]+)", llm_response)
+            moves = ["".join(match) for match in matches]
+            invalid_moves = []
+            valid_moves = []
+            for move in moves:
+                cleaned_move_name = move.strip().lower()
+                if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys():
+                    if self.player_nb == 1:
+                        print(
+                            f"[red] Player {self.player_nb} move: {cleaned_move_name}"
+                        )
+                    elif self.player_nb == 2:
+                        print(
+                            f"[green] Player {self.player_nb} move: {cleaned_move_name}"
+                        )
+                    valid_moves.append(cleaned_move_name)
+                else:
+                    logger.debug(f"Invalid completion: {move}")
+                    logger.debug(f"Cleaned move name: {cleaned_move_name}")
+                    invalid_moves.append(move)
+
+            if len(invalid_moves) > 1:
+                logger.warning(f"Many invalid moves: {invalid_moves}")
+
+        logger.debug(f"Next moves: {valid_moves}")
+        return valid_moves

     def call_llm(
         self,
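
As a side note on the parsing step above, a small standalone sketch (illustrative, not part of the commit) of what the regex extracts from a typical bullet-point reply:

import re

# The LLM is expected to answer with a bullet list; each "- move name"
# line becomes one candidate move. \w and space don't match newlines,
# so each capture stops at the end of its line.
llm_response = "- Move closer\n- Medium Punch\n- Fireball"
matches = re.findall(r"- ([\w ]+)", llm_response)
print(matches)  # ['Move closer', 'Medium Punch', 'Fireball']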
@@ -330,8 +340,6 @@ def call_llm(
         Edit this method to change the behavior of the robot!
         """
-        provider_name, model_name = get_provider_and_model(self.model)
-        client = get_sync_client(provider_name)
 
         # Generate the prompts
         move_list = "- " + "\n - ".join([move for move in META_INSTRUCTIONS])
@@ -351,17 +359,16 @@
         - Move closer"""
 
         start_time = time.time()
-        completion = client.chat.completions.create(
-            model=model_name,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": "Your next moves are:"},
-            ],
-            temperature=temperature,
-            max_tokens=max_tokens,
-            top_p=top_p,
-        )
-
+        client = get_client(self.model)
+
+        messages = [
+            ChatMessage(role="system", content=system_prompt),
+            ChatMessage(role="user", content="Your next moves are:"),
+        ]
+        resp = client.stream_chat(messages)
 
         logger.debug(f"LLM call to {self.model}: {system_prompt}")
         logger.debug(f"LLM call to {self.model}: {time.time() - start_time}s")
-        llm_response = completion.choices[0].message.content.strip()
-        return llm_response
+
+        return resp
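
Note that call_llm now returns the stream itself rather than a finished string, so the caller drains it. A minimal sketch of that contract (assuming robot is an initialized Robot; llama-index's stream_chat yields chunks whose .delta carries the newly generated text):

# Illustrative consumption of the stream returned by call_llm.
stream = robot.call_llm()

llm_response = ""
for chunk in stream:
    llm_response += chunk.delta  # each chunk.delta is the next slice of text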
4 changes: 4 additions & 0 deletions requirements.txt
@@ -47,3 +47,7 @@ tqdm==4.66.2; python_version >= '3.7'
 typing-extensions==4.10.0; python_version >= '3.8'
 tzdata==2024.1; python_version >= '2'
 rich==13.7.1
+llama-index
+llama-index-llms-openai
+llama-index-llms-anthropic
+llama-index-llms-groq
6 changes: 3 additions & 3 deletions script.py
@@ -1,7 +1,7 @@
 import sys
 
 from dotenv import load_dotenv
-from eval.game import Game, Player1, Player2, generate_random_model
+from eval.game import Game, Player1, Player2
 from loguru import logger
 
 logger.remove()
@@ -16,11 +16,11 @@ def main():
         render=True,
         player_1=Player1(
             nickname="Daddy",
-            model=generate_random_model(mistral=True),
+            model="groq:gemma-7b-it",
         ),
         player_2=Player2(
             nickname="Baby",
-            model=generate_random_model(openai=True),
+            model="anthropic:claude-3-haiku-20240307",
         ),
     )
     return game.run()
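
One practical note: the providers chosen above read credentials from the environment, so a run needs the corresponding keys set (script.py already imports load_dotenv, so a .env file works). A hedged pre-flight sketch; the variable names are the SDKs' conventional defaults, not something this commit defines:

import os

# Conventional env vars for the Groq and Anthropic SDKs (assumption).
for key in ("GROQ_API_KEY", "ANTHROPIC_API_KEY"):
    if not os.environ.get(key):
        raise RuntimeError(f"{key} must be set before running script.py")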
