In [1]:
from pathlib import Path
import sys

# Load the annotated data
project_root = Path.cwd()
while not (project_root / "pyproject.toml").exists() and project_root != project_root.parent:
    project_root = project_root.parent
sys.path.insert(0, str(project_root))

In [2]:
from llm_python.datasets.io import read_soar_parquet

df = read_soar_parquet(project_root / "experimental/transduction/transductive_train_filtered.parquet")

In [3]:
from llm_python.transduction.code_classifier import CodeTransductionClassifier
from llm_python.utils.task_loader import TaskLoader

task_loader = TaskLoader()

df["groundtruth_transductive_prediction"] = (
    df["is_transductive"]
)
# Compute current transductive classification.
classifier = CodeTransductionClassifier()
df["current_transductive_prediction"] = df.apply(
    lambda row: classifier.is_transductive(
        row["code"], task_loader.get_task(row["task_id"])
    )[0],
    axis=1,
)

Loading arc-prize-2024...
  Training: 400 tasks
  Evaluation: 400 tasks
  Test: 100 tasks
Loading arc-prize-2025...
  Training: 1000 tasks
  Evaluation: 120 tasks
  Test: 240 tasks




In [4]:
stats = {
    "groundtruth_transductive_prediction_true": df["groundtruth_transductive_prediction"].sum(),
    "groundtruth_transductive_prediction_false": (~df["groundtruth_transductive_prediction"]).sum(),
    "current_transductive_prediction_true": df["current_transductive_prediction"].sum(),
    "current_transductive_prediction_false": (~df["current_transductive_prediction"]).sum(),
}
for k, v in stats.items():
    print(f"{k}: {v}")


groundtruth_transductive_prediction_true: 123
groundtruth_transductive_prediction_false: 264
current_transductive_prediction_true: 120
current_transductive_prediction_false: 267


In [5]:
import pandas as pd

# Split off a 10% eval set
eval_frac = 0.1
eval_size = int(len(df) * eval_frac)
eval_df = df.iloc[:eval_size].reset_index(drop=True)
train_df = df.iloc[eval_size:].reset_index(drop=True)

# Do NOT reassign df; keep train_df and eval_df separate

In [6]:
import re
import ast

def extract_literals_ast(program_code):
    """Extract numeric literals using AST parsing"""
    try:
        tree = ast.parse(program_code)
        literals = []

        class LiteralVisitor(ast.NodeVisitor):
            def visit_Constant(self, node):
                if isinstance(node.value, (int, float)):
                    literals.append(node.value)
                self.generic_visit(node)

            def visit_Num(self, node):
                literals.append(node.n)
                self.generic_visit(node)

        visitor = LiteralVisitor()
        visitor.visit(tree)
        return literals
    except SyntaxError:
        numeric_literals = re.findall(r'\b\d+(?:\.\d+)?\b', program_code)
        return [float(lit) if '.' in lit else int(lit) for lit in numeric_literals]


def extract_ast_features(program):
    """Extract AST structural features"""
    try:
        tree = ast.parse(program)
        counts = {
            'function_definitions': 0, 'for_loops': 0, 'while_loops': 0,
            'if_statements': 0, 'assignments': 0, 'comparisons': 0,
            'binary_ops': 0, 'subscripts': 0, 'method_calls': 0,
            'list_comprehensions': 0, 'lambda_functions': 0,
            'try_statements': 0, 'return_statements': 0
        }

        class ASTVisitor(ast.NodeVisitor):
            def visit_FunctionDef(self, node):
                counts['function_definitions'] += 1
                self.generic_visit(node)
            
            def visit_For(self, node):
                counts['for_loops'] += 1
                self.generic_visit(node)
            
            def visit_While(self, node):
                counts['while_loops'] += 1
                self.generic_visit(node)
            
            def visit_If(self, node):
                counts['if_statements'] += 1
                self.generic_visit(node)
            
            def visit_Assign(self, node):
                counts['assignments'] += 1
                self.generic_visit(node)
            
            def visit_Compare(self, node):
                counts['comparisons'] += 1
                self.generic_visit(node)
            
            def visit_BinOp(self, node):
                counts['binary_ops'] += 1
                self.generic_visit(node)
            
            def visit_Subscript(self, node):
                counts['subscripts'] += 1
                self.generic_visit(node)
            
            def visit_Call(self, node):
                counts['method_calls'] += 1
                self.generic_visit(node)
            
            def visit_ListComp(self, node):
                counts['list_comprehensions'] += 1
                self.generic_visit(node)
            
            def visit_Lambda(self, node):
                counts['lambda_functions'] += 1
                self.generic_visit(node)
            
            def visit_Try(self, node):
                counts['try_statements'] += 1
                self.generic_visit(node)
            
            def visit_Return(self, node):
                counts['return_statements'] += 1
                self.generic_visit(node)

        visitor = ASTVisitor()
        visitor.visit(tree)
        return counts
    except SyntaxError:
        return {key: 0 for key in ['function_definitions', 'for_loops', 'while_loops', 'if_statements',
                                    'assignments', 'comparisons', 'binary_ops', 'subscripts', 'method_calls',
                                    'list_comprehensions', 'lambda_functions', 'try_statements', 'return_statements']}

