In [3]:
# Necessary imports
from typing import List, Tuple, Any, Dict

# Import necessary modules from brancharchitect
from brancharchitect.io import read_newick
from brancharchitect.tree import Node
from brancharchitect.jumping_taxa.tree_interpolation import interpolate_tree
from brancharchitect.consensus import create_majority_consensus_tree

# Import necessary modules from brancharchitect

# ---------------------------------------------------------------------------------------------------
# Function Definitions
# ---------------------------------------------------------------------------------------------------


def node_to_list(node: Node) -> Any:
    """
    Build the nested-list form of the subtree rooted at ``node``.

    A node without children becomes its ``name`` (or the string "Unknown"
    when the name is ``None``); a node with children becomes a list holding
    the converted form of each child, in child order.

    Side effect: a ``children`` attribute that is ``None`` is replaced by an
    empty list on the node itself.
    """
    # Normalise a missing child container first; this mutates the node.
    if node.children is None:
        node.children = []

    offspring = node.children
    if len(offspring) == 0:
        label = node.name
        if label is None:
            return "Unknown"
        return label

    nested = []
    for child in offspring:
        nested.append(node_to_list(child))
    return nested


def flatten_tree_list(tree_list: Any) -> List[str]:
    """
    Collect the leaves of a nested list in left-to-right order.

    Anything that is not a ``list`` counts as a leaf, so a bare leaf yields a
    one-element list and an empty list yields an empty list.
    """
    if not isinstance(tree_list, list):
        return [tree_list]

    leaves = []
    # Explicit stack; items are pushed reversed so that pop() visits them
    # in their original left-to-right order.
    pending = list(reversed(tree_list))
    while pending:
        item = pending.pop()
        if isinstance(item, list):
            pending.extend(reversed(item))
        else:
            leaves.append(item)
    return leaves


def generate_subtree_rotations(
    tree_list: Any, memo: Dict[str, List[List[str]]] = None
) -> List[List[str]]:
    """
    Enumerate the leaf orders obtainable by flipping subtrees of a nested list.

    At every internal node the children are kept either in their original
    order or in reversed order, and every such choice is combined with every
    rotation of every child subtree.

    Args:
        tree_list: A leaf (any non-list value) or a nested list of leaves.
        memo: Optional cache keyed by ``str(subtree)``; created when omitted.

    Returns:
        A list of flattened leaf orders without duplicates. A leaf yields
        ``[[leaf]]`` and an empty list yields ``[[]]``.

    Note:
        The number of results can double at every internal node, so the
        output grows exponentially with the number of internal nodes.
    """
    from itertools import product

    if memo is None:
        memo = {}

    tree_key = str(tree_list)
    if tree_key in memo:
        return memo[tree_key]

    # Base case: a leaf has exactly one ordering.
    if not isinstance(tree_list, list):
        rotations = [[tree_list]]
        memo[tree_key] = rotations
        return rotations

    child_rotations = [generate_subtree_rotations(child, memo) for child in tree_list]

    rotations = []
    seen = set()
    for ordering in (child_rotations, child_rotations[::-1]):
        # product() pairs every rotation of one child with every rotation of
        # the others. zip() would only pair the i-th rotations in lockstep
        # and silently stop at the child with the fewest rotations.
        for combination in product(*ordering):
            flattened = [taxon for part in combination for taxon in part]
            signature = tuple(flattened)
            # The reversed ordering repeats the original for single-child or
            # symmetric nodes; keep only the first occurrence.
            if signature not in seen:
                seen.add(signature)
                rotations.append(flattened)

    memo[tree_key] = rotations
    return rotations


def find_best_permutation(
    permutation_list: List[List[str]], tree_list: List[Node]
) -> Tuple[List[str], float]:
    """
    Pick the candidate order with the smallest summed circular distance.

    For each candidate the distance to the circular taxa order of every tree
    is added up (one term per tree). The first candidate with the strictly
    smallest total wins.

    Args:
        permutation_list: Candidate taxa orders.
        tree_list: Trees of the trajectory.

    Returns:
        ``(best_permutation, min_total_distance)``; ``(None, inf)`` when
        ``permutation_list`` is empty.
    """
    # A tree's circular order does not depend on the candidate, so compute it
    # once per tree rather than once per (candidate, tree) pair.
    # Assumes get_taxa_circular_order has no side effects -- confirm.
    tree_orders = [get_taxa_circular_order(tree) for tree in tree_list]

    min_total_distance = float("inf")
    best_permutation = None

    for perm in permutation_list:
        total_distance = sum(circular_distance(perm, order) for order in tree_orders)
        if total_distance < min_total_distance:
            min_total_distance = total_distance
            best_permutation = perm

    return best_permutation, min_total_distance


def reorder_tree_preserving_splits(tree: Node, permutation: List[str]) -> None:
    """
    Sort the children of every node in place so they follow ``permutation``.

    Each child is ranked by the smallest position in ``permutation`` among
    the names of the leaves returned by its ``get_leaves()``. Only the order
    of the ``children`` lists is changed; no other attribute is touched.

    Raises:
        KeyError: if a leaf name is missing from ``permutation``.
    """
    rank_of = {}
    for position, taxon in enumerate(permutation):
        rank_of[taxon] = position

    def earliest_rank(subtree: Node) -> int:
        # Smallest permutation index among the leaves below this child.
        return min(rank_of[leaf.name] for leaf in subtree.get_leaves())

    # Pre-order walk with an explicit stack: sort a node, then its subtrees.
    pending = [tree]
    while pending:
        current = pending.pop()
        if not current.children:
            continue
        current.children.sort(key=earliest_rank)
        # Push reversed so the first child is handled next.
        pending.extend(reversed(current.children))


