In [1]:
# Imports
import hypothesis as hp
from hypothesis import strategies as st
from itertools import combinations
from collections import defaultdict
from typing import List, Optional, Set, Union
import pandas as pd
from pprint import pprint
from hypothesis_pick import (
    find_disagreements,
    find_stronger_weaker,
    infer_implications,
 )

In [2]:
# Configure settings
MAX_EXAMPLES_PER_COMBINATION = 15

In [3]:
# Import assignment information

In [4]:
# Get candidate predicates from LLM.

In [5]:
# Define candidate predicates and sampling strategy.
def p1_le_100(x: int) -> bool:
    return x <= 100


def p2_le_25(x: int) -> bool:
    return x <= 25


def p3_even(x: int) -> bool:
    return x % 2 == 0


def p4_gt100_odd(x: int) -> bool:
    return x > 100 and (x % 2 == 1)

PREDICATES = [p1_le_100, p2_le_25, p3_even, p4_gt100_odd]
PREDICATE_NAMES = [
    "≤ 100",
    "≤ 25",
    "even",
    "≥ 100 & odd",
]

strategy = st.integers(min_value=0, max_value=1000)

NAME_TO_PREDICATE = dict(zip(PREDICATE_NAMES, PREDICATES))

assert len(PREDICATES) == len(PREDICATE_NAMES)

In [6]:
impl = infer_implications(
    predicates=PREDICATES,
    strategy=strategy,
    max_examples=1_000,
    predicate_names=PREDICATE_NAMES,
 )

def classify_relationship(name_a: str, name_b: str) -> str:
    if impl.equivalent(name_a, name_b):
        return "equivalent"
    if impl.implies(name_a, name_b):
        return "subset"
    if impl.implies(name_b, name_a):
        return "superset"
    predicate_a = NAME_TO_PREDICATE[name_a]
    predicate_b = NAME_TO_PREDICATE[name_b]
    for candidate in range(0, 2_000):
        if predicate_a(candidate) and predicate_b(candidate):
            return "overlap"
    return "disjoint"


def build_pair_entry(name_a: str, name_b: str) -> dict:
    return {
        "relation": classify_relationship(name_a, name_b),
    }

pair_relationships = {
    (name_a, name_b): build_pair_entry(name_a, name_b)
    for i, name_a in enumerate(PREDICATE_NAMES)
    for name_b in PREDICATE_NAMES[i + 1 :]
}

combination_examples = defaultdict(list)

def record_combination(value: int) -> None:
    key = tuple(bool(NAME_TO_PREDICATE[name](value)) for name in PREDICATE_NAMES)
    bucket = combination_examples[key]
    if len(bucket) < MAX_EXAMPLES_PER_COMBINATION and value not in bucket:
        bucket.append(value)

def combinations_saturated() -> bool:
    if not combination_examples:
        return False
    return all(len(bucket) >= MAX_EXAMPLES_PER_COMBINATION for bucket in combination_examples.values())

# I believe this is doing duplicate work to Sid's library, 
# but I couldn't get its example cacheing to work, so reimplementing here.
# TODO: Fix that and use it instead.
def populate_combination_examples(max_draws: int = 5_000) -> None:
    seen_values = set()
    draws = 0
    while draws < max_draws:
        value = strategy.example()
        draws += 1
        if value in seen_values:
            continue
        seen_values.add(value)
        record_combination(value)
        if combinations_saturated():
            break

populate_combination_examples()


In [7]:
def render_ascii_table(headers, records):
    widths = {}
    for header in headers:
        candidates = [len(header)]
        candidates.extend(len(str(record.get(header, ""))) for record in records)
        widths[header] = max(candidates)
    border = "+" + "+".join("-" * (widths[header] + 2) for header in headers) + "+"
    lines = [border]
    header_line = (
        "| " + " | ".join(header.ljust(widths[header]) for header in headers) + " |"
    )
    lines.append(header_line)
    lines.append(border)
    for record in records:
        row_line = (
            "| "
            + " | ".join(
                str(record.get(header, "")).ljust(widths[header]) for header in headers
            )
            + " |"
        )
        lines.append(row_line)
    lines.append(border)
    return "\n".join(lines)

In [8]:
print("Pairwise relationships between predicates:")
headers = ["Pred 1", "Pred 2", "relationship"]
rows = []

for i, name_a in enumerate(PREDICATE_NAMES):
    for name_b in PREDICATE_NAMES[i + 1 :]:
        rows.append({
            "Pred 1": name_a,
            "Pred 2": name_b,
            "relationship": classify_relationship(name_a, name_b),
        })

if rows:
    print(render_ascii_table(headers, rows))
else:
    print("No predicate pairs available.")

Pairwise relationships between predicates:
+--------+-------------+--------------+
| Pred 1 | Pred 2      | relationship |
+--------+-------------+--------------+
| ≤ 100  | ≤ 25        | superset     |
| ≤ 100  | even        | overlap      |
| ≤ 100  | ≥ 100 & odd | disjoint     |
| ≤ 25   | even        | overlap      |
| ≤ 25   | ≥ 100 & odd | disjoint     |
| even   | ≥ 100 & odd | disjoint     |
+--------+-------------+--------------+


In [9]:
print(f"\nPredicate combination examples (up to {MAX_EXAMPLES_PER_COMBINATION} values per combination):")
if combination_examples:
    rows = []
    headers = PREDICATE_NAMES + ["examples"]
    for combination, values in sorted(combination_examples.items()):
        row = {
            name: "True" if state else "False"
            for name, state in zip(PREDICATE_NAMES, combination)
        }
        row["examples"] = ", ".join(str(value) for value in sorted(values))
        rows.append(row)

    print(render_ascii_table(headers, rows))
