In [None]:
# Imports

from dataclasses import dataclass
from typing import List

import random

import asyncio
import re
import math
import random
import numpy as np
from sympy import simplify

random.seed(0)

from async_engine.batched_api import BatchingAPI
from async_engine.api import API

from src.prompts.adapt import gameof24 as llama_prompts
from utils import parse_suggestions, create_box

In [None]:
# State class

@dataclass(frozen=True)
class GameOf24State:
    """
    Immutable snapshot of one Game of 24 attempt.

    The dataclass is frozen, so fields cannot be reassigned; note that `steps`
    is still a regular (mutable) list. Equality is the dataclass-generated
    field-by-field comparison, while hashing is customised below.
    """
    # game of 24 puzzle, for example 1 1 4 6
    puzzle: str

    # initialized to the same value as puzzle, but is updated as the game progresses
    current_state: str

    # Steps taken so far, in the order they were performed
    steps: List[str]

    # Randomness used for resampling (random seed)
    randomness: int

    def __hash__(self):
        # `steps` is a list and therefore unhashable, so it is folded into one
        # string. `randomness` does not take part in the hash.
        return hash((self.puzzle, self.current_state, " -> ".join(self.steps)))

    def items(self):
        """Return the fields as a tuple: (puzzle, current_state, steps, randomness)."""
        return self.puzzle, self.current_state, self.steps, self.randomness

    def duplicate(self, randomness=None):
        """
        Return a copy of this state.

        Args:
            randomness: seed for the copy; when None the current seed is kept.

        The step list is copied so that the duplicate does not share a mutable
        list with this state (mutating a shared list would change both states
        and their hashes).
        """
        return GameOf24State(
            puzzle=self.puzzle,
            current_state=self.current_state,
            steps=list(self.steps),
            randomness=self.randomness if randomness is None else randomness)

In [None]:
#Testing the game of 24

states = []
puzzle = "1 1 4 6"
example = GameOf24State(puzzle=puzzle, current_state=puzzle, steps=[], randomness=random.randint(0, 1000))

print("Initial State")
print(example.items(), "\n")

# Hand-written walkthrough of the puzzle: each entry is
# (numbers left after the step, operation performed in the step).
walkthrough = [
    ("1 4 6", "1 * 1 = 1"),
    ("1 24", "4 * 6 = 24"),
    ("24", "1 * 24 = 24"),
]

# One new state (and one randint draw) per walkthrough entry, so the number of
# draws taken from the seeded global RNG in this cell stays at four.
for step_index, (numbers_left, operation) in enumerate(walkthrough):
    print(f"Step: {step_index}")
    example = GameOf24State(
        puzzle=puzzle,
        current_state=numbers_left,
        steps=example.steps + [operation],
        randomness=random.randint(0, 1000),
    )
    print(example.items(), "\n")

In [None]:
#Reflexion agent :O

