In [1]:
# Automatically reload imported modules before each cell executes, so edits to
# local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Data loading

In [6]:
# Core libraries for the Concord integration workflow.
import concord as ccd
import scanpy as sc
import torch
import warnings
import time
from pathlib import Path

warnings.filterwarnings('ignore')

# Input data location (created if it does not exist yet).
data_dir = Path('../data/CBCEcombineN2/')
data_dir.mkdir(parents=True, exist_ok=True)

# Per-project output directory, stamped with the current month and day.
proj_name = "concord_combine_packerN2"
save_dir = Path(f"../save/dev_{proj_name}-{time.strftime('%b%d')}/")
save_dir.mkdir(parents=True, exist_ok=True)

# Use CUDA device index 3 when CUDA is available, otherwise the CPU.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
seed = 0

In [7]:
celsub_path = data_dir / 'adata_celsub_Jun26-1610.h5ad'  # cell-subset AnnData to integrate
adata = sc.read_h5ad(celsub_path)

In [None]:
# Flag (without subsetting) the top 10k seurat_v3 highly-variable genes, then run a
# 50-component PCA restricted to them.
# NOTE(review): scanpy documents the seurat_v3 flavor as expecting raw counts in adata.X — confirm.
sc.pp.highly_variable_genes(adata, n_top_genes=10000, flavor='seurat_v3', subset=False)
sc.tl.pca(adata, svd_solver='arpack', n_comps=50, use_highly_variable=True)

In [None]:
# Loosely pick the top 10k seurat_v3 features so that enough information is preserved.
feature_list = ccd.ul.select_features(adata, n_top_features=10000, flavor='seurat_v3')

# Concord hyper-parameters, later unpacked into ccd.Concord(**concord_args).
concord_args = dict(
    adata=adata,
    input_feature=feature_list,
    domain_key='batch',              # adata.obs column holding the domain (batch) labels
    batch_size=64,                   # training mini-batch size
    latent_dim=300,                  # size of the learned embedding
    encoder_dims=[1000],             # encoder hidden widths (recommended larger than latent_dim)
    use_decoder=False,               # no reconstruction decoder in this run
    decoder_dims=[1000],             # ignored while use_decoder is False
    augmentation_mask_prob=0.3,      # feature-masking probability (suggested 0.2-0.5)
    dropout_prob=0,                  # dropout disabled here (suggested 0.1-0.3)
    clr_temperature=0.3,             # contrastive-loss temperature (suggested 0.1-0.5)
    clr_beta=0.0,
    p_intra_knn=0.3,                 # probability of intra-neighbourhood sampling (must be < 0.5)
    sampler_emb=None,
    sampler_knn=1000,                # neighbourhood size for intra-neighbourhood sampling
    n_epochs=15,                     # number of training epochs
    verbose=True,                    # detailed training output
    seed=seed,                       # random seed for reproducibility
    device=device,                   # 'cpu', 'cuda' or 'mps' device
    save_dir=save_dir,               # where the model and results are written
)

In [None]:
# Timestamped obsm key for the freshly trained Concord embedding.
file_suffix = time.strftime('%b%d-%H%M')
output_key = 'Concord_' + file_suffix
print(f"Output key: {output_key}")
cur_ccd = ccd.Concord(**concord_args)
# Trains the model and stores the embedding in adata.obsm[output_key].
cur_ccd.fit_transform(output_key=output_key)

In [None]:
# Fill missing lineage labels with 'unannotated'.
# The column is categorical (a later cell reads `.cat.categories`), and a categorical
# Series rejects fill values that are not already categories, so register the new
# category first. Reassigning the column, instead of `fillna(inplace=True)` on
# `adata.obs[...]`, also avoids pandas' chained-assignment pitfall.
import pandas as pd

lineage_labels = adata.obs['lineage']
if isinstance(lineage_labels.dtype, pd.CategoricalDtype) and 'unannotated' not in lineage_labels.cat.categories:
    lineage_labels = lineage_labels.cat.add_categories('unannotated')
adata.obs['lineage'] = lineage_labels.fillna('unannotated')

