## Ranks and topics: hSBM results on protein families

In [80]:
import pandas as pd
import math
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
import tqdm as tqdm
import seaborn as sns
from bisect import bisect_left, bisect_right
from collections import defaultdict
import os
import sys
import networkx as nx

In [2]:
taxid_to_name_path = '../taxonomy_ncbi/taxid_to_name.tsv'
taxid_to_parent_path = '../taxonomy_ncbi/taxid_to_parent.names.ranks.tsv'

tid2nm = {}
with open(taxid_to_name_path, 'r') as f:
	for line in f:
		tid, name = line.strip().split('\t')
		tid2nm[int(tid)] = name
                
nm2tid = {v: k for k, v in tid2nm.items()}

________________

### Count matrix

In [4]:
countmatrix_path = '../uniprotkb_reference_proteomes/big_bacteria_countmatrix.csv'
df_cm = pd.read_csv(countmatrix_path, index_col=0)
df_cm

Unnamed: 0_level_0,PF00002,PF00004,PF00005,PF00006,PF00008,PF00009,PF00010,PF00011,PF00012,PF00013,...,PF25791,PF25792,PF25794,PF25796,PF25799,PF25800,PF25815,PF25816,PF25818,PF25819
species,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
UP000000212_1234679,0,6,95,5,0,9,0,0,2,2,...,0,0,0,0,0,0,0,0,0,0
UP000000231_312153,0,8,30,3,0,7,0,0,2,1,...,0,0,0,0,0,1,0,0,0,0
UP000000233_379731,0,13,58,4,0,9,0,2,3,1,...,0,0,0,0,0,2,0,0,0,0
UP000000235_369723,0,9,96,3,0,8,0,1,2,2,...,0,0,0,0,0,0,0,0,0,0
UP000000238_349521,0,13,91,9,0,9,0,2,5,1,...,0,0,0,0,0,1,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
UP001268683_1456591,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
UP001288380_3079793,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
UP001289135_1332059,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
UP001296776_85075,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
idx = [x.split('_')[1] for x in df_cm.index]
df_cm.index = idx
df_cm

Unnamed: 0,PF00002,PF00004,PF00005,PF00006,PF00008,PF00009,PF00010,PF00011,PF00012,PF00013,...,PF25791,PF25792,PF25794,PF25796,PF25799,PF25800,PF25815,PF25816,PF25818,PF25819
1234679,0,6,95,5,0,9,0,0,2,2,...,0,0,0,0,0,0,0,0,0,0
312153,0,8,30,3,0,7,0,0,2,1,...,0,0,0,0,0,1,0,0,0,0
379731,0,13,58,4,0,9,0,2,3,1,...,0,0,0,0,0,2,0,0,0,0
369723,0,9,96,3,0,8,0,1,2,2,...,0,0,0,0,0,0,0,0,0,0
349521,0,13,91,9,0,9,0,2,5,1,...,0,0,0,0,0,1,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1456591,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3079793,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1332059,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
85075,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


_____________

In [None]:
tree = nx.Graph()
tree_dir = nx.DiGraph()

with open(taxid_to_parent_path, 'r') as f:
    for line in f:
        node_tid, node_rank, node_name, parent_tid, parent_rank, parent_name = line.rstrip('\n').split('\t')

        # normalize ids (keep as strings or cast to int consistently)
        node_tid = int(node_tid)
        parent_tid = int(parent_tid)

        # add nodes with attributes
        tree.add_node(node_tid, name=node_name.strip(), rank=node_rank.strip())
        tree_dir.add_node(node_tid, name=node_name.strip(), rank=node_rank.strip())
        tree.add_node(parent_tid, name=parent_name.strip(), rank=parent_rank.strip())
        tree_dir.add_node(parent_tid, name=parent_name.strip(), rank=parent_rank.strip())

        # parent -> child
        tree.add_edge(parent_tid, node_tid)
        tree_dir.add_edge(parent_tid, node_tid)