class GameOf24Agent:
    """
    Reflexion-style agent for the Game of 24.

    All methods are static; the class only groups the stepping, parsing,
    verification and reflection logic.
    """

    @staticmethod
    async def step(state: "GameOf24State", api, namespace, reflection: list) -> "GameOf24State":
        """
        Given a state, returns the next state.

        ToT uses bfs_prompt to generate next steps but then uses
        the cot_prompt to get the final expression.
        For example, input : 1 1 4 6
        Step 0 : '1 - 1 = 0 (left: 0 4 6)'          BFS prompt
        Step 1 : '0 + 4 = 4 (left: 4 6)'            BFS prompt
        Step 2 : '4 * 6 = 24 (left: 24)'            BFS prompt
        Step 3 : Answer : ((1 - 1) + 4) * 6 = 24    CoT prompt

        Args:
            state: the state to advance.
            api: object exposing an awaitable
                `buffered_request(prompt, key=..., namespace=...)`.
            namespace: forwarded unchanged to `api.buffered_request`.
            reflection: empty when no reflection is available; otherwise
                `reflection[0]` and `reflection[1]` fill the `reflection` and
                `steps` slots of the reflexion prompt.

        Returns:
            A new state with the same puzzle, the selected suggestion appended
            to `steps`, and a freshly drawn `randomness`.

        Raises:
            AssertionError: if no suggestion could be parsed from the response.

        Side effects:
            In the BFS branch the global `random` module is re-seeded with
            `state.randomness` before a suggestion is chosen.
        """
        current_state = state.current_state

        if current_state.strip() == "24":
            # Only the final expression is missing.
            if not reflection:
                steps = "\n".join(state.steps) + "\n"
                prompt = llama_prompts.cot_prompt.format(input=state.puzzle) + "Steps:\n" + steps + "Answer: "
            else:
                # NOTE(review): with a reflection present, the final-answer
                # request also uses the BFS reflexion prompt rather than a
                # CoT-style prompt -- confirm this is intended.
                prompt = GameOf24Agent._reflexion_prompt(state, reflection)

            # Get the final expression
            suggestions = await api.buffered_request(prompt, key=hash(state), namespace=namespace)

            # State does not change, only the steps
            selected_suggestion = suggestions
            selected_state = state.current_state

        else:
            if not reflection:
                prompt = llama_prompts.bfs_prompt.format(input=current_state)
            else:
                prompt = GameOf24Agent._reflexion_prompt(state, reflection)

            # Get candidate next steps
            suggestions = await api.buffered_request(prompt, key=hash(state), namespace=namespace)

            # parse suggestions, based on the current state
            parsed_suggestions = parse_suggestions(suggestions)
            if not parsed_suggestions:
                print(f"No suggestions were paresed from state: {state}")
                print(f"\nPrompt: {prompt}\nSuggestions: {suggestions}\nParsed suggestions: {' | '.join(parsed_suggestions)}\n")
                # Raised explicitly (instead of `assert False`) so the check
                # is not stripped when Python runs with -O.
                raise AssertionError("No suggestions found.")

            # Seeding from the state makes the choice reproducible per state.
            random.seed(state.randomness)
            selected_suggestion = random.choice(parsed_suggestions)
            selected_state = GameOf24Agent.parse_next_state(selected_suggestion)

        # set up new state object
        return GameOf24State(
            puzzle=state.puzzle,
            current_state=selected_state,
            steps=state.steps + [selected_suggestion],
            randomness=random.randint(0, 1000)
        )

    @staticmethod
    def _reflexion_prompt(state: "GameOf24State", reflection: list) -> str:
        """Fill the BFS reflexion prompt from `state` and a previous attempt's reflection."""
        return llama_prompts.bfs_reflexion_prompt.format(
            input=state.current_state,
            # Use the state's own puzzle instead of a hard-coded "1 1 4 6".
            puzzle=state.puzzle,
            reflection=reflection[0],
            steps=reflection[1],
        )

    @staticmethod
    def parse_next_state(suggestion: str) -> str:
        """Extract the remaining numbers from a step such as '4 * 6 = 24 (left: 24)' -> '24'."""
        return suggestion.split('left: ')[-1].split(')')[0]

    @staticmethod
    def verify(state: "GameOf24State") -> dict:
        """
        Verifies the output of a given task
            1. Checks if the numbers used are the same as the ones provided.
            2. Checks if the operations performed result to 24.

        States
            {"r": 0} : Not finished.
            {"r": 1} : Finished and correct.
            {"r": -1} : Finished and incorrect.
        """
        # Whitespace-tolerant split, consistent with the `.strip()` used in
        # `step`: a padded state such as " 25 " is a single number.
        remaining_numbers = state.current_state.split()
        if len(remaining_numbers) > 1 or len(state.steps) <= 3:
            # More than one number left, or not enough steps taken yet
            return {'r': 0}
        if not remaining_numbers or remaining_numbers[0] != "24":
            # Nothing left, or one number left and it is not 24
            return {'r': -1}

        # One number left and it is 24: check the final expression
        expression = state.steps[-1].lower().replace('answer: ', '').split('=')[0]
        numbers = re.findall(r'\d+', expression)
        problem_numbers = re.findall(r'\d+', state.puzzle)
        if sorted(numbers) != sorted(problem_numbers):
            # Numbers used are not the same as the ones provided
            return {'r': -1}
        try:
            # NOTE(review): `expression` originates from model output, and
            # sympy parses strings via sympify, which is documented as relying
            # on eval -- consider a restricted arithmetic parser here.
            if simplify(expression) == 24:
                return {'r': 1}
            # Operations performed do not result to 24
            return {'r': -1}
        except Exception as e:
            # An unparsable expression counts as an incorrect answer.
            print(e)
            return {'r': -1}

    @staticmethod
    async def generate_reflection(puzzle: str, steps, state: "GameOf24State", api, namespace) -> str:
        """
        Request a reflection on a finished attempt.

        Args:
            puzzle: value for the `puzzle` slot of the reflexion prompt.
            steps: value for the `steps` slot of the reflexion prompt.
            state: state whose hash is used as the request key.
            api: object exposing an awaitable `buffered_request`.
            namespace: forwarded unchanged to `api.buffered_request`.

        Returns:
            Whatever `api.buffered_request` resolves to for the reflexion prompt.
        """
        prompt = llama_prompts.reflexion_prompt.format(puzzle=puzzle, steps=steps)
        return await api.buffered_request(prompt, key=hash(state), namespace=namespace)


