In [1]:
# imports
import os
import asyncio
from diskcache import Cache
from groq import AsyncGroq
from omegaconf import OmegaConf
from cachesaver.pipelines import OnlineAPI
from cachesaver.typedefs import Response, Batch
from typing import Any, List

import sys
sys.path.append(os.getcwd())

from src.utils import tokens2cost
from src.algorithms import *
from src.models import OnlineLLM, API
from src.typedefs import DecodingParameters, Model
from src.tasks.game24 import EnvironmentGame24, AgentBfsGame24, AgentAggregateGame24, AgentEvaluateGame24, StateGame24

In [2]:
from groq import RateLimitError

class MockLLM(Model):
    """Thin async wrapper around a chat-completions client.

    The wrapper forwards a single user prompt to
    ``client.chat.completions.create`` and returns the text of every choice.
    Rate-limit errors are retried with a capped exponential backoff; any
    other exception is printed and re-raised.
    """

    # Backoff bounds (seconds) applied when the API raises RateLimitError.
    INITIAL_BACKOFF_SECONDS = 1
    MAX_BACKOFF_SECONDS = 90

    def __init__(self, client: Any, model: str)-> None:
        """Store the API client and the model identifier used for every call."""
        self.client = client
        self.model = model

    async def request(self, prompt: str, n: int, request_id: int, namespace: str, params: DecodingParameters) -> Response:
        """Send one prompt and return the content of each returned choice.

        Args:
            prompt: Text sent as the single ``user`` message.
            n: Number of completions requested from the API.
            request_id: Accepted for interface compatibility; not used here.
            namespace: Accepted for interface compatibility; not used here.
            params: Decoding parameters (``max_completion_tokens``,
                ``temperature``, ``stop``, ``top_p``, ``logprobs`` are read).

        Returns:
            A list with ``choice.message.content`` for every returned choice.

        Raises:
            Exception: Any non-rate-limit error from the client is re-raised
                after being printed.
        """
        # Use explicit None checks: `value or 1` would silently replace a
        # legitimate 0 / 0.0 (e.g. greedy decoding) with the default of 1.
        temperature = 1 if params.temperature is None else params.temperature
        top_p = 1 if params.top_p is None else params.top_p

        backoff = self.INITIAL_BACKOFF_SECONDS
        while True:
            try:
                completion = await self.client.chat.completions.create(
                    messages = [
                        {
                            "role" : "user",
                            "content" : prompt
                        }
                    ],
                    model = self.model,
                    n = n,
                    max_tokens= params.max_completion_tokens or None, # or None not needed but just to be explicit
                    temperature = temperature,
                    stop = params.stop or None,
                    top_p = top_p,
                    seed = 1234,
                    logprobs = params.logprobs or False,
                    top_logprobs = None,
                )
                break
            except RateLimitError:
                # Exponential backoff: 1s, 2s, 4s, ... never waiting longer
                # than MAX_BACKOFF_SECONDS between attempts.
                await asyncio.sleep(min(backoff, self.MAX_BACKOFF_SECONDS))
                backoff = min(backoff * 2, self.MAX_BACKOFF_SECONDS)
            except Exception as e:
                print(f"Error {e}")
                # Bare raise keeps the original traceback intact.
                raise
        return [choice.message.content for choice in completion.choices]

    async def batch_request(self, batch: Batch) -> List[Response]:
        """Run every request of ``batch`` concurrently and return the results in order.

        NOTE(review): assumes each entry of ``batch.requests`` exposes
        ``prompt``, ``n``, ``request_id``, ``namespace`` and ``params``
        attributes matching ``request``'s signature -- TODO confirm against
        the ``Batch`` / request typedefs.
        """
        pending = [
            self.request(
                prompt=req.prompt,
                n=req.n,
                request_id=req.request_id,
                namespace=req.namespace,
                params=req.params,
            )
            for req in batch.requests
        ]
        return await asyncio.gather(*pending)

In [3]:
# Model identifier served by the API and the Game24 puzzle used for this run.
llm = "llama-3.3-70b-versatile"
game_simple = "10 10 1 4"

# Build the cache directory with os.path.join so the path also resolves on
# non-Windows hosts (a literal "caches\\game24" is a single odd file name there).
cache = Cache(os.path.join("caches", "game24"))

client = AsyncGroq()
model = MockLLM(client=client, model=llm)

# TODO: the cached/batched pipeline below is disabled for now; `model` is
# passed to the agents directly instead.
# pipeline = OnlineAPI(
#     model=model,
#     cache=cache,
#     batch_size=2,
#     timeout=0.1,
# )
# api = API(
#     pipeline=pipeline,
#     model=llm
# )

# Initial state: nothing has been played yet, so the current state is the puzzle itself.
state = StateGame24(
    puzzle=game_simple,
    current_state=game_simple,
    steps=[],
    randomness=None,
)

In [4]:
params = DecodingParameters(
    temperature=0.7,
    max_completion_tokens=100,
    top_p=1.0,
    stop=None,
    logprobs=False,
)

config = OmegaConf.load("scripts\\game24.yaml")

agents = AgentDictGOT(
    step=AgentBfsGame24,
    aggregate=AgentAggregateGame24,
    evaluate=AgentEvaluateGame24,
    step_params=params,
    aggregate_params=params,
    eval_params=params,
)
method = AlgorithmGOT(
    model=model,
    agents=agents,
    env=EnvironmentGame24,
    num_selections=config.got.num_selections,
    num_steps=config.got.num_steps,
    num_best=config.got.num_best,
    num_evaluations=config.got.num_evaluations,
)

results = await method.solve(idx=0, state=state, namespace="small", value_cache=None)

Step 0, current state: [StateGame24(puzzle='10 10 1 4', current_state='10 10 1 4', steps=[], randomness=6311)]


IndexError: list index out of range

# Snippets of each method, to test each one individually

In [8]:
generate = AgentBfsGame24()
aggregate = AgentAggregateGame24()
evaluate = AgentEvaluateGame24()
env = EnvironmentGame24()

generate_results = await generate.act(model=model, state=state, namespace="small", request_id=0, params=params)
print(generate_results)

aggregate_results = await aggregate.act(model=model, state=state, actions=generate_results, k=3, n=1, namespace="small", request_id=0, params=params)
print(aggregate_results)

proposals = []
for action in aggregate_results:
    proposals.append(env.step(state, action))

print(proposals)

evaluate_coroutines = [evaluate.act(model=model, state=state, n=1, namespace="small", request_id=0, params=params, cache=None) for state in proposals]
evaluate_results = await asyncio.gather(*evaluate_coroutines)
print(evaluate_results)

['10 + 10 = 20 (left: 20 1 4)', '10 + 1 = 11 (left: 10 11 4)', '10 + 4 = 14 (left: 10 14 1)', '10 * 10 = 100 (left: 100 1 4)', '10 * 1 = 10 (left: 10 10 4)']
['10 + 10 = 20 (left: 20 1 4)', '10 + 4 = 14 (left: 10 14 1)', '10 * 1 = 10 (left: 10 10 4)']
[StateGame24(puzzle='10 10 1 4', current_state='20 1 4', steps=['10 + 10 = 20 (left: 20 1 4)'], randomness=5723), StateGame24(puzzle='10 10 1 4', current_state='10 14 1', steps=['10 + 4 = 14 (left: 10 14 1)'], randomness=6963), StateGame24(puzzle='10 10 1 4', current_state='10 10 4', steps=['10 * 1 = 10 (left: 10 10 4)'], randomness=5602)]
[20.0, 20.0, 20.0]