In [None]:
# 2-D UMAP of the new Concord embedding, coloured by the main annotations.
basis = output_key
file_suffix = f"{time.strftime('%b%d-%H%M')}"
show_basis = f"{basis}_UMAP"
ccd.ul.run_umap(
    adata, source_key=basis, result_key=show_basis,
    n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed,
)
show_cols = ['plot.cell.type', 'raw.embryo.time', "batch", 'dataset', 'lineage']
# One palette per column, in the same order as show_cols.
pal = dict(zip(show_cols, ['tab20', 'BlueGreenRed', 'tab20', 'Set1', 'tab20']))
plot_path = save_dir / f"{show_basis}_{file_suffix}.png"
ccd.pl.plot_embedding(
    adata, show_basis, show_cols,
    figsize=(13, 8), dpi=600, ncols=3, font_size=5, point_size=2, legend_loc=None,
    pal=pal, save_path=plot_path,
)

In [None]:
# Interactive 3-D UMAP, saved as one HTML figure per annotation column.
show_basis = f'{basis}_UMAP_3D'
ccd.ul.run_umap(
    adata, source_key=basis, result_key=show_basis,
    n_components=3, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed,
)
import plotly.io as pio
pio.renderers.default = 'notebook'
for col in show_cols:
    html_path = save_dir / f'{show_basis}_{col}_{file_suffix}.html'
    ccd.pl.plot_embedding_3d(
        adata, basis=show_basis, color_by=col, pal=pal[col],
        save_path=html_path,
        point_size=1, opacity=0.8, width=1500, height=1000,
    )

In [None]:
# Persist adata together with the embeddings computed above.
adata_out_path = data_dir / f"{proj_name}_{file_suffix}.h5ad"
adata.write_h5ad(adata_out_path)
print(f"Saved adata to {adata_out_path}")


In [None]:
# Export to a VisCello project directory.
# NOTE(review): organism='hsa' is conventionally Homo sapiens, but this dataset appears
# to be C. elegans / C. briggsae (e.g. the lineage table loaded later is
# 'cel_lineage_tree_tbl.csv') — confirm which organism code anndata_to_viscello expects.
cello_out_dir = data_dir / f"cello_{proj_name}_{file_suffix}"
ccd.ul.anndata_to_viscello(adata, cello_out_dir, project_name=proj_name, organism='hsa')
print(f"Saved viscello to {cello_out_dir}")

### Linear probe evaluation

In [None]:
# Embeddings to benchmark: every obsm entry except the UMAP visualisations.
eval_keys = [key for key in adata.obsm.keys() if 'UMAP' not in key]
eval_keys

In [None]:
import pandas as pd

# Sanity-check the probe targets before evaluation.
if adata.obs['raw.embryo.time'].isna().any():
    print("Warning: raw.embryo.time contains NA values. Please check your data.")

if adata.obs['plot.cell.type'].isna().any():
    sum_na = adata.obs['plot.cell.type'].isna().sum()
    print(f"Warning: plot.cell.type contains {sum_na} NA values. ")

# Set NA cell types to "unannotated".
# If the column is categorical, fillna rejects a value that is not an existing
# category, so add it first. Reassigning the column also avoids chained
# `fillna(inplace=True)` on adata.obs, which pandas may apply to a temporary copy.
celltype_labels = adata.obs['plot.cell.type']
if isinstance(celltype_labels.dtype, pd.CategoricalDtype) and 'unannotated' not in celltype_labels.cat.categories:
    celltype_labels = celltype_labels.cat.add_categories('unannotated')
adata.obs['plot.cell.type'] = celltype_labels.fillna('unannotated')

In [None]:
from concord.benchmarking import LinearProbeEvaluator

# Fit a linear probe for every target on every embedding listed in eval_keys.
target_keys = ['raw.embryo.time', 'plot.cell.type', 'lineage', 'batch']
# Cells carrying these labels are excluded from the probes. This must be a plain list:
# a trailing comma (`['unannotated'],`) would silently wrap it in a tuple.
ignore_values = ['unannotated']
linear_results = {}
for target_key in target_keys:
    print(f"Evaluating {target_key}...")
    evaluator = LinearProbeEvaluator(
        adata=adata,
        emb_keys=eval_keys,
        target_key=target_key,
        task="auto",                     # "auto" | "classification" | "regression"
        batch_size=128,
        lr=1e-2,
        weight_decay=1e-3,
        epochs=20,                       # default replicates HCL
        print_every=1,
        ignore_values=ignore_values,     # ignore unannotated cells
        return_preds=True,               # also return per-cell predictions
        device="cpu",
    )
    results_df, pred_bank = evaluator.run()
    linear_results[target_key] = results_df


In [None]:
from concord.benchmarking import KNNProbeEvaluator