# Solve 1 1 4 6 puzzle:

In [None]:
# Initialization

# Sampling/request settings. `step_api_config` is bound to the very same dict
# object as `eval_api_config`.
eval_api_config = {
    "max_tokens": 1000,
    "temperature": 0.7,
    "top_p": 1,
    "request_timeout": 120,
    "top_k": 50
}
step_api_config = eval_api_config

# eligible providers ["TogetherAI", "OpenAI", "Groq"]
model = "llama-3.3-70b-versatile"
provider = "Groq"

# One independent settings dict per role, both pointing at the same model.
models = {
    role: {"model_name": model, "provider": provider}
    for role in ("step", "eval")
}

api = API(eval_api_config, models=models.values(), resources=2, verbose=False)

puzzle = "1 1 4 6"
num_steps = 4

# Create initial state/environment: nothing solved yet, so the current state
# equals the puzzle itself.
game_env = GameOf24State(puzzle=puzzle, current_state=puzzle, steps=[], randomness=random.randint(0, 1000))
step_batcher = BatchingAPI(api, batch_size=1, timeout=2, model=models["step"]["model_name"], tab="step")

states = [game_env]


In [None]:
# Attempting to solve the puzzle (without reflexion)

states = [game_env]
finished_states = []

#Stepping
for step in range(num_steps):
    
    print(f"Step {step} : Stepping")
    agent_tasks = [
        asyncio.create_task(
        GameOf24Agent.step(state, step_batcher, namespace=(0, f"Agent: {agent_id}", f"Step : {step}"), reflection="")
        )
        for agent_id, state in enumerate(states)
    ]
    states = await asyncio.gather(*agent_tasks)
    print(f"Current step: {states[-1].steps[-1]} \n")

    # Evaluate whether a puzzle has been solved
    i = 0
    while i < len(states):
        if GameOf24Agent.verify(states[i]) == {"r": 1}:
            print(f"Puzzle finished: {states[i].puzzle}")
            finished_states.append(states.pop(i))
        else:
            i += 1

    # If all puzzles have been solved, break
    if len(states) == 0:
        break

In [None]:
# Generate reflexions if the puzzle is not solved

agent_reflections = [
    asyncio.create_task(
    GameOf24Agent.generate_reflection(puzzle=puzzle, steps=state.steps, state=state, api=step_batcher, namespace=(0, f"Agent: {agent_id}", f"Step : {step}"))
    )
    for agent_id, state in enumerate(states)
]
reflection = await asyncio.gather(*agent_reflections)
reflection.append(states[0].steps)

print(f"Reflection: {reflection[0]}")

In [None]:
# Reattempting to solve the puzzle (with reflexion)

#Resetting
states = [game_env]
finished_states = []

#Stepping
for step in range(num_steps):
    
    print(f"Step {step} : Stepping")
    agent_tasks = [
        asyncio.create_task(
        GameOf24Agent.step(state, step_batcher, namespace=(0, f"Agent: {agent_id}", f"Step : {step}"), reflection=reflection)
        )
        for agent_id, state in enumerate(states)
    ]
    states = await asyncio.gather(*agent_tasks)
    print(f"Current step: {states[-1].steps[-1]} \n")

    # Evaluate whether a puzzle has been solved
    i = 0
    while i < len(states):
        if GameOf24Agent.verify(states[i]) == {"r": 1}:
            print(f"Puzzle finished: {states[i].puzzle}")
            finished_states.append(states.pop(i))
        else:
            i += 1

    # If all puzzles have been solved, break
    if len(states) == 0:
        break