tree.remove_edge(1, 1)
tree_dir.remove_edge(1, 1)

print(f'Number of connected components in the tree: {nx.number_connected_components(tree)}')
print(f'Number of weakly connected components in the directed tree: {nx.number_weakly_connected_components(tree_dir)}')

print(f'Number of nodes in the tree: {tree.number_of_nodes()}')
print(f'Number of edges in the tree: {tree.number_of_edges()}')


Number of connected components in the tree: 1
Number of weakly connected components in the directed tree: 1
Number of nodes in the tree: 2703951
Number of edges in the tree: 2703950


### hSBM results

In [66]:
lvl = 0
clusters_path = f'./results/topsbm_level_{str(lvl)}_clusters.csv'

In [67]:
clusters_df = pd.read_csv(clusters_path)
n_clusters = clusters_df.shape[1]
print(f'Number of clusters at level {lvl}: {n_clusters}')

Number of clusters at level 0: 48


In [68]:
clusters_df

Unnamed: 0,Cluster 1,Cluster 2,Cluster 3,Cluster 4,Cluster 5,Cluster 6,Cluster 7,Cluster 8,Cluster 9,Cluster 10,...,Cluster 39,Cluster 40,Cluster 41,Cluster 42,Cluster 43,Cluster 44,Cluster 45,Cluster 46,Cluster 47,Cluster 48
0,UP000001136_367737,UP000000813_176299,UP000014118_397291,UP000001115_316278,UP000016648_1115809,UP000007809_675635,UP000008871_393595,UP000000640_196162,UP000012063_1293054,UP000004263_207949,...,UP000001902_591001,UP000003195_706434,UP000005439_679936,UP000007332_1071400,UP000015503_1245471,UP000007113_682795,UP000014154_1235796,UP000005286_879305,UP000002415_381764,UP000002480_272947
1,UP000002495_235279,UP000002040_137722,UP000003438_411471,UP000001961_64471,UP000006657_709991,UP000004245_525370,UP000014568_421052,UP000003868_312284,UP000006427_469381,UP000006286_930169,...,UP000016640_1321779,UP000010802_1209989,UP000000323_525904,UP000007271_1185325,UP000017050_1203578,UP000006844_401053,UP000012651_997872,UP000005984_525254,UP000004793_511051,UP000007794_291272
2,UP000002043_638303,UP000005258_1110502,UP000014129_1235835,UP000000788_93059,UP000004295_553175,UP000001584_83332,UP000003511_1055526,UP000008229_469383,UP000006804_688269,UP000006282_385025,...,UP000007093_568816,UP000003277_742743,UP000004508_485913,UP000001285_714313,UP000002515_223283,UP000019151_861299,UP000003056_443342,UP000003422_997350,UP000006889_632518,UP000001174_177416
3,UP000005981_392423,UP000018524_1287292,UP000014134_1235800,UP000010379_195253,UP000004394_862515,UP000001419_233413,UP000001231_523791,UP000001116_266940,UP000006866_572479,UP000017838_1408164,...,UP000003503_888062,UP000000272_555079,,UP000001279_1069534,UP000001062_717774,UP000002432_204669,UP000005947_525256,UP000005244_796937,UP000002382_521045,UP000018554_1073362
4,UP000007803_563040,UP000001054_394,UP000014150_397290,UP000001420_167539,UP000006545_879243,UP000005143_1097667,UP000007477_871585,UP000008221_351607,UP000005096_584708,UP000002677_83406,...,UP000002072_519441,UP000010880_748449,,UP000001259_390333,UP000000233_379731,UP000002207_240015,UP000004830_742742,UP000003280_862517,UP000007719_515635,UP000004116_1005043
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73,,,,,,,,,,,...,,,,UP000000432_220668,,,,,,
74,,,,,,,,,,,...,,,,UP000005147_883112,,,,,,
75,,,,,,,,,,,...,,,,UP000019050_592010,,,,,,
76,,,,,,,,,,,...,,,,UP000005388_764291,,,,,,


