In [2]:
from pathlib import Path

import orjson
import polars as pl
from rich.console import Console

from cogelot.structures.vima import Partition, Task

# Rich console shared by the helpers in this notebook for printing run metadata and table rows.
console = Console()


def process_old_style_csv(path: str) -> pl.DataFrame:
    """Process the old style CSV files and return a DataFrame with the relevant columns.

    The old CSVs do not store a per-episode step budget directly, so ``time_limit`` is
    reconstructed from the object-count columns, which are dropped afterwards.

    Args:
        path: Path to an old-style episodes CSV.

    Returns:
        One row per episode with columns ``partition`` and ``task`` (encoded as Int16
        ids; task ids are the ``Task`` enum value plus one), ``is_successful_at_end``,
        ``steps_taken``, ``time_limit`` and ``is_successful_without_mistakes``.

    Raises:
        AssertionError: If any episode ends up without a ``time_limit``.
    """
    object_count_columns = [
        "max_swept_obj",
        "num_dragged_obj",
        "num_base_obj",
        "num_target_base_obj",
    ]
    episodes = (
        pl.read_csv(path)
        # Missing counts are written as the literal string "null" in the old CSVs.
        .with_columns(pl.col(object_count_columns).replace("null", None))
        .select(
            pl.col("partition", "task", "is_successful_at_end"),
            pl.col("steps_taken", *object_count_columns).cast(pl.Int32),
        )
        # NOTE(review): the novel_adj tasks are given a fixed count of one dragged object,
        # presumably because the CSV does not record it for them -- confirm.
        .with_columns(
            pl.when(pl.col("task").is_in(["novel_adj", "novel_adj_and_noun"]))
            .then(1)
            .otherwise(pl.col("num_dragged_obj"))
            .alias("num_dragged_obj"),
        )
        # Choose, per task family, which object count serves as the step budget.
        .with_columns(
            pl.when(pl.col("task").is_in(["sweep_without_touching", "sweep_without_exceeding"]))
            .then(pl.col("max_swept_obj"))
            .when(pl.col("task") == "pick_in_order_then_restore")
            .then(pl.col("num_target_base_obj"))
            .otherwise(pl.col("num_dragged_obj"))
            .alias("time_limit"),
        )
        .with_columns(
            # Strict success: succeeded AND took no more steps than the time limit.
            pl.when(pl.col("is_successful_at_end"))
            .then(pl.col("time_limit").sub(pl.col("steps_taken")).ge(0))
            .otherwise(pl.lit(False))  # noqa: FBT003
            .alias("is_successful_without_mistakes"),
            # Encode names as integer ids; tasks are shifted by one, matching the
            # `task + 1` applied in process_new_style_json.
            pl.col("task")
            .replace(old=[x.name for x in Task], new=[x.value + 1 for x in Task])
            .cast(pl.Int16),
            pl.col("partition")
            .replace(old=[x.name for x in Partition], new=[x.value for x in Partition])
            .cast(pl.Int16),
        )
        .drop(*object_count_columns)
    )

    # Every episode must resolve to a time limit, otherwise strict success is undefined.
    num_missing_time_limits = episodes.get_column("time_limit").null_count()
    assert num_missing_time_limits == 0, f"{num_missing_time_limits} episode(s) have no time_limit"

    return episodes


def compute_per_task_performance(episodes_df: pl.DataFrame) -> pl.DataFrame:
    """Compute the performance of each task in the dataset.

    Args:
        episodes_df: One row per episode, with at least the columns ``partition``,
            ``task``, ``is_successful_at_end``, ``is_successful_without_mistakes`` and
            ``time_limit``.

    Returns:
        One row per ``(partition, task)`` pair, sorted by both, with columns
        ``num_successful``, ``num_successful_strict``, ``total_episodes`` and the success
        fractions ``percentage_successful_strict`` / ``percentage_successful``
        (in ``[0, 1]``, not multiplied by 100).
    """
    return (
        # The numerators and the denominator share the same grouping keys, so one
        # aggregation pass suffices; no second group_by + join is needed.
        episodes_df.group_by(["partition", "task"])
        .agg(
            pl.col("is_successful_at_end").sum().alias("num_successful"),
            pl.col("is_successful_without_mistakes").sum().alias("num_successful_strict"),
            # NOTE(review): depending on the polars version, `count()` may skip nulls, in
            # which case episodes without a time_limit drop out of the denominator only.
            pl.col("time_limit").count().alias("total_episodes"),
        )
        .with_columns(
            pl.col("num_successful_strict")
            .truediv(pl.col("total_episodes"))
            .alias("percentage_successful_strict"),
            pl.col("num_successful")
            .truediv(pl.col("total_episodes"))
            .alias("percentage_successful"),
        )
        .sort(["partition", "task"])
    )