# k-NN probe (k=15, Euclidean) for each target, on the same embeddings and ignore list.
knneval_results = {}
for target_key in target_keys:
    print(f"Evaluating {target_key} with KNN...")
    knn_kwargs = dict(
        adata=adata,
        emb_keys=eval_keys,
        target_key=target_key,
        ignore_values=ignore_values,
        k=15,
        metric="euclidean",
        return_preds=True,
        seed=0,
    )
    knn_eval = KNNProbeEvaluator(**knn_kwargs)
    metrics_df, preds_bank = knn_eval.run()
    knneval_results[target_key] = metrics_df

### Lineage graph

In [None]:
import importlib
import lineage_helpers          # local helper module; reloaded below to pick up edits
import pandas as pd

# Re-read lineage_helpers from disk, then bind the refreshed function.
importlib.reload(lineage_helpers)
from lineage_helpers import build_lineage_graph

lineage_tree_csv = "../data/CE_CB/cel_lineage_tree_tbl.csv"
tbl = pd.read_csv(lineage_tree_csv, index_col=0)

graph_kwargs = dict(
    broad_lineage_groups=None,
    add_broad_group=True,
    plot=True,
    plot_path=save_dir / "lineage_tree.pdf",
)
G_tree, tbl_aug = build_lineage_graph(adata, tbl, **graph_kwargs)
mapped_counts = tbl_aug["mapped"].value_counts()
print(mapped_counts)

In [None]:
from lineage_helpers import assign_cells_to_tree, compute_node_medoids

# Attach cells to tree nodes, preferring the lineage label over the cell-type label.
assign_kwargs = {
    "lineage_key": "lineage",
    "celltype_key": "plot.cell.type",
    "prefer": "lineage",
}
_ = assign_cells_to_tree(G_tree, adata, **assign_kwargs)



In [None]:
# Compute one medoid per tree node on the Concord embedding stored under this key.
# NOTE(review): the previous comment called this a "2-D" embedding, but Concord runs in
# this notebook use latent_dim=300 — confirm the dimensionality of 'Concord_Jun21-1419'.
medoid_emb_key = "Concord_Jun21-1419"
medoid_kwargs = dict(emb_key=medoid_emb_key, method="medoid", jitter=0)
_ = compute_node_medoids(G_tree, adata, **medoid_kwargs)

In [None]:
import importlib
import lineage_helpers as lh

importlib.reload(lh)  # pick up on-disk edits to the helper module

# Pairwise Euclidean embedding distances for the tree's nodes, as computed by lineage_helpers.
D_euc, nodes = lh.pairwise_embedding_distances(
    G_tree, adata, emb_key="Concord_Jun21-1419", metric="euclidean",
)

In [None]:
# Geodesic distances between tree nodes on a 30-NN graph of the embedding
# (see lineage_helpers.pairwise_geodesic_distances for the exact weighting).
import importlib, lineage_helpers as lh

geo_kwargs = dict(rep_key="Concord_Jun21-1419", n_neighbors=30, directed=False)
D_geo, _ = lh.pairwise_geodesic_distances(G_tree, adata, **geo_kwargs)


In [None]:
# Tree path lengths between lineage nodes, computed on an undirected copy of the tree.
G_tree_undirected = G_tree.to_undirected()
D_lin, _ = lh.pairwise_lineage_distances(G_tree_undirected)

In [None]:
# Spearman Mantel test: tree path lengths vs. kNN-graph geodesics (999 permutations).
mantel_out = lh.mantel_correlation(D_lin, D_geo, method="spearman", permutations=999)
r, p, z = mantel_out
print("Mantel r = {:.3f}, p = {:.2e}".format(r, p))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Heatmap of D_lin (path lengths on the lineage tree).
# NOTE(review): the old comment said "geo distances", but the matrix plotted is D_lin,
# not D_geo (kNN-graph geodesics) — confirm which one was intended.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(D_lin, annot=False, cmap='viridis', ax=ax)
ax.set_title("Geodesic Distances on Lineage Tree")
fig.savefig(save_dir / "geodesic_distances_heatmap.pdf", bbox_inches='tight')

In [None]:
# --- Lineage analyses ---
# Rebinds `output_key` to the reference Concord embedding and rebuilds the default
# 'neighbors' graph (30 nearest neighbours) on it.
output_key = 'Concord_Jun21-1419'
sc.pp.neighbors(adata, n_neighbors=30, use_rep=output_key)


In [None]:
neighbors_info = adata.uns["neighbors"]
conn_key = neighbors_info["connectivities_key"]   # usually 'connectivities'
dist_key = neighbors_info["distances_key"]        # usually 'distances'

