In [1]:
import networkx as nx
from collections import deque, defaultdict
import hypothesis as hp
from hypothesis import strategies as st
from itertools import combinations
from collections import defaultdict
from typing import List, Optional, Set, Union
import pandas as pd
from pprint import pprint
from hypothesis_pick import (
    find_disagreements,
    find_stronger_weaker,
    infer_implications,
)

# The assignment defines a valid input as a Directed Acyclic Graph (DAG)
# represented by edges[cite: 196, 233].


def topsort_implementation(edges: list[tuple[str, str]]) -> list[str]:
    """
    Implementation of topological sort based on assignment pseudocode.

    Args:
        edges: A list of (u, v) tuples representing directed edges.
    Returns:
        A list of vertices in topological order.
    """
    # Build graph and in-degrees
    adj = defaultdict(list)
    in_degree = defaultdict(int)
    nodes = set()

    for u, v in edges:
        adj[u].append(v)
        nodes.add(u)
        nodes.add(v)
        in_degree[v] += 1
        if u not in in_degree:
            in_degree[u] = 0

    # L = empty list to store ordering [cite: 204]
    L = []

    # S = set of vertices with no incoming edges [cite: 205]
    # We sort S to make the output deterministic for testing, though not strictly required by alg.
    S = sorted([n for n in nodes if in_degree[n] == 0])
    queue = deque(S)

    # while S is not empty [cite: 206]
    while queue:
        # u = vertex removed from S [cite: 207]
        u = queue.popleft()
        # append u to L [cite: 207]
        L.append(u)

        # for each vertex v where edge e=(u,v) in g [cite: 208]
        for v in adj[u]:
            # remove e from g (simulated by decrementing in-degree) [cite: 209]
            in_degree[v] -= 1
            # if v has no other incoming edges [cite: 210]
            if in_degree[v] == 0:
                # insert v in S [cite: 210]
                queue.append(v)

    return L

In [2]:
# Unit tests for topsort_implementation (run this cell)

import unittest

def _is_valid_toposort(edges: list[tuple[str, str]], order: list[str]) -> bool:

    """Checks that `order` is a valid topological ordering for `edges`."""

    nodes = {u for u, _ in edges} | {v for _, v in edges}

    if not nodes:

        return order == []



    if set(order) != nodes or len(order) != len(nodes):

        return False



    pos = {node: i for i, node in enumerate(order)}

    return all(pos[u] < pos[v] for u, v in edges)





class TestToposortImplementation(unittest.TestCase):

    def test_empty_graph(self) -> None:

        self.assertEqual(topsort_implementation([]), [])



    def test_single_edge(self) -> None:

        edges = [("A", "B")]

        order = topsort_implementation(edges)

        self.assertEqual(order, ["A", "B"])  # deterministic for this input

        self.assertTrue(_is_valid_toposort(edges, order))



    def test_diamond_graph(self) -> None:

        # A before B and C; both before D.

        edges = [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")]

        order = topsort_implementation(edges)

        self.assertEqual(order, ["A", "B", "C", "D"])  # deterministic given edge order

        self.assertTrue(_is_valid_toposort(edges, order))



    def test_disconnected_components(self) -> None:

        edges = [("A", "B"), ("C", "D")]

        order = topsort_implementation(edges)

        self.assertEqual(order, ["A", "C", "B", "D"])  # deterministic given current algorithm

        self.assertTrue(_is_valid_toposort(edges, order))



    def test_multiple_sources_single_sink(self) -> None:

        edges = [("A", "D"), ("B", "D"), ("C", "D")]

        order = topsort_implementation(edges)

        self.assertTrue(_is_valid_toposort(edges, order))

        self.assertEqual(order[-1], "D")





suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestToposortImplementation)

runner = unittest.TextTestRunner(verbosity=2)

runner.run(suite)


test_diamond_graph (__main__.TestToposortImplementation.test_diamond_graph) ... ok
test_disconnected_components (__main__.TestToposortImplementation.test_disconnected_components) ... ok
test_empty_graph (__main__.TestToposortImplementation.test_empty_graph) ... ok
test_multiple_sources_single_sink (__main__.TestToposortImplementation.test_multiple_sources_single_sink) ... ok
test_single_edge (__main__.TestToposortImplementation.test_single_edge) ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.006s

OK


<unittest.runner.TextTestResult run=5 errors=0 failures=0>

In [3]:
# Strategy to generate a DAG and a candidate result list
@st.composite
def graph_and_candidate_strategy(draw):
    # 1. Generate a valid DAG
    # We do this by generating nodes and only allowing edges from lower index to higher index
    # to guarantee acyclicity.
    nodes = draw(
        st.lists(
            st.text(alphabet="ABCDE", min_size=1, max_size=2),
            min_size=2,
            max_size=6,
            unique=True,
        )
    )
    nodes = sorted(nodes)

    edges = []
    # Create random forward edges
    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            if draw(st.booleans()):
                edges.append((nodes[i], nodes[j]))

    # 2. Determine the "Ground Truth" using our implementation
    valid_sort = topsort_implementation(edges)

    # 3. Create a candidate output that might be wrong to trigger predicate disagreements
    # We mix: Valid sorts, Alphabetical sorts, and Random shuffles.
    case_type = draw(st.sampled_from(["valid", "lexical", "shuffled", "missing_node"]))

    if case_type == "valid":
        candidate = valid_sort
    elif case_type == "lexical":
        candidate = sorted(nodes)
    elif case_type == "shuffled":
        candidate = valid_sort[:]  # copy
        # We need a deterministic shuffle for reproducibility in PBT,
        # but for this specific setup, st.permutations is cleaner:
        candidate = draw(st.permutations(valid_sort))
    else:  # missing_node
        candidate = valid_sort[:-1] if valid_sort else []

    # The input to our predicates is the Tuple(Edges, Candidate_List)
    return (edges, candidate)


# Update the main strategy variable for the PICK system
strategy = graph_and_candidate_strategy()

In [4]:
GraphInput = tuple[List[tuple[str, str]], List[str]]


def p1_is_valid_toposort(x: GraphInput) -> bool:
    """True if candidate respects all edge dependencies and contains all nodes."""
    edges, candidate = x

    # Check 1: Are all nodes present?
    graph_nodes = set(u for u, v in edges) | set(v for u, v in edges)
    if set(candidate) != graph_nodes or len(candidate) != len(graph_nodes):
        return False

    # Check 2: Are edge constraints respected? (u must appear before v) [cite: 286]
    # Create a map of node -> index for O(1) lookups
    position = {node: i for i, node in enumerate(candidate)}

    for u, v in edges:
        # If u or v isn't in candidate, it fails (caught by Check 1 usually, but safe to check)
        if u not in position or v not in position:
            return False
        if position[u] > position[v]:
            return False

    return True


def p2_is_permutation(x: GraphInput) -> bool:
    """True if candidate contains exactly the set of nodes in the graph (ignores order)."""
    edges, candidate = x
    graph_nodes = set(u for u, v in edges) | set(v for u, v in edges)
    return set(candidate) == graph_nodes and len(candidate) == len(graph_nodes)


def p3_is_lexical(x: GraphInput) -> bool:
    """True if candidate is sorted alphabetically (ignores graph structure)."""
    edges, candidate = x
    # Just checks if the list is sorted
    return candidate == sorted(candidate) and len(candidate) > 0


def p4_valid_source_sink(x: GraphInput) -> bool:
    """True if the first element is a valid source and last is a valid sink."""
    edges, candidate = x
    if not candidate:
        return False

    # Build degrees
    in_degree = defaultdict(int)
    out_degree = defaultdict(int)
    nodes = set()
    for u, v in edges:
        in_degree[v] += 1
        out_degree[u] += 1
        nodes.add(u)
        nodes.add(v)

    first = candidate[0]
    last = candidate[-1]

    # First node must have 0 in-degree (Source) [cite: 295]
    is_source = in_degree[first] == 0
    # Last node must have 0 out-degree (Sink)
    is_sink = out_degree[last] == 0

    return is_source and is_sink

In [5]:
PREDICATES = [
    p1_is_valid_toposort,
    p2_is_permutation,
    p3_is_lexical,
    p4_valid_source_sink,
]
PREDICATE_NAMES = [
    "Valid Toposort",
    "Is Permutation",
    "Is Alphabetical",
    "Valid Source/Sink",
]

NAME_TO_PREDICATE = dict(zip(PREDICATE_NAMES, PREDICATES))

assert len(PREDICATES) == len(PREDICATE_NAMES)

# Reset the cache for the new problem domain
combination_examples = defaultdict(list)

In [6]:
impl = infer_implications(
    predicates=PREDICATES,
    strategy=strategy,
    max_examples=1_000,
    predicate_names=PREDICATE_NAMES,
)


def classify_relationship(name_a: str, name_b: str) -> str:
    if impl.equivalent(name_a, name_b):
        return "equivalent"
    if impl.implies(name_a, name_b):
        return "subset"
    if impl.implies(name_b, name_a):
        return "superset"
    predicate_a = NAME_TO_PREDICATE[name_a]
    predicate_b = NAME_TO_PREDICATE[name_b]
    for candidate in range(0, 2_000):
        if predicate_a(candidate) and predicate_b(candidate):
            return "overlap"
    return "disjoint"


def build_pair_entry(name_a: str, name_b: str) -> dict:
    return {
        "relation": classify_relationship(name_a, name_b),
    }


pair_relationships = {
    (name_a, name_b): build_pair_entry(name_a, name_b)
    for i, name_a in enumerate(PREDICATE_NAMES)
    for name_b in PREDICATE_NAMES[i + 1 :]
}

combination_examples = defaultdict(list)


def record_combination(value: int) -> None:
    key = tuple(bool(NAME_TO_PREDICATE[name](value)) for name in PREDICATE_NAMES)
    bucket = combination_examples[key]
    if len(bucket) < MAX_EXAMPLES_PER_COMBINATION and value not in bucket:
        bucket.append(value)


def combinations_saturated() -> bool:
    if not combination_examples:
        return False
    return all(
        len(bucket) >= MAX_EXAMPLES_PER_COMBINATION
        for bucket in combination_examples.values()
    )


# I believe this is doing duplicate work to Sid's library,
# but I couldn't get its example cacheing to work, so reimplementing here.
# TODO: Fix that and use it instead.
def populate_combination_examples(max_draws: int = 5_000) -> None:
    seen_values = set()
    draws = 0
    while draws < max_draws:
        value = strategy.example()
        draws += 1
        if value in seen_values:
            continue
        seen_values.add(value)
        record_combination(value)
        if combinations_saturated():
            break


populate_combination_examples()

FailedHealthCheck: Input generation is slow: Hypothesis only generated 2 valid inputs after 4.72 seconds.

      count | fraction |    slowest draws (seconds)
  x |    2  |    100%  |      --      --      --      --   4.723

This could be for a few reasons:
1. This strategy could be generating too much data per input. Try decreasing the amount of data generated, for example by decreasing the minimum size of collection strategies like st.lists().
2. Some other expensive computation could be running during input generation. For example, if @st.composite or st.data() is interspersed with an expensive computation, HealthCheck.too_slow is likely to trigger. If this computation is unrelated to input generation, move it elsewhere. Otherwise, try making it more efficient, or disable this health check if that is not possible.

If you expect input generation to take this long, you can disable this health check with @settings(suppress_health_check=[HealthCheck.too_slow]). See https://hypothesis.readthedocs.io/en/latest/reference/api.html#hypothesis.HealthCheck for details.