def optimize_tree_pair(tree1: Node, tree2: Node):
    """
    Align the child order of two trees using rotations of their interpolation.

    The first four items returned by ``interpolate_tree(tree1, tree2)`` are
    used; the second one (``c1``) serves as the pair consensus. Its subtree
    rotations are scored by the sum of circular distances to both input
    trees, and the best-scoring order is applied in place to both trees and
    to the four interpolation trees. Nothing is returned.
    """
    interpolation = interpolate_tree(tree1, tree2)
    intermediate1, c1, c2, intermediate2 = (
        interpolation[0],
        interpolation[1],
        interpolation[2],
        interpolation[3],
    )

    # Candidate orders come from rotating subtrees of c1, used as consensus.
    candidates = generate_subtree_rotations(node_to_list(c1))

    order1 = get_taxa_circular_order(tree1)
    order2 = get_taxa_circular_order(tree2)

    # Keep the first candidate with the strictly smallest combined distance.
    best_pair_permutation, lowest_cost = None, float("inf")
    for candidate in candidates:
        cost = circular_distance(candidate, order1) + circular_distance(
            candidate, order2
        )
        if cost < lowest_cost:
            best_pair_permutation, lowest_cost = candidate, cost

    # Reorder inputs and interpolation trees alike; only child order changes.
    for target in (tree1, tree2, intermediate1, c1, c2, intermediate2):
        reorder_tree_preserving_splits(target, best_pair_permutation)


def circular_distances_trees(trees: List[Node]) -> float:
    """
    Return the mean circular distance between consecutive trees.

    Every neighbouring pair ``(trees[i], trees[i + 1])`` is compared via the
    circular taxa orders of both trees, and each pair's distance is printed.
    With fewer than two trees the result is 0.
    """
    num_pairs = len(trees) - 1
    running_total = 0.0

    for i, (left, right) in enumerate(zip(trees, trees[1:])):
        distance = circular_distance(
            get_taxa_circular_order(left), get_taxa_circular_order(right)
        )
        running_total += distance
        print(f"Distance between tree {i} and tree {i + 1}: {distance}")

    if num_pairs > 0:
        return running_total / num_pairs
    return 0


from brancharchitect.tree_order_optimisation import generate_permutations_on_randomness


# ---------------------------------------------------------------------------------------------------
# Main Code
# ---------------------------------------------------------------------------------------------------
# Read your trees
tree_file_path = "./../data/alltrees_treees_cutted/alltrees.trees_cutted.newick"  # Replace with your actual file path
# The file is read twice: the first pass only feeds
# generate_permutations_on_randomness, whose result is then passed to
# read_newick as its second argument (presumably a taxa order for the
# reader -- TODO confirm against brancharchitect.io).
tree_list = read_newick(tree_file_path)
min_perm = generate_permutations_on_randomness(tree_list)
tree_list = read_newick(tree_file_path, min_perm)
print(f"Loaded {len(tree_list)} trees.")


# Step 1: Compute the majority consensus tree
threshold = 0.99
mc_tree = create_majority_consensus_tree(tree_list, threshold=threshold)
print(f"Computed majority consensus tree with threshold {threshold}.")

# Convert consensus tree to nested list
consensus_tree_list = node_to_list(mc_tree)
# Step 2: Generate permutations by rotating subtrees
permutations = generate_subtree_rotations(consensus_tree_list)
print(f"Generated {len(permutations)} permutations.")

# Step 3: Find and apply the best permutation to the trajectory
best_permutation, min_total_distance = find_best_permutation(permutations, tree_list)
# find_best_permutation adds one distance per tree, so the mean is taken over
# len(tree_list) terms (not len(tree_list) - 1). The guard avoids a
# ZeroDivisionError when no trees were loaded.
num_trees = len(tree_list)
average_distance_before = min_total_distance / num_trees if num_trees else 0.0
print(f"\nBest permutation found with average distance {average_distance_before}:")
print(best_permutation)

# Apply the best permutation to every tree of the trajectory, preserving splits
for tree in tree_list:
    reorder_tree_preserving_splits(tree, best_permutation)

# Step 4: Optimize between pairs of trees, including intermediates
for i in range(len(tree_list) - 1):
    tree1 = tree_list[i]
    tree2 = tree_list[i + 1]
    # Optimize the order of the trees and intermediates
    optimize_tree_pair(tree1, tree2)

# Step 5: Compute the circular distances between consecutive trees after optimization
print("\nCircular distances between consecutive trees after optimization:")
average_distance_after = circular_distances_trees(tree_list)
print(
    f"\nAverage circular distance over the trajectory after optimization: {average_distance_after}"
)

Distance between tree 0 and tree 1: 0.0
Distance between tree 1 and tree 2: 0.0
Distance between tree 2 and tree 3: 0.0
Distance between tree 3 and tree 4: 0.0
Distance between tree 4 and tree 5: 0.0
Distance between tree 5 and tree 6: 0.0
Distance between tree 6 and tree 7: 0.0
Distance between tree 7 and tree 8: 0.0
Distance between tree 8 and tree 9: 0.0
Distance between tree 9 and tree 10: 0.0
Distance between tree 10 and tree 11: 0.0
Distance between tree 11 and tree 12: 0.0
Distance between tree 12 and tree 13: 0.0
Distance between tree 13 and tree 14: 0.0
Distance between tree 14 and tree 15: 0.0
Distance between tree 15 and tree 16: 0.0
Distance between tree 16 and tree 17: 0.0
Distance between tree 17 and tree 18: 0.0
Distance between tree 18 and tree 19: 0.0
Distance between tree 19 and tree 20: 0.0
Distance between tree 20 and tree 21: 0.0
Distance between tree 21 and tree 22: 0.0
Distance between tree 22 and tree 23: 0.0
Distance between tree 23 and tree 24: 0.0
Distance be