G = adata.obsp[conn_key]   # sparse cells × cells connectivity weights
D = adata.obsp[dist_key]   # sparse cells × cells neighbour distances

In [None]:
import numpy as np
# Medoid coordinates and index for every lineage category on the embedding under `output_key`.
# NOTE(review): `get_representative_point` is not defined or imported anywhere in this
# notebook, so this cell raises NameError on a fresh kernel — confirm where it comes from.
# NOTE(review): it is not visible whether the returned index is local to `coords` or a
# global row of adata, nor whether the helper tolerates a category with zero cells
# (empty `coords`) — verify before relying on these dicts.
# `medoid_idx` is rebound to a list/array of medoid indices in a later cell.
medoid_idx = {}
medoid_coord = {}
for lin in adata.obs["lineage"].cat.categories:
    idx = np.where(adata.obs["lineage"] == lin)[0]  # integer row positions of this lineage's cells
    coords = adata.obsm[output_key][idx, :]
    medoid_coord[lin], medoid_idx[lin] = get_representative_point(coords, method="medoid", return_idx=True)

In [None]:
# Map lineage / cell-type labels to cell row positions, then attach to every tree node
# the cells of its first matching lineage annotation plus their medoid.

output_key = 'Concord_Jun21-1419'
lineage_to_cells = (
    adata.obs.groupby('lineage')
         .indices              # dict {lineage_name: ndarray of cell row positions}
)
celltype_to_cells = (
    adata.obs.groupby('plot.cell.type')
            .indices           # dict {celltype_name: ndarray of cell row positions}
)

for n, data in G_tree.nodes(data=True):
    lineage_mapped = data.get('lineage_annot', None)
    # Skip nodes without an annotation (None or NaN).
    if lineage_mapped is None or (isinstance(lineage_mapped, float) and pd.isna(lineage_mapped)):
        continue

    print(f"Processing node {n} with lineage {lineage_mapped}")
    # A bare string would otherwise be iterated character by character below.
    candidates = [lineage_mapped] if isinstance(lineage_mapped, str) else lineage_mapped

    # Use only the first candidate lineage that is present in adata.
    cells = []
    for lin in candidates:
        if lin in lineage_to_cells:
            cells.extend(lineage_to_cells[lin])
            break
    if len(cells) == 0:
        continue          # node not present in adata – skip later

    # Medoid = cell minimising the sum of squared Euclidean distances to the others.
    # Because sum_j ||z_i - z_j||^2 = n * ||z_i - mean(z)||^2 + const, that is the cell
    # closest to the centroid. Computing it this way needs O(n*d) memory instead of an
    # (n, n, d) pairwise-difference tensor, which can exhaust memory for large lineages.
    Z = np.asarray(adata.obsm[output_key][cells])
    sq_to_centroid = ((Z - Z.mean(axis=0)) ** 2).sum(axis=1)
    medoid = cells[int(np.argmin(sq_to_centroid))]
    data["medoid"] = medoid
    data["cells"] = cells               # keep full set in case you want avg/min
    print(f"Node {n} has {len(cells)} cells, medoid index: {medoid}")


In [None]:
# kNN graph on the reference embedding, stored under its own key ('concord_nn') so the
# default 'neighbors' entry is left as it is.
sc.pp.neighbors(adata, use_rep=output_key, n_neighbors=30, key_added="concord_nn")
conn_key = adata.uns["concord_nn"]["connectivities_key"]
dist_key = adata.uns["concord_nn"]["distances_key"]
# G_knn feeds shortest-path (Dijkstra) computations, which need edge *lengths*. The
# connectivities matrix holds similarity weights (larger = closer), which would invert
# the geometry, so use the neighbour distances instead.
G_knn = adata.obsp[dist_key].tocsr()         # SciPy CSR

In [None]:
# Keep only tree nodes whose medoid exists and is not NaN; nodes without one (lineage
# absent from adata, or medoid not computed) are skipped silently.
valid_nodes = []
valid_medoids = []
for node, attrs in G_tree.nodes(data=True):
    med = attrs.get("medoid", None)
    if med is not None and not np.isnan(med):
        valid_nodes.append(node)
        valid_medoids.append(int(med))

lin_nodes = valid_nodes
medoid_idx = np.asarray(valid_medoids, dtype=int)
L = len(lin_nodes)
print(f"Using {L} lineage nodes that have valid medoids.")


In [None]:
import scipy.sparse.csgraph as cg
import numpy as np