def compute_per_partition_performance(episodes_per_task_df: pl.DataFrame) -> pl.DataFrame:
    """Compute the performance of each partition in the dataset.

    A partition's score is the unweighted mean of its per-task success fractions (every
    task counts equally regardless of its episode count), expressed as a percentage and
    rounded to one decimal place.

    Args:
        episodes_per_task_df: Output of ``compute_per_task_performance``.

    Returns:
        One row per partition, sorted by partition, with ``percentage_successful_strict``
        and ``percentage_successful``.
    """
    rate_columns = ("percentage_successful_strict", "percentage_successful")
    # A multi-column expression keeps each input column's name in the output.
    partition_rates = episodes_per_task_df.group_by("partition").agg(
        pl.col(*rate_columns).mean().mul(100).round(1)
    )
    return partition_rates.sort("partition")


def process_new_style_json(path: str) -> pl.DataFrame:
    """Process the new style JSON files and return a DataFrame with the relevant columns.

    The file holds a table serialised as ``{"columns": [...], "data": [...]}``; this
    assumes ``data`` contains one list per episode, ordered like ``columns`` (the layout
    of W&B table JSON files) -- confirm for non-W&B exports.

    Args:
        path: Path to the JSON table.

    Returns:
        One row per episode with columns ``partition``, ``is_successful_at_end``,
        ``time_limit`` (from ``minimum_steps``), ``steps_taken`` (from ``total_steps``),
        ``task`` (shifted by one) and ``is_successful_without_mistakes``.
    """
    # orjson parses bytes directly, so skip the intermediate text decode.
    raw_episodes = orjson.loads(Path(path).read_bytes())

    # Name the columns and fix the orientation explicitly: if it is left to inference,
    # polars guesses from the element types of the first row and can treat every
    # episode as a column.
    episodes = pl.DataFrame(
        raw_episodes["data"],
        schema=raw_episodes["columns"],
        orient="row",
    )

    episodes = episodes.select(
        pl.col("partition", "is_successful_at_end"),
        pl.col("minimum_steps").alias("time_limit"),
        pl.col("total_steps").alias("steps_taken"),
        pl.col("task").add(1).alias("task"),
    ).with_columns(
        # Strict success: succeeded AND took no more steps than the minimum required.
        pl.when(pl.col("is_successful_at_end"))
        .then(pl.col("time_limit").sub(pl.col("steps_taken")).ge(0))
        .otherwise(pl.lit(False))  # noqa: FBT003
        .alias("is_successful_without_mistakes")
    )
    return episodes

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
import statistics
from decimal import Decimal

x_obj_episodes = process_old_style_csv("storage/data/x-obj-para-episodes-old.csv")
x_obj_episode_per_task = compute_per_task_performance(x_obj_episodes)

performances = x_obj_episode_per_task.select(
    # "L<partition>" labels; pl.format keeps multi-digit partitions intact.
    pl.format("L{}", pl.col("partition")).alias("partition"),
    pl.col("task"),
    pl.col("percentage_successful_strict").mul(100).round(1).alias("success"),
).to_dicts()

# Strict success (%) per level, keyed by task id.
performance_per_level = {}
for row in performances:
    # str() stops Decimal from inheriting the float's binary expansion.
    performance_per_level.setdefault(row["partition"], {})[row["task"]] = Decimal(
        str(row["success"])
    ).quantize(Decimal("1.0"))

# One LaTeX row per level: a cell per task in `Task` order ("{---}" where the task was
# not evaluated at that level), then the mean over the evaluated tasks.
for level, success_by_task in performance_per_level.items():
    row_cells = {task.value + 1: success_by_task.get(task.value + 1, "{---}") for task in Task}
    average = statistics.mean(
        cell for cell in row_cells.values() if isinstance(cell, Decimal)
    )
    cells = " & ".join(map(str, row_cells.values()))
    average_cell = str(average.quantize(Decimal("1.0")))
    # Plain print: rich's console would soft-wrap these long lines.
    print(r"\bfseries " + f"{level} & {cells} & {average_cell} " + r"\\")

