In [1]:
import json
import random
import time
import tqdm
from collections import deque

from application.aws import AwsAPI
from application.game import Game
from application.utils import get_kill_events

# Realistic kill events test

For this test, kill events were generated to be as close to real ones as possible. A tree structure was generated to map events to users, and each user was randomly assigned 0 to 6 kills. The tree was then traversed from the leaves to the root, so that in the end only one user is alive. For performance reasons, the context window size was set to 20 messages. The two initial prompt messages are always kept, so each request contains the 8 most recent event/response pairs plus the current event (each event message is a batch of 1–4 kill events). Claude v2 was originally used, and the notebook below also runs Amazon Nova Pro. The notebook can be reused to compare the generation quality of other models on this task.

In [2]:
# Load kill events once from the CSV; every `run_dialogue` call below builds a fresh
# `Game` from this module-level `kills` object.
kills = get_kill_events("data/kills.csv")

In [3]:
# Seed conversation that opens every dialogue: a user instruction followed by a
# canned assistant acknowledgement. `_get_sentences` assigns roles by position
# (even index -> "user", odd index -> "assistant"), so this list must keep an even
# length for the kill-event messages appended after it to be sent as user turns.
# `_trim_queue` keeps these messages at the front of the history at all times.
init_prompts = [
    (
        "Imagine you're commenting the battle royale game match. You'll be getting kill events from the game like this one: "
        '[{"victim": {"username": str},"KillInstigator": {"username": str,"Distance": float,"first_kill": bool,"used_weapon": {"type": str,"name": str},"Headshot": bool,"OneShot": bool,"num_kills": int,"previous_victims": [str]},"location": str,"num_players_alive": int}]. '
        "Sometimes you will be getting more than one event in this list. "
        "You'll need to comment on the kill events, using 3 sentences max. "
        "If there are multiple events, you can decide either to comment on all of them or comment the last of them while also keeping in mind other events. "
        "Even if there are multiple events, you should still be able to say everything in 3 sentences. "
        "You're not supposed to always use each field for the comment, but you can use them if you think they're relevant. "
        "Try to also remember previous events and if you see some patterns feel free to voice them. Understood?"
    ),
    "Understood. I'm ready to commentate on the battle royale game events.",
]

In [4]:
from application.game import Player


def _dump_messages_to_file(q, filename):
    with open(filename, "a") as f:
        f.write(f"User: {q[-2]}\n")
        f.write(f"Assistant: {q[-1]}\n\n")


def _create_event_lists(game: Game, root: Player):
    kill_events: list[list[dict]] = []
    kill_events_flat = [kill_event for kill_event in game.get_kill_event(root)]
    i = 0
    while i < len(kill_events_flat):
        kill_events.append([])
        events_num = min(random.randint(1, 4), len(kill_events_flat) - i)
        for _ in range(events_num):
            kill_events[-1].append(kill_events_flat[i])
            i += 1
    return kill_events


def _trim_queue(q: deque, context_window_size: int):
    if len(q) + 2 > context_window_size:
        for _ in range(len(init_prompts)):
            q.popleft()
        q.popleft()
        q.popleft()
        for prompt in reversed(init_prompts):
            q.appendleft(prompt)


def _get_sentences(q, aws: AwsAPI, model_id, temperature):
    """Lazily stream the model's reply to the conversation stored in ``q``.

    Messages in ``q`` are tagged by position: even indices become ``"user"``
    turns and odd indices ``"assistant"`` turns. Each message body is wrapped by
    the model's ``get_content`` helper (looked up in ``aws.models_mapping``).
    The chunks produced by ``aws.get_streamed_response`` are re-yielded as they
    arrive. Nothing is requested until the generator is first iterated.
    """
    get_content = aws.models_mapping[model_id].get_content

    messages = []
    for idx, text in enumerate(q):
        role = "assistant" if idx % 2 else "user"
        messages.append({"role": role, "content": [get_content(text)]})

    yield from aws.get_streamed_response(model_id, messages, time.time(), temperature)

In [5]:
def run_dialogue(model_id, region, context_window_size=20, temperature=0.9, seed=None):
    """Simulate one match and stream the model's commentary for every event batch.

    A fresh ``Game`` is built from the module-level ``kills``, its kill events are
    batched with ``_create_event_lists``, and each batch is sent (as JSON) to the
    model together with the trimmed conversation history. The initial prompts and
    every user/assistant exchange are appended to
    ``output/text/<unix-time>-<model_id>.txt``.

    Args:
        model_id: Model identifier, used as a key into ``aws.models_mapping``.
        region: Region passed to ``AwsAPI``.
        context_window_size: Maximum number of messages kept in the history,
            including the exchange being generated.
        temperature: Sampling temperature forwarded to the model.
        seed: Optional seed for the global ``random`` module, applied before the
            game is built, so event batching is reproducible.

    A KeyboardInterrupt while streaming stops the run early. The average time per
    event batch is computed only over the batches that completed.
    """
    if seed is not None:
        # Makes `_create_event_lists` reproducible; whether the players tree also
        # becomes reproducible depends on Game using the global `random` — TODO confirm.
        random.seed(seed)

    aws = AwsAPI(region)
    game = Game(kills)
    root = game.create_players_tree()
    q = deque(init_prompts)
    # NOTE(review): model ids such as "amazon.nova-pro-v1:0" contain ':', which is
    # not a valid filename character on Windows.
    filename = f"output/text/{int(time.time())}-{model_id}.txt"

    _dump_messages_to_file(q, filename)
    kill_events = _create_event_lists(game, root)

    elapsed_total = 0.0
    processed = 0
    for kill_events_list in tqdm.tqdm(kill_events):
        _trim_queue(q, context_window_size)
        q.append(json.dumps(kill_events_list))

        # perf_counter is monotonic, unlike time.time(), so it suits interval timing.
        start_time = time.perf_counter()
        response = []
        try:
            for sentence in _get_sentences(q, aws, model_id, temperature):
                response.append(sentence)
        except KeyboardInterrupt:
            # Stop early, but still report stats for the batches that completed.
            break
        elapsed_total += time.perf_counter() - start_time
        processed += 1

        q.append("".join(response))
        _dump_messages_to_file(q, filename)

    # Average over completed batches only: an early break or an empty event list
    # previously produced a wrong average or a ZeroDivisionError.
    if processed:
        print(f"Avg. time per event: {elapsed_total / processed:.2f}s")
    else:
        print("No events were processed.")

# Run the same simulated match commentary against each model, in order.
for banner, model_id, region in (
    ("Testing Nova Pro...", "amazon.nova-pro-v1:0", "us-east-1"),
    ("Testing Claude V2...", "anthropic.claude-v2", "eu-central-1"),
):
    print(banner)
    run_dialogue(model_id, region)


Testing Nova Pro...


100%|██████████| 52/52 [01:35<00:00,  1.84s/it]


Avg. time per event: 1.84s
Testing Claude V2...


100%|██████████| 54/54 [04:39<00:00,  5.17s/it]

Avg. time per event: 5.17s