# Shortest-path lengths on the kNN graph from every medoid to every cell: shape (L, n_cells).
dist_src_all = cg.dijkstra(G_knn, directed=False, indices=medoid_idx)

# Restrict targets to the medoids themselves -> (L, L) medoid-to-medoid matrix.
dist_lin_all = dist_src_all[:, medoid_idx]

# Lookup table: lineage node -> row/column of dist_lin_all.
lin2row = dict(zip(lin_nodes, range(len(lin_nodes))))

In [None]:
# Heatmap of medoid-to-medoid kNN-graph shortest-path distances.
import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(dist_lin_all, cmap='viridis', cbar=True, ax=ax)



In [None]:
n_lineage_nodes = len(lin_nodes)  # tree nodes with a valid medoid
n_lineage_nodes

In [None]:
import scipy.sparse as sp
import scipy.sparse.csgraph as cg
import networkx as nx
from ete3 import Tree

# ------------------------------------------------------------------
# 1.  Minimum spanning tree over the medoid latent distance matrix
# ------------------------------------------------------------------
mst_sparse = cg.minimum_spanning_tree(sp.csr_matrix(dist_lin_all))
mst_graph = nx.Graph(mst_sparse)             # undirected graph on positions 0..L-1

# Replace positional labels by the lineage node names (in place).
mapping = dict(enumerate(lin_nodes))
mst_graph = nx.relabel_nodes(mst_graph, mapping, copy=False)

# Optionally root at the zygote ('P0') when present, otherwise at the first node.
root_node = "P0" if "P0" in lin_nodes else lin_nodes[0]
mst_rooted = nx.bfs_tree(mst_graph, source=root_node).to_undirected()

In [None]:
# 2-D UMAP of the lineage medoids from the precomputed medoid distance matrix.
import umap

umap_params = dict(n_neighbors=15, min_dist=0.1, metric='precomputed', random_state=42)
umap_model = umap.UMAP(**umap_params)
umap_embedding = umap_model.fit_transform(dist_lin_all)

In [None]:
# Draw the rooted MST with a seeded spring layout and save it.
mst_fig = plt.figure(figsize=(30, 20))
pos = nx.spring_layout(mst_rooted, seed=42)
draw_style = dict(
    with_labels=True, node_size=300, node_color='lightblue',
    font_size=7, font_color='black', edge_color='gray',
)
nx.draw(mst_rooted, pos, **draw_style)
mst_fig.savefig(save_dir / "mst_tree.png", dpi=300)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np

# Map each lineage node (in lin_nodes order) to its row of the medoid UMAP embedding.
pos = dict(zip(lin_nodes, umap_embedding))

# Ground-truth tree edges whose two endpoints both have a UMAP position.
segments = [
    [pos[parent], pos[child]]
    for parent, child in G_tree.edges()
    if parent in pos and child in pos
]

edge_colors = "grey"   # a single colour; could be a per-segment list
edge_lw = 0.6          # thin lines keep the plot readable

fig, ax = plt.subplots(figsize=(8, 7), dpi=150)

# Medoids as points, drawn above the edges.
xs, ys = umap_embedding[:, 0], umap_embedding[:, 1]
ax.scatter(xs, ys, s=12, alpha=0.9, zorder=3, edgecolors="none")

# True-tree edges underneath the points.
lc = LineCollection(segments, colors=edge_colors, linewidths=edge_lw, alpha=0.7, zorder=2)
ax.add_collection(lc)

# Small labels with the lineage names.
for label, (x, y) in pos.items():
    ax.annotate(label, xy=(x, y), xytext=(2, 2), textcoords="offset points", fontsize=3, alpha=0.8)

ax.set(xlabel="UMAP 1", ylabel="UMAP 2", title="UMAP of lineage medoids with true tree overlay")
fig.tight_layout()
fig.savefig(save_dir / "umap_lineage_tree_overlay.png", dpi=300)
plt.show()


In [None]:
# Tree nodes with at least two children, all of which have a medoid attribute.
branch_points = []
for node in G_tree.nodes():
    if G_tree.out_degree(node) < 2:
        continue
    if all("medoid" in G_tree.nodes[child] for child in G_tree.successors(node)):
        branch_points.append(node)

print(f"{len(branch_points)} branch points with ≥2 mapped daughters")

In [None]:
import numpy as np

# helpers ---------------------------------------------------------------
def _as_iter(x):
    """None → []; str → [str]; list/tuple → list"""
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]          # single string

