In [2]:
import cassiopeia as cas
import numpy as np
import pandas as pd 
import networkx as nx
from tqdm import tqdm
from scipy.cluster.hierarchy import linkage, dendrogram, to_tree
from scipy.spatial.distance import squareform
from cassiopeia.utils import _get_digraph, get_leaves
from skbio import DistanceMatrix
from skbio.tree import nj



In [3]:
binary_sim = cas.sim.CompleteBinarySimulator(depth=15)
full_tree = binary_sim.simulate_tree()
print("Total leaves:", len(full_tree.leaves))  # should be 2**20

Total leaves: 32768


In [4]:
sequential_sim = cas.sim.SequentialLineageTracingDataSimulator(
    number_of_cassettes=13,
    size_of_cassette=5,
    initiation_rate=3.5,
    continuation_rate=3.0,
    state_priors={a:0.0526315789 for a in range(19)},
    heritable_silencing_rate=0.0,
    stochastic_silencing_rate=0.0,
    heritable_missing_data_state=-1,
    stochastic_missing_data_state=-1,
    random_seed=17
)
sequential_sim.overlay_data(full_tree)
sequential_full_charmat = full_tree.character_matrix


In [6]:
nonzero_counts = {}

for idx, column in character_matrix.items():
    nonzero_count = (column != 0).sum()
    nonzero_counts[idx] = nonzero_count

mod5_counts = {}
for mod in range(5):
    mod5_counts[mod] = int((character_matrix.iloc[:, mod::5] != 0).sum().sum())

print(mod5_counts)
total_nonzero = sum(nonzero_counts.values())
numcols = character_matrix.shape[1]
numcells = character_matrix.shape[0]
print(f"numcols: {numcols}")
print(f"numcells: {numcells}")
nonzero_rate_mod = {}
for mod in range(5):
    nonzero_rate_mod[mod] = mod5_counts[mod] / (numcells * numcols /5)

print(f"mod 5 nonzero rates: {nonzero_rate_mod}")
print(f"fraction of edited sites: {total_nonzero / (character_matrix.shape[0] * character_matrix.shape[1])}")


{0: 411518, 1: 348269, 2: 251046, 3: 153321, 4: 81269}
numcols: 65
numcells: 32768
mod 5 nonzero rates: {0: 0.9660409780649039, 1: 0.8175635704627404, 2: 0.5893319936899039, 3: 0.35992196890024036, 4: 0.19077946589543268}
fraction of edited sites: 0.5847275954026442


In [None]:
target_leaves = 3257
subsampler = cas.sim.UniformLeafSubsampler(number_of_leaves=target_leaves)

sequential_3257 = subsampler.subsample_leaves(full_tree)
print("Leaves after subsampling:", len(sequential_3257.leaves))  # should be 3257

Leaves after subsampling: 3257


In [None]:
# % unique leaf barcodes. could make this a function 
leaves = tree_3257.leaves

# leaves = [int(leaf) for leaf in leaves]
leaf_arrays = [tuple  (character_matrix.loc[leaf]) for leaf in leaves]

# Set of unique barcodes
unique_barcodes = set(leaf_arrays)
num_unique = len(unique_barcodes)
num_leaves = len(leaves)

pct_unique_cells = num_unique / num_leaves

print("Unique leaf barcodes:", num_unique, "of", num_leaves,
      f"({pct_unique_cells:.3f})")


Unique leaf barcodes: 3253 of 3257 (0.999)


In [None]:
# Build list of non-root edges (parent, child)
edges = []
root_count = 0
for node in tree_3257.nodes:
    if not tree_3257.is_root(node):
        parent = tree_3257.parent(node)
        edges.append((parent, node))
    else:
        root_count += 1

# Debug: check tree structure
print(f"Total nodes: {len(tree_3257.nodes)}")
print(f"Root nodes: {root_count}")
print(f"Leaves: {len(tree_3257.leaves)}")
print(f"Edges (non-root nodes): {len(edges)}")
print(f"Expected edges: {len(tree_3257.nodes) - 1}")

# Helper: does edge carry at least one mutation?
# Here: mutation = any character differs between parent and child
# (you can tweak this to ignore -1, etc.)
informative_flags = []
for parent, child in edges:
    # Use get_character_states() to get states for any node (internal or leaf)
    parent_state = np.array(tree_3257.get_character_states(parent))
    child_state = np.array(tree_3257.get_character_states(child))
    # boolean: is there any position that differs?
    informative = np.any(parent_state != child_state)
    informative_flags.append(informative)

