In [None]:
from pathlib import Path
import sys

# Locate the repository root: the nearest directory (the current working
# directory or one of its ancestors) that contains a pyproject.toml.
# Falls back to the current working directory when none is found.
_cwd = Path.cwd()
project_root = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / "pyproject.toml").exists()),
    _cwd,
)

# Make project packages (e.g. llm_python) importable from the notebook.
# Guarded so re-running this cell in the same kernel does not keep appending
# duplicate entries to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [None]:
# Destination of the refinement fine-tuning parquet, written at the project root.
output_path = project_root.joinpath("refinement_finetuning.parquet")

In [None]:
from llm_python.datasets.superking import load_superking

# Load the full superking dataset. Later cells use it both as the source of
# refined rows (non-null refined_from_id) and as the lookup table for each
# refined row's original program (joined on row_id).
superking_df = load_superking()

In [None]:
# Number of superking rows that were produced by refining another row.
print(int(superking_df["refined_from_id"].notna().sum()))

In [None]:
# Keep only refined examples: rows with a non-null refined_from_id, further
# narrowed by filter_soar_df (presumably to the arc-prize-2024/training subset
# and to programs that get every train and test example right — confirm
# against filter_soar_df).
from llm_python.datasets.query import filter_soar_df


# Filter before copying. Copying all of superking_df first and then
# boolean-indexing it duplicates the whole dataset in memory and leaves `df`
# as a slice of a temporary. `.loc[mask].copy()` copies only the kept rows and
# yields an independent frame, so later column assignments are safe.
is_refined = superking_df["refined_from_id"].notna()
df = superking_df.loc[is_refined].copy()
df = filter_soar_df(
    df,
    include_subset="arc-prize-2024/training",
    all_train_correct=True,
    all_test_correct=True,
)


In [None]:
# Attach each refined row's original (pre-refinement) program by joining
# refined_from_id against row_id in the full superking dataset. df already has
# all four of these columns, so the right-hand copies come through with the
# "_original" suffix (row_id_original, code_original, ...).
# NOTE(review): not idempotent; re-running this cell merges a second time.
refined_row_count = len(df)
df = df.merge(
    superking_df[["row_id", "code", "predicted_train_output", "predicted_test_output"]],
    left_on="refined_from_id",
    right_on="row_id",
    how="left",
    suffixes=("", "_original"),
)
# A left join preserves the row count only if every refined_from_id matches at
# most one row_id. Duplicate row_ids would silently duplicate training examples.
assert len(df) == refined_row_count, (
    f"merge changed row count {refined_row_count} -> {len(df)}; row_id is not unique"
)

In [None]:
import numpy as np

from llm_python.datasets.query import sample_by_task

# Derive ranking columns: how many train/test inputs each program got right,
# and the program's length in characters.
df = df.assign(
    correct_train_input_count=df["correct_train_input"].apply(lambda values: np.sum(values)),
    correct_test_input_count=df["correct_test_input"].apply(lambda values: np.sum(values)),
    code_length=df["code"].str.len(),
)

# Ranking: more test inputs correct first, then more train inputs correct, then
# shorter code. sample_by_task presumably keeps the top `task_limit` rows per
# task under this ordering — confirm against its implementation.
ranking_columns = ["correct_test_input_count", "correct_train_input_count", "code_length"]
ranking_ascending = [False, False, True]
df = sample_by_task(
    df,
    sort_keys=ranking_columns,
    sort_ascending=ranking_ascending,
    task_limit=10,
)

In [None]:
# Sanity check on the sampled dataset: row count, then the first rows.
print(df.shape[0])
print(df.head())

In [None]:
from llm_python.datasets.io import write_soar_parquet
from llm_python.datasets.schema import REFINEMENT_PARQUET_SCHEMA

# Persist the sampled refinement pairs to output_path using the refinement schema.
print("Saving final dataset to: " + str(output_path))
write_soar_parquet(df, output_path, schema=REFINEMENT_PARQUET_SCHEMA)

In [None]:
from llm_python.datasets.statistics import analyze_dataset_statistics

# Summary statistics for the final sampled dataset. "refinement" is presumably
# a label used in the report — confirm against analyze_dataset_statistics.
analyze_dataset_statistics(df, "refinement")

In [None]:
# Eyeball a few original -> refined program pairs. Clamp the sample size so the
# cell still works when fewer than 10 rows remain: DataFrame.sample raises
# ValueError when n exceeds the row count (sampling without replacement).
# With 10 or more rows this is the same call, and so the same rows, as before.
preview_count = min(10, len(df))
sample = df[["code_original", "code"]].sample(n=preview_count, random_state=42)
for idx, row in sample.iterrows():
    print(f"Row {idx}:")
    print("Original code:\n", row["code_original"])
    print("Refined code:\n", row["code"])
    print("-" * 80)