In [None]:
# """
# gsm8k_validator_v2.py
# ======================

# Performs a flexible, score-based validation of multiple LLM-generated `solve()`
# functions for a given GSM8K problem. Instead of a rigid filter, this script
# analyzes all pairs of models to find the most robust and comprehensive consensus.

# The final output is a confidence score (0.0-1.0+) that reflects:
# 1.  The completeness of parameter alignment between models.
# 2.  The semantic clarity of the aligned parameters.
# 3.  The level of consensus (number of models in agreement).

# Core Logic:
# -----------
# 1.  **Parse & Pre-filter (UT-0):** All generated Python files for a problem are
#     parsed and filtered, keeping only those that produce the correct answer
#     for the default values.

# 2.  **Pairwise Alignment (UT-2):** For every pair of surviving models, find the
#     best possible alignment between their function arguments based on semantic
#     similarity (SBERT) and matching default values. Order is ignored.

# 3.  **Pairwise Fuzzing (UT-3):** For each aligned pair, fuzz-test for functional
#     equivalence using the "Fuzz Aligned, Freeze Unaligned" strategy. This
#     ensures logical soundness even with partially matching signatures.

# 4.  **Scoring & Consensus:**
#     - Each successfully validated pair receives a `PairwiseQualityScore` based
#       on its Alignment Ratio and Semantic Strength.
#     - The script identifies the largest "clique" of models that are all
#       mutually validated.
#     - A final `ConfidenceScore` is computed from the clique's average quality,
#       boosted by a bonus for the size of the consensus.

# Dependencies
# ------------
# black                – whitespace-stable formatting
# libcst               – reliable CST traversal with comment access
# hypothesis           – property-based fuzzing
# sentence-transformers (mpnet-base) – SBERT cosine for comment semantics
# numpy                - for fast vector operations
# """

# from __future__ import annotations

# # --- (inside gsm8k_validator.py) ---

# import importlib.util
# import inspect
# import itertools
# import json
# import re
# import sys
# from dataclasses import dataclass, field
# from pathlib import Path
# from typing import List, Any, Tuple, Dict

# import black
# import hypothesis.strategies as st
# import libcst as cst  # <--- ADD THIS LINE
# import numpy as np
# from hypothesis import given, settings, HealthCheck
# from sentence_transformers import SentenceTransformer

# # ---------------------------------------------------------------------- #
# #  Global constants & Configuration
# # ---------------------------------------------------------------------- #

# _MODEL = SentenceTransformer("all-mpnet-base-v2")
# _COS_THRESHOLD = 0.90  # SBERT cosine ≥ 0.9 ⇒ semantic match
# _FUZZ_EXAMPLES = 50  # Hypothesis draws
# _MIN_ALIGNMENT_FOR_FUZZ = 1 # A pair must align on at least one arg to be fuzzed

# # --- Scoring Weights --- #
# W_ALIGNMENT = 0.7
# W_SEMANTIC = 0.3

# # --- Consensus Bonus Multipliers --- #
# CONSENSUS_BONUS = {
#     2: 1.0,  # Baseline for a pair
#     3: 1.1,  # 10% bonus for a 3-way consensus
#     4: 1.2,  # 20% bonus for a 4-way consensus
#     5: 1.3,
# }

# _TRACE_RE = re.compile(r"^#: L(\d+)\b")
# _DOC_INDEX_RE = re.compile(r"^Index:\s*(\d+)")


# # ---------------------------------------------------------------------- #
# #  Dataclasses for Structured Data
# # ---------------------------------------------------------------------- #

# @dataclass(frozen=True)
# class Argument:
#     """Represents a single argument from a function signature."""
#     name: str
#     default: Any
#     comment: str


# @dataclass(frozen=True)
# class ParsedFile:
#     """Normalised representation of a generated code file."""
#     path: Path
#     module_code: str
#     func: Any
#     args: List[Argument] = field(default_factory=list)


# @dataclass(frozen=True)
# class AlignmentResult:
#     """Stores the result of aligning two ParsedFiles."""
#     aligned_pairs: List[Tuple[Argument, Argument]]
#     unaligned_A: List[Argument]
#     unaligned_B: List[Argument]
#     semantic_scores: List[float]


# @dataclass
# class PairwiseValidation:
#     """Stores the full result of a successful pairwise validation."""
#     file_A: ParsedFile
#     file_B: ParsedFile
#     alignment: AlignmentResult
#     quality_score: float


# # ---------------------------------------------------------------------- #
# #  Core Logic Implementation
# # ---------------------------------------------------------------------- #

# def parse_file(path: Path) -> ParsedFile | None:
#     """Parse one generated .py file into a structured ParsedFile object."""
#     try:
#         src_raw = path.read_text(encoding="utf-8")
#         src_fmt = black.format_str(src_raw, mode=black.FileMode())

#         mod = cst.parse_module(src_fmt)
#         func_nd = next(
#             n for n in mod.body if isinstance(n, cst.FunctionDef) and n.name.value == "solve"
#         )

#         args = []
#         for param in func_nd.params.params:
#             if not param.default:
#                 continue

#             # --- Argument Name and Default Value ---
#             name = param.name.value
#             default = eval(mod.code_for_node(param.default))

#             # --- Robust Comment Extraction ---
#             comment = ""
#             comment_node = None

#             # Check for comment on the parameter's trailing whitespace
#             # This handles both simple cases and the 'ParenthesizedWhitespace' edge case
#             if hasattr(param, "trailing_whitespace") and param.trailing_whitespace:
#                 ws_node = param.trailing_whitespace
#                 if hasattr(ws_node, "comment") and ws_node.comment:
#                     comment_node = ws_node.comment
#                 elif hasattr(ws_node, "last_line") and hasattr(ws_node.last_line, "comment") and ws_node.last_line.comment:
#                     comment_node = ws_node.last_line.comment

#             # If not found, check for comment on the trailing comma (for non-last params)
#             if not comment_node and hasattr(param, "comma") and param.comma:
#                 if hasattr(param.comma, "whitespace_after") and param.comma.whitespace_after:
#                     ws_node_after_comma = param.comma.whitespace_after
#                     if hasattr(ws_node_after_comma, "comment") and ws_node_after_comma.comment:
#                          comment_node = ws_node_after_comma.comment

#             if comment_node:
#                 comment = comment_node.value.lstrip("#").strip()


#             # --- Append the fully parsed argument to the list ---
#             args.append(Argument(
#                 name=name,
#                 default=default,
#                 comment=comment
#             ))

#         spec = importlib.util.spec_from_loader(f"gsm8k_{path.stem}_{hash(path)}", loader=None)
#         mod_dyn = importlib.util.module_from_spec(spec)
#         exec(src_fmt, mod_dyn.__dict__)

#         return ParsedFile(
#             path=path,
#             module_code=src_fmt,
#             func=mod_dyn.solve,
#             args=args
#         )
#     except (FileNotFoundError, StopIteration, SyntaxError, Exception) as e:
#         print(f"[Parser Error] Skipping {path.name}: {e!r}", file=sys.stderr)
#         return None


# def ut0_answer_match(files: List[ParsedFile], gold: float) -> List[ParsedFile]:
#     """Keep only files whose solve() returns the official answer with default args."""
#     ok_files = []
#     for pf in files:
#         try:
#             if np.isclose(pf.func(), gold):
#                 ok_files.append(pf)
#         except Exception as e:
#             print(f"[UT-0 Fail] {pf.path.name} raised {e!r}", file=sys.stderr)
#     return ok_files


# def find_best_alignment(file_A: ParsedFile, file_B: ParsedFile) -> AlignmentResult:
#     """Find the best argument alignment between two functions, ignoring order."""
#     args_A, args_B = file_A.args, file_B.args
#     if not args_A or not args_B:
#         return AlignmentResult([], args_A, args_B, [])

#     comments_B = [arg.comment for arg in args_B]
#     embeddings_B = _MODEL.encode(comments_B, normalize_embeddings=True)

#     aligned_pairs = []
#     semantic_scores = []
#     used_b_indices = set()

#     for arg_A in args_A:
#         embedding_A = _MODEL.encode([arg_A.comment], normalize_embeddings=True)
#         similarities = (embedding_A @ embeddings_B.T).flatten()

#         best_b_idx = -1
#         # Find best match among available B args
#         for b_idx in np.argsort(similarities)[::-1]:
#             if b_idx not in used_b_indices:
#                 best_b_idx = b_idx
#                 break
        
#         if best_b_idx == -1: continue

#         best_match_arg_B = args_B[best_b_idx]
#         similarity_score = similarities[best_b_idx]

#         if similarity_score >= _COS_THRESHOLD and arg_A.default == best_match_arg_B.default:
#             aligned_pairs.append((arg_A, best_match_arg_B))
#             semantic_scores.append(similarity_score)
#             used_b_indices.add(best_b_idx)
    
#     unaligned_A = [arg for arg in args_A if arg not in [p[0] for p in aligned_pairs]]
#     unaligned_B = [args_B[i] for i in range(len(args_B)) if i not in used_b_indices]

#     return AlignmentResult(aligned_pairs, unaligned_A, unaligned_B, semantic_scores)


# def fuzz_aligned_pair(alignment: AlignmentResult, func_A: callable, func_B: callable) -> bool:
#     """Fuzz-test an aligned pair using the 'Fuzz Aligned, Freeze Unaligned' strategy."""
#     if len(alignment.aligned_pairs) < _MIN_ALIGNMENT_FOR_FUZZ:
#         return False

#     strat_map = {}
#     for i, (arg_A, _) in enumerate(alignment.aligned_pairs):
#         literal = arg_A.default
#         strat = st.floats if isinstance(literal, float) else st.integers
#         strat_map[f"pair_{i}"] = strat(min_value=1, max_value=50)

#     # Freeze unaligned args to their defaults
#     frozen_kwargs_A = {arg.name: arg.default for arg in alignment.unaligned_A}
#     frozen_kwargs_B = {arg.name: arg.default for arg in alignment.unaligned_B}

#     @settings(max_examples=_FUZZ_EXAMPLES, deadline=None, suppress_health_check=[HealthCheck.too_slow])
#     @given(st.fixed_dictionaries(strat_map))
#     def _check(fuzzed_values):
#         kwargs_A = frozen_kwargs_A.copy()
#         kwargs_B = frozen_kwargs_B.copy()

#         for i, (arg_A, arg_B) in enumerate(alignment.aligned_pairs):
#             fuzzed_val = fuzzed_values[f"pair_{i}"]
#             kwargs_A[arg_A.name] = fuzzed_val
#             kwargs_B[arg_B.name] = fuzzed_val

#         assert np.isclose(func_A(**kwargs_A), func_B(**kwargs_B))

#     try:
#         _check()
#         return True
#     except Exception:
#         return False


# def calculate_pairwise_score(alignment: AlignmentResult, file_A: ParsedFile, file_B: ParsedFile) -> float:
#     """Calculate the quality score for a single validated pair."""
#     num_aligned = len(alignment.aligned_pairs)
    
#     # Alignment Ratio
#     total_args_A = len(file_A.args)
#     total_args_B = len(file_B.args)
#     max_possible_args = max(total_args_A, total_args_B)
#     alignment_ratio = num_aligned / max_possible_args if max_possible_args > 0 else 1.0

#     # Semantic Strength
#     semantic_strength = np.mean(alignment.semantic_scores) if alignment.semantic_scores else 1.0

#     return (W_ALIGNMENT * alignment_ratio) + (W_SEMANTIC * semantic_strength)


# # ---------------------------------------------------------------------- #
# #  Orchestration and Reporting
# # ---------------------------------------------------------------------- #

# def analyze_problem_outputs(problem_dir: Path, gold_answer: float):
#     """Main orchestrator to analyze all model outputs for a single problem."""
#     print(f"\n{'='*20} Analyzing Problem: {problem_dir.name} {'='*20}")
    
#     all_files = list(problem_dir.glob("*.py"))
#     if not all_files:
#         print("No Python files found in this directory.")
#         return

#     parsed_files = [pf for pf in [parse_file(p) for p in all_files] if pf is not None]
#     print(f"Found and parsed {len(parsed_files)} files.")

#     survivors_ut0 = ut0_answer_match(parsed_files, gold_answer)
#     print(f"{len(survivors_ut0)} files passed UT-0 (correct default answer).")
#     if len(survivors_ut0) < 2:
#         print("Not enough models passed UT-0 to find a pair. Aborting.")
#         return

#     # --- Pairwise Validation ---
#     validated_pairs: List[PairwiseValidation] = []
#     for file_A, file_B in itertools.combinations(survivors_ut0, 2):
#         alignment = find_best_alignment(file_A, file_B)
        
#         if fuzz_aligned_pair(alignment, file_A.func, file_B.func):
#             score = calculate_pairwise_score(alignment, file_A, file_B)
#             validated_pairs.append(PairwiseValidation(file_A, file_B, alignment, score))
#             print(f"  ✓ Validated Pair: ({file_A.path.name}, {file_B.path.name}), Score: {score:.3f}")

#     if not validated_pairs:
#         print("\nNo functionally equivalent pairs found after fuzzing.")
#         return

#     # --- Find Best Consensus Clique ---
#     nodes = survivors_ut0
#     adj = {pf.path.name: set() for pf in nodes}
#     for vp in validated_pairs:
#         adj[vp.file_A.path.name].add(vp.file_B.path.name)
#         adj[vp.file_B.path.name].add(vp.file_A.path.name)

#     best_clique = []
#     # Check for cliques of decreasing size
#     for size in range(len(nodes), 1, -1):
#         for combo in itertools.combinations(nodes, size):
#             names = [pf.path.name for pf in combo]
#             is_clique = all(
#                 names[j] in adj[names[i]] for i in range(size) for j in range(i + 1, size)
#             )
#             if is_clique:
#                 best_clique = list(combo)
#                 break
#         if best_clique:
#             break
    
#     # --- Calculate Final Score and Report ---
#     if not best_clique:
#         # Should not happen if validated_pairs is not empty
#         best_pair = max(validated_pairs, key=lambda vp: vp.quality_score)
#         final_score = best_pair.quality_score
#         clique_size = 2
#         best_clique_names = [best_pair.file_A.path.name, best_pair.file_B.path.name]
#         avg_quality = final_score
#     else:
#         clique_size = len(best_clique)
#         clique_names = [pf.path.name for pf in best_clique]
        
#         clique_pairs_scores = [
#             vp.quality_score for vp in validated_pairs 
#             if vp.file_A.path.name in clique_names and vp.file_B.path.name in clique_names
#         ]
#         avg_quality = np.mean(clique_pairs_scores) if clique_pairs_scores else 0
#         bonus = CONSENSUS_BONUS.get(clique_size, max(CONSENSUS_BONUS.values()))
#         final_score = avg_quality * bonus

#     print("\n" + "-"*50)
#     print("                 VALIDATION SUMMARY")
#     print("-"*50)
#     print(f"Best Consensus Found: {clique_size}-way agreement")
#     print(f"Models in Consensus: {best_clique_names}")
#     print(f"Average Pairwise Quality in Clique: {avg_quality:.4f}")
#     print(f"Consensus Bonus Multiplier: x{CONSENSUS_BONUS.get(clique_size, 'N/A')}")
#     print(f"FINAL CONFIDENCE SCORE: {final_score:.4f}")
#     print("-"*50)

In [58]:
"""
gsm8k_validator_v2.py
======================

Performs a flexible, score-based validation of multiple LLM-generated `solve()`
functions for a given GSM8K problem. Instead of a rigid filter, this script
analyzes all pairs of models to find the most robust and comprehensive consensus.

The final output is a confidence score (0.0-1.0+) that reflects:
1.  The completeness of parameter alignment between models.
2.  The semantic clarity of the aligned parameters.
3.  The level of consensus (number of models in agreement).

Core Logic:
-----------
1.  **Parse & Pre-filter (UT-0):** All generated Python files for a problem are
    parsed and filtered, keeping only those that produce the correct answer
    for the default values.

2.  **Pairwise Alignment (UT-2):** For every pair of surviving models, greedily
    align their function arguments. Arguments are only compared when their type
    annotation and default value match; within such a group they are matched by
    the semantic similarity (SBERT) of their name and comment. Order is ignored.

3.  **Pairwise Fuzzing (UT-3):** For each aligned pair, fuzz-test for functional
    equivalence using the "Fuzz Aligned, Freeze Unaligned" strategy. This
    ensures logical soundness even with partially matching signatures.

4.  **Scoring & Consensus:**
    - Each successfully validated pair receives a `PairwiseQualityScore` based
      on its Alignment Ratio and Semantic Strength.
    - The script identifies the largest "clique" of models that are all
      mutually validated.
    - A final `ConfidenceScore` is computed from the clique's average quality,
      boosted by a bonus for the size of the consensus.

Dependencies
------------
black                – whitespace-stable formatting
libcst               – reliable CST traversal with comment access
hypothesis           – property-based fuzzing
sentence-transformers (mpnet-base) – SBERT cosine for comment semantics
numpy                - for fast vector operations
"""

from __future__ import annotations

# --- (inside gsm8k_validator.py) ---

import ast
import importlib.util
import inspect
import io
import itertools
import json
import re
import sys
import tokenize
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Any, Tuple, Dict

import black
import hypothesis.strategies as st
import libcst as cst
import numpy as np
from hypothesis import given, settings, HealthCheck
from sentence_transformers import SentenceTransformer

# ---------------------------------------------------------------------- #
#  Global constants & Configuration
# ---------------------------------------------------------------------- #

# Sentence encoder used to embed argument descriptions for semantic matching.
# It is instantiated once, when this module/cell is executed.
_MODEL = SentenceTransformer("all-mpnet-base-v2")

# --- Alignment / fuzzing knobs --- #
_COS_THRESHOLD = 0.70  # minimum SBERT cosine for two arguments to be matched
_FUZZ_EXAMPLES = 50  # max Hypothesis examples drawn per pair
_MIN_ALIGNMENT_FOR_FUZZ = 1  # pairs with fewer aligned arguments are not fuzzed

# --- Scoring weights --- #
# pairwise quality = W_ALIGNMENT * alignment ratio + W_SEMANTIC * semantic strength
W_ALIGNMENT, W_SEMANTIC = 0.7, 0.3

# --- Consensus bonus --- #
# Multiplier on the clique's average pairwise quality, keyed by clique size
# (2 = a plain pair, i.e. no bonus; +0.1 per extra agreeing model up to 5).
CONSENSUS_BONUS = {2: 1.0, 3: 1.1, 4: 1.2, 5: 1.3}

# --- Compiled patterns (not referenced by the functions shown in this cell) --- #
_TRACE_RE = re.compile(r"^#: L(\d+)\b")
_DOC_INDEX_RE = re.compile(r"^Index:\s*(\d+)")


# ---------------------------------------------------------------------- #
#  Dataclasses for Structured Data
# ---------------------------------------------------------------------- #

@dataclass(frozen=True)
class Argument:
    """One defaulted parameter of a generated ``solve()`` function.

    Immutable; equality and hashing are derived from all four fields.
    """

    name: str  # parameter name as written in the signature
    arg_type: str  # annotation text, e.g. "int" or "float"
    default: Any  # the parameter's default value
    comment: str  # trailing "#" comment on the parameter's line ("" if none)


@dataclass(frozen=True)
class ParsedFile:
    """A generated solution file together with its loaded ``solve`` function.

    Frozen; when ``args`` is omitted each instance gets its own empty list.
    """

    path: Path  # where the source was read from
    module_code: str  # source text after black formatting
    func: Any  # the ``solve`` callable taken from the executed module
    args: list[Argument] = field(default_factory=list)  # defaulted parameters of ``solve``


@dataclass(frozen=True)
class AlignmentResult:
    """Argument matching between two parsed files.

    ``semantic_scores`` runs parallel to ``aligned_pairs`` (one similarity per
    matched pair); the ``unaligned_*`` lists hold whatever was left unmatched
    on each side.
    """

    aligned_pairs: list[tuple[Argument, Argument]]  # (argument from A, argument from B)
    unaligned_A: list[Argument]  # arguments of file A without a partner
    unaligned_B: list[Argument]  # arguments of file B without a partner
    semantic_scores: list[float]  # similarity of each aligned pair


@dataclass
class PairwiseValidation:
    """A pair of files that passed fuzzing, plus how they were aligned and scored.

    Unlike the other records in this module, instances are mutable (not frozen).
    """

    file_A: ParsedFile  # first file of the pair
    file_B: ParsedFile  # second file of the pair
    alignment: AlignmentResult  # argument matching used while fuzzing
    quality_score: float  # pairwise quality score assigned to the pair


# ---------------------------------------------------------------------- #
#  Core Logic Implementation
# ---------------------------------------------------------------------- #

# In your validator script, find and replace the parse_file function

def parse_file(path: Path) -> ParsedFile | None:
    """
    Parse one generated .py file into a structured ParsedFile object using
    a direct, regex-based approach.
    """
    try:
        src_raw = path.read_text(encoding="utf-8")
        src_fmt = black.format_str(src_raw, mode=black.FileMode())

        signature_match = re.search(r"def solve\s*\((.*?)\):", src_fmt, re.DOTALL)
        if not signature_match:
            raise ValueError("Could not find a 'def solve(...):' signature.")
        
        signature_content = signature_match.group(1)

        args = []
        # --- MODIFIED REGEX: Now captures the type hint in group(2) ---
        arg_pattern = re.compile(
            r"^\s*([a-zA-Z_]\w*)\s*:\s*(\w+)\s*=\s*(.*?)\s*,?\s*(?:#\s*(.*))?$"
        )

        for line in signature_content.splitlines():
            if not line.strip(): continue
            match = arg_pattern.match(line)
            if match:
                name = match.group(1)
                arg_type = match.group(2) # <--- CAPTURE TYPE
                default_str = match.group(3).strip()
                default_val = eval(default_str)
                comment = match.group(4).strip() if match.group(4) else ""
                
                args.append(Argument(
                    name=name,
                    arg_type=arg_type, # <--- STORE TYPE
                    default=default_val,
                    comment=comment
                ))

        spec = importlib.util.spec_from_loader(f"gsm8k_{path.stem}_{hash(path)}", loader=None)
        mod_dyn = importlib.util.module_from_spec(spec)
        exec(src_fmt, mod_dyn.__dict__)

        return ParsedFile(
            path=path,
            module_code=src_fmt,
            func=mod_dyn.solve,
            args=args
        )
    except (FileNotFoundError, StopIteration, SyntaxError, Exception) as e:
        print(f"[Parser Error] Skipping {path.name}: {e!r}", file=sys.stderr)
        return None


def ut0_answer_match(files: List[ParsedFile], gold: float) -> List[ParsedFile]:
    """
    UT-0 filter: keep the files whose ``solve()``, called with no arguments
    (i.e. on its defaults), is ``np.isclose`` to ``gold``.

    A file whose ``solve()`` raises, or whose result cannot be compared with
    ``gold``, is reported on stderr and dropped. Input order is preserved.
    """

    def _reproduces_gold(candidate: ParsedFile) -> bool:
        try:
            return bool(np.isclose(candidate.func(), gold))
        except Exception as exc:
            print(f"[UT-0 Fail] {candidate.path.name} raised {exc!r}", file=sys.stderr)
            return False

    return [candidate for candidate in files if _reproduces_gold(candidate)]

def find_best_alignment(file_A: ParsedFile, file_B: ParsedFile) -> AlignmentResult:
    """
    Finds the best argument alignment using a 'bucket and match' strategy.
    1. Groups args from each file into buckets by (type, default_value).
    2. Only performs semantic comparison on args within matching buckets.
    """
    # --- 1. Create buckets for each file's arguments ---
    buckets_A = defaultdict(list)
    for arg in file_A.args:
        buckets_A[(arg.arg_type, arg.default)].append(arg)
        
    buckets_B = defaultdict(list)
    for arg in file_B.args:
        buckets_B[(arg.arg_type, arg.default)].append(arg)

    aligned_pairs = []
    semantic_scores = []
    
    # --- 2. Iterate through buckets that exist in BOTH files ---
    common_keys = set(buckets_A.keys()) & set(buckets_B.keys())
    
    for key in common_keys:
        args_in_bucket_A = buckets_A[key]
        args_in_bucket_B = buckets_B[key]
        
        # --- 3. Perform semantic alignment ONLY within the current bucket ---
        texts_A = [arg.name.replace("_", " ") + " | " + arg.comment for arg in args_in_bucket_A]
        texts_B = [arg.name.replace("_", " ") + " | " + arg.comment for arg in args_in_bucket_B]
        
        embeddings_A = _MODEL.encode(texts_A, normalize_embeddings=True)
        embeddings_B = _MODEL.encode(texts_B, normalize_embeddings=True)
        similarity_matrix = embeddings_A @ embeddings_B.T

        # Use a greedy matching strategy within the bucket
        sorted_indices = np.argsort(similarity_matrix, axis=None)[::-1]
        flat_indices = np.atleast_1d(sorted_indices)
        rows, cols = np.unravel_index(flat_indices, similarity_matrix.shape)

        used_in_bucket_A = set()
        used_in_bucket_B = set()

        for i, j in zip(rows, cols):
            if i in used_in_bucket_A or j in used_in_bucket_B:
                continue

            similarity_score = similarity_matrix[i, j]
            if similarity_score >= _COS_THRESHOLD:
                aligned_pairs.append((args_in_bucket_A[i], args_in_bucket_B[j]))
                semantic_scores.append(similarity_score)
                used_in_bucket_A.add(i)
                used_in_bucket_B.add(j)

    # --- 4. Calculate the final unaligned sets ---
    final_aligned_A = {p[0] for p in aligned_pairs}
    final_aligned_B = {p[1] for p in aligned_pairs}
    unaligned_A = [arg for arg in file_A.args if arg not in final_aligned_A]
    unaligned_B = [arg for arg in file_B.args if arg not in final_aligned_B]

    return AlignmentResult(aligned_pairs, unaligned_A, unaligned_B, semantic_scores)


def fuzz_aligned_pair(alignment: AlignmentResult, func_A: callable, func_B: callable) -> bool:
    """
    Property-test two ``solve`` functions for agreement on their aligned arguments.

    "Fuzz Aligned, Freeze Unaligned": each aligned pair of arguments receives the
    same Hypothesis-drawn value in both functions (floats in [1, 50] when the
    A-side default is a float, integers in [1, 50] otherwise), while every
    unaligned argument is passed its own default. The pair passes only if every
    generated example (up to ``_FUZZ_EXAMPLES``) yields ``np.isclose`` results.

    Returns False when fewer than ``_MIN_ALIGNMENT_FOR_FUZZ`` arguments are
    aligned, or when any example disagrees or raises.
    """
    pairs = alignment.aligned_pairs
    if len(pairs) < _MIN_ALIGNMENT_FOR_FUZZ:
        return False

    # One strategy per aligned pair, keyed "pair_<k>".
    slot_names = [f"pair_{k}" for k in range(len(pairs))]
    strategies = {}
    for slot, (left, _right) in zip(slot_names, pairs):
        make = st.floats if isinstance(left.default, float) else st.integers
        strategies[slot] = make(min_value=1, max_value=50)

    # Unaligned arguments stay pinned to their defaults.
    frozen_A = {arg.name: arg.default for arg in alignment.unaligned_A}
    frozen_B = {arg.name: arg.default for arg in alignment.unaligned_B}

    @settings(max_examples=_FUZZ_EXAMPLES, deadline=None, suppress_health_check=[HealthCheck.too_slow])
    @given(st.fixed_dictionaries(strategies))
    def _check(drawn):
        call_A = dict(frozen_A)
        call_B = dict(frozen_B)
        for slot, (left, right) in zip(slot_names, pairs):
            call_A[left.name] = drawn[slot]
            call_B[right.name] = drawn[slot]
        assert np.isclose(func_A(**call_A), func_B(**call_B))

    try:
        _check()
    except Exception:
        return False
    return True


def calculate_pairwise_score(alignment: AlignmentResult, file_A: ParsedFile, file_B: ParsedFile) -> float:
    """
    Score a validated pair as ``W_ALIGNMENT * ratio + W_SEMANTIC * strength``.

    * ratio    - aligned pairs divided by the larger argument count of the two
                 files (1.0 when neither file has arguments);
    * strength - mean of the alignment's semantic scores (1.0 when there are none).
    """
    denominator = max(len(file_A.args), len(file_B.args))
    if denominator > 0:
        coverage = len(alignment.aligned_pairs) / denominator
    else:
        coverage = 1.0

    if alignment.semantic_scores:
        strength = np.mean(alignment.semantic_scores)
    else:
        strength = 1.0

    return W_ALIGNMENT * coverage + W_SEMANTIC * strength


# ---------------------------------------------------------------------- #
#  Orchestration and Reporting
# ---------------------------------------------------------------------- #

def analyze_problem_outputs(problem_dir: Path, gold_answer: float):
    """
    Validate every generated ``*.py`` file in ``problem_dir`` and print a consensus report.

    Pipeline: parse each file (``parse_file``), keep the files whose default
    ``solve()`` matches ``gold_answer`` (UT-0), align and fuzz every pair of
    survivors, then find the largest clique of mutually validated files. The
    final score is the clique's mean pairwise quality times ``CONSENSUS_BONUS``
    for its size. Cliques larger than any key in that table get the largest
    bonus, and the report prints the multiplier that was actually applied.

    Args:
        problem_dir: Directory holding the generated solutions for one problem.
        gold_answer: Expected numeric answer for the problem's default inputs.

    Returns:
        The final confidence score as a float, or None when the analysis stops
        early (no files, fewer than two UT-0 survivors, or no validated pair).
    """
    print(f"\n{'='*20} Analyzing Problem: {problem_dir.name} {'='*20}")

    all_files = list(problem_dir.glob("*.py"))
    if not all_files:
        print("No Python files found in this directory.")
        return None

    parsed_files = [pf for pf in (parse_file(p) for p in all_files) if pf is not None]
    print(f"Found and parsed {len(parsed_files)} files.")

    survivors_ut0 = ut0_answer_match(parsed_files, gold_answer)
    print(f"{len(survivors_ut0)} files passed UT-0 (correct default answer).")
    if len(survivors_ut0) < 2:
        print("Not enough models passed UT-0 to find a pair. Aborting.")
        return None

    # --- Pairwise Validation ---
    validated_pairs: List[PairwiseValidation] = []
    for file_A, file_B in itertools.combinations(survivors_ut0, 2):
        alignment = find_best_alignment(file_A, file_B)

        if fuzz_aligned_pair(alignment, file_A.func, file_B.func):
            score = calculate_pairwise_score(alignment, file_A, file_B)
            validated_pairs.append(PairwiseValidation(file_A, file_B, alignment, score))
            print(f"  ✓ Validated Pair: ({file_A.path.name}, {file_B.path.name}), Score: {score:.3f}")

    if not validated_pairs:
        print("\nNo functionally equivalent pairs found after fuzzing.")
        return None

    # --- Find Best Consensus Clique ---
    adj = {pf.path.name: set() for pf in survivors_ut0}
    for vp in validated_pairs:
        adj[vp.file_A.path.name].add(vp.file_B.path.name)
        adj[vp.file_B.path.name].add(vp.file_A.path.name)

    # Brute-force maximum clique (exponential, fine for a handful of models).
    # Isolated files cannot belong to any clique, and no clique can be larger
    # than (max degree + 1), so both are pruned without changing the result.
    nodes = [pf for pf in survivors_ut0 if adj[pf.path.name]]
    max_size = min(len(nodes), max(len(nbrs) for nbrs in adj.values()) + 1)

    best_clique = []
    for size in range(max_size, 1, -1):
        for combo in itertools.combinations(nodes, size):
            names = [pf.path.name for pf in combo]
            if all(b in adj[a] for a, b in itertools.combinations(names, 2)):
                best_clique = list(combo)
                break
        if best_clique:
            break

    # --- Calculate Final Score and Report ---
    if best_clique:
        best_clique_names = [pf.path.name for pf in best_clique]
        clique_size = len(best_clique)
        members = set(best_clique_names)
        clique_pairs_scores = [
            vp.quality_score for vp in validated_pairs
            if vp.file_A.path.name in members and vp.file_B.path.name in members
        ]
        avg_quality = float(np.mean(clique_pairs_scores)) if clique_pairs_scores else 0.0
    else:
        # Defensive only: any validated pair is itself a 2-clique.
        best_pair = max(validated_pairs, key=lambda vp: vp.quality_score)
        best_clique_names = [best_pair.file_A.path.name, best_pair.file_B.path.name]
        clique_size = 2
        avg_quality = best_pair.quality_score

    # One multiplier for both the score and the report.
    bonus = CONSENSUS_BONUS.get(clique_size, max(CONSENSUS_BONUS.values()))
    final_score = avg_quality * bonus

    print("\n" + "-"*50)
    print("                 VALIDATION SUMMARY")
    print("-"*50)
    print(f"Best Consensus Found: {clique_size}-way agreement")
    print(f"Models in Consensus: {best_clique_names}")
    print(f"Average Pairwise Quality in Clique: {avg_quality:.4f}")
    print(f"Consensus Bonus Multiplier: x{bonus}")
    print(f"FINAL CONFIDENCE SCORE: {final_score:.4f}")
    print("-"*50)

    return final_score

In [64]:
from datasets import load_dataset

# GSM8K, "main" configuration, training split.
gsm8k_train = load_dataset(path="gsm8k", name="main", split="train")

# Fetch a single training record by row index and display it.
sample_5464 = gsm8k_train[5464]
print(sample_5464)

{'question': 'Bill milked his cow and got 16 gallons of milk. He turned 1/4 into sour cream, 1/4 into butter, and kept the rest as whole milk. It takes 4 gallons of milk to make one gallon of butter and 2 gallons of milk to make 1 gallon of sour cream. If Bill sells butter for $5/gallon, sour cream for $6/gallon, and whole milk for $3/gallon, how much money does he make?', 'answer': 'First find how much milk Bill turned into sour cream and butter: 16 gallons * 1/4 = <<16*1/4=4>>4 gallons\nThen find how many gallons of butter he makes out of 4 gallons of milk: 4 gallons milk / 4 gallons milk/1 gallon butter = <<4/4/1=1>>1 gallon butter\nThen find how many gallons of sour cream he makes out of 4 gallons of milk: 4 gallons milk / 2 gallons milk/1 gallon sour cream = <<4/2/1=2>>2 gallon sour cream\nThen subtract the amount of milk turned into butter and sour cream to find the remaining amount of whole milk: 16 gallons - 4 gallons - 4 gallons = <<16-4-4=8>>8 gallons\nThen multiply the numbe

### Cell 1: Setup and Imports

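This cell imports the necessary libraries and sets up the problem parameters we'll use for the test run. Note that the `from gsm8k_validator import ...` statement is commented out, so the validator functions used below are the ones defined in the earlier notebook cell; uncomment it to use the standalone script instead.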

In [70]:
# Cell 1: Setup and Imports
import sys
from pathlib import Path
import json
import pprint

# Make modules in the current working directory importable (e.g. a standalone
# gsm8k_validator.py). Guarded so re-running the cell does not keep appending.
_cwd = str(Path.cwd())
if _cwd not in sys.path:
    sys.path.append(_cwd)

# NOTE: the import below is commented out, so this notebook relies on the
# validator functions defined in the earlier cell. Uncomment to use the script.
# # Import all the necessary functions and classes from your validator
# from gsm8k_validator import (
#     parse_file,
#     ut0_answer_match,
#     find_best_alignment,
#     fuzz_aligned_pair,
#     calculate_pairwise_score,
#     analyze_problem_outputs,
#     ParsedFile,
#     AlignmentResult,
#     PairwiseValidation
# )

# --- Configuration for our test run ---
# Problem 4483 from our previous discussion; GOLD_ANSWER is its expected answer.
BASE_DIR = Path("code_generation_outputs_cleaned")
PROBLEM_INDEX = 4483
GOLD_ANSWER = 100.0  # Sammy: 200*4=800, Bryan: 100*6+100*1=700. Diff=100

# Pretty printer for clean output
pp = pprint.PrettyPrinter(indent=2)

print("Setup complete. Configuration set (validator functions come from the earlier cell).")

Setup complete. Functions imported and configuration set.


### Cell 2: Test `parse_file` on a Single File

This cell tests the foundational parsing logic. We'll pick one file and inspect the `ParsedFile` object it produces to ensure arguments, comments, and the function itself are extracted correctly.

In [None]:
# Cell 2: Test `parse_file` on a single file

def test_parse_file(base_dir, problem_index,
                    file_name="anthropic_claude-3-5-haiku-20241022.py"):
    """Parse one generated solution file and print what `parse_file` extracted.

    Args:
        base_dir: Root directory containing one sub-directory per problem
            index (a ``Path`` or a string).
        problem_index: GSM8K problem index; selects ``base_dir / str(problem_index)``.
        file_name: Generated file to parse inside the problem directory.
            Defaults to the Anthropic output, which had a different argument
            order from the other models.

    Returns:
        Whatever ``parse_file`` returned for the file (falsy when parsing
        failed), so the caller can inspect it further.
    """
    from pathlib import Path

    # Use the arguments rather than the notebook globals so the helper can be
    # pointed at any problem / output directory.
    problem_dir = Path(base_dir) / str(problem_index)
    file_to_test = problem_dir / file_name

    print(f"--- Testing parse_file on: {file_to_test.name} ---")
    parsed_file_single = parse_file(file_to_test)

    if not parsed_file_single:
        print("\nFailed to parse the file.")
        return parsed_file_single

    print("\nSuccessfully parsed the file. Contents:")
    print(f"Path: {parsed_file_single.path.name}")
    print(f"Callable function found: {callable(parsed_file_single.func) and parsed_file_single.func.__name__ == 'solve'}")

    print("\nExtracted Arguments:")
    for arg in parsed_file_single.args:
        print(f"  - Name: {arg.name}, Default: {arg.default}, Comment: '{arg.comment}'")

    return parsed_file_single

--- Testing parse_file on: anthropic_claude-3-5-haiku-20241022.py ---

Successfully parsed the file. Contents:
Path: anthropic_claude-3-5-haiku-20241022.py
Callable function found: True

Extracted Arguments:
  - Name: total_records, Default: 200, Comment: 'Peggy has 200 records'
  - Name: sammy_price_per_record, Default: 4, Comment: 'Sammy offers 4 dollars each'
  - Name: bryan_high_price_per_record, Default: 6, Comment: 'Bryan offers 6 dollars for half'
  - Name: bryan_low_price_per_record, Default: 1, Comment: 'Bryan offers 1 dollar for other half'


### Cell 3: Load All Files and Run `ut0_answer_match`

Now, let's parse all files for the problem and run the first filtering step (UT-0) to see which ones produce the correct default answer.

In [None]:
# Cell 3: Load all files and run `ut0_answer_match`

def test_ut0(problem_index, gold_answer, base_dir=None):
    """Parse every generated file for a problem and run the UT-0 answer filter.

    Args:
        problem_index: GSM8K problem index; files are read from
            ``base_dir / str(problem_index)``.
        gold_answer: Expected answer passed to ``ut0_answer_match``.
        base_dir: Root output directory (``Path`` or string). Defaults to the
            notebook-level ``BASE_DIR`` configured in Cell 1.

    Returns:
        The survivors returned by ``ut0_answer_match`` (used by the alignment
        step in Cell 4).
    """
    from pathlib import Path

    # Derive the directory locally; `problem_dir` is not guaranteed to exist
    # as a notebook-level name.
    root = Path(BASE_DIR if base_dir is None else base_dir)
    problem_dir = root / str(problem_index)

    # Path.glob() yields entries in filesystem order; sort so the parse,
    # filter and print order is reproducible.
    all_files_for_problem = sorted(problem_dir.glob("*.py"))
    print(f"Found {len(all_files_for_problem)} files in '{problem_dir}'.")

    # Parse all of them, dropping files for which parse_file returned a falsy value.
    all_parsed_files = [pf for pf in map(parse_file, all_files_for_problem) if pf]
    print(f"Successfully parsed {len(all_parsed_files)} files.")

    print(f"\n--- Running UT-0 (Answer Match against Gold Answer: {gold_answer}) ---")
    survivors_ut0 = ut0_answer_match(all_parsed_files, gold_answer)

    print(f"\n{len(survivors_ut0)} files passed UT-0:")
    for pf in survivors_ut0:
        print(f"  - {pf.path.name}")

    return survivors_ut0

Found 9 files in 'code_generation_outputs_cleaned/4483'.
Successfully parsed 9 files.

--- Running UT-0 (Answer Match against Gold Answer: 100.0) ---

9 files passed UT-0:
  - anthropic_claude-3-5-haiku-20241022.py
  - google_gemini-2.5-flash.py
  - google_gemini-2.5-flash-lite-preview-06-17.py
  - openai_o3-mini.py
  - openai_gpt-4.1.py
  - google_gemini-2.0-flash-thinking-exp.py
  - openai_o4-mini.py
  - openai_gpt-4.1-mini.py
  - google_gemini-2.5-pro.py


### Cell 4: Test `find_best_alignment` on a Pair

This is a critical step. We will manually select two files that we know have different argument structures and see if the alignment logic correctly identifies the matching parameters.

In [None]:
# Cell 4: Test `find_best_alignment` on a Pair (Corrected Debug Report)

from collections import defaultdict


def test_alignment(survivors,
                   name_a="google_gemini-2.0-flash-thinking-exp",
                   name_b="anthropic_claude-3-5-haiku"):
    """Align two UT-0 survivors and print a per-bucket similarity debug report.

    Args:
        survivors: Parsed files that passed UT-0 (e.g. the list returned by
            ``ut0_answer_match``). Each item must expose ``path`` and ``args``.
        name_a: Substring that identifies file A by its file name.
        name_b: Substring that identifies file B by its file name.

    Returns:
        ``(file_A, file_B, alignment_result)`` so the fuzzing step can reuse
        them, or ``None`` if either file is not among ``survivors``.
    """
    # --- 1. Select the two files to compare ---
    # Only these lookups are expected to raise StopIteration; keep the try
    # narrow so failures elsewhere are not reported as "file not found".
    try:
        file_A = next(pf for pf in survivors if name_a in pf.path.name)
        file_B = next(pf for pf in survivors if name_b in pf.path.name)
    except StopIteration:
        print("Could not find the specified files among UT-0 survivors. Please check filenames.")
        return None

    print(f"--- Aligning {file_A.path.name} (A) and {file_B.path.name} (B) ---")

    # --- 2. Run the real alignment function to see its correct output ---
    alignment_result = find_best_alignment(file_A, file_B)

    print("\n--- Correct Alignment Result (from find_best_alignment) ---")
    print(f"  - Found {len(alignment_result.aligned_pairs)} aligned pairs.")
    if alignment_result.aligned_pairs:
        sorted_pairs = sorted(alignment_result.aligned_pairs, key=lambda p: p[0].name)
        for i, (arg_A, arg_B) in enumerate(sorted_pairs, start=1):
            print(f"    - Pair {i}: '{arg_A.name}' (A) <=> '{arg_B.name}' (B)")

    # --- 3. DETAILED DEBUG REPORT USING THE "BUCKET AND MATCH" STRATEGY ---
    print("\n" + "="*50)
    print("           DETAILED ALIGNMENT DEBUG REPORT")
    print("="*50)

    def _bucket(args):
        """Group arguments by their (arg_type, default) key."""
        buckets = defaultdict(list)
        for arg in args:
            buckets[(arg.arg_type, arg.default)].append(arg)
        return buckets

    def _embed_text(arg):
        """Text fed to the sentence encoder: spaced-out name plus comment."""
        # Guard against a missing comment so concatenation cannot fail.
        return arg.name.replace("_", " ") + " | " + (arg.comment or "")

    buckets_A = _bucket(file_A.args)
    buckets_B = _bucket(file_B.args)

    # Only buckets present in BOTH files can yield candidate pairs. Sort by
    # repr (defaults may mix types) so the report order is stable run to run.
    common_keys = sorted(buckets_A.keys() & buckets_B.keys(), key=repr)

    if not common_keys:
        print("No common (type, default_value) buckets found between files.")

    for key in common_keys:
        print(f"\n--- Comparing Bucket: type='{key[0]}', value={key[1]} ---")

        args_in_bucket_A = buckets_A[key]
        args_in_bucket_B = buckets_B[key]

        # Perform semantic comparison only within this bucket.
        embeddings_A = _MODEL.encode([_embed_text(a) for a in args_in_bucket_A], normalize_embeddings=True)
        embeddings_B = _MODEL.encode([_embed_text(b) for b in args_in_bucket_B], normalize_embeddings=True)
        # Embeddings are normalized, so the dot product is cosine similarity.
        similarity_matrix = embeddings_A @ embeddings_B.T

        for i, arg_A in enumerate(args_in_bucket_A):
            for j, arg_B in enumerate(args_in_bucket_B):
                sim_score = similarity_matrix[i, j]
                status = "PASS" if sim_score >= _COS_THRESHOLD else "FAIL (Similarity Too Low)"

                print(f"  - Comparing A:'{arg_A.name}' with B:'{arg_B.name}'")
                print(f"      Similarity: {sim_score:.4f} | ==> Status: {status}")

    print("\n" + "="*50)
    print("NOTE: Any comparisons not shown were skipped because the arguments were")
    print("      not in a matching (type, default_value) bucket.")
    print("="*50)

    return file_A, file_B, alignment_result

--- Aligning google_gemini-2.0-flash-thinking-exp.py (A) and anthropic_claude-3-5-haiku-20241022.py (B) ---

--- Correct Alignment Result (from find_best_alignment) ---
  - Found 4 aligned pairs.
    - Pair 1: 'bryan_price_interested' (A) <=> 'bryan_high_price_per_record' (B)
    - Pair 2: 'bryan_price_not_interested' (A) <=> 'bryan_low_price_per_record' (B)
    - Pair 3: 'sammy_price_per_record' (A) <=> 'sammy_price_per_record' (B)
    - Pair 4: 'total_records' (A) <=> 'total_records' (B)

           DETAILED ALIGNMENT DEBUG REPORT

--- Comparing Bucket: type='int', value=6 ---
  - Comparing A:'bryan_price_interested' with B:'bryan_high_price_per_record'
      Similarity: 0.7210 | ==> Status: PASS

--- Comparing Bucket: type='int', value=1 ---
  - Comparing A:'bryan_price_not_interested' with B:'bryan_low_price_per_record'
      Similarity: 0.7451 | ==> Status: PASS

--- Comparing Bucket: type='int', value=200 ---
  - Comparing A:'total_records' with B:'total_records'
      Similarity

### Cell 5: Test `fuzz_aligned_pair` and `calculate_pairwise_score`

Now we'll take the `alignment_result` from the previous cell and use it to run the fuzzing test. If it passes, we'll calculate the quality score for this specific pair.

In [None]:
# Cell 5: Test `fuzz_aligned_pair` and `calculate_pairwise_score`

# Relies on `alignment_result`, `file_A` and `file_B` left over from Cell 4;
# bail out with a hint when that cell has not produced a result.
if not ('alignment_result' in locals() and alignment_result):
    print("Alignment result not found. Please run the previous cell first.")
else:
    print("--- Running Fuzz Test on the aligned pair ---")

    is_equivalent = fuzz_aligned_pair(alignment_result, file_A.func, file_B.func)

    if not is_equivalent:
        print("\n❌ FUZZING FAILED: The pair is NOT functionally equivalent.")
    else:
        print("\n✅ FUZZING PASSED: The pair is functionally equivalent.")

        # A quality score only makes sense for a pair that fuzzed clean.
        print("\n--- Calculating Pairwise Quality Score ---")
        score = calculate_pairwise_score(alignment_result, file_A, file_B)
        print(f"Pairwise Quality Score: {score:.4f}")

### Cell 6: Run the Full Orchestrator

Finally, this cell calls the main function from the script to run the entire analysis end-to-end and print the final summary report. This confirms that all the pieces work together as expected.

In [None]:
# Cell 6: Run the Full Orchestrator

# This function ties everything together and provides the final report.
# Build the problem directory from the Cell 1 configuration instead of relying
# on a notebook-level `problem_dir`, which earlier cells may not have defined.
print("--- Running the full analysis pipeline ---")
analyze_problem_outputs(BASE_DIR / str(PROBLEM_INDEX), GOLD_ANSWER)