def _indices_from_names(names, mapping_dict):
    """Concatenate index lists for all present names"""
    out = []
    for n in names:
        out.extend(mapping_dict.get(n, []))
    return out

# main ------------------------------------------------------------------
def cells_of_lineage(node):
    """
    Return a sorted, de-duplicated integer np.ndarray of cell row positions for a tree node.

    Sources checked, in order:
      1. node attribute "linannot"  (list or string) in lineage_to_cells
      2. node attribute "celltype"  (list or string) in celltype_to_cells
      3. node attribute "linorct"   (mixed list)     in either dict (lineage first)
      4. the node *name* itself in lineage_to_cells (only if 1-3 found nothing)

    NOTE(review): an earlier cell reads the lineage annotation from the attribute
    'lineage_annot', while this function reads 'linannot' — confirm which key
    build_lineage_graph sets.
    """
    data = G_tree.nodes[node]          # attribute dict

    idx = []
    # 1 & 2. explicit mappings
    idx += _indices_from_names(_as_iter(data.get("linannot")),
                               lineage_to_cells)
    idx += _indices_from_names(_as_iter(data.get("celltype")),
                               celltype_to_cells)
    # 3. combined mapping field linorct
    for name in _as_iter(data.get("linorct")):
        if name in lineage_to_cells:
            idx.extend(lineage_to_cells[name])
        elif name in celltype_to_cells:
            idx.extend(celltype_to_cells[name])

    # 4. fallback: node label may match lineage_to_cells directly
    if not idx and node in lineage_to_cells:
        idx.extend(lineage_to_cells[node])

    # np.unique([]) has dtype float64; force an integer dtype so the result is always a
    # valid index array and never upcasts an np.concatenate with other index arrays.
    return np.unique(np.asarray(idx, dtype=np.intp))



In [None]:
# Row positions of the cells in each lineage / cell-type group.
lineage_to_cells = adata.obs.groupby('lineage').indices
celltype_to_cells = adata.obs.groupby('plot.cell.type').indices


In [None]:
latent_key: str = 'Concord_Jun21-1419'  # obsm key of the Concord embedding used below

In [None]:
# Recompute each node's medoid from the cells collected by cells_of_lineage.
for n, data in G_tree.nodes(data=True):
    cells = cells_of_lineage(n)
    print(f"Node {n} has {len(cells)} cells in lineage")
    if cells.size == 0:
        continue
    Z = np.asarray(adata.obsm[latent_key][cells])
    # argmin_i sum_j ||z_i - z_j||^2 == argmin_i ||z_i - mean(z)||^2, so avoid building the
    # (n, n, d) pairwise-difference tensor, which can exhaust memory for large lineages.
    d2 = ((Z - Z.mean(axis=0)) ** 2).sum(axis=1)
    data["medoid"] = int(cells[np.argmin(d2)])
    data["cells"] = cells               # keep the full set if needed later

In [None]:

# ------------------------------------------------------------------------------
# 2.  IDENTIFY BRANCH POINTS AND COLLECT CELLS
# ------------------------------------------------------------------------------
# Evaluate cells_of_lineage once per node (as integer index arrays) and reuse it below.
node_cells = {node: np.asarray(cells_of_lineage(node), dtype=int) for node in G_tree.nodes()}

branch_points = [
    n for n in G_tree.nodes()
    if G_tree.out_degree(n) >= 2
       and sum(node_cells[ch].size > 0 for ch in G_tree.successors(n)) >= 2
]

triplet_data = {}
for p in branch_points:
    daughters = [ch for ch in G_tree.successors(p) if node_cells[ch].size]
    if len(daughters) < 2:
        continue
    # Keep the first two mapped daughters for the Y-split and gather cells only from the
    # parent and those two. Cells of any further daughter would otherwise be mixed in and
    # end up labelled as the parent when the triplet is plotted.
    daughters = daughters[:2]
    cells = np.unique(np.concatenate([node_cells[p]] + [node_cells[d] for d in daughters]))
    if cells.size == 0:
        continue
    triplet_data[p] = dict(daughters=daughters, cells=cells)

print(f"Found {len(triplet_data)} branch points with ≥2 mapped daughters.")



In [None]:
# ------------------------------------------------------------------------------
# 3.  CORE ROUTINES
# ------------------------------------------------------------------------------
from sklearn.neighbors import NearestNeighbors
import scipy.sparse as sp
import scipy.sparse.csgraph as cg
import umap, numpy as np, matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from pathlib import Path
from scipy.spatial.distance import pdist, squareform
import pandas as pd