In [69]:
for col in clusters_df.columns:
    # populate cluster with taxid
    clusters_df[col] = clusters_df[col].apply(lambda x: int(x.split('_')[1]) if pd.notnull(x) else x)
    
clusters_df

Unnamed: 0,Cluster 1,Cluster 2,Cluster 3,Cluster 4,Cluster 5,Cluster 6,Cluster 7,Cluster 8,Cluster 9,Cluster 10,...,Cluster 39,Cluster 40,Cluster 41,Cluster 42,Cluster 43,Cluster 44,Cluster 45,Cluster 46,Cluster 47,Cluster 48
0,367737.0,176299.0,397291.0,316278.0,1115809.0,675635.0,393595.0,196162.0,1293054.0,207949.0,...,591001.0,706434.0,679936.0,1071400,1245471.0,682795.0,1235796.0,879305.0,381764.0,272947.0
1,235279.0,137722.0,411471.0,64471.0,709991.0,525370.0,421052.0,312284.0,469381.0,930169.0,...,1321779.0,1209989.0,525904.0,1185325,1203578.0,401053.0,997872.0,525254.0,511051.0,291272.0
2,638303.0,1110502.0,1235835.0,93059.0,553175.0,83332.0,1055526.0,469383.0,688269.0,385025.0,...,568816.0,742743.0,485913.0,714313,223283.0,861299.0,443342.0,997350.0,632518.0,177416.0
3,392423.0,1287292.0,1235800.0,195253.0,862515.0,233413.0,523791.0,266940.0,572479.0,1408164.0,...,888062.0,555079.0,,1069534,717774.0,204669.0,525256.0,796937.0,521045.0,1073362.0
4,563040.0,394.0,397290.0,167539.0,879243.0,1097667.0,871585.0,351607.0,584708.0,83406.0,...,519441.0,748449.0,,390333,379731.0,240015.0,742742.0,862517.0,515635.0,1005043.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73,,,,,,,,,,,...,,,,220668,,,,,,
74,,,,,,,,,,,...,,,,883112,,,,,,
75,,,,,,,,,,,...,,,,592010,,,,,,
76,,,,,,,,,,,...,,,,764291,,,,,,


In [70]:
clusters = {}
for i in range(n_clusters):
    clusters[i] = clusters_df.iloc[:, i].dropna().tolist()

________________________________

In [71]:
def mrca(tree_dir: nx.DiGraph, nodes, *, verbose=True):
    # --- normalize labels like 795666.0 -> 795666
    def _normalize(u):
        if isinstance(u, float) and u.is_integer():
            return int(u)
        return u

    nodes = [ _normalize(u) for u in nodes ]
    if not nodes:
        return None

    # Build parent map + find root (assumes arborescence)
    parent = {}
    root = None
    for v in tree_dir.nodes:
        preds = list(tree_dir.predecessors(v))
        if len(preds) > 1:
            raise ValueError(f"Graph is not a tree (node {v} has >1 parent)")
        parent[v] = preds[0] if preds else None
        if not preds:
            root = v
    if root is None:
        raise ValueError("No root found (cycle or no in-degree-0 node)")

    # Depth from root (BFS)
    depth = {root: 0}
    q = deque([root])
    while q:
        a = q.popleft()
        for b in tree_dir.successors(a):
            depth[b] = depth[a] + 1
            q.append(b)

    # Keep only nodes present in this tree
    present = [u for u in nodes if u in parent]
    skipped = len(nodes) - len(present)
    if verbose and skipped:
        print(f"[mrca] Skipping {skipped} node(s) not in the tree.")

    if len(present) < 2:
        return None

    # Ancestor set (including self); empty if node somehow missing
    def anc_set(u):
        s = set()
        while u is not None and u in parent:
            s.add(u)
            u = parent[u]
        return s

    # Intersect ancestor sets
    common = anc_set(present[0])
    for u in present[1:]:
        common &= anc_set(u)
        if not common:
            return None  # different components or no shared ancestor in this tree

    # Deepest common ancestor
    return max(common, key=lambda x: depth.get(x, -1))