def extract_code_features(program):
    # Extract all features
    literals = extract_literals_ast(program)
    ast_features = extract_ast_features(program)

    # Calculate features
    features = {
        'total_chars': len(program),
        'total_lines': len(program.split('\n')),
        'avg_line_length': len(program) / max(1, len(program.split('\n'))),
        'indentation_variance': np.var([len(line) - len(line.lstrip()) for line in program.split('\n')]),
        'max_indentation': max([len(line) - len(line.lstrip()) for line in program.split('\n')]),
        'comments': len(re.findall(r'#.*', program)),
        'docstrings': len(re.findall(r'""".*?"""', program, re.DOTALL)),
        'total_literals': len(literals),
        'literals_over_9': len([lit for lit in literals if lit > 9]),
        'literals_over_99': len([lit for lit in literals if lit > 99]),
        'max_literal': max(literals) if literals else 0,
        'unique_literals': len(set(literals)),
        'zero_literals': len([lit for lit in literals if lit == 0]),
        'single_digit_literals': len([lit for lit in literals if 0 <= lit <= 9]),
        'elif_chains': len(re.findall(r'elif\b', program)),
        'nested_loops': len(re.findall(r'for.*?:\s*.*?for.*?:', program, re.DOTALL)),
        'range_calls': len(re.findall(r'\brange\s*\(', program)),
        'len_calls': len(re.findall(r'\blen\s*\(', program)),
        'enumerate_calls': len(re.findall(r'\benumerate\s*\(', program)),
        'zip_calls': len(re.findall(r'\bzip\s*\(', program)),
        'numpy_usage': len(re.findall(r'\bnp\.|numpy\.', program)),
        'grid_shape_access': len(re.findall(r'\.shape\b|len\s*\(\s*grid\s*\)|len\s*\(\s*\w+\[0\]\s*\)', program)),
        'coordinate_patterns': len(re.findall(r'\[\s*\d+\s*\]\s*\[\s*\d+\s*\]', program)),
        'hardcoded_coordinates': len(re.findall(r'\[\s*\d+\s*\]', program)),
        'specific_conditionals': len(re.findall(r'==\s*\d+|!=\s*\d+|>\s*\d+|<\s*\d+|>=\s*\d+|<=\s*\d+', program)),
        'brackets_count': program.count('[') + program.count(']'),
        'general_loops': ast_features['for_loops'] + ast_features['while_loops'],
        'generic_variables': len(re.findall(r'\b(i|j|k|x|y|row|col|idx)\b', program)),
        'shape_operations': len(re.findall(r'\.shape|len\(.*\)', program)),
        'mathematical_ops': len(re.findall(r'\+|\-|\*|\/|\%|\*\*', program)),
        'array_creation': len(re.findall(r'\[\s*\[|np\.array|np\.zeros|np\.ones', program)),
        'imports': len(re.findall(r'^\s*import|^\s*from', program, re.MULTILINE)),
    }

    # Add AST features
    features.update(ast_features)

    # return features
    # return features
    # Extract and scale only the active features
    feature_names = ['max_indentation', 'max_literal', 'unique_literals', 'function_definitions', 'if_statements', 'method_calls', 'elif_chains', 'enumerate_calls', 'coordinate_patterns', 'specific_conditionals', 'brackets_count', 'generic_variables', 'array_creation']
    return {k: features[k] for k in feature_names if k in features}

