In [None]:
from pathlib import Path
import sys
# Walk upward from the current working directory to the nearest folder that
# contains a pyproject.toml; fall back to the working directory if none does.
_cwd = Path.cwd()
project_root = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / "pyproject.toml").exists()),
    _cwd,
)
# Guarded so that re-running this cell does not pile up duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [None]:
# Local parquet path handed to download_superking and read_soar_parquet below.
superking_path = "/tmp/superking.parquet"
# Destination of the dataset written at the end of this notebook (inside project_root).
output_path = project_root.joinpath("refinement_finetuning.parquet")

In [None]:
from llm_python.datasets.superking import download_superking

# Presumably fetches the superking dataset and stores it at superking_path, which
# the next cell reads back -- confirm in llm_python.datasets.superking.
download_superking(superking_path)

In [None]:
from llm_python.datasets.io import read_soar_parquet


# Load the full dataset from the local parquet file.
superking_df = read_soar_parquet(superking_path)

# Report how many rows carry a refined_from_id, i.e. are refinements of another row.
has_refinement_source = superking_df["refined_from_id"].notna()
print(int(has_refinement_source.sum()))

In [None]:
from llm_python.utils.task_loader import get_task_loader

def filter_soar_df(
    df,
    include_subset=None,
    exclude_subset=None,
    all_train_correct=None,
    all_test_correct=None,
    any_train_correct=None,
    exclude_transductive=True,
    max_rows=None
):
    """
    Filter a SOAR-format DataFrame by subset membership and correctness.

    Filters run in this order: subset inclusion, subset exclusion, correctness
    flags, transductive exclusion, row cap. Only boolean indexing and ``head``
    are used, so the frame passed in is never modified.

    Args:
        df (pd.DataFrame): Frame to filter. It must contain the column each
            requested filter reads: ``task_id`` (subset filters),
            ``correct_train_input`` / ``correct_test_input`` (correctness
            filters, each cell an iterable of flags) and ``is_transductive``
            (when ``exclude_transductive`` is true).
        include_subset (str, optional): Subset name to include (from task_loader).
        exclude_subset (str, optional): Subset name to exclude (from task_loader).
        all_train_correct (bool, optional): Keep rows where
            ``all(correct_train_input)`` equals this value. None skips the filter.
        all_test_correct (bool, optional): Keep rows where
            ``all(correct_test_input)`` equals this value. None skips the filter.
        any_train_correct (bool, optional): Keep rows where
            ``any(correct_train_input)`` equals this value. None skips the filter.
        exclude_transductive (bool, optional): If True, drop rows whose
            ``is_transductive`` flag is set.
        max_rows (int, optional): Keep at most this many leading rows.

    Returns:
        pd.DataFrame: Filtered DataFrame.
    """

    def _keep_where(frame, column, reducer, expected):
        # Series.apply on an empty frame yields an object-dtype Series, which
        # DataFrame.__getitem__ would treat as a list of column labels (dropping
        # every column). Casting to bool keeps it a boolean row mask.
        mask = frame[column].apply(lambda flags: reducer(flags) == expected)
        return frame[mask.astype(bool)]

    # Subset filtering. The task loader is only built when a subset is requested.
    if include_subset or exclude_subset:
        task_loader = get_task_loader()
    if include_subset:
        allowed_ids = {task_id for task_id, _ in task_loader.get_subset_tasks(include_subset)}
        df = df[df['task_id'].isin(allowed_ids)]
    if exclude_subset:
        excluded_ids = {task_id for task_id, _ in task_loader.get_subset_tasks(exclude_subset)}
        df = df[~df['task_id'].isin(excluded_ids)]

    # Correctness filters
    if all_train_correct is not None:
        df = _keep_where(df, 'correct_train_input', all, all_train_correct)
    if all_test_correct is not None:
        df = _keep_where(df, 'correct_test_input', all, all_test_correct)
    if any_train_correct is not None:
        df = _keep_where(df, 'correct_train_input', any, any_train_correct)

    # Exclude transductive
    if exclude_transductive:
        df = df[~df['is_transductive']]

    # Limit rows
    if max_rows is not None:
        df = df.head(max_rows)

    return df

In [None]:
# Start from a copy of the full dataset and keep only rows that have a
# refined_from_id, then narrow to the arc-prize-2025 training subset where every
# train and test flag is correct.
refined_only = superking_df.copy().loc[lambda frame: frame["refined_from_id"].notna()]
df = filter_soar_df(
    refined_only,
    include_subset="arc-prize-2025/training",
    all_train_correct=True,
    all_test_correct=True,
)


In [None]:
# Attach each refined row's source program: look up the superking row whose row_id
# equals refined_from_id and carry its id, code and predictions over under an
# "_original" suffix. With how="left", rows without a match keep nulls there.
original_columns = superking_df[
    ["row_id", "code", "predicted_train_output", "predicted_test_output"]
].add_suffix("_original")
# Drop columns left behind by an earlier run of this cell; merging on top of them
# would make pandas raise on the duplicated "_original" names.
df = df.drop(columns=[name for name in original_columns.columns if name in df.columns])
df = df.merge(
    original_columns,
    left_on="refined_from_id",
    right_on="row_id_original",
    how="left",
)

In [None]:
# Sanity check after the merge: row count first, then the leading rows.
row_count = len(df)
preview = df.head()
print(row_count, preview, sep="\n")

In [None]:
import pyarrow as pa
import pyarrow.parquet as pq

from llm_python.datasets.schema import PARQUET_SCHEMA

# Triple-nested int64 list type shared by both predicted-output columns.
grid_list_type = pa.list_(pa.list_(pa.list_(pa.int64())))

# Columns added on top of the base schema. The two prediction columns are
# declared non-nullable (required); code_original keeps the default nullability.
original_fields = [
    pa.field("code_original", pa.string()),
    pa.field("predicted_train_output_original", grid_list_type, nullable=False),
    pa.field("predicted_test_output_original", grid_list_type, nullable=False),
]
refinement_schema = pa.schema(list(PARQUET_SCHEMA) + original_fields)

print(f"Saving final dataset to: {output_path}")
# Convert using the explicit schema, then write the table to output_path.
table = pa.Table.from_pandas(df, schema=refinement_schema)
pq.write_table(table, output_path)

In [None]:
from llm_python.datasets.statistics import analyze_dataset_statistics

# Presumably reports summary statistics for the merged frame under the label
# "refinement" -- confirm in llm_python.datasets.statistics.
analyze_dataset_statistics(df, "refinement")