In [None]:
from collections import defaultdict
from difflib import Differ
import json

import altair as alt

# Route chart data through the VegaFusion data transformer (requires the
# `vegafusion` package). NOTE(review): presumably enabled so the per-result
# rows built below are not rejected by Altair's default transformer, which
# raises MaxRowsError above 5000 rows — confirm it applies to the inline
# `alt.Data(values=...)` source used by the charts.
alt.data_transformers.enable("vegafusion")

In [None]:
# Load every reward record: rewards.jsonl holds one JSON object per line.
# JSON Lines is UTF-8 by definition, so decode explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open("../results/vibevolve/rewards.jsonl", "r", encoding="utf-8") as f:
    lines = f.readlines()
# Skip whitespace-only lines; json.loads("") would raise on them.
records = [json.loads(line) for line in lines if line.strip()]


In [None]:
# Run ids excluded from the plots (commented-out ids are currently kept in).
blacklist = [
    "0c00e39a",
    "89bbc310",
    # "299e3db9",
    "641bf644",
    "58771c1b",
    "bcb3952c",
    # "f0a48f56",
]

# run_id -> timestamp of the run's first record; origin of its "minutes" axis.
min_timestamp = {}
# run_id -> {opponent: best reward seen so far}, i.e. a running maximum.
cummax_reward = defaultdict(dict)
# Flat rows, one per (record, result) pair, consumed by the charts below.
data = []

for record in sorted(records, key=lambda rec: rec["timestamp"]):
    run_id = record["run_id"]
    if run_id in blacklist:
        continue

    timestamp = record["timestamp"]
    # Records are visited in time order, so the first one seen is the run start.
    run_start = min_timestamp.setdefault(run_id, timestamp)

    for result in record["results"]:
        # The "tmp_agent" entry is plotted as the "round_robin" opponent with its
        # reward as-is; every other entry is named after an opponent and its
        # reward is negated (presumably stored from the opponent's side — confirm).
        is_self = result["name"] == "tmp_agent"
        opponent = "round_robin" if is_self else result["name"]
        reward = result["reward"] if is_self else -result["reward"]
        # NOTE(review): a reward of exactly -1.0 is flagged as an error; a genuine
        # full loss would also score -1.0 — confirm how failures are encoded.
        error = reward == -1.0

        best = cummax_reward[run_id]
        best[opponent] = max(best.get(opponent, -1.0), reward)

        data.append({
            "opponent": opponent,
            "reward": reward,
            "cummax_reward": best[opponent],
            "run_id": run_id,
            "timestamp": timestamp,
            "minutes": (timestamp - run_start) / 60,
            "error": error,
            "completion_id": record["completion_id"],
            "parent_completion_id": record["parent_completion_id"],
        })

In [None]:
# Inline the flat rows built above as the charts' data source.
alt_data = alt.Data(values=data)

x = alt.X("minutes:Q", axis=alt.Axis(title="minutes"))
color = alt.Color("run_id:N")
# Both layers plot rewards on the same fixed [-1, 1] scale.
reward_scale = alt.Scale(domain=(-1, 1))
# Hover fields shared by both layers, each listed once. cummax_reward is
# included so hovering the running-max line shows the value that line plots,
# and opponent identifies the facet a mark belongs to.
tooltip_cols = [
    "run_id:N",
    "opponent:N",
    "completion_id:N",
    "parent_completion_id:N",
    "reward:Q",
    "cummax_reward:Q",
    "minutes:Q",
]

# Layer 1: every individual evaluation as a faint point.
y = alt.Y(
    "reward:Q",
    axis=alt.Axis(title="reward"),
    scale=reward_scale,
)
points = alt.Chart(alt_data).encode(
    x,
    y,
    color,
    tooltip=tooltip_cols,
).mark_circle(opacity=0.2)

# Layer 2: the per-run running maximum as a line, on the same "reward" axis.
y = alt.Y(
    "cummax_reward:Q",
    axis=alt.Axis(title="reward"),
    scale=reward_scale,
)
cummax = alt.Chart(alt_data).encode(
    x,
    y,
    color,
    tooltip=tooltip_cols,
).mark_line(strokeWidth=2)

# One panel per opponent, ordered by mean reward (best-performing first).
figure = (
    (cummax + points)
    .properties(width=300, height=120)
    .facet(
        alt.Facet(
            'opponent:N',
            sort=alt.EncodingSortField("reward", op='mean', order='descending'),
            header=alt.Header(
                labelFontSize=14,
                titleFontSize=16,
            ),
        ),
        columns=3,
    )
    .configure_axis(
        labelFontSize=14,
        titleFontSize=14
    )
)
figure.interactive()