num_informative = np.sum(informative_flags)
num_edges = len(edges)
pct_informative_edges = num_informative / num_edges

print("Informative edges:", num_informative, "of", num_edges,
      f"({pct_informative_edges:.3f} fraction)")


Informative edges: 6117 of 6513 (0.939 fraction)


In [52]:
char_mat_3257 = tree_3257.character_matrix
print(char_mat_3257)

       0   1   2   3   4   5   6   7   8   9   ...  55  56  57  58  59  60  \
32770  10  12   9   2   0  14   5   0   0   0  ...   3  18  17   0   0  18   
32777  10  12   9   2   0  14   5   0   0   0  ...   3  18   0   0   0  18   
32785  10  12   9   2   4  14   5   0   0   0  ...   3  18   0   0   0  18   
32787  10  12   9   2   4  14   5   5   0   0  ...   3  18   0   0   0  18   
32793  10  12   9   2   4  14   5   2  17   0  ...   3  18   8  14   1  18   
...    ..  ..  ..  ..  ..  ..  ..  ..  ..  ..  ...  ..  ..  ..  ..  ..  ..   
65496  10  16   0   0   0  14  13  11   3   9  ...  16   9   0   0   0  11   
65498  10  16   0   0   0  14  13  11   3   9  ...  16   9   0   0   0  11   
65512  10  16  10   0   0  14  13  11   3   9  ...  16   9   0   0   0  11   
65519  10  16   0   0   0  14  13  11   3   9  ...  16   9   0   0   0  11   
65529  10  16   2   0   0  14  13  11   3   9  ...  16   9   7   9   0  11   

       61  62  63  64  
32770  11  16  11  12  
32777  11  16  

In [None]:
def build_prefix_distance_df(character_matrix, tape_structure):
    """
    df: (cells × sites) DataFrame, entries = edit ID or 0 
    (0 = no edit)
    tape_structure: list of lists, e.g.,
        [[site_1, site_2, site_3], [site_4, site_5]] 
        grouping columns into separate tapes.

    Returns: distance matrix (DataFrame)
    """
    cell_ids = character_matrix.index.tolist()
    n = len(cell_ids)
    max_shared = sum(len(tape) for tape in tape_structure)
    
    # Convert DataFrame to numpy array once 
    # Map column names to indices
    col_to_idx = {col: idx for idx, col in enumerate(character_matrix.columns)}
    
    # Convert tape_structure from column names to indices
    tape_indices = [[col_to_idx[col] for col in tape] for tape in tape_structure]
    
    # Get values as numpy array (row-major: each row is a cell, each column is a site)
    values = character_matrix.values  # shape: (n_cells, n_sites)

    D = np.zeros((n, n))

    for i in tqdm(range(n), desc="Computing distances"):
        for j in range(i + 1, n):
            shared = 0
            for tape_idx in tape_indices:
                prefix_valid = True
                for col_idx in tape_idx:
                    a = values[i, col_idx]
                    b = values[j, col_idx]

                    if a == b and a != 0:
                        if prefix_valid:
                            shared += 1
                    else:
                        prefix_valid = False

            D[i, j] = D[j, i] = max_shared - shared

    result = pd.DataFrame(D, index=cell_ids, columns=cell_ids)
    return result
    
    


In [62]:
tape_structure = [[n, n+1, n+2, n+3, n+4] for n in range(0, 61, 5)]
distance_matrix = build_prefix_distance_df(char_mat_3257, tape_structure)

[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24], [25, 26, 27, 28, 29], [30, 31, 32, 33, 34], [35, 36, 37, 38, 39], [40, 41, 42, 43, 44], [45, 46, 47, 48, 49], [50, 51, 52, 53, 54], [55, 56, 57, 58, 59], [60, 61, 62, 63, 64]]


Computing distances: 100%|██████████| 3257/3257 [00:40<00:00, 80.10it/s] 


In [68]:
# Use scikit-bio's Neighbor Joining

# Convert scikit-bio tree to networkx DiGraph
def skbio_tree_to_networkx(skbio_tree, node_counter=[0]):
    """Convert scikit-bio TreeNode to networkx DiGraph"""
    G = nx.DiGraph()
    
    def traverse(node, parent_name=None):
        # If it's a leaf, use the node name directly
        if node.is_tip():
            leaf_name = node.name
            if parent_name:
                G.add_edge(parent_name, leaf_name, length=node.length or 1.0)
            return leaf_name
        else:
            # Internal node - create a new internal node
            node_counter[0] += 1
            node_name = f"internal_{node_counter[0]}"
            
            if parent_name:
                G.add_edge(parent_name, node_name, length=node.length or 1.0)
            
            # Traverse children
            for child in node.children:
                traverse(child, node_name)
            
            return node_name
    
    # Note: nj_tree_skbio is a TreeNode (the root node), which represents the entire tree
    # In scikit-bio, TreeNode objects are nodes, and you traverse recursively via node.children
    root_name = traverse(skbio_tree)
    return G, root_name