def tip_descendants(tree_dir, node):
    tips = set()
    q = deque([node])
    while q:
        u = q.popleft()
        children = list(tree_dir.successors(u))
        if not children:
            tips.add(u)
        else:
            q.extend(children)
    return tips

def cluster_clade_stats(tree_dir, C):
	
	C = set(C)
	
	print(f'\tComputing MRCA...')
	r = mrca(tree_dir, C)
    
	if r is None:
		print(f'\tNo MRCA found!')
		return dict(size=len(C), precision=0.0, recall=0.0, F1=0.0, clade_like=False)

	print(f'\tMRCA found: {r}. Computing descendants...')
	D = tip_descendants(tree_dir, r)
	inter = len(C & D)
	print(f'\tCluster size: {len(C)}. MRCA descendants size: {len(D)}. Intersection size: {inter}.')
	
	if inter == 0:
		print(f'\tNo intersection between cluster and descendants of MRCA!')
		return dict(size=len(C), precision=0.0, recall=0.0, F1=0.0, clade_like=False)
	
	print(f'\tComputing precision, recall, F1...')
	precision = inter / len(D)
	recall = inter / len(C)
	F1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
	
	return dict(size=len(C), precision=precision, recall=recall, F1=F1,
                clade_like=(precision == 1.0 and recall == 1.0), mrca=r, mrca_size=len(D))


cluster_stats = {}
for i in range(n_clusters):
	print(f'Processing cluster {i+1}/{n_clusters} ...')
	cluster_stats[i] = cluster_clade_stats(tree_dir, clusters[i])

Processing cluster 1/48 ...
	Computing MRCA...
	MRCA found: 3379134. Computing descendants...
	Cluster size: 42. MRCA descendants size: 278020. Intersection size: 41.
	Computing precision, recall, F1...
Processing cluster 2/48 ...
	Computing MRCA...
	MRCA found: 28211. Computing descendants...
	Cluster size: 31. MRCA descendants size: 68707. Intersection size: 31.
	Computing precision, recall, F1...
Processing cluster 3/48 ...
	Computing MRCA...
	MRCA found: 1239. Computing descendants...
	Cluster size: 40. MRCA descendants size: 90689. Intersection size: 40.
	Computing precision, recall, F1...
Processing cluster 4/48 ...
	Computing MRCA...
	MRCA found: 3028117. Computing descendants...
	Cluster size: 22. MRCA descendants size: 31831. Intersection size: 22.
	Computing precision, recall, F1...
Processing cluster 5/48 ...
	Computing MRCA...
	MRCA found: 976. Computing descendants...
	Cluster size: 52. MRCA descendants size: 23910. Intersection size: 52.
	Computing precision, recall, F1..

In [None]:
cluster_stats_df = pd.DataFrame.from_dict(cluster_stats, orient='index')
cluster_stats_df
cluster_stats_df['mrca_name'] = cluster_stats_df['mrca'].apply(lambda x: tid2nm[x] if x in tid2nm else 'unknown')

In [73]:
cluster_stats_df

Unnamed: 0,size,precision,recall,F1,clade_like,mrca,mrca_size,mrca_name
0,42,0.000147,0.97619,0.000295,False,3379134,278020,Pseudomonadati
1,31,0.000451,1.0,0.000902,False,28211,68707,Alphaproteobacteria
2,40,0.000441,1.0,0.000882,False,1239,90689,Bacillota
3,22,0.000691,1.0,0.001381,False,3028117,31831,Cyanophyceae
4,52,0.002175,1.0,0.00434,False,976,23910,Bacteroidota
5,35,0.000363,0.971429,0.000727,False,201174,93547,Actinomycetota
6,13,5.5e-05,1.0,0.000111,False,1224,234365,Pseudomonadota
7,39,6.8e-05,1.0,0.000137,False,2,569721,Bacteria
8,13,2.3e-05,1.0,4.6e-05,False,2,569721,Bacteria
9,20,8.5e-05,1.0,0.000171,False,1224,234365,Pseudomonadota