n_neighbors = 10

def build_subgraph(cells):
    """Return the latent coordinates of `cells` and a sparse kNN distance graph over them.

    Returns
    -------
    Z : rows `cells` of adata.obsm[latent_key]
    G : scipy.sparse.csr_matrix of shape (len(cells), len(cells)); G[i, j] is the
        Euclidean distance from cell i to its neighbour j. Each row holds
        k = min(n_neighbors, len(cells)) entries; because the query points are the
        fitted points, these normally include the cell itself at distance 0.
    """
    Z = adata.obsm[latent_key][cells]
    n_cells = len(cells)
    k = min(n_neighbors, n_cells)  # NearestNeighbors rejects k > n_samples
    nbrs = NearestNeighbors(n_neighbors=k).fit(Z)
    # kneighbors returns (distances, indices) — in that order.
    knn_dist, knn_idx = nbrs.kneighbors(Z, return_distance=True)
    rows = np.repeat(np.arange(n_cells), k)
    G = sp.csr_matrix((knn_dist.ravel(), (rows, knn_idx.ravel())),
                      shape=(n_cells, n_cells))
    return Z, G

def embed_subgraph(dists):
    """UMAP embedding (seeded, random_state=42) of a precomputed square distance matrix."""
    params = {
        "n_neighbors": n_neighbors,
        "min_dist": 0.1,
        "metric": "precomputed",
        "random_state": 42,
    }
    reducer = umap.UMAP(**params)
    return reducer.fit_transform(dists)

def branch_metrics(parent_mask, d1_mask, d2_mask, Z):
    """Centroid distances for a parent -> (daughter1, daughter2) split in latent space Z.

    Returns
    -------
    (mean of the two parent-to-daughter centroid distances,
     daughter-to-daughter centroid distance)
    """
    parent_c, d1_c, d2_c = (Z[mask].mean(0) for mask in (parent_mask, d1_mask, d2_mask))
    parent_to_daughter = 0.5 * (np.linalg.norm(parent_c - d1_c) + np.linalg.norm(parent_c - d2_c))
    daughter_to_daughter = np.linalg.norm(d1_c - d2_c)
    return parent_to_daughter, daughter_to_daughter

def plot_triplet(p, daughters, cells, Z, X, out_png):
    """Scatter a parent/daughter triplet in its 2-D embedding X and save it to out_png.

    Cells are coloured grey (parent, default), blue (daughters[0]) or red (daughters[1]);
    an MST of the latent distances (computed on Z) is overlaid in X coordinates.
    """
    d1, d2 = daughters[0], daughters[1]
    labels = np.full(len(cells), p, dtype=object)
    # Later assignments win, so cells in both daughters end up labelled d2.
    for daughter in (d1, d2):
        labels[np.isin(cells, cells_of_lineage(daughter))] = daughter
    palette = {p: "#636363", d1: "#4575b4", d2: "#d73027"}
    colors = [palette[label] for label in labels]

    fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
    ax.scatter(X[:, 0], X[:, 1], s=6, c=colors, alpha=.9, rasterized=True)
    # MST skeleton for clarity.
    mst = cg.minimum_spanning_tree(squareform(pdist(Z)))
    src, dst = mst.nonzero()
    skeleton = [[X[a], X[b]] for a, b in zip(src, dst)]
    ax.add_collection(LineCollection(skeleton, colors="#00000055", lw=.6))
    ax.set_title(f"{p} → {d1}, {d2}")
    ax.axis("off")
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)

In [None]:
example_parent = 'ABalaaa'
triplet_data[example_parent]

In [None]:
# Walk-through for a single branch point: kNN-graph geodesics -> precomputed UMAP.
info = triplet_data['ABalaaa']
daughters = info["daughters"]        # two daughters
# Force integer row positions, as the loop below does: an index array that went through
# an empty (float64) np.unique result would otherwise break adata.obsm[...][cells].
cells = np.asarray(info["cells"], dtype=int)
Z, G_sub = build_subgraph(cells)
# Full all-pairs geodesic matrix; disconnected kNN components yield inf entries.
dists_sub = cg.dijkstra(G_sub, directed=False, indices=None)
X_2d = embed_subgraph(dists_sub)


In [None]:
# ------------------------------------------------------------------
# 4.  RUN LOOP: make plots and metrics  (UMAP-from-latent version)
# ------------------------------------------------------------------
results = []