def scikit_nj_reconstruct(distance_matrix, char_mat):
        
    # Get taxa names (cell IDs)
    taxa = distance_matrix.index.tolist()

    # Convert pandas DataFrame to scikit-bio DistanceMatrix
    # scikit-bio expects a 2D array and list of IDs
    dist_array = distance_matrix.values
    skbio_dm = DistanceMatrix(dist_array, taxa)

    # Run Neighbor Joining
    nj_tree_skbio = nj(skbio_dm)

    # Convert scikit-bio tree to networkx
    nj_networkx, root = skbio_tree_to_networkx(nj_tree_skbio)

    # Create CassiopeiaTree from networkx
    scikit_nj_reconstructed = cas.data.CassiopeiaTree(
        character_matrix=char_mat,
        tree=nj_networkx
    )
    print("treecreated")
    return scikit_nj_reconstructed







In [72]:
mytree =scikit_nj_reconstruct(distance_matrix, sequential_3257_charmat)

treecreated


In [None]:
# Run Neighbor Joining using Cassiopeia's fast implementation
# Create CassiopeiaTree with character matrix and distance matrix
nj_tree = cas.data.CassiopeiaTree(
    character_matrix=char_mat_3257,
    dissimilarity_map=distance_matrix
)

# Use Neighbor Joining solver (generic implementation, no CCPhylo required)
nj_solver = cas.solver.NeighborJoiningSolver(
    fast=False,  # Use generic implementation (doesn't require CCPhylo)
    add_root=True
)

# Solve the tree
nj_solver.solve(nj_tree, collapse_mutationless_edges=True)

# Rename for consistency with rest of notebook
upgma_cassiopeia = nj_tree



  self.__dissimilarity_map.loc[other, node] = dissim


KeyboardInterrupt: 

In [None]:
# debug: Compare leaf sets between ground truth tree and reconstructed CassiopeiaTree
print("Comparing leaf sets...")

# Get leaves from the reconstructed CassiopeiaTree (not the raw networkx graph)
reconstructed_leaves = set(upgma_cassiopeia.leaves)

# Get leaves from ground truth tree
ground_truth_leaves = set(tree_3257.leaves)

# DEBUG: Check what we're comparing
print(f"\nDEBUG: Ground truth tree leaves: {len(ground_truth_leaves)}")
print(f"DEBUG: Reconstructed CassiopeiaTree leaves: {len(reconstructed_leaves)}")

# Sample some leaves to check types
gt_sample = list(ground_truth_leaves)[:5]
recon_sample = list(reconstructed_leaves)[:5]
print(f"\nDEBUG: Sample ground truth leaves (first 5): {gt_sample}")
print(f"DEBUG: Types of ground truth leaves: {[type(l).__name__ for l in gt_sample]}")
print(f"\nDEBUG: Sample reconstructed leaves (first 5): {recon_sample}")
print(f"DEBUG: Types of reconstructed leaves: {[type(l).__name__ for l in recon_sample]}")

# Check if types match
gt_types = set(type(l).__name__ for l in ground_truth_leaves)
recon_types = set(type(l).__name__ for l in reconstructed_leaves)
print(f"\nDEBUG: All types in ground truth: {gt_types}")
print(f"DEBUG: All types in reconstructed: {recon_types}")

# Check if sets are identical
if ground_truth_leaves == reconstructed_leaves:
    print("\n✓ Leaf sets are identical!")