In [8]:
# Extract code features for each row and analyze correlation with groundtruth label (top 15 features only)
import pandas as pd
from scipy.stats import pointbiserialr
import numpy as np

feature_dicts = train_df["code"].apply(extract_code_features)
features_df = pd.DataFrame(list(feature_dicts))

# Concatenate features with groundtruth label
analysis_df = pd.concat([features_df, train_df[["groundtruth_transductive_prediction"]].reset_index(drop=True)], axis=1)

# Compute correlation for each feature with groundtruth label
correlations = {}
for col in features_df.columns:
    if analysis_df[col].dtype in [int, float]:
        try:
            corr, pval = pointbiserialr(analysis_df["groundtruth_transductive_prediction"], analysis_df[col])
            correlations[col] = {"correlation": corr, "p_value": pval}
        except Exception as e:
            correlations[col] = {"correlation": None, "p_value": None, "error": str(e)}
    else:
        correlations[col] = {"correlation": None, "p_value": None, "error": "Non-numeric feature"}

# Print sorted correlations
sorted_corr = sorted(correlations.items(), key=lambda x: abs(x[1]["correlation"] or 0), reverse=True)
# Print sorted correlations
def safe_fmt(val, fmt):
    return fmt % val if val is not None else "None"

for feat, vals in sorted_corr:
    corr_str = f"{vals['correlation']:.3f}" if vals['correlation'] is not None else "None"
    pval_str = f"{vals['p_value']:.3g}" if vals['p_value'] is not None else "None"
    print(f"{feat}: correlation={corr_str}, p-value={pval_str}")


unique_literals: correlation=0.815, p-value=4.4e-84
elif_chains: correlation=0.616, p-value=7e-38
if_statements: correlation=0.524, p-value=5.77e-26
specific_conditionals: correlation=0.483, p-value=7.92e-22
brackets_count: correlation=0.325, p-value=4.8e-10
array_creation: correlation=0.317, p-value=1.3e-09
generic_variables: correlation=-0.268, p-value=3.76e-07
max_indentation: correlation=-0.243, p-value=4.29e-06
coordinate_patterns: correlation=0.146, p-value=0.00629
function_definitions: correlation=-0.079, p-value=0.142
enumerate_calls: correlation=-0.074, p-value=0.169
method_calls: correlation=0.050, p-value=0.354
max_literal: correlation=None, p-value=None


In [9]:
# Train a logistic regression classifier and compare to current transductive classifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report

X = features_df.drop(columns=[col for col in features_df.columns if col not in analysis_df.columns or col == "groundtruth_transductive_prediction"]).values
y = analysis_df["groundtruth_transductive_prediction"].values

clf = LogisticRegression(max_iter=10000)

clf.fit(X, y)

y_pred = clf.predict(X)
logreg_acc = accuracy_score(y, y_pred)
print(f"Logistic Regression accuracy (train): {logreg_acc:.3f}")
print(classification_report(y, y_pred))

# Logistic Regression on eval split
eval_feature_dicts = eval_df["code"].apply(extract_code_features)
eval_features_df = pd.DataFrame(list(eval_feature_dicts))
eval_X = eval_features_df.drop(columns=[col for col in features_df.columns if col not in analysis_df.columns or col == "groundtruth_transductive_prediction"]).values
eval_y = eval_df["groundtruth_transductive_prediction"].values
eval_y_pred = clf.predict(eval_X)
eval_logreg_acc = accuracy_score(eval_y, eval_y_pred)
print(f"Logistic Regression accuracy (eval): {eval_logreg_acc:.3f}")
print(classification_report(eval_y, eval_y_pred))

