In [1]:
import json
import copy
import time
import tqdm
from collections import deque

from application.models import ModelID
from application.aws import AwsAPI
from application.game import Game
from application.utils import (
    get_game_events,
    dump_message_to_file,
    generate_batch_sizes,
    trim_queue,
)

# Realistic game events test

For this test, game events were generated as close to real ones as possible: a tree structure was generated to map events to users, each user got 0 to 6 kills randomly assigned, tree was traversed from leaves to root, so that in the end only 1 user is alive. Context window size was set to 20 messages for performance reasons, so the model generates responses based on 9 previous events and the current event. Claude v2 was used, this notebook can be utilized to test other models' generations quality on that task

In [2]:
game_events = get_game_events("data/events.csv")

In [3]:
init_prompts = [
    (
        "Imagine you're commenting the battle royale game match. You'll be getting kill events from the game like this one: "
        '[{"victim": {"username": str},"KillInstigator": {"username": str,"Distance": float,"first_kill": bool,"used_weapon": {"type": str,"name": str},"Headshot": bool,"OneShot": bool,"num_kills": int,"previous_victims": [str]},"location": str,"num_players_alive": int}]. '
        "Sometimes you will be getting more than one event in this list. "
        "You'll need to comment on the kill events, using 3 senteces max. "
        "If there are multiple events, you can decide either to comment on all of them or comment the last of them while also keeping in mind other events. "
        "Even if there are multiple events, you should still be able to say everything in 3 sentences. "
        "You're not supposed to always use each field for the comment, but you can use them if you think they're relevant. "
        "Try to also remember previous events and if you see some patterns feel free to voice them. Understood?"
    ),
    "Understood. I'm ready to commentate on the battle royale game events.",
]

In [4]:
def _get_sentences(q, aws: AwsAPI, model_id: ModelID, temperature: float):
    roles = ("user", "assistant")
    model = aws.models_mapping[model_id]
    
    messages = [
        {
            "role": roles[i % 2],
            "content": [model.get_content(prompt)],
        }
        for i, prompt in enumerate(q)
    ]
    yield from aws.get_streamed_response(model_id, messages, time.time(), temperature)

In [5]:
game = Game(game_events)
winner = game.create_players_tree()
num_events_total = game.num_players_total - 1
batch_sizes = generate_batch_sizes(num_events_total)


def run_dialogue(model_id, region, context_window_size=20, temperature=0.9):
    aws = AwsAPI(region)
    q = deque(init_prompts)
    events_nums = copy.copy(batch_sizes)
    filename = f"output/text/{int(time.time())}-{model_id.value}.txt"

    dump_message_to_file("User", q[-2], filename)
    dump_message_to_file("Assistant", q[-1], filename)
    game.start()

    elapsed_total = 0
    events_num = events_nums.pop()
    game_events_list = []
    for event in tqdm.tqdm(game.get_events(winner), total=num_events_total):
        if events_num > 0:
            game_events_list.append(event)
            events_num -= 1
            continue
        else:
            events_num = events_nums.pop()
            game_events_list = [event]
            events_num -= 1

        trim_queue(q, context_window_size, init_prompts)

        q.append(json.dumps(game_events_list))
        dump_message_to_file("User", q[-1], filename)

        start_time = time.time()
        response = []
        try:
            for sentence in _get_sentences(q, aws, model_id, temperature):
                response.append(sentence)
        except KeyboardInterrupt:
            break
        elapsed_total += time.time() - start_time

        q.append("".join(response))
        dump_message_to_file("Assistant", q[-1], filename)

    print(f"Avg. time per event: {elapsed_total / len(batch_sizes):.2f}s")
    # print(aws.get_bedrock_stats(model_id))


print("Testing Claude V3.5 Sonnet...")
run_dialogue(ModelID.CLAUDE_SONNET, "eu-central-1")
print("Testing Nova Pro...")
run_dialogue(ModelID.NOVA_PRO, "us-east-1")
print("Testing Claude V2...")
run_dialogue(ModelID.CLAUDE_V2, "eu-central-1")

Testing Claude V3.5 Sonnet...


100%|██████████| 130/130 [03:35<00:00,  1.65s/it]


Avg. time per event: 3.71s
Testing Nova Pro...


100%|██████████| 130/130 [01:36<00:00,  1.35it/s]


Avg. time per event: 1.66s
Testing Claude V2...


100%|██████████| 130/130 [03:42<00:00,  1.71s/it]

Avg. time per event: 3.83s