else:
    print("\n✗ Leaf sets are NOT identical")
    # Find differences
    
    missing_in_reconstructed = ground_truth_leaves - reconstructed_leaves
    extra_in_reconstructed = reconstructed_leaves - ground_truth_leaves
    common_leaves = ground_truth_leaves & reconstructed_leaves

    print(f"\nCommon leaves: {len(common_leaves)}")
    print(f"Missing in reconstructed (in ground truth but not in reconstructed): {len(missing_in_reconstructed)}")
    print(f"Extra in reconstructed (in reconstructed but not in ground truth): {len(extra_in_reconstructed)}")

    if missing_in_reconstructed:
        missing_list = sorted(list(missing_in_reconstructed))[:10]
        print(f"\nFirst 10 missing leaves in reconstructed: {missing_list}")
        print(f"Types of missing leaves: {[type(l).__name__ for l in missing_list]}")
        # Try converting to see if that's the issue
        if missing_list:
            first_missing = missing_list[0]
            print(f"\nDEBUG: First missing leaf value: {first_missing!r}, type: {type(first_missing)}")
            # Check if it exists as a different type
            if isinstance(first_missing, int):
                str_version = str(first_missing)
                print(f"DEBUG: Checking if '{str_version}' exists in reconstructed: {str_version in reconstructed_leaves}")
            elif isinstance(first_missing, str):
                try:
                    int_version = int(first_missing)
                    print(f"DEBUG: Checking if {int_version} exists in reconstructed: {int_version in reconstructed_leaves}")
                except ValueError:
                    pass
    
    if extra_in_reconstructed:
        extra_list = sorted(list(extra_in_reconstructed))[:10]
        print(f"\nFirst 10 extra leaves in reconstructed: {extra_list}")
        print(f"Types of extra leaves: {[type(l).__name__ for l in extra_list]}")



Comparing leaf sets...

DEBUG: Ground truth tree leaves: 3257
DEBUG: Reconstructed CassiopeiaTree leaves: 3257

DEBUG: Sample ground truth leaves (first 5): ['54519', '42526', '59560', '42514', '53602']
DEBUG: Types of ground truth leaves: ['str', 'str', 'str', 'str', 'str']

DEBUG: Sample reconstructed leaves (first 5): ['54519', '42526', '59560', '42514', '53602']
DEBUG: Types of reconstructed leaves: ['str', 'str', 'str', 'str', 'str']

DEBUG: All types in ground truth: {'str'}
DEBUG: All types in reconstructed: {'str'}

✓ Leaf sets are identical!


In [149]:
def _robinson_foulds_bitset(tree1: nx.DiGraph, tree2: nx.DiGraph):
    """Compute the unrooted Robinson–Foulds distance using bitsets."""
    leaves1 = sorted([n for n in tree1 if tree1.out_degree(n) == 0])
    leaves2 = sorted([n for n in tree2 if tree2.out_degree(n) == 0])
    if set(leaves1) != set(leaves2):
        raise ValueError("Trees must have identical leaf sets.")

    leaf_index = {leaf: i for i, leaf in enumerate(leaves1)}

    def get_splits(tree, leaf_index):
        """Return a set of canonical bitmasks representing bipartitions."""
        topo = list(nx.topological_sort(tree))
        bitset = {}
        # postorder accumulation of leaf bitsets
        for n in reversed(topo):
            if tree.out_degree(n) == 0:
                bitset[n] = (1 << leaf_index[n]) if n in leaf_index else 0
            else:
                m = 0
                for c in tree.successors(n):
                    m |= bitset[c]
                bitset[n] = m

        all_mask = bitset[topo[0]]
        length = all_mask.bit_count()
        # For unrooted splits, each internal edge defines a bipartition.
        # Canonicalize by mapping each side to min(side, complement),
        # so the split is independent of rooting.
        splits = set()
        for _, c in tree.edges:
            bs = bitset[c]
            k = bs.bit_count()
            # exclude trivial: 1 or length-1 leaves
            if 1 < k < length - 0:  # k<length and k>1
                comp = all_mask ^ bs
                # exclude complement-trivial as well; the test above already ensures k<length
                if comp != 0 and comp != all_mask:
                    splits.add(min(bs, comp))
        return splits

    splits1 = get_splits(tree1, leaf_index)
    splits2 = get_splits(tree2, leaf_index)

    rf = len(splits1.symmetric_difference(splits2))
    max_rf = len(splits1) + len(splits2)
    return rf, max_rf


In [None]:
def robinson_foulds(
    tree1,
    tree2,
    key1: str | None = None,
    key2: str | None = None,
) -> tuple[float, float]:
    """Compute the Robinson–Foulds distance between two trees.

    Args:
        tree1: The tree object.
        tree2: The tree object to compare against. If ``None``, ``key1`` and ``key2``
            are used to select two trees from the `tree1` object.
        key1: If ``tree1`` is a :class:`treedata.TreeData`, specifies the ``obst`` key to use.
            Only required if multiple trees are present.
        key2: The ``obst`` key to compare against. Selects from ``tree2`` if provided,
            otherwise selects from ``tree1``. Only required if multiple trees are present.

    Returns:
        tuple[float, float]: The Robinson–Foulds distance and the maximum
        possible distance for the pair of trees.
    """
    if tree2 is None and (key1 is None or key2 is None):
        raise ValueError("If tree2 is None, both key1 and key2 must be provided.")
    t1, _ = _get_digraph(tree1, tree_key=key1)
    t2, _ = (
        _get_digraph(tree2, tree_key=key2)
        if tree2 is not None
        else _get_digraph(tree1, tree_key=key2)
    )

    if set(get_leaves(t1)) != set(get_leaves(t2)):
        raise ValueError("Trees must have identical leaf sets.")

    return _robinson_foulds_bitset(t1, t2)