# Logistic Regression on full split
eval_feature_dicts = df["code"].apply(extract_code_features)
eval_features_df = pd.DataFrame(list(eval_feature_dicts))
eval_X = eval_features_df.drop(columns=[col for col in features_df.columns if col not in analysis_df.columns or col == "groundtruth_transductive_prediction"]).values
eval_y = df["groundtruth_transductive_prediction"].values
eval_y_pred = clf.predict(eval_X)
eval_logreg_acc = accuracy_score(eval_y, eval_y_pred)
print(f"Logistic Regression accuracy (original): {eval_logreg_acc:.3f}")
print(classification_report(eval_y, eval_y_pred))


Logistic Regression accuracy (train): 0.943
              precision    recall  f1-score   support

         0.0       0.95      0.96      0.96       237
         1.0       0.92      0.90      0.91       112

    accuracy                           0.94       349
   macro avg       0.94      0.93      0.93       349
weighted avg       0.94      0.94      0.94       349

Logistic Regression accuracy (eval): 0.921
              precision    recall  f1-score   support

         0.0       0.93      0.96      0.95        27
         1.0       0.90      0.82      0.86        11

    accuracy                           0.92        38
   macro avg       0.91      0.89      0.90        38
weighted avg       0.92      0.92      0.92        38

Logistic Regression accuracy (original): 0.941
              precision    recall  f1-score   support

         0.0       0.95      0.96      0.96       264
         1.0       0.92      0.89      0.91       123

    accuracy                           0.94     

In [10]:
# Legacy model on eval dataset
legacy_orig_y = eval_df["groundtruth_transductive_prediction"].values
legacy_orig_logreg_pred = eval_df["current_transductive_prediction"].values
legacy_orig_logreg_acc = accuracy_score(legacy_orig_y, legacy_orig_logreg_pred)
print(f"Logistic Regression accuracy (eval): {legacy_orig_logreg_acc:.3f}")
print(classification_report(legacy_orig_y, legacy_orig_logreg_pred))

# Legacy model on original dataset
legacy_orig_y = df["groundtruth_transductive_prediction"].values
legacy_orig_logreg_pred = df["current_transductive_prediction"].values
legacy_orig_logreg_acc = accuracy_score(legacy_orig_y, legacy_orig_logreg_pred)
print(f"Logistic Regression accuracy (original): {legacy_orig_logreg_acc:.3f}")
print(classification_report(legacy_orig_y, legacy_orig_logreg_pred))

Logistic Regression accuracy (eval): 0.921
              precision    recall  f1-score   support

         0.0       0.93      0.96      0.95        27
         1.0       0.90      0.82      0.86        11

    accuracy                           0.92        38
   macro avg       0.91      0.89      0.90        38
weighted avg       0.92      0.92      0.92        38

Logistic Regression accuracy (original): 0.941
              precision    recall  f1-score   support

         0.0       0.95      0.96      0.96       264
         1.0       0.92      0.89      0.91       123

    accuracy                           0.94       387
   macro avg       0.93      0.93      0.93       387
weighted avg       0.94      0.94      0.94       387



In [None]:
# Save the trained logistic regression model weights to disk (code_classifier_v2)
import joblib
model_path = project_root / "llm_python/transduction/code_classifier_v2.joblib"
joblib.dump(clf, model_path)
print(f"Saved code_classifier_v2 model to {model_path}")

In [None]:
# Print highest and lowest confidence examples from eval_df for spot checking
eval_feature_dicts = eval_df["code"].apply(extract_code_features)
eval_features_df = pd.DataFrame(list(eval_feature_dicts))
eval_X = eval_features_df.values
eval_probs = clf.predict_proba(eval_X)[:, 1]

# Get indices of highest and lowest confidence predictions
n = 5
high_conf_idx = eval_probs.argsort()[-n:][::-1]
low_conf_idx = eval_probs.argsort()[:n]

print("Highest confidence examples:")
for idx in high_conf_idx:
    print(f"Index: {idx}, Prob: {eval_probs[idx]:.4f}, True label: {eval_df.iloc[idx]['groundtruth_transductive_prediction']}")
    print(eval_df.iloc[idx]['code'])
    print("-" * 40)

print("Lowest confidence examples:")
for idx in low_conf_idx:
    print(f"Index: {idx}, Prob: {eval_probs[idx]:.4f}, True label: {eval_df.iloc[idx]['groundtruth_transductive_prediction']}")
    print(eval_df.iloc[idx]['code'])
    print("-" * 40)