for p, info in triplet_data.items():
    daughters = info["daughters"]            # two daughters
    cells = np.asarray(info["cells"], int)
    print(f"Processing {p} → {daughters[0]}, {daughters[1]}  (n={len(cells)})")

    # Latent vectors of the triplet. This version embeds Z directly, so the kNN graph
    # that build_subgraph would also construct is not needed and is skipped.
    Z = adata.obsm[latent_key][cells]

    # UMAP directly on Z (Euclidean metric)
    um = umap.UMAP(n_neighbors=n_neighbors,
                   min_dist=0.1,
                   metric="euclidean",
                   random_state=42)
    X_2d = um.fit_transform(Z)               # shape (n_cells, 2)

    # masks for the metrics
    lab_p = np.isin(cells, cells_of_lineage(p))
    lab_d1 = np.isin(cells, cells_of_lineage(daughters[0]))
    lab_d2 = np.isin(cells, cells_of_lineage(daughters[1]))
    d_pd, d_dd = branch_metrics(lab_p, lab_d1, lab_d2, Z)

    results.append(dict(parent=p,
                        daughter1=daughters[0],
                        daughter2=daughters[1],
                        d_parent_daughter=d_pd,
                        d_daughter_daughter=d_dd,
                        ratio=d_dd/d_pd if d_pd else np.nan))

    # plot & save; also replace '/' so a node name can never be read as a sub-directory
    safe_name = p.replace(' ', '_').replace('/', '_')
    out_png = save_dir / f"triplet_{safe_name}.png"
    plot_triplet(p, daughters, cells, Z, X_2d, out_png)

metrics_df = pd.DataFrame(results)
metrics_df.to_csv(save_dir / "branch_metrics.csv", index=False)
print("✓ Finished.   Plots →", save_dir,
      "\n              Metrics →", save_dir / "branch_metrics.csv")


In [None]:
triplet_latent = adata.obsm[latent_key][cells]  # latent rows of the last processed triplet
triplet_latent

In [None]:
branch_results = results  # per-branch-point metric records (also written to branch_metrics.csv)
branch_results

In [None]:
# X = kNN-graph geodesic distances between lineage medoids in the latent space,
#     shape (L, L), rows/columns ordered as `lin_nodes`.
X = dist_lin_all

# Build the biological path-length matrix Y from the ground-truth lineage tree.
import networkx as nx, numpy as np, pandas as pd, scipy.sparse.csgraph as cg

lin2row = {n: i for i, n in enumerate(lin_nodes)}     # node -> row/col in X and Y

# Edge weighting for tree paths:
#   "count" - every division counts as 1
#   "time"  - sum of |embryo_time| differences along the path
#             (requires an 'embryo_time' attribute on every node of the path)
edge_weight = "count"                  # or "time"


def true_path_matrix(G_tree, lin_nodes, edge_weight="count"):
    """Symmetric (L, L) matrix of tree path lengths between `lin_nodes`.

    Paths are taken along the tree ignoring edge direction. Unreachable pairs stay
    np.inf and the diagonal is 0. With edge_weight="count" the length is the number of
    edges; otherwise it is the sum of absolute 'embryo_time' differences along the path.
    Raises KeyError if a node on a "time" path lacks 'embryo_time'.
    """
    # On a parent->child DiGraph, siblings/cousins have no directed path, and a
    # descendant listed before its ancestor would never reach it, so use an undirected view.
    tree = G_tree.to_undirected(as_view=True) if G_tree.is_directed() else G_tree
    n_nodes = len(lin_nodes)
    Y = np.full((n_nodes, n_nodes), np.inf)
    np.fill_diagonal(Y, 0)
    use_count = edge_weight == "count"
    # Node times are looked up once, not once per pair.
    times = None if use_count else nx.get_node_attributes(G_tree, "embryo_time")
    for i, u in enumerate(lin_nodes):
        # One single-source search per node instead of one search per pair.
        if use_count:
            reach = nx.single_source_shortest_path_length(tree, u)
        else:
            reach = nx.single_source_shortest_path(tree, u)
        for j in range(i + 1, n_nodes):
            v = lin_nodes[j]
            if v not in reach:
                continue                   # no path: leave inf
            if use_count:
                d = reach[v]
            else:
                path = reach[v]
                d = sum(abs(times[a] - times[b]) for a, b in zip(path, path[1:]))
            Y[i, j] = Y[j, i] = d
    return Y


Y = true_path_matrix(G_tree, lin_nodes, edge_weight=edge_weight)