In [254]:
#reconstruct without sequential 

# Method 1: VanillaGreedySolver
print("Reconstructing with VanillaGreedySolver...")
vanilla_greedy = cas.solver.VanillaGreedySolver()
vanilla_tree = cas.data.CassiopeiaTree(character_matrix=char_mat_3257)
vanilla_greedy.solve(vanilla_tree, collapse_mutationless_edges=True)
reconstructed_trees = {}
reconstructed_trees['SequentialVanillaGreedy'] = vanilla_tree
print(f"✓ VanillaGreedy completed. Tree has {len(vanilla_tree.leaves)} leaves")



Reconstructing with VanillaGreedySolver...
✓ VanillaGreedy completed. Tree has 3257 leaves


  t1, _ = _get_digraph(tree1, tree_key=key1)
  _get_digraph(tree2, tree_key=key2)


In [256]:
rf_distance, rf_max = robinson_foulds(tree_3257, vanilla_tree)
normalized_rf = rf_distance / rf_max if rf_max > 0 else 0.0

print(normalized_rf)

0.3336550836550837


  t1, _ = _get_digraph(tree1, tree_key=key1)
  _get_digraph(tree2, tree_key=key2)


In [242]:
rf_distance, rf_max = cas.critique.compare.robinson_foulds(tree_3257, upgma_cassiopeia)


ValueError: Trees must have identical leaf sets.

In [257]:
# Method 2: NeighborJoiningSolver
print("Reconstructing with NeighborJoiningSolver...")
nj_solver = cas.solver.NeighborJoiningSolver(add_root=True)
nj_tree = cas.data.CassiopeiaTree(character_matrix=char_mat_3257)
nj_solver.solve(nj_tree, collapse_mutationless_edges=True)
reconstructed_trees['NeighborJoining'] = nj_tree
print(f"✓ NeighborJoining completed. Tree has {len(nj_tree.leaves)} leaves")


Reconstructing with NeighborJoiningSolver...
✓ NeighborJoining completed. Tree has 3257 leaves


In [259]:
rf_distance, rf_max = robinson_foulds(tree_3257, nj_tree)
normalized_rf = rf_distance / rf_max if rf_max > 0 else 0.0

print(normalized_rf)

0.07954362847501205


  t1, _ = _get_digraph(tree1, tree_key=key1)
  _get_digraph(tree2, tree_key=key2)


In [268]:
# Compare UPGMA tree to ground truth tree using Robinson-Foulds distance

rf_distance, rf_max = robinson_foulds(tree_3257, upgma_cassiopeia)
normalized_rf = rf_distance / rf_max

# Calculate triplets correct statistics
triplets_stats = cas.critique.compare.triplets_correct(
    tree_3257, 
    upgma_cassiopeia, 
    number_of_trials=1000
)
all_triplets_correct, resolvable_triplets_correct, unresolved_triplets_correct, proportion_resolvable = triplets_stats

# Summary statistics
avg_all_triplets = np.mean([v for v in all_triplets_correct.values()])
avg_resolvable = np.mean([v for v in resolvable_triplets_correct.values()])

print(f"Robinson-Foulds distance: {rf_distance}/{rf_max} (normalized: {normalized_rf:.4f})")
print(f"Average triplets correct: {avg_all_triplets:.4f}")
# print(f"Average resolvable triplets correct: {avg_resolvable:.4f}")


  t1, _ = _get_digraph(tree1, tree_key=key1)
  _get_digraph(tree2, tree_key=key2)


Robinson-Foulds distance: 628/6508 (normalized: 0.0965)
Average triplets correct: 0.9116


In [277]:
print(Out.keys())
Out[60]

dict_keys([60, 150, 153, 204, 206, 209, 211, 219, 221, 222, 224, 225, 226, 227, 269, 270, 271, 272, 273, 274, 275, 276])


Unnamed: 0,32770,32777,32785,32787,32793,32802,32804,32820,32827,32829,...,65472,65473,65475,65483,65490,65496,65498,65512,65519,65529
32770,0.0,34.0,32.0,32.0,33.0,36.0,36.0,36.0,36.0,36.0,...,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0
32777,34.0,0.0,34.0,34.0,34.0,36.0,36.0,36.0,36.0,36.0,...,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0
32785,32.0,34.0,0.0,29.0,32.0,36.0,36.0,36.0,36.0,36.0,...,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0
32787,32.0,34.0,29.0,0.0,32.0,36.0,36.0,36.0,36.0,36.0,...,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0
32793,33.0,34.0,32.0,32.0,0.0,36.0,36.0,35.0,35.0,36.0,...,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
65496,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,...,38.0,38.0,38.0,38.0,37.0,0.0,34.0,39.0,39.0,39.0
65498,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,...,38.0,38.0,38.0,38.0,37.0,34.0,0.0,39.0,39.0,39.0
65512,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,...,38.0,39.0,39.0,38.0,39.0,39.0,39.0,0.0,35.0,37.0
65519,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,59.0,...,39.0,39.0,39.0,38.0,39.0,39.0,39.0,35.0,0.0,37.0


In [None]:
small_sim = cas.sim.CompleteBinarySimulator(depth=10)
small_tree = small_sim.simulate_tree()

cas9_sim = cas.sim.Cas9LineageTracingDataSimulator(
    number_of_cassettes=3, 
    size_of_cassette=30,
    mutation_rate=0.4,
    number_of_states=100,
    heritable_silencing_rate=0.0,
    stochastic_silencing_rate=0.0,
    collapse_sites_on_cassette=False 

)
cas9_sim.overlay_data(small_tree)
small_character_matrix = small_tree.character_matrix

In [220]:
print(small_character_matrix)

      0   1   2   3   4   5   6   7   8   9   ...  80  81  82  83  84  85  86  \
1024  47   0  47  -1  -1  -1  -1  -1  -1  -1  ...  -1  -1  -1  -1  -1  -1   0   
1025  47   0  47  -1  -1  -1  -1  -1  -1  -1  ...  -1  -1  -1  -1  -1  -1   0   
1026  47   0  47  -1  -1  -1  -1  -1  -1  -1  ...  -1  -1  -1  -1  -1  -1   0   
1027  47   0  47  -1  -1  -1  -1  -1  -1  -1  ...  -1  -1  -1  -1  -1  -1   0   
1028  47   0  47  -1  -1  -1  -1  -1  -1  -1  ...  -1  -1  -1  -1  -1  -1   0   
...   ..  ..  ..  ..  ..  ..  ..  ..  ..  ..  ...  ..  ..  ..  ..  ..  ..  ..   
2043   0   0  -1  -1  -1  -1  -1  -1  -1  -1  ...  86   0   0   0   0   0   0   
2044   0   0   0   0  40   0   0   0   0   0  ...   0   0   0   0   0   0   0   
2045   0   0   0   0  40   0   0   0   0   0  ...  -1  -1  -1   0   0   0   0   
2046   0   0   0   0  40   0   0   0   0   0  ...   0   0   0   0   0   0  42   
2047   0   0   0   0  40   0   0   0   0   0  ...   0   0   0   0   0  62   0   

      87  88  89  
1024   0

In [227]:
def seen_mutation_rate(character_matrix):
    mutated_state = ((character_matrix != 0) & (character_matrix != -1)).astype(int)
    num_mutated = np.sum(mutated_state, axis=0)
    num_mutated = np.sum(num_mutated, axis=0)
    total_sites = character_matrix.shape[1] * character_matrix.shape[0]
    return float(num_mutated / total_sites)

seen_mutation_rate(small_character_matrix)


0.07581380208333334

In [281]:
mutation_rates = np.arange(0.0, 4, 0.1)
seen_mutation_rates = []
for mutation_rate in mutation_rates:
    cas9_sim = cas.sim.Cas9LineageTracingDataSimulator(
        number_of_cassettes=22, 
        size_of_cassette=3,
        mutation_rate=mutation_rate,
        number_of_states=100,
        heritable_silencing_rate=0.0,
        stochastic_silencing_rate=0.0,
        # collapse_sites_on_cassette=False 
    )
    cas9_sim.overlay_data(small_tree)
    seen_mutation_rates.append(seen_mutation_rate(small_tree.character_matrix))

print(seen_mutation_rates)

[0.0, 0.09777462121212122, 0.17344341856060605, 0.24971886837121213, 0.32299064867424243, 0.36527876420454547, 0.3990737452651515, 0.4662198153409091, 0.49829841382575757, 0.5003551136363636, 0.5222093986742424, 0.6405658143939394, 0.5726207386363636, 0.6169211647727273, 0.659372040719697, 0.6478900331439394, 0.6729551373106061, 0.6822176846590909, 0.6421934185606061, 0.7170780066287878, 0.6895271070075758, 0.62744140625, 0.6292761600378788, 0.7132605350378788, 0.6744495738636364, 0.6974431818181818, 0.701467803030303, 0.6743903882575758, 0.681640625, 0.7300840435606061, 0.6377840909090909, 0.697265625, 0.7116033380681818, 0.6568714488636364, 0.6359493371212122, 0.6283291903409091, 0.65966796875, 0.5416518702651515, 0.6098928740530303, 0.685354521780303]


In [246]:
binary_sim = cas.sim.CompleteBinarySimulator(depth=15)
cas9_full_tree = binary_sim.simulate_tree()
print("Total leaves:", len(full_tree.leaves))  # should be 2**20
cas9_sim = cas.sim.Cas9LineageTracingDataSimulator(
    number_of_cassettes=3, 
    size_of_cassette=30,
    mutation_rate=1.0,
    number_of_states=100,
    collapse_sites_on_cassette=False 

)
cas9_sim.overlay_data(cas9_full_tree)
cas9_character_matrix = cas9_full_tree.character_matrix

print("finished overlaying cas9 data")


cas9_tree_3257 = subsampler.subsample_leaves(cas9_full_tree)
print("Leaves after subsampling:", len(tree_3257.leaves))  # should be 3257
print("fraction of sites mutated:", seen_mutation_rate(cas9_character_matrix))

Total leaves: 32768
finished overlaying cas9 data
Leaves after subsampling: 3257
fraction of sites mutated: 0.6136837429470486


In [247]:
# Use the function to calculate unique barcodes
num_unique, num_leaves, pct_unique_cells = calculate_unique_barcodes(
    cas9_character_matrix, 
    leaves=cas9_tree_3257.leaves
)

collision_rate = 1.0 - pct_unique_cells
print("Collision rate:", collision_rate)


Unique leaf barcodes: 3256 of 3257 (1.000 fraction)
Collision rate: 0.000307031010132075


In [82]:
# Build list of non-root edges (parent, child)
edges = []
for node in cas9_tree_3257.nodes:
    if not cas9_tree_3257.is_root(node):
        parent = cas9_tree_3257.parent(node)
        edges.append((parent, node))

# Helper: does edge carry at least one mutation?
# Here: mutation = any character differs between parent and child
# (you can tweak this to ignore -1, etc.)
informative_flags = []
for parent, child in edges:
    # Use get_character_states() to get states for any node (internal or leaf)
    parent_state = np.array(cas9_tree_3257.get_character_states(parent))
    child_state = np.array(cas9_tree_3257.get_character_states(child))
    # boolean: is there any position that differs?
    informative = np.any(parent_state != child_state)
    informative_flags.append(informative)

num_informative = np.sum(informative_flags)
num_edges = len(edges)
pct_informative_edges = num_informative / num_edges

print("Informative edges:", num_informative, "of", num_edges,
      f"({pct_informative_edges:.3f} fraction)")


Informative edges: 845 of 6513 (0.130 fraction)


In [248]:
# Reconstruct trees using different Cassiopeia solvers
# Extract character matrix for the subsampled leaves
cas9_subsampled_leaves = cas9_tree_3257.leaves
cas9_subsampled_character_matrix = cas9_tree_3257.character_matrix 

print(f"Character matrix shape: {cas9_subsampled_character_matrix.shape}")
print(f"Number of leaves: {len(cas9_subsampled_leaves)}")

# Store reconstructed trees
reconstructed_trees = {}


Character matrix shape: (3257, 90)
Number of leaves: 3257


In [249]:
# Method 1: VanillaGreedySolver
print("Reconstructing with VanillaGreedySolver...")
vanilla_greedy = cas.solver.VanillaGreedySolver()
vanilla_tree = cas.data.CassiopeiaTree(character_matrix=cas9_subsampled_character_matrix)
vanilla_greedy.solve(vanilla_tree, collapse_mutationless_edges=True)
reconstructed_trees['VanillaGreedy'] = vanilla_tree
print(f"✓ VanillaGreedy completed. Tree has {len(vanilla_tree.leaves)} leaves")


Reconstructing with VanillaGreedySolver...
✓ VanillaGreedy completed. Tree has 3257 leaves


In [171]:
# Method 2: NeighborJoiningSolver
print("Reconstructing with NeighborJoiningSolver...")
nj_solver = cas.solver.NeighborJoiningSolver(add_root=True)
nj_tree = cas.data.CassiopeiaTree(character_matrix=cas9_subsampled_character_matrix)
nj_solver.solve(nj_tree, collapse_mutationless_edges=True)
reconstructed_trees['NeighborJoining'] = nj_tree
print(f"✓ NeighborJoining completed. Tree has {len(nj_tree.leaves)} leaves")


Reconstructing with NeighborJoiningSolver...
✓ NeighborJoining completed. Tree has 3257 leaves


In [170]:
# Method 3: UPGMASolver
print("Reconstructing with UPGMASolver...")
upgma_solver = cas.solver.UPGMASolver()
upgma_tree = cas.data.CassiopeiaTree(character_matrix=cas9_subsampled_character_matrix)
upgma_solver.solve(upgma_tree, collapse_mutationless_edges=True)
reconstructed_trees['UPGMA'] = upgma_tree
print(f"✓ UPGMA completed. Tree has {len(upgma_tree.leaves)} leaves")


Reconstructing with UPGMASolver...
✓ UPGMA completed. Tree has 3257 leaves


In [None]:
# Method 4: ILPSolver (may be slower for large trees)
print("Reconstructing with ILPSolver...")
try:
    ilp_solver = cas.solver.ILPSolver()
    ilp_tree = cas.data.CassiopeiaTree(character_matrix=cas9_subsampled_character_matrix)
    ilp_solver.solve(ilp_tree, collapse_mutationless_edges=True)
    reconstructed_trees['ILP'] = ilp_tree
    print(f"✓ ILP completed. Tree has {len(ilp_tree.leaves)} leaves")
except Exception as e:
    print(f"✗ ILP failed: {e}")


In [None]:
# Method 5: HybridSolver (combines multiple methods)
print("Reconstructing with HybridSolver...")
try:
    hybrid_solver = cas.solver.HybridSolver()
    hybrid_tree = cas.data.CassiopeiaTree(character_matrix=cas9_subsampled_character_matrix)
    hybrid_solver.solve(hybrid_tree, collapse_mutationless_edges=True)
    reconstructed_trees['Hybrid'] = hybrid_tree
    print(f"✓ Hybrid completed. Tree has {len(hybrid_tree.leaves)} leaves")
except Exception as e:
    print(f"✗ Hybrid failed: {e}")


In [None]:
# Method 6: SharedMutationJoiningSolver
print("Reconstructing with SharedMutationJoiningSolver...")
try:
    smj_solver = cas.solver.SharedMutationJoiningSolver()
    smj_tree = cas.data.CassiopeiaTree(character_matrix=cas9_subsampled_character_matrix)
    smj_solver.solve(smj_tree, collapse_mutationless_edges=True)
    reconstructed_trees['SharedMutationJoining'] = smj_tree
    print(f"✓ SharedMutationJoining completed. Tree has {len(smj_tree.leaves)} leaves")
except Exception as e:
    print(f"✗ SharedMutationJoining failed: {e}")


In [None]:
# Summary: Compare all reconstructed trees to ground truth
print("\n" + "="*60)
print("COMPARISON SUMMARY")
print("="*60)

ground_truth = cas9_tree_3257

for method_name, recon_tree in reconstructed_trees.items():
    print(f"\n{method_name}:")
    try:
        # Robinson-Foulds distance
        rf_distance, rf_max = robinson_foulds(ground_truth, recon_tree)
        normalized_rf = rf_distance / rf_max if rf_max > 0 else 0.0
        print(f"  Robinson-Foulds: {rf_distance}/{rf_max} (normalized: {normalized_rf:.4f})")
        
        # Triplets correct
        triplets_stats = cas.critique.compare.triplets_correct(
            ground_truth, recon_tree, number_of_trials=500
        )
        all_triplets_correct, resolvable_triplets_correct, _, _ = triplets_stats
        avg_triplets = np.mean([v for v in all_triplets_correct.values()])
        print(f"  Average triplets correct: {avg_triplets:.4f}")
        
    except Exception as e:
        print(f"  Error comparing trees: {e}")

print("\n" + "="*60)



COMPARISON SUMMARY

SequentialVanillaGreedy:
  Error comparing trees: Trees must have identical leaf sets.



  t1, _ = _get_digraph(tree1, tree_key=key1)
  _get_digraph(tree2, tree_key=key2)
