In [None]:
import art
import asyncio
from dotenv import load_dotenv
import re
from tqdm.asyncio import tqdm

load_dotenv()


api = art.LocalAPI()
model = await api.get_or_create_model(
    name="001", project="prisoners-dilemma", base_model="Qwen/Qwen2.5-7B-Instruct"
)
client = await model.openai_client()

num_rounds = 10
prompt = f"""
You are playing a game of prisoners' dilemma with another player.

You are given a choice between two actions:

1. Cooperate
2. Defect

The payoffs are as follows:

- If both players cooperate, you get 3 points and the other player gets 3 points.
- If one player cooperates and the other defects, the defector gets 5 points and the cooperator gets 0 points.
- If both players defect, you both get 1 point.

You will play this game {num_rounds} times with the same player.

For your first turn, would you like to cooperate or defect?
"""


# Payoffs for one round, keyed by (player 0 action, player 1 action), for the
# cases where both players named a recognizable action.
_PAYOFFS: dict[tuple[str, str], tuple[int, int]] = {
    ("cooperate", "cooperate"): (3, 3),
    ("cooperate", "defect"): (0, 5),
    ("defect", "cooperate"): (5, 0),
    ("defect", "defect"): (1, 1),
}
# Fallback when at least one reply named no action: each player is scored on
# their own action alone, and a reply without an action earns nothing.
_DEFAULT_REWARDS: dict[str, int] = {"cooperate": 3, "defect": 5, "none": 0}
_ACTION_PATTERN = re.compile(r"cooperate|defect", flags=re.IGNORECASE)


def _parse_action(content: "str | None") -> str:
    """Return the last action named in a reply, lower-cased.

    ``content`` may be ``None`` (treated as empty). Returns ``"none"`` when the
    reply names neither ``cooperate`` nor ``defect``.
    """
    matches = _ACTION_PATTERN.findall(content or "")
    # findall returns the text as written (IGNORECASE only affects matching),
    # so "Cooperate"/"DEFECT" must be normalized before comparing against the
    # lower-case payoff keys; otherwise the fallback lookup raises KeyError.
    return matches[-1].lower() if matches else "none"


def _round_rewards(action_0: str, action_1: str) -> tuple[int, int]:
    """Return the points (player 0, player 1) earn for a single round."""
    pair = (action_0, action_1)
    if pair in _PAYOFFS:
        return _PAYOFFS[pair]
    return _DEFAULT_REWARDS[action_0], _DEFAULT_REWARDS[action_1]


async def rollout_game() -> tuple[art.Trajectory, art.Trajectory]:
    """Play one multi-round prisoner's dilemma game between two model instances.

    Both players start from the same ``prompt`` and play ``num_rounds`` rounds.
    Each round, both replies are requested concurrently. The last
    ``cooperate``/``defect`` mentioned in each reply (case-insensitive) is
    taken as that player's action and scored. Each player is then shown the
    opponent's full reply and both running scores.

    Returns:
        One trajectory per player. ``messages_and_choices`` holds the prompt,
        each sampled choice and each feedback message. ``reward`` is that
        player's cumulative score.

    NOTE: a feedback message asking for "the next round" is also appended
    after the final round.
    """
    messages: tuple[art.Messages, art.Messages] = (
        [{"role": "user", "content": prompt}],
        [{"role": "user", "content": prompt}],
    )
    trajectories = (
        art.Trajectory(messages_and_choices=[*messages[0]], reward=0),
        art.Trajectory(messages_and_choices=[*messages[1]], reward=0),
    )
    for _ in range(num_rounds):
        chat_completions = await asyncio.gather(
            client.chat.completions.create(messages=messages[0], model=model.name),
            client.chat.completions.create(messages=messages[1], model=model.name),
        )
        choices = [chat_completion.choices[0] for chat_completion in chat_completions]
        for i in range(2):
            messages[i].append(
                {"role": "assistant", "content": choices[i].message.content}
            )
            trajectories[i].messages_and_choices.append(choices[i])

        actions = [_parse_action(choice.message.content) for choice in choices]
        # Both scores must be updated before either feedback message is built,
        # because each message reports both players' running totals.
        for trajectory, reward in zip(trajectories, _round_rewards(*actions)):
            trajectory.reward += reward

        for i in range(2):
            messages[i].append(
                {
                    "role": "user",
                    "content": f"The other player responded as follows: \n\n{choices[1 - i].message.content}\n\n"
                    f"Your score is {trajectories[i].reward}. The other player's score is {trajectories[1 - i].reward}.\n\n"
                    "For the next round, would you like to cooperate or defect?",
                }
            )
            trajectories[i].messages_and_choices.append(messages[i][-1])
    return trajectories


for i in range(await model.get_step(), 1_000):
    game_trajectories = await tqdm.gather(
        *(rollout_game() for _ in range(32)), desc="rollout"
    )
    trajectory_groups = [
        art.TrajectoryGroup(
            game_trajectories[game_index][player_index]
            for game_index in range(len(game_trajectories))
        )
        for player_index in range(2)
    ]
    await model.train(
        trajectory_groups,
        config=art.TrainConfig(learning_rate=5e-5),
    )