In [None]:
def src_from_history(run_id, completion_id, results_dir="../results/vibevolve"):
    """Load a stored completion and extract its source code.

    Reads ``{results_dir}/{run_id}/{completion_id}/completion.txt`` as UTF-8
    and passes the text to ``src_from_completion``.

    Args:
        run_id: name of the run directory under ``results_dir``.
        completion_id: name of the completion directory inside that run.
        results_dir: root directory holding one sub-directory per run; the
            default is the path this notebook has always read from.

    Returns:
        Whatever ``src_from_completion`` returns for the file's text.

    Raises:
        OSError: if the completion file cannot be opened.
    """
    path = f"{results_dir}/{run_id}/{completion_id}/completion.txt"
    # Explicit encoding: the locale default is platform dependent and LLM
    # completions routinely contain non-ASCII characters.
    with open(path, "r", encoding="utf-8") as f:
        completion = f.read()
    return src_from_completion(completion)

def src_from_completion(completion):
    """Extract the first fenced ```python block from a completion string.

    The text between the first ```python fence and the next ``` fence is
    returned verbatim (so it usually begins with a newline).

    Returns ``None`` when the completion has no ```python fence, or when no
    closing ``` appears before the next ```python fence or the end of text.
    """
    _, opening, after = completion.partition("```python")
    if not opening:
        return None
    # The closing fence must come before any second ```python fence.
    section = after.partition("```python")[0]
    code, closing, _ = section.partition("```")
    return code if closing else None

def show_diff(string1, string2):
    """Print a git-style, ANSI-coloured line diff from ``string1`` to ``string2``.

    Uses ``difflib.Differ``: removed lines ("- ") are red, added lines ("+ ")
    green, intraline hint lines ("? ") yellow, and unchanged lines plain.
    """
    # https://www.timsanteford.com/posts/creating-a-git-like-diff-viewer-in-python-using-difflib/
    ansi_codes = {"- ": "31", "+ ": "32", "? ": "33"}
    for entry in Differ().compare(string1.splitlines(), string2.splitlines()):
        code = ansi_codes.get(entry[:2])
        print(entry if code is None else f"\033[{code}m{entry}\033[0m")

# Index every raw reward record (blacklisted runs included) by
# (run_id, completion_id); if a key repeats, the later record wins.
lookup = dict(
    ((rec["run_id"], rec["completion_id"]), rec)
    for rec in records
)

def get_parent(run_id, completion_id):
    """Return the ``parent_completion_id`` of the record for this completion.

    Looks the record up in the module-level ``lookup`` index; raises
    ``KeyError`` if no record exists for ``(run_id, completion_id)``.
    Callers in this notebook treat a ``None`` result as "no parent".
    """
    record = lookup[run_id, completion_id]
    return record["parent_completion_id"]

def get_reward(run_id, completion_id):
    """Return the top-level ``reward`` field of the record for this completion.

    Looks the record up in the module-level ``lookup`` index; raises
    ``KeyError`` if no record exists for ``(run_id, completion_id)``.
    """
    key = (run_id, completion_id)
    record = lookup[key]
    return record["reward"]

# Pick the highest-reward completion, skipping blacklisted runs so this agrees
# with the runs shown in the plots above.
run_id, completion_id = max(
    (key for key in lookup if key[0] not in blacklist),
    key=lambda key: lookup[key]["reward"],
)

# Walk parent links from the best completion back to its root. Each parent is
# appended before moving on to its own parent, so no generation is skipped.
# A trailing None sentinel makes the root get diffed against an empty file.
history = [completion_id]
parent_completion_id = get_parent(run_id, completion_id)
while parent_completion_id is not None:
    history.append(parent_completion_id)
    parent_completion_id = get_parent(run_id, parent_completion_id)
history.append(None)
history.reverse()

# Print one coloured diff per parent -> child step, oldest first.
for parent_id, child_id in zip(history[:-1], history[1:]):
    if parent_id is not None:
        parent_src = src_from_history(run_id, parent_id)
        parent_reward = get_reward(run_id, parent_id)
    else:
        # Root of the lineage: compare against an empty file and the lowest
        # reward (assumes rewards lie in [-1, 1], matching the plot domain).
        parent_src = ""
        parent_reward = -1.0

    child_src = src_from_history(run_id, child_id)
    child_reward = get_reward(run_id, child_id)

    print("-" * 120)
    print(parent_id, "->", child_id)
    print(parent_reward, "->", child_reward, "+", child_reward - parent_reward)
    # src_from_history yields None when a completion has no closed ```python
    # block; show such a completion as empty instead of crashing the diff.
    show_diff(parent_src or "", child_src or "")