In [None]:
from pathlib import Path
import sys
# Walk upward from the working directory and pick the first folder that
# contains a pyproject.toml; fall back to the working directory itself when
# no ancestor has one.
_cwd = Path.cwd()
project_root = next(
    (candidate for candidate in [_cwd, *_cwd.parents] if (candidate / "pyproject.toml").exists()),
    _cwd,
)
# Put that folder on the import search path so modules under it can be imported.
sys.path.append(str(project_root))

In [None]:
# Location of the raw dataset parquet file: handed to download_superking and
# then read back with pd.read_parquet in the cells below.
output_path = "/tmp/superking.parquet"

In [None]:
from llm_python.datasets.superking import download_superking

# Fetch the dataset. Presumably this writes a parquet file at output_path
# (a later cell reads that path) — confirm against download_superking.
download_superking(output_path)

In [None]:
import pandas as pd
from llm_python.utils.numpy import convert_numpy_types

# Load the downloaded parquet file, then pass every cell of the four
# prediction / correctness columns through convert_numpy_types.
df = pd.read_parquet(output_path)
for _column in (
    "predicted_train_output",
    "correct_train_input",
    "predicted_test_output",
    "correct_test_input",
):
    df[_column] = df[_column].apply(convert_numpy_types)

In [None]:
from llm_python.utils.arc_tester import ArcTester
from llm_python.utils.task_loader import get_task_loader
from llm_python.datasets.validation import validate_soar_dataframe
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import time

# Show the size and column names of the frame before any row is re-executed.
print(
    f"Original dataset shape: {df.shape}",
    f"Original columns: {df.columns.tolist()}",
    sep="\n",
)

# Module-level loader; process_single_row looks tasks up through it by task_id.
task_loader = get_task_loader()

def process_single_row(row_data):
    """Re-run one row's program against its task and return the refreshed row.

    Meant to be submitted to a process pool, one call per dataset row. Any
    exception raised while running the program is caught and reported in the
    return value instead of propagating, so one bad row cannot stop the pool.

    Args:
        row_data: An ``(idx, row)`` pair. ``idx`` is an opaque identifier that
            is echoed back unchanged. ``row`` must support ``row["code"]``,
            ``row["task_id"]``, ``row.copy()`` and item assignment.

    Returns:
        ``('success', idx, corrected_row)`` where ``corrected_row`` is a copy
        of ``row`` whose ``predicted_train_output``, ``predicted_test_output``,
        ``correct_train_input`` and ``correct_test_input`` entries are replaced
        by the values from the fresh ``ArcTester`` result; or
        ``('failed', idx, message)`` where ``message`` is
        ``"<ExceptionType>: <exception text>"``.
    """
    idx, row = row_data
    try:
        # Create instances inside the worker process
        arc_tester = ArcTester()

        result = arc_tester.test_program(
            row["code"], task_loader.get_task(row["task_id"])
        )

        # Work on a copy so the caller's row is left untouched, and overwrite
        # the stored outputs with what the program actually produced.
        corrected_row = row.copy()
        corrected_row["predicted_train_output"] = result.train_outputs
        corrected_row["predicted_test_output"] = result.test_outputs
        corrected_row["correct_train_input"] = result.correct_train_input
        corrected_row["correct_test_input"] = result.correct_test_input

        return ('success', idx, corrected_row)

    except Exception as e:
        # Include the exception type: str(e) alone is ambiguous (a KeyError
        # renders as just the missing key) and can even be empty.
        return ('failed', idx, f"{type(e).__name__}: {e}")

# Re-run every row of df through process_single_row using a process pool,
# one batch at a time, collecting the refreshed rows and the positions of the
# rows that failed.

# Determine optimal number of workers
num_workers = min(mp.cpu_count(), 8)  # Don't use too many to avoid memory issues
print(f"Using {num_workers} parallel workers")

# Process in batches for better memory management
batch_size = 5000  # Larger batches since we're using parallel processing
total_rows = len(df)
all_corrected_rows = []
all_failed_indices = []

print(f"Processing {total_rows} rows in batches of {batch_size}")

start_time = time.time()