else:
    print("No combination examples collected.")


Predicate combination examples (up to 15 values per combination):
+-------+-------+-------+-------------+---------------------------------------------------------------------------+
| ≤ 100 | ≤ 25  | even  | ≥ 100 & odd | examples                                                                  |
+-------+-------+-------+-------------+---------------------------------------------------------------------------+
| False | False | False | True        | 121, 217, 235, 335, 355, 359, 443, 481, 563, 649, 671, 683, 699, 909, 931 |
| False | False | True  | False       | 118, 230, 232, 266, 362, 408, 418, 440, 486, 514, 550, 736, 748, 758, 922 |
| True  | False | False | False       | 27, 29, 31, 47, 53, 61, 69, 77, 79, 81, 87, 89, 91, 97, 99                |
| True  | False | True  | False       | 28, 34, 36, 40, 44, 56, 58, 66, 70, 74, 76, 80, 82, 88, 96                |
| True  | True  | False | False       | 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25                             |
| Tru

In [10]:
# def collect_candidate_values() -> List[int]:
#     """Return a sorted pool of candidate inputs for predicate interrogation."""
#     if not combination_examples:
#         populate_combination_examples()
#     values: Set[int] = set()
#     for bucket in combination_examples.values():
#         values.update(bucket)
#     return sorted(values)


# def select_best_candidate(
#     remaining: Set[str],
#     candidates: List[int],
#     used: Set[int],
#     attempts: int = 1_000,
#  ) -> Optional[int]:
#     """Pick a candidate that most evenly splits the remaining predicates."""
#     best_value: Optional[int] = None
#     best_score: float = -1.0
#     checked = 0
#     for value in candidates:
#         if value in used:
#             continue
#         checked += 1
#         truth_values = [NAME_TO_PREDICATE[name](value) for name in remaining]
#         true_count = sum(truth_values)
#         false_count = len(truth_values) - true_count
#         if true_count == 0 or false_count == 0:
#             score = 0.0
#         else:
#             score = min(true_count, false_count) / len(truth_values)
#         if score > best_score:
#             best_score = score
#             best_value = value
#         if checked >= attempts and best_score > 0:
#             break
#     return best_value


# def expand_candidate_pool(
#     candidates: List[int],
#     desired: int = 25,
#     max_draws: int = 2_000,
#  ) -> None:
#     """Top up the candidate pool with new strategy samples."""
#     draws = 0
#     while len(candidates) < desired and draws < max_draws:
#         value = strategy.example()
#         draws += 1
#         if value in candidates:
#             continue
#         record_combination(value)
#         candidates.append(value)
#     candidates.sort()


# def prompt_user_label(value: int) -> Union[bool, str, None]:
#     """Ask the user to classify the value; 'unsure' leaves predicates unchanged."""
#     while True:
#         response = input(
#             f"Is {value} valid for the unknown predicate? [y/n/u/q]: "
#         ).strip().lower()
#         if response in {"y", "yes", "valid", "v", "t", "true"}:
#             return True
#         if response in {"n", "no", "invalid", "f", "false"}:
#             return False
#         if response in {"u", "unsure", "skip", "s"}:
#             print("Skipping this value; keeping current predicate set.")
#             return "skip"
#         if response in {"q", "quit", "exit"}:
#             print("Stopping interactive classification.")
#             return None
#         print("Please answer with 'y', 'n', 'u', or 'q'.")


# def summarize_round(
#     value: int,
#     user_label: bool,
#     eliminated: List[str],
#     remaining: Set[str],
#  ) -> None:
#     """Print a compact summary of the classification round."""
#     status_rows = []
#     for name in sorted(eliminated):
#         status_rows.append({
#             "predicate": name,
#             "status": "eliminated",
#         })
#     for name in sorted(remaining):
#         status_rows.append({
#             "predicate": name,
#             "status": "still possible",
#         })
#     print(f"\nAfter labeling {value} as {'valid' if user_label else 'invalid'}:")
#     print(render_ascii_table(["predicate", "status"], status_rows))


# def interactive_predicate_narrowing(max_rounds: int = 25) -> None:
#     """Interactive session to identify the correct predicate via user labels."""
#     remaining: Set[str] = set(PREDICATE_NAMES)
#     candidates = collect_candidate_values()
#     used_values: Set[int] = set()
#     round_index = 0
#     if not candidates:
#         expand_candidate_pool(candidates)
#     while len(remaining) > 1 and round_index < max_rounds:
#         round_index += 1
#         candidate = select_best_candidate(remaining, candidates, used_values)
#         if candidate is None:
#             expand_candidate_pool(candidates)
#             candidate = select_best_candidate(remaining, candidates, used_values)
#         if candidate is None:
#             print("Unable to find informative candidate; stopping.")
#             break
#         used_values.add(candidate)
#         label = prompt_user_label(candidate)
#         if label is None:
#             break
#         if isinstance(label, str) and label == "skip":
#             continue
#         eliminated = []
#         for name in list(remaining):
#             outcome = NAME_TO_PREDICATE[name](candidate)
#             if outcome != label:
#                 remaining.remove(name)
#                 eliminated.append(name)
#         summarize_round(candidate, bool(label), eliminated, remaining)
#         if not remaining:
#             print("All predicates eliminated; no candidates remain.")
#             return
#     if len(remaining) == 1:
#         sole = next(iter(remaining))
#         print(f"\nLikely predicate identified: {sole}")
#     elif len(remaining) > 1:
#         print(
#             f"\nStopped with {len(remaining)} predicates still possible: {sorted(remaining)}",
#         )

In [11]:
interactive_predicate_narrowing()

NameError: name 'interactive_predicate_narrowing' is not defined