In [2]:
import os
from ete3 import Tree
from collections import defaultdict

def compute_purity(sizes):
    """
    Compute purity = sum_i (|S_i| / n)^2
    where n is the total number of elements in the partition (for one label).
    """
    n = sum(sizes)
    if n == 0:
        return 0.0
    return sum((s / n) ** 2 for s in sizes)

def get_label_dict(tree):
    """
    Map each leaf node to its label (full leaf.name if no underscore, else prefix before underscore).
    """
    return {leaf: leaf.name.split("_")[0] for leaf in tree.iter_leaves()}

def find_maximal_subtrees(tree, label_dict, target_label):
    """
    Find maximal subtrees where all leaves share the given target_label.
    Returns a list of sets of leaf node objects.
    """
    partitions = []

    def is_pure(node):
        return {label_dict[leaf] for leaf in node.get_leaves()} == {target_label}

    for node in tree.traverse("postorder"):
        if node.is_leaf():
            continue
        if is_pure(node):
            parent = node.up
            if parent is None or not is_pure(parent):
                partitions.append(set(node.get_leaves()))
    return partitions

def compute_label_purities(tree):
    """
    For each unique label in the tree:
    - Identify maximal pure subtrees
    - Add singleton partitions for any leaves not covered
    - Compute purity and return sizes
    
    Returns:
        dict: {label: (purity, [partition sizes])}
    """
    label_dict = get_label_dict(tree)
    all_labels = set(label_dict.values())
    purity_dict = {}

    for label in all_labels:
        # All leaf nodes bearing this label
        leaves_with_label = {leaf for leaf, lab in label_dict.items() if lab == label}

        # Maximal pure subtrees
        partitions = find_maximal_subtrees(tree, label_dict, label)
        covered = set().union(*partitions) if partitions else set()

        # Singletons for uncovered leaves
        for leaf in leaves_with_label - covered:
            partitions.append({leaf})

        sizes = [len(part) for part in partitions]
        purity = compute_purity(sizes)
        purity_dict[label] = (purity, sizes)

    return purity_dict

def process_folder(folder_path, extensions=(".nwk", ".txt")):
    """
    Process all tree files in a folder and print per-label purities and partition sizes.
    """
    for fname in os.listdir(folder_path):
        if not fname.endswith(extensions):
            continue
        try:
            tree = Tree(os.path.join(folder_path, fname), format=1)
            label_data = compute_label_purities(tree)
            avg = sum(p for p, _ in label_data.values()) / len(label_data) if label_data else 0.0

            print(f"\n{fname}:")
            for label, (purity, sizes) in sorted(label_data.items()):
                print(f"  {label}: purity = {purity:.4f}, sizes = {sizes}")
            print(f"  → Average purity: {avg:.4f}")

        except Exception as e:
            print(f"[!] Error processing {fname}: {e}")

# Example usage:
if __name__ == "__main__":
    trees = [
        'sarscov2_tree2',
        'rhinovirus_tree2',
        'mammalianMT_tree2',
        ]
    for test in trees:
        folder = f"./trees_purity/{test}"
        print(test)
        process_folder(folder)


sarscov2_tree2

sarscov2_NVM_tree.nwk:
  A: purity = 0.2778, sizes = [2, 2, 1, 1]
  B: purity = 0.4400, sizes = [3, 1, 1]
  D: purity = 0.2778, sizes = [2, 2, 1, 1]
  G: purity = 0.2778, sizes = [2, 2, 1, 1]
  GH: purity = 0.4400, sizes = [3, 1, 1]
  L: purity = 0.3750, sizes = [2, 1, 1]
  M: purity = 0.7222, sizes = [5, 1]
  O: purity = 0.1667, sizes = [1, 1, 1, 1, 1, 1]
  → Average purity: 0.3722

sarscov2_MKS_tree.nwk:
  A: purity = 1.0000, sizes = [6]
  B: purity = 1.0000, sizes = [5]
  D: purity = 0.5000, sizes = [4, 1, 1]
  G: purity = 0.7222, sizes = [5, 1]
  GH: purity = 0.6800, sizes = [4, 1]
  L: purity = 1.0000, sizes = [4]
  M: purity = 1.0000, sizes = [6]
  O: purity = 0.7222, sizes = [5, 1]
  → Average purity: 0.8281

sarscov2_FPS_tree.nwk:
  A: purity = 0.2778, sizes = [2, 2, 1, 1]
  B: purity = 0.2800, sizes = [2, 1, 1, 1]
  D: purity = 0.1667, sizes = [1, 1, 1, 1, 1, 1]
  G: purity = 0.1667, sizes = [1, 1, 1, 1, 1, 1]
  GH: purity = 0.2000, sizes = [1, 1, 1, 1, 1]
  L: