# In [None]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

def infer_lineage_tree(trajectories):
    """
    Infer a lineage tree from trajectory data by matching shared prefixes.

    Each trajectory ``i`` (for ``i >= 1``) is attached to the earlier
    trajectory ``j < i`` with which it shares the longest identical prefix,
    i.e. the most time steps before the two first differ.  Ties are broken in
    favour of the lowest index.  A trajectory that branches off at the same
    time as several identical siblings is therefore attached to the earliest
    of them.  A trajectory that already differs from every earlier one at t=0
    shares no history and gets parent -1.

    Parameters:
        trajectories (array_like): Array of shape (N_traj, N_time, ...), e.g.
            (N_traj, N_time, dim).  Each trajectory may be duplicated until
            branching.  Time steps are compared exactly (``!=``); a step
            differs if any of its components differ.

    Returns:
        lineage (dict): Dictionary where keys are child indices and values are
            parent indices (plain ints).  Trajectory 0 is the root with parent
            -1.  An input with no trajectories yields an empty dict.

    Raises:
        ValueError: If ``trajectories`` has fewer than 2 dimensions.
    """
    trajs = np.asarray(trajectories)
    if trajs.ndim < 2:
        raise ValueError(
            "trajectories must have shape (N_traj, N_time, ...), "
            f"got shape {trajs.shape}"
        )
    n_traj, n_time = trajs.shape[:2]
    if n_traj == 0:
        return {}
    if n_time == 0:
        # No time steps means no shared history: every trajectory is a root.
        return {i: -1 for i in range(n_traj)}

    # Collapse any per-step state dims into one axis.  An explicit size is
    # used instead of -1 so zero-sized state dims still reshape cleanly.
    state_size = int(np.prod(trajs.shape[2:], dtype=int))
    trajs = trajs.reshape(n_traj, n_time, state_size)

    lineage = {0: -1}  # Root trajectory has no parent (-1)
    for i in range(1, n_traj):
        # differs[j, t] is True where earlier trajectory j differs from i at t.
        differs = np.any(trajs[:i] != trajs[i], axis=2)
        # Shared-prefix length = index of the first differing step, or n_time
        # when j is an exact duplicate of i.
        shared = np.where(differs.any(axis=1), differs.argmax(axis=1), n_time)
        # np.argmax returns the first (lowest) index among equal maxima.
        best = int(np.argmax(shared))
        # A zero-length shared prefix means no common ancestor.
        lineage[i] = best if shared[best] > 0 else -1

    return lineage

def plot_lineage_tree(lineage):
    """
    Plot a lineage tree using networkx and matplotlib (spring layout).

    Parameters:
        lineage (dict): Mapping of child index -> parent index, where -1
            marks an entry with no parent (e.g. the root).

    Every child index is drawn as a node, including parent-less entries that
    take part in no edge.  An edge parent -> child is drawn for every entry
    whose parent is not -1.  The figure is displayed with ``plt.show()``.
    """
    G = nx.DiGraph()

    for child, parent in lineage.items():
        # Add the node explicitly so parent-less trajectories (a root without
        # children, or trajectories diverging at t=0) still appear in the plot.
        G.add_node(child)
        if parent != -1:  # -1 is the "no parent" sentinel, not a real node
            G.add_edge(parent, child)

    plt.figure(figsize=(10, 6))

    # Use spring layout (no pygraphviz required)
    pos = nx.spring_layout(G, seed=42)
    nx.draw(
        G,
        pos,
        with_labels=True,
        node_size=500,
        node_color="skyblue",
        edge_color="gray",
        font_size=10,
        arrows=True,
    )

    plt.title("Inferred Lineage Tree from Trajectory Data")
    plt.show()

# Example: simulated trajectory dataset of shape (N_traj, N_time, dim).
# Trajectories are duplicated until they branch; the per-row comments give the
# lineage the inference is expected to recover.
N_traj, N_time, dim = 6, 5, 2
trajectories = np.array([
    [[0,0], [1,1], [2,2], [3,3], [4,4]],  # Traj 0 (Root)
    [[0,0], [1,1], [2,2], [3,3], [4,4]],  # Traj 1 (Duplicate of 0, should be child of 0)
    [[0,0], [1,1], [2,2], [3,3], [5,5]],  # Traj 2 (Branches from 0 at t=4)
    [[0,0], [1,1], [2,2], [3,3], [6,6]],  # Traj 3 (Branches from 0 at t=4)
    [[0,0], [1,1], [2,2], [3,3], [5,5]],  # Traj 4 (Duplicate of 2, should be child of 2)
    [[0,0], [1,1], [2,2], [3,4], [4,4]]   # Traj 5 (Branches from 0 at t=3)
])

# Infer the lineage tree from the dataset defined above
lineage_tree = infer_lineage_tree(trajectories)

# Plot the inferred tree
plot_lineage_tree(lineage_tree)

# Print inferred lineage for debugging
print("Inferred Lineage Tree:", lineage_tree)