In [74]:
clusters_names = {}
for i in range(n_clusters):
	clusters_names[i] = [tid2nm[tid] if tid in tid2nm else 'unknown' for tid in clusters[i]]

clusters_names

{0: ['Aliarcobacter butzleri RM4018',
  'Helicobacter hepaticus ATCC 51449',
  'Thermocrinis albus DSM 14484',
  'Hydrogenivirga sp. 128-5-R1-1',
  'Sulfurimonas autotrophica DSM 16294',
  'Nitratiruptor sp. SB155-2',
  'Helicobacter canis NCTC 12740',
  'Helicobacter macacae MIT 99-5501',
  'Sulfurimonas hongkongensis',
  'Helicobacter canadensis MIT 98-5491',
  'Helicobacter bilis ATCC 43879',
  'Hippea maritima DSM 10411',
  'Helicobacter fennelliae MRY12-0050',
  'Arcobacter nitrofigilis DSM 7299',
  'Desulfurobacterium thermolithotrophum DSM 11699',
  'Sulfurimonas denitrificans DSM 1251',
  'Sulfurovum sp. AR',
  'Nitratifractor salsuginis DSM 16511',
  'Campylobacter curvus 525.92',
  'Mucispirillum schaedleri ASF457',
  'Campylobacter jejuni subsp. jejuni NCTC 11168 = ATCC 700819',
  'Aquifex aeolicus VF5',
  'Sulfurimonas gotlandica GD1',
  'Chitinivibrio alkaliphilus ACht1',
  'Hydrogenobaculum sp. HO',
  'Borreliella burgdorferi B31',
  'Helicobacter bizzozeronii CIII-1',
  

In [None]:
def index_tree(tree_dir: nx.DiGraph):
    # Find root (in-degree 0)
    roots = [v for v in tree_dir.nodes if tree_dir.in_degree(v) == 0]
    if len(roots) != 1:
        raise ValueError(f"Expected exactly one root, found {len(roots)}")
    root = roots[0]

    # Build parent/children
    parent = {root: None}
    children = {v: [] for v in tree_dir.nodes}
    for u, v in tree_dir.edges:
        parent[v] = u
        children[u].append(v)

    # Iterative DFS to get tin, depth, subtree_size
    tin, depth, subtree_size = {}, {}, {}
    node_by_tin = []
    t = 0
    depth[root] = 0
    stack = [(root, 0)]  # (node, 0=enter, 1=exit)
    while stack:
        v, state = stack.pop()
        if state == 0:
            tin[v] = t
            node_by_tin.append(v)
            t += 1
            stack.append((v, 1))
            # Push children in reverse so leftmost visits first
            for w in reversed(children[v]):
                depth[w] = depth[v] + 1
                stack.append((w, 0))
        else:
            s = 1
            for w in children[v]:
                s += subtree_size[w]
            subtree_size[v] = s

    tout = {v: tin[v] + subtree_size[v] - 1 for v in tree_dir.nodes}
    return {
        "root": root, "parent": parent, "children": children, "depth": depth,
        "tin": tin, "tout": tout, "subtree_size": subtree_size, "n": len(node_by_tin)
    }

def _normalize(u):
    # e.g. 795666.0 -> 795666
    return int(u) if isinstance(u, float) and u.is_integer() else u

def cluster_tree_stats(idx, cluster_nodes, alpha=0.8, size_cap=None):
    """
    Compute robust clade-based summaries for a cluster:
      - best coverage clade (max |C∩subtree(v)| / |C|)
      - best Jaccard clade ( |C∩S| / (|C| + |S| - |C∩S|) ) to penalize huge clades
      - alpha-majority clade: deepest v with |C∩subtree(v)| >= ceil(alpha*|C|)
    Optional: size_cap limits candidate clades to subtree_size(v) <= size_cap.
    """

    tin, tout, depth, parent, subtree_size = (idx["tin"], idx["tout"], idx["depth"], idx["parent"], idx["subtree_size"])

    C_all = list(cluster_nodes)	# list 
    C = [_normalize(u) for u in C_all if u in tin]  # filter to nodes present in this tree
    skipped = len(C_all) - len(C)
    if len(C) == 0:
        return {"n_in_cluster": 0, "n_skipped": skipped}

    # Pre-sort cluster members in preorder
    C_tins = sorted(tin[u] for u in C)

    def count_in_subtree(v):
        a, b = tin[v], tout[v]
        return bisect_right(C_tins, b) - bisect_left(C_tins, a)

    # Candidates = union of ancestors of the cluster (small)
    candidates = set()
    for u in C:
        v = u
        while v is not None and v not in candidates:
            if size_cap is None or subtree_size[v] <= size_cap:
                candidates.add(v)
            v = parent[v]

    best_cov_node, best_cov, best_cov_depth = None, -1, -1
    best_j_node, best_j, best_j_depth = None, -1.0, -1
    thr = math.ceil(alpha * len(C))
    alpha_node, alpha_cov, alpha_depth = None, 0, -1

    for v in candidates:
        cov = count_in_subtree(v)
        # 1) Best coverage (fraction of cluster captured)
        if cov > best_cov or (cov == best_cov and depth[v] > best_cov_depth):
            best_cov, best_cov_node, best_cov_depth = cov, v, depth[v]

        # 2) Best Jaccard (penalizes very big clades)
        j = cov / (len(C) + subtree_size[v] - cov)
        if j > best_j or (j == best_j and depth[v] > best_j_depth):
            best_j, best_j_node, best_j_depth = j, v, depth[v]

        # 3) α-majority clade (deepest node capturing >= α of the cluster)
        if cov >= thr and depth[v] > alpha_depth:
            alpha_node, alpha_cov, alpha_depth = v, cov, depth[v]

    result = {
        "n_in_cluster": len(C),
        "n_skipped": skipped,
        "best_coverage_node": best_cov_node,
        "best_coverage_fraction": best_cov / len(C),
        "best_coverage_subtree_size": subtree_size[best_cov_node] if best_cov_node else None,
        "best_jaccard_node": best_j_node,
        "best_jaccard": best_j,
        "best_jaccard_subtree_size": subtree_size[best_j_node] if best_j_node else None,
        "alpha": alpha,
        "alpha_node": alpha_node,
        "alpha_fraction": (alpha_cov / len(C)) if alpha_node else 0.0,
        "alpha_subtree_size": subtree_size[alpha_node] if alpha_node else None,
        "alpha_outliers": (len(C) - alpha_cov) if alpha_node else len(C) - best_cov,
    }
    return result

def cluster_core_and_outliers(idx, cluster_nodes, chosen_node):
    """Split a cluster into the core inside chosen_node's subtree and the outliers."""
    if chosen_node is None:
        return [], list(cluster_nodes)
    tin, tout = idx["tin"], idx["tout"]
    a, b = tin[chosen_node], tout[chosen_node]
    core, out = [], []
    for u in cluster_nodes:
        u2 = _normalize(u)
        if u2 in tin and a <= tin[u2] <= b:
            core.append(u)
        else:
            out.append(u)
    return core, out

In [92]:
idx = index_tree(tree_dir)

In [117]:
def fmt_node(node, tid2nm):
    """Format a node as 'taxid [name]' if name exists; handle None."""
    if node is None:
        return "None"
    name = tid2nm.get(node)
    return f"{node} [{name}]" if name else str(node)

def summarize_clusters(idx, clusters, tid2nm, alpha=0.75, size_cap=None, show_all=False):
    """
    Print a compact summary per cluster with translated node labels in brackets.
    Set show_all=True to also dump the full stats dict.
    """
    n_clusters = len(clusters)
    for i, C in enumerate(clusters, 1):
        stats = cluster_tree_stats(idx, C, alpha=alpha, size_cap=size_cap)
        rep = stats.get("alpha_node") or stats.get("best_jaccard_node")
        core, outliers = cluster_core_and_outliers(idx, C, rep)

        ac = stats.get  # tiny shortcut

        print(f"Cluster {i}/{n_clusters}")
        print(f"  in_tree={ac('n_in_cluster',0)}  skipped={ac('n_skipped',0)}")

        # Majority (alpha) clade
        print(f"  Majority α={ac('alpha')}: {fmt_node(ac('alpha_node'), tid2nm)}  "
              f"frac={ac('alpha_fraction',0):.3f}  size={ac('alpha_subtree_size')}")

        # Best coverage clade
        print(f"  Best coverage: {fmt_node(ac('best_coverage_node'), tid2nm)}  "
              f"frac={ac('best_coverage_fraction',0):.3f}  "
              f"size={ac('best_coverage_subtree_size')}")

        # Best Jaccard clade
        print(f"  Best Jaccard: {fmt_node(ac('best_jaccard_node'), tid2nm)}  "
              f"J={ac('best_jaccard',0):.3f}  size={ac('best_jaccard_subtree_size')}")

        # Representative clade (alpha if present, else best Jaccard)
        print(f"  Representative clade: {fmt_node(rep, tid2nm)}")

        print(f"  Core={len(core)}  Outliers={len(outliers)}")
        if show_all:
            for k, v in stats.items():
                print(f"    {k}: {v}")
        print()
        
summarize_clusters(idx, clusters.values(), tid2nm, alpha=0.75, size_cap=None, show_all=True)

Cluster 1/48
  in_tree=42  skipped=0
  Majority α=0.75: 3379134 [Pseudomonadati]  frac=1.000  size=290974
  Best coverage: 3379134 [Pseudomonadati]  frac=1.000  size=290974
  Best Jaccard: 64898 [Aquificaceae]  J=0.033  size=114
  Representative clade: 3379134 [Pseudomonadati]
  Core=42  Outliers=0
    n_in_cluster: 42
    n_skipped: 0
    best_coverage_node: 3379134
    best_coverage_fraction: 1.0
    best_coverage_subtree_size: 290974
    best_jaccard_node: 64898
    best_jaccard: 0.033112582781456956
    best_jaccard_subtree_size: 114
    alpha: 0.75
    alpha_node: 3379134
    alpha_fraction: 1.0
    alpha_subtree_size: 290974
    alpha_outliers: 0

Cluster 2/48
  in_tree=31  skipped=0
  Majority α=0.75: 356 [Hyphomicrobiales]  frac=0.871  size=36655
  Best coverage: 28211 [Alphaproteobacteria]  frac=1.000  size=71591
  Best Jaccard: 176299 [Agrobacterium fabrum str. C58]  J=0.032  size=1
  Representative clade: 356 [Hyphomicrobiales]
  Core=27  Outliers=4
    n_in_cluster: 31
    

In [115]:
def cluster_separation(idx, reps):
    """
    Compute how distinct clusters' representative clades are on the tree.

    Returns:
      - overlap_matrix[i,j] = True if subtrees overlap (nested or intersect)
      - distance_matrix[i,j] = edge distance between representative nodes
      - separation_fraction = fraction of cluster pairs that are disjoint
      - mean_interclade_distance = mean distance across distinct clades
    """
    tin, tout, depth, parent = idx["tin"], idx["tout"], idx["depth"], idx["parent"]
    n = len(reps)
    overlap = np.zeros((n, n), dtype=bool)
    dist = np.zeros((n, n), dtype=float)

    # helper to find LCA distance
    def lca_distance(u, v):
        if u is None or v is None:
            return np.nan
        a, b = u, v
        while not (tin[a] <= tin[b] <= tout[a]) and not (tin[b] <= tin[a] <= tout[b]):
            if depth[a] > depth[b]:
                a = parent[a]
            else:
                b = parent[b]
        lca = a if tin[a] <= tin[b] <= tout[a] else b
        return depth[u] + depth[v] - 2 * depth[lca]

    for i in range(n):
        a = reps[i]
        if a is None:
            continue
        for j in range(i + 1, n):
            b = reps[j]
            if b is None:
                continue
            # overlap test
            if (tin[a] <= tin[b] <= tout[a]) or (tin[b] <= tin[a] <= tout[b]):
                overlap[i, j] = overlap[j, i] = True
            # distance
            dist[i, j] = dist[j, i] = lca_distance(a, b)

    # fraction of disjoint (non-overlapping) pairs
    nonoverlap_fraction = 1 - overlap[np.triu_indices(n, 1)].mean()
    # mean interclade distance ignoring NaN
    mean_dist = np.nanmean(dist[np.triu_indices(n, 1)])
    return {
        "overlap_matrix": overlap,
        "distance_matrix": dist,
        "fraction_disjoint": nonoverlap_fraction,
        "mean_interclade_distance": mean_dist,
    }

In [121]:
def summarize_separation(idx, clusters, tid2nm, alpha=0.65):
    """Print global separation stats and a compact matrix of overlaps."""
    reps = []
    for C in clusters:
        stats = cluster_tree_stats(idx, C, alpha=alpha)
        reps.append(stats.get("alpha_node") or stats.get("best_jaccard_node"))

    sep = cluster_separation(idx, reps)
    print(f"\nGlobal cluster separation:")
    print(f"  Fraction of disjoint clades: {sep['fraction_disjoint']:.3f}")
    print(f"  Mean interclade distance: {sep['mean_interclade_distance']:.2f}\n")

    # show which clusters overlap
    n = len(reps)
    print("Pairwise overlap (Y=overlap, .=disjoint):")
    for i in range(n):
        row = ''.join('Y' if sep["overlap_matrix"][i,j] else '.' for j in range(n))
        print(f"Cluster {i+1:>2}: {row}")
        
summarize_separation(idx, clusters.values(), tid2nm, alpha=0.65)


Global cluster separation:
  Fraction of disjoint clades: 0.673
  Mean interclade distance: 4.67

Pairwise overlap (Y=overlap, .=disjoint):
Cluster  1: .Y..Y.Y.YY.YYY..Y.YY.YYY.YYYYY.YYY..YYY...YY..YY
Cluster  2: Y.......Y..YY.....Y.......Y.Y..Y....Y.Y.......YY
Cluster  3: ........Y.Y...................Y...Y.Y.YY....YYY.
Cluster  4: ........Y...........Y...............Y.Y.....Y.Y.
Cluster  5: Y.......Y.......Y.Y.......Y.Y.......Y.Y.......Y.
Cluster  6: .......YY......Y...................YY.Y.....Y.Y.
Cluster  7: Y.......YY.YYY....Y.......Y.Y....Y..Y.Y...Y...YY
Cluster  8: .....Y..Y......Y...................YY.Y.....Y.Y.
Cluster  9: YYYYYYYY.YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY
Cluster 10: Y.....Y.Y..YYY....Y.......Y.Y....Y..Y.Y...Y...YY
Cluster 11: ..Y.....Y.....................Y...Y.Y.YY....YYY.
Cluster 12: YY....Y.YY..YY....YY..Y..YY.Y..Y.Y..YYY...Y...YY
Cluster 13: YY....Y.YY.Y.Y....YY..Y..YY.Y..Y.Y..YYY...Y...YY
Cluster 14: Y.....Y.YY.YY.....Y.......Y.Y....Y..Y.Y...Y...YY
Clust