\bfseries L1 & 99.0 & 99.0 & 99.0 & 83.5 & 0.0 & 97.5 & 98.0 & {---} & 11.5 & {---} & 92.0 & 97.5 & {---} & {---} & 96.0 & 41.5 & 0.0 & 70.3 \\
\bfseries L2 & 97.5 & 98.0 & 99.5 & 78.0 & 0.0 & 98.0 & 99.0 & {---} & 13.5 & {---} & 91.0 & 91.5 & {---} & {---} & 94.5 & 46.0 & 0.0 & 69.7 \\
\bfseries L3 & 98.0 & 97.0 & 99.5 & 77.5 & 0.0 & 97.5 & 95.5 & {---} & 15.5 & {---} & 92.5 & {---} & {---} & {---} & 94.5 & 47.5 & 0.0 & 67.9 \\
\bfseries L4 & {---} & {---} & {---} & {---} & {---} & {---} & {---} & 92.0 & {---} & 0.0 & {---} & {---} & 0.0 & 95.0 & {---} & {---} & {---} & 46.8 \\


In [4]:
# Per-partition strict / lenient success rates for the episodes in episodes-d-obj.json.
d_obj_episodes = process_new_style_json("storage/data/episodes-d-obj.json")
d_obj_episode_per_task = d_obj_episodes.pipe(compute_per_task_performance)
d_obj_episode_per_partition = d_obj_episode_per_task.pipe(compute_per_partition_performance)
d_obj_episode_per_partition

partition,percentage_successful_strict,percentage_successful
i64,f64,f64
1,72.2,80.4
2,71.4,78.2
3,65.7,74.8
4,45.5,49.0


In [5]:
# Per-partition strict / lenient success rates for the episodes in episodes-d-ptch.json.
d_ptch_episodes = process_new_style_json("storage/data/episodes-d-ptch.json")
d_ptch_episode_per_task = d_ptch_episodes.pipe(compute_per_task_performance)
d_ptch_episode_per_partition = d_ptch_episode_per_task.pipe(compute_per_partition_performance)
d_ptch_episode_per_partition

partition,percentage_successful_strict,percentage_successful
i64,f64,f64
1,61.2,67.1
2,57.6,62.8
3,46.0,52.0
4,12.9,19.8


In [6]:
# Per-partition strict / lenient success rates for the episodes in episodes-x-ptch.json.
x_ptch_episodes = process_new_style_json("storage/data/episodes-x-ptch.json")
x_ptch_episode_per_task = x_ptch_episodes.pipe(compute_per_task_performance)
x_ptch_episode_per_partition = x_ptch_episode_per_task.pipe(compute_per_partition_performance)
x_ptch_episode_per_partition

partition,percentage_successful_strict,percentage_successful
i64,f64,f64
1,58.2,63.9
2,57.3,63.0
3,44.3,49.5
4,15.9,20.4


In [7]:
# Per-partition strict / lenient success rates for the old-style no-prompt X-Obj CSV.
x_obj_episodes_no_prompt = process_old_style_csv("storage/data/episodes-x-obj-no-prompt.csv")
x_obj_episode_per_task_no_prompt = x_obj_episodes_no_prompt.pipe(compute_per_task_performance)
x_obj_episode_per_partition_no_prompt = x_obj_episode_per_task_no_prompt.pipe(
    compute_per_partition_performance
)
x_obj_episode_per_partition_no_prompt

partition,percentage_successful_strict,percentage_successful
i16,f64,f64
1,45.1,52.3
2,45.4,52.5
3,34.5,41.7
4,26.1,35.8


In [55]:
import wandb


def download_episodes_from_wandb(
    run_id: str, project_path: str = "pyop/cogelot-evaluation"
) -> str:
    """Download the episodes from the given run id.

    Args:
        run_id: ID of the W&B evaluation run.
        project_path: ``"<entity>/<project>"`` that the run belongs to.

    Returns:
        Path (as a string) to the downloaded ``episodes.table.json`` file, stored under
        ``./storage/artifacts/<run_id>``.

    Raises:
        ValueError: If the run logged no artifact whose name contains ``"episodes"``.
    """
    run = wandb.Api().run(f"{project_path}/{run_id}")
    console.print("Run:", run.id)
    console.print("Name:", run.name)

    # Search all logged artifacts rather than assuming the episodes table comes first.
    episodes_artifact = next(
        (artifact for artifact in run.logged_artifacts() if "episodes" in artifact.name),
        None,
    )
    if episodes_artifact is None:
        msg = f"Run {run_id} has no logged artifact with 'episodes' in its name"
        raise ValueError(msg)

    download_dir = episodes_artifact.download(root=f"./storage/artifacts/{run_id}")
    return str(Path(download_dir) / "episodes.table.json")


def print_for_paper(episodes_per_partition: pl.DataFrame) -> None:
    """Print the strict numbers for the paper as a single LaTeX table row.

    Values are emitted in the frame's row order (sorted by partition when the frame comes
    from ``compute_per_partition_performance``), joined with ``" & "`` and followed by
    the LaTeX row terminator.

    Args:
        episodes_per_partition: Frame with a ``percentage_successful_strict`` column.
    """
    strict_rates = episodes_per_partition.get_column("percentage_successful_strict").to_list()
    console.print(" & ".join(map(str, strict_rates)), r"\\")

In [75]:
run_id = "sm0t3gea"
episodes_path = download_episodes_from_wandb(run_id)
episodes = process_new_style_json(episodes_path)
episode_per_task = compute_per_task_performance(episodes)
episode_per_task.filter(pl.col("partition") == 1)
# episode_per_partition = compute_per_partition_performance(episode_per_task)
# episode_per_partition

[34m[1mwandb[0m:   1 of 1 files downloaded.  


partition,task,num_successful,num_successful_strict,total_episodes,percentage_successful_strict,percentage_successful
i64,i64,u32,u32,u32,f64,f64
1,1,200,198,200,0.99,1.0
1,2,199,198,200,0.99,0.995
1,3,199,199,200,0.995,0.995
1,4,194,192,200,0.96,0.97
1,5,17,6,200,0.03,0.085
…,…,…,…,…,…,…
1,11,183,182,200,0.91,0.915
1,12,192,178,200,0.89,0.96
1,15,193,193,200,0.965,0.965
1,16,99,92,200,0.46,0.495


In [70]:
from decimal import Decimal

# NOTE(review): this cell reads `episode_per_task` from the W&B run cell (In [75]), which
# has a later execution count, and its saved output does not match that cell's table.
# Re-run after that cell to reproduce it.
performances = episode_per_task.select(
    # "L<partition>" labels; pl.format keeps multi-digit partitions intact.
    pl.format("L{}", pl.col("partition")).alias("partition"),
    pl.col("task"),
    pl.col("percentage_successful_strict").mul(100).round(1).alias("success"),
).to_dicts()

# Strict success (%) per level, keyed by task id.
performance_per_level = {}
for row in performances:
    # str() stops Decimal from inheriting the float's binary expansion.
    performance_per_level.setdefault(row["partition"], {})[row["task"]] = Decimal(
        str(row["success"])
    ).quantize(Decimal("1.0"))
performance_per_level

{'L1': {1: Decimal('90.0'),
  2: Decimal('38.5'),
  3: Decimal('8.5'),
  4: Decimal('9.0'),
  5: Decimal('1.0'),
  6: Decimal('43.5'),
  7: Decimal('74.0'),
  9: Decimal('3.5'),
  11: Decimal('67.5'),
  12: Decimal('82.5'),
  15: Decimal('49.0'),
  16: Decimal('19.0'),
  17: Decimal('0.0')},
 'L2': {1: Decimal('80.5'),
  2: Decimal('27.0'),
  3: Decimal('4.5'),
  4: Decimal('8.0'),
  5: Decimal('1.0'),
  6: Decimal('49.5'),
  7: Decimal('74.5'),
  9: Decimal('3.0'),
  11: Decimal('75.5'),
  12: Decimal('77.0'),
  15: Decimal('50.0'),
  16: Decimal('23.5'),
  17: Decimal('0.0')},
 'L3': {1: Decimal('62.5'),
  2: Decimal('35.0'),
  3: Decimal('4.0'),
  4: Decimal('11.0'),
  5: Decimal('0.5'),
  6: Decimal('42.0'),
  7: Decimal('43.5'),
  9: Decimal('3.5'),
  11: Decimal('71.5'),
  15: Decimal('25.5'),
  16: Decimal('14.5'),
  17: Decimal('0.0')},
 'L4': {8: Decimal('17.5'),
  10: Decimal('1.5'),
  13: Decimal('0.0'),
  14: Decimal('26.0')}}