for batch_start in range(0, total_rows, batch_size):
    batch_end = min(batch_start + batch_size, total_rows)
    batch_number = batch_start // batch_size + 1
    batch_df = df.iloc[batch_start:batch_end].copy()

    print(f"\nProcessing batch {batch_number}: rows {batch_start} to {batch_end-1}")

    # Pair every row with its absolute position in df so results can be
    # attributed (and re-ordered) after they come back out of order.
    row_data = [
        (batch_start + offset, row)
        for offset, (_, row) in enumerate(batch_df.iterrows())
    ]

    batch_successes = []  # (position, corrected_row) pairs
    batch_failed_indices = []

    # Process batch in parallel with progress bar. A fresh pool per batch
    # means a pool broken by a crashed worker only affects the current batch.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit all tasks
        future_to_idx = {
            executor.submit(process_single_row, data): data[0]
            for data in row_data
        }

        # Process results with progress bar
        with tqdm(total=len(row_data), desc=f"Batch {batch_number}") as pbar:
            for future in as_completed(future_to_idx):
                try:
                    result_type, idx, result_data = future.result()
                except Exception as exc:
                    # process_single_row catches its own errors, so reaching
                    # here means the pool itself failed (e.g. a worker died).
                    # Record the row as failed instead of aborting the run.
                    result_type = 'failed'
                    idx = future_to_idx[future]
                    result_data = f"{type(exc).__name__}: {exc}"

                if result_type == 'success':
                    batch_successes.append((idx, result_data))
                else:
                    batch_failed_indices.append(idx)
                    if len(batch_failed_indices) <= 5:  # Only print first few errors per batch
                        print(f"    Failed row {idx}: {result_data}")

                pbar.update(1)

    # as_completed yields in completion order; restore the original row order
    # so the resulting dataset is identical from run to run.
    batch_successes.sort(key=lambda pair: pair[0])
    batch_corrected_rows = [corrected for _, corrected in batch_successes]
    batch_failed_indices.sort()

    # Add batch results to overall results
    all_corrected_rows.extend(batch_corrected_rows)
    all_failed_indices.extend(batch_failed_indices)

    elapsed = time.time() - start_time
    processed_so_far = batch_end
    # Guard against a zero elapsed time (coarse clocks, tiny batches).
    rate = processed_so_far / elapsed if elapsed > 0 else 0.0
    remaining = total_rows - processed_so_far
    eta = remaining / rate if rate > 0 else 0

    print(f"  Batch completed: {len(batch_corrected_rows)} successful, {len(batch_failed_indices)} failed")
    print(f"  Total so far: {len(all_corrected_rows)} successful, {len(all_failed_indices)} failed")
    print(f"  Rate: {rate:.1f} rows/sec, ETA: {eta/60:.1f} minutes")

print("\nProcessing complete!")
print(f"Total successful rows: {len(all_corrected_rows)}")
print(f"Total failed rows: {len(all_failed_indices)}")
print(f"Total time: {(time.time() - start_time)/60:.1f} minutes")


In [None]:
# Create final cleaned dataframe and filter out invalid rows

# Evict any cached copy of the validation module from sys.modules so that the
# import statement that follows executes the module again.
import sys
sys.modules.pop('llm_python.datasets.validation', None)

from llm_python.datasets.validation import validate_soar_sample, validate_soar_dataframe
from llm_python.programsdb.schema import PARQUET_SCHEMA

# Build a frame from the re-executed rows, keep only those that pass
# validate_soar_sample, write the survivors to parquet, and report the
# frame-level validation message.
if not all_corrected_rows:
    print("No rows were successfully processed!")
else:
    df_cleaned = pd.DataFrame(all_corrected_rows)
    print(f"Original cleaned dataset shape: {df_cleaned.shape}")

    # Check rows one at a time; valid ones are kept as plain dicts.
    valid_rows = []
    invalid_count = 0

    for row_label, record in df_cleaned.iterrows():
        is_valid, error_msg = validate_soar_sample(record.to_dict())
        if is_valid:
            # Store a freshly built dict rather than the one handed to the validator.
            valid_rows.append(record.to_dict())
            continue
        invalid_count += 1
        if invalid_count <= 5:  # Only the first few errors are printed
            print(f"Row {row_label} invalid: {error_msg}")

    print("\nValidation results:")
    print(f"Valid rows: {len(valid_rows)}")
    print(f"Invalid rows: {invalid_count}")

    if not valid_rows:
        print("No valid rows found!")
    else:
        # Rebuild the frame from the surviving records only.
        df_final = pd.DataFrame(valid_rows)
        print(f"Final dataset shape after filtering: {df_final.shape}")

        # Persist using the project's parquet schema.
        output_file = "/tmp/superking_cleaned_v2.parquet"
        df_final.to_parquet(output_file, schema=PARQUET_SCHEMA)
        print(f"Saved final cleaned dataset to {output_file}")

        # Frame-level check of what was just written; only its message is shown.
        is_valid, validation_message = validate_soar_dataframe(df_final)
        print(f"Final validation result: {validation_message}")