In [None]:
from itertools import permutations
import sys

# The tree helpers below are recursive; raise the interpreter's recursion
# ceiling so they have head-room.
sys.setrecursionlimit(10_000)

# Taxon labels: the ingroup leaves, and the single outgroup label.
ingroup_taxa = list("ABCDEF")
outgroup = "O"


def generate_tree_shapes(n):
    """
    Generate all unique rooted binary tree shapes with n leaves.

    A shape is ``None`` for a single leaf, or a 2-tuple ``(left, right)`` of
    sub-shapes for an internal node.  Mirror images are counted once: the
    left subtree never has more leaves than the right one, and when both
    sides have the same number of leaves each unordered pair of sub-shapes
    is emitted exactly once.

    Args:
        n: Number of leaves.

    Returns:
        A list of shapes.  ``[None]`` for ``n == 1`` and an empty list for
        ``n < 1``.
    """
    if n == 1:
        return [None]  # Leaf node

    shapes = []
    # Only left_size <= right_size can survive the mirror-image rule, so the
    # larger-left half of the range is never visited.
    for left_size in range(1, n // 2 + 1):
        right_size = n - left_size
        left_shapes = generate_tree_shapes(left_size)
        if left_size < right_size:
            # Different sizes can never be mirror images of each other.
            right_shapes = generate_tree_shapes(right_size)
            for left in left_shapes:
                for right in right_shapes:
                    shapes.append((left, right))
        else:
            # Equal sizes: pick unordered pairs by list position.  Comparing
            # the shapes themselves would compare None with a tuple and
            # raise TypeError as soon as two distinct shapes share a size.
            for index, left in enumerate(left_shapes):
                for right in left_shapes[index:]:
                    shapes.append((left, right))
    return shapes


def label_tree(shape, taxa):
    """
    Label the given tree shape with the provided taxa.

    Every way of splitting ``taxa`` between the left and right subtree is
    tried, and both subtrees are labelled recursively.  The left-hand taxa
    are chosen as unordered combinations: the recursive call already tries
    every arrangement inside a subtree, so feeding it each ordering of the
    same taxa would only multiply identical results.

    Args:
        shape: ``None`` for a leaf or a ``(left, right)`` tuple of shapes.
        taxa: Sequence of taxon labels to place on the leaves.

    Returns:
        A list of labelled trees (a label for a leaf, nested 2-tuples
        otherwise).  Left/right mirror images are NOT merged here.  The list
        is empty when an internal shape's leaf count does not match
        ``len(taxa)``.
    """
    # Local import: the file's top-level import only brings in permutations.
    from itertools import combinations

    if shape is None:
        # Only one taxon left
        return [taxa[0]]

    left_shape, right_shape = shape
    left_size = count_leaves(left_shape)
    right_size = count_leaves(right_shape)

    labeled_trees = []
    for left_taxa in combinations(taxa, left_size):
        right_taxa = tuple(t for t in taxa if t not in left_taxa)
        if len(right_taxa) != right_size:
            # Number of taxa does not fit this shape.
            continue
        left_labeled = label_tree(left_shape, left_taxa)
        right_labeled = label_tree(right_shape, right_taxa)
        labeled_trees.extend(
            (lt, rt) for lt in left_labeled for rt in right_labeled
        )
    return labeled_trees


def count_leaves(shape):
    """
    Count the number of leaves in the tree shape.

    Args:
        shape: ``None`` for a leaf or a ``(left, right)`` tuple of shapes.

    Returns:
        The number of ``None`` leaves reachable from ``shape``.
    """
    if shape is None:
        return 1
    return count_leaves(shape[0]) + count_leaves(shape[1])


def canonical_form(tree):
    """
    Return an order-normalised copy of a labelled tree.

    A string (leaf) is returned unchanged.  For an internal node the two
    children are normalised recursively and then arranged so that the child
    whose ``str()`` text sorts first comes first; ties keep the original
    left/right order.
    """
    if isinstance(tree, str):
        return tree

    first = canonical_form(tree[0])
    second = canonical_form(tree[1])
    # Swap only on a strict "less than" so equal keys keep their order.
    if str(second) < str(first):
        first, second = second, first
    return (first, second)


def tree_to_newick(tree):
    """
    Render a labelled tree as a parenthesised Newick-style string.

    A string leaf is returned as-is; an internal node becomes
    ``(<left>,<right>)`` with both children rendered recursively.
    """
    if not isinstance(tree, str):
        first, second = tree
        return "(" + tree_to_newick(first) + "," + tree_to_newick(second) + ")"
    return tree


def attach_outgroup(tree_newick, outgroup):
    """
    Wrap a Newick string in a new root whose first child is the outgroup.

    Returns ``(<outgroup>,<tree_newick>)``.
    """
    return "({},{})".format(outgroup, tree_newick)


def ring_permutations_of_trees(trees):
    """
    Pair each tree with its successor, wrapping around at the end.

    Returns a list with one ``[str(tree_i), str(tree_i_plus_1)]`` entry per
    input tree; the last tree is paired with the first.  An empty input
    gives an empty list.
    """
    count = len(trees)
    labels = [str(trees[position]) for position in range(count)]
    return [
        [labels[position], labels[(position + 1) % count]]
        for position in range(count)
    ]


# Generate all unique tree shapes with one leaf per ingroup taxon
print("Generating all unique rooted binary tree shapes for ingroup taxa...")
tree_shapes = generate_tree_shapes(len(ingroup_taxa))
print(f"Total number of unique tree shapes: {len(tree_shapes)}")

# Generate all labeled trees
print("Generating all labeled trees...")
labeled_trees_set = set()
for shape in tree_shapes:
    for tree in label_tree(shape, tuple(ingroup_taxa)):
        # The canonical form merges trees that differ only by left/right
        # order, so the set keeps one representative of each.
        labeled_trees_set.add(canonical_form(tree))

# Sort by text form: set iteration order for strings varies between
# interpreter runs, which would otherwise change the ring of pairs (and the
# output file) from run to run.
all_trees = sorted(labeled_trees_set, key=str)
print(len(all_trees))
num_trees = len(all_trees)

# Generate ring permutations of the trees in pairs
tree_pairs = ring_permutations_of_trees(all_trees)

# One Newick line per tree, two consecutive lines per pair.  The nested
# tuple's str() form becomes Newick once quotes and spaces are stripped.
output_path = "".join(ingroup_taxa) + ".trees"
with open(output_path, "w", encoding="utf-8") as f:
    for pair in tree_pairs:
        for tree_text in pair:
            newick = tree_text.replace("'", "").replace(" ", "")
            f.write(f"({newick},{outgroup});\n")

Generating all unique rooted binary tree shapes for ingroup taxa...
Total number of unique tree shapes: 6
Generating all labeled trees...
945


In [5]:
from brancharchitect.partition_set import PartitionSet, Partition
from brancharchitect.jumping_taxa.lattice.matrix_ops import (
    solve_matrix_puzzle,
    generalized_meet_product,
)

def test_one():
    """
    Run solve_matrix_puzzle on two 2x2 matrices of PartitionSets and check
    that exactly the two expected solutions come back.
    """
    encoding = {str(code): code for code in (7, 8, 2, 3)}

    def cell(*pairs):
        # Every matrix entry gets its own freshly built PartitionSet.
        return PartitionSet(set(pairs), encoding)

    matrix1 = [
        # Row 0: intersection = {(7,8)}
        [cell((7, 8), (2, 3)), cell((6, 7), (7, 8))],
        [cell((7, 8), (2, 3)), cell((11, 12), (2, 3))],
    ]
    matrix2 = [
        # Row 0: intersection = {(6,7)}
        [cell((11, 12), (6, 7)), cell((6, 7), (7, 8))],
        # Row 1: intersection = {(11,12)}
        [cell((11, 12), (6, 7)), cell((11, 12), (2, 3))],
    ]

    sols = solve_matrix_puzzle(matrix1, matrix2)

    expected = [
        cell((7, 8), (11, 12)),
        cell((2, 3), (6, 7)),
    ]
    for exp in expected:
        assert exp in sols, f"Expected solution {exp} not found in {sols}"
    assert len(sols) == 2, f"Expected 2 solutions, got {len(sols)}"


def test_two():
    """
    Feed a single 2x2 matrix of PartitionSets to generalized_meet_product
    and assert on the solutions it returns.

    NOTE(review): the notebook output captured for this cell shows the first
    assertion failing, with ``sols`` printed as ``[[((2)), ((3)), ((4, 5))]]``
    (one inner list).  Either the expected values below or the assumed
    return shape of generalized_meet_product need to be confirmed.
    """
    # Define an encoding mapping for labels.
    encoding = {
        "X": 1,
        "A1": 2,
        "A2": 3,
        "A3": 4,
        "A4": 5,
        "B1": 6,
        "B2": 7,
        "B3": 8,
        "B4": 9,
    }

    matrix = [
        [
            PartitionSet(
                {Partition((1,)), Partition((2,)), Partition((3,)), Partition((4, 5))},
                encoding,
            ),
            # NOTE(review): this entry passes a bare Partition, whereas every
            # other PartitionSet call in this cell passes a set of
            # Partitions - confirm this is intentional.
            PartitionSet(
                Partition((1,)),
                encoding,
            ),
        ],
        [
            PartitionSet({Partition((6, 7, 8, 9))}, encoding),
            PartitionSet(
                {
                    Partition((2,)),
                    Partition((3,)),
                    Partition((6, 7, 8, 9)),
                    Partition((4, 5)),
                },
                encoding,
            ),
        ],
    ]

    sols = generalized_meet_product(matrix)
    # Each expected solution is a single-partition PartitionSet.
    expected = [
        PartitionSet({Partition((2,)),}, encoding),
        PartitionSet({Partition((1,)),}, encoding),
    ]
    # Debug aid: show what was actually returned before asserting.
    print(sols)
    for exp in expected:
        assert exp in sols, f"Expected solution {exp} not found in {sols}"
    assert len(sols) == 2, f"Expected 2 solutions, got {len(sols)}"


# Run both checks immediately when this cell executes; the first failing
# assertion aborts the cell.
test_one()
test_two()





⎡ {(2, 3), (7, 8)}     │ {(6, 7), (7, 8)}       ⎤
⎢ ───────────────────────────────────────────── ⎥
⎣ {(2, 3), (7, 8)}     │ {(11, 12), (2, 3)}     ⎦


⎡ {(11, 12), (6, 7)}     │ {(6, 7), (7, 8)}       ⎤
⎢ ─────────────────────────────────────────────── ⎥
⎣ {(11, 12), (6, 7)}     │ {(11, 12), (2, 3)}     ⎦


Intersections for Matrix 1
[{((7, 8))}, {((2, 3))}]


Intersections for Matrix 2
[{((6, 7))}, {((11, 12))}]




⎡ {(1), (2), (3), (4, 5)}   │ {(1)}                            ⎤
⎢ ────────────────────────────────────────────────────────── ⎥
⎣ {(6, 7, 8, 9)}            │ {(2), (3), (4, 5), (6, 7, 8, 9)} ⎦


[[((2)), ((3)), ((4, 5))]]


AssertionError: Expected solution (2) not found in [[((2)), ((3)), ((4, 5))]]

In [3]:
####python
# filepath: /path/to/tests/test_lattice_construction.py
import pytest
from brancharchitect.partition_set import Partition, PartitionSet
from brancharchitect.jumping_taxa.lattice.lattice_edge import LatticeEdge
from brancharchitect.jumping_taxa.lattice.lattice_construction import (
    is_independent_any,
    gather_independent_partitions,
    check_non_subsumption_with_residual,
    check_atomic_inclusion,
    check_independence_conditions,
    pairwise_lattice_analysis
)


def test_is_independent_any():
    """
    is_independent_any must yield True when at least one flag in the tuple
    is True, and False when none is.
    """
    cases = (
        ((False, False, True), True),
        ((False, False, False), False),
        ((True,), True),
    )
    for flags, wanted in cases:
        assert is_independent_any(flags) is wanted


def test_check_non_subsumption_with_residual():
    """
    The helper is expected to return True only when both hold:
     - the primary set is NOT a subset of the comparison set
     - the residual is non-empty
    """
    contained = PartitionSet({Partition((1,2))})
    reference = PartitionSet({Partition((1,2)), Partition((3,))})
    leftover = PartitionSet({Partition((99,))})

    # `contained` sits entirely inside `reference` -> falsy result.
    subset_outcome = check_non_subsumption_with_residual(contained, reference, leftover)
    assert not subset_outcome

    # Partition (4,) is absent from `reference` -> exactly True.
    overhanging = PartitionSet({Partition((1,2)), Partition((4,))})
    overhang_outcome = check_non_subsumption_with_residual(overhanging, reference, leftover)
    assert overhang_outcome is True

    # Same sets, but an empty residual -> falsy result.
    empty_residual_outcome = check_non_subsumption_with_residual(
        overhanging, reference, PartitionSet()
    )
    assert not empty_residual_outcome


def test_check_atomic_inclusion():
    """
    check_atomic_inclusion is expected to return True only when:
      1) the primary set is NOT a subset of the comparison set
      2) the primary set has exactly 1 element
      3) the comparison set has > 1 element
    """
    three_atoms = PartitionSet({Partition((7,)), Partition((8,)), Partition((9,))})

    # One element that is already in the comparison set -> falsy.
    inside = PartitionSet({Partition((7,))})
    inside_outcome = check_atomic_inclusion(inside, three_atoms)
    assert not inside_outcome

    # One element that is not in the comparison set -> exactly True.
    outside = PartitionSet({Partition((10,))})
    outside_outcome = check_atomic_inclusion(outside, three_atoms)
    assert outside_outcome is True

    # Two elements in the primary set -> falsy regardless of membership.
    pair = PartitionSet({Partition((10,)), Partition((11,))})
    pair_outcome = check_atomic_inclusion(pair, three_atoms)
    assert not pair_outcome



def test_gather_independent_partitions():
    """
    gather_independent_partitions(intersection_map, left_minus_right_map, right_minus_left_map)
    is expected to return a list of dicts carrying the keys "A" and "B".
    """
    shared_key = frozenset({Partition((1,))})
    record = {
        "covet_left": PartitionSet({Partition((1,)), Partition((2,))}),
        "covet_right": PartitionSet({Partition((1,)), Partition((2,)), Partition((3,))}),
        "b-a": PartitionSet({Partition((4,))}),
        "a-b": PartitionSet({Partition((2,))}),
    }
    intersection_map = {shared_key: record}

    # The two difference maps are shallow copies holding the same entry.
    left_minus_right_map = intersection_map.copy()
    right_minus_left_map = intersection_map.copy()

    results = gather_independent_partitions(
        intersection_map,
        left_minus_right_map,
        right_minus_left_map,
    )

    # A single result is expected, and it must name both sides.
    assert len(results) == 1
    for side in ("A", "B"):
        assert side in results[0]


def test_pairwise_lattice_analysis():
    """
    pairwise_lattice_analysis(s_edge) -> (intersection_map, left_minus_right_map, right_minus_left_map)
    Check that dictionaries are formed correctly.

    NOTE(review): the notebook output captured for this cell records
    ``TypeError: LatticeEdge.__init__() missing 5 required positional
    arguments: 'left_unique_atoms', 'right_unique_atoms',
    'left_unique_covet', 'right_unique_covet', and 'look_up'``.  The
    LatticeEdge construction below does not supply those, so this test
    cannot currently get past building the edge - suitable values need to
    be confirmed against the LatticeEdge definition.
    """
    # Build an edge whose left and right covers share the partition (2, 3).
    edge = LatticeEdge(
        split=frozenset({Partition((1,2,3))}),
        left_cover=[
            PartitionSet({Partition((1,2))}),
            PartitionSet({Partition((2,3))}),
        ],
        right_cover=[
            PartitionSet({Partition((2,3))}),
            PartitionSet({Partition((3,4))}),
        ],
        child_meet=PartitionSet(),
        left_node=None,
        right_node=None,
    )
    intersection_map, left_minus_right_map, right_minus_left_map = pairwise_lattice_analysis(edge)

    # Structural checks only: each returned map must be a dict.
    assert isinstance(intersection_map, dict)
    assert isinstance(left_minus_right_map, dict)
    assert isinstance(right_minus_left_map, dict)

    # The singleton partition (2,) is expected as a key of intersection_map.
    key2 = frozenset({Partition((2,))})
    assert key2 in intersection_map, f"Expected {key2} in intersection_map"

    # The entry stored under that key must expose all four named fields.
    entry = intersection_map[key2]
    for field in ["covet_left", "covet_right", "b-a", "a-b"]:
        assert field in entry

    print("pairwise_lattice_analysis -> OK")


if __name__ == "__main__":
    # Direct execution: run every test in declaration order, then report.
    for run_test in (
        test_is_independent_any,
        test_check_non_subsumption_with_residual,
        test_check_atomic_inclusion,
        test_gather_independent_partitions,
        test_pairwise_lattice_analysis,
    ):
        run_test()
    print("All tests passed (run directly).")

Final independence determination: True


TypeError: LatticeEdge.__init__() missing 5 required positional arguments: 'left_unique_atoms', 'right_unique_atoms', 'left_unique_covet', 'right_unique_covet', and 'look_up'

In [2]:
def test_check_independence_conditions():
    """
    check_independence_conditions(left_dict, right_dict) =>
      (LeftNonSubsump, RightNonSubsump, LeftAtomicIncl, RightAtomicIncl)

    NOTE(review): this test never calls check_independence_conditions; it
    evaluates the four conditions itself with
    check_non_subsumption_with_residual and check_atomic_inclusion and
    asserts on that tuple.  Confirm whether the function under test should
    be called here instead.
    """
    left = {
        "covet_left": PartitionSet({Partition((1, 2, 3))}),  # multi-element partition
        "b-a": PartitionSet({Partition((9,))}),
    }
    right = {
        "covet_right": PartitionSet({Partition((1, 2, 3)), Partition((4,))}),
        "a-b": PartitionSet(),  # empty => means no leftover from right
    }

    # Extract partition sets (an empty PartitionSet when the key is absent)
    left_arm = left.get("covet_left", PartitionSet())
    right_arm = right.get("covet_right", PartitionSet())

    # Condition 1: Left partition is not fully contained
    # and has a non-empty right-side residual
    left_non_subsumption = check_non_subsumption_with_residual(
        primary_set=left_arm,
        comparison_set=right_arm,
        residual=left.get("b-a", PartitionSet()),
    )

    # Debug aid: show the residual that was passed for condition 1.
    print(left.get("b-a", PartitionSet()))

    # Condition 2: Right partition is not fully contained
    # and has a non-empty left-side residual
    right_non_subsumption = check_non_subsumption_with_residual(
        primary_set=right_arm,
        comparison_set=left_arm,
        residual=right.get("a-b", PartitionSet()),
    )

    # Condition 3: Left atomic inclusion
    # (single element set not subset of a larger set)
    left_atomic_inclusion = check_atomic_inclusion(
        primary_set=left_arm, comparison_set=right_arm
    )

    # Condition 4: Right atomic inclusion
    # (single element set not subset of a larger set)
    right_atomic_inclusion = check_atomic_inclusion(
        primary_set=right_arm, comparison_set=left_arm
    )

    results = (
        left_non_subsumption,
        right_non_subsumption,
        left_atomic_inclusion,
        right_atomic_inclusion,
    )

    # We expect => (True, False, False, False)
    # Explanation:
    #  - Left non-subsumption => True if left is not subset and has a non-empty residual
    #  - Right non-subsumption => probably False (since right's leftover is empty)
    #  - Left atomic => False (the left partition has more than 1 element)
    #  - Right atomic => False (the right partition also has more than 1 element)
    # NOTE(review): the only partition in covet_left, (1, 2, 3), is also
    # present in covet_right.  If the helper treats that as "left is a
    # subset of right", the first flag would be False rather than True -
    # confirm the expected tuple against the helper's actual semantics.
    print(results)
    assert results == (True, False, False, False)


# Run the check immediately when this cell executes.
# NOTE(review): the captured output for this cell is a NameError for
# PartitionSet, so the cell was presumably run before the names it uses had
# been imported into the session.
test_check_independence_conditions()

NameError: name 'PartitionSet' is not defined