In [1]:
import os
# Set the thread-count environment variables read by OpenMP, OpenBLAS, MKL,
# vecLib and numexpr to "1".
# NOTE(review): presumably this has to run before the numeric libraries are
# imported for the limits to take effect — confirm.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "VECLIB_MAXIMUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
    }
)

In [2]:
import numpy as np
import scanpy as sc
import pandas as pd
import anndata as ad

In [23]:
# Folder that holds the saved projection embeddings (.npy files).
embed_path = "/Users/apple/Desktop/KB/data/feat_LCL_final/cell_tag_cell_tag_multi"

# Projection embeddings written by one run folder.
proj_embed = np.load(
    embed_path+"/feat_celltagMulti_lambda005_unlab50_bs250_testAsPenalty/test_proj_embed.npy"
)
proj_embed_multi = np.load(
    embed_path+"/feat_celltagMulti_lambda005_unlab50_bs250_testAsPenalty/train_proj_embed.npy"
)

# AnnData objects the embeddings are attached to.
# NOTE(review): the `test_*` embedding is paired with the `*_train_*` AnnData
# and the `train_*` embedding with the `*_test_multi_*` AnnData — confirm that
# this cross pairing is intended.
adata = ad.read_h5ad(
    "/Users/apple/Desktop/KB/data/Cell_tag-Cell_tag_multi_integrated/Seurat_method/cellTag_train_tag_Seurat.h5ad"
)
adata_multi = ad.read_h5ad(
    "/Users/apple/Desktop/KB/data/Cell_tag-Cell_tag_multi_integrated/Seurat_method/cellTag_test_multi_Seurat_clone_id.h5ad"
)

# Store the embedding on `adata_multi` under the "X_lcl" key.
adata_multi.obsm["X_lcl"] = proj_embed_multi


In [4]:
# Sanity check: display the dimensions of the embedding array loaded into
# `proj_embed` before it is attached to `adata`.
proj_embed.shape

(6534, 32)

In [18]:
# Attach the loaded projection embedding to `adata` under the "X_lcl" key.
adata.obsm["X_lcl"] = proj_embed
# Expose `adata.X` as an obsm entry so it can be passed by key as the baseline
# representation ("X_integrate").
# NOTE(review): presumably `adata.X` holds the Seurat-integrated values — confirm.
adata.obsm["X_integrate"] = adata.X

In [16]:
def cluster_in_embedding(
    adata,
    embedding_key,
    cluster_key,
    resolution=1.0,
    n_neighbors=30,
    random_state=0,
):
    """
    Build a kNN graph on a stored embedding and run Leiden clustering on it.

    Parameters
    ----------
    adata : AnnData
        Object holding the cells; passed to both scanpy calls.
    embedding_key : str
        Key in adata.obsm of the embedding to use (e.g. "X_lcl"); forwarded
        as `use_rep` to `sc.pp.neighbors`.
    cluster_key : str
        Forwarded as `key_added` to `sc.tl.leiden`, i.e. the name under which
        the Leiden labels are stored.
    resolution : float
        Leiden resolution parameter.
    n_neighbors : int
        Number of neighbors for the kNN graph.
    random_state : int
        Seed passed to both the neighbor search and Leiden.

    Returns
    -------
    AnnData
        The same `adata` object that was passed in.
    """
    graph_options = {
        "use_rep": embedding_key,
        "n_neighbors": n_neighbors,
        "random_state": random_state,
    }
    leiden_options = {
        "resolution": resolution,
        "key_added": cluster_key,
        "random_state": random_state,
    }

    # Step 1: neighbor graph in the chosen embedding space.
    sc.pp.neighbors(adata, **graph_options)
    # Step 2: Leiden community detection.
    sc.tl.leiden(adata, **leiden_options)

    return adata


def _sample_pairs(n_cells, n_pairs, rng):
    """
    Efficiently sample `n_pairs` unordered pairs (i, j) with i != j
    from {0, ..., n_cells-1}.
    """
    i = rng.integers(0, n_cells, size=n_pairs)
    j = rng.integers(0, n_cells, size=n_pairs)

    # ensure i != j
    mask = i != j
    # if some are equal, resample those positions
    # while mask.sum() < n_pairs:
    #     print(mask.sum())
    #     missing = n_pairs - mask.sum()
    #     i_new = rng.integers(0, n_cells, size=missing)
    #     j_new = rng.integers(0, n_cells, size=missing)
    #     new_mask = i_new != j_new

    #     # keep only valid pairs from the new batch
    #     i_new = i_new[new_mask]
    #     j_new = j_new[new_mask]

    #     # fill in remaining slots
    #     i[~mask][:len(i_new)] = i_new
    #     j[~mask][:len(j_new)] = j_new

    #     mask = i != j

    return i, j


def evaluate_pairwise_tpr_fpr(
    adata,
    lineage_key="lineage",
    cluster_key="pred_cluster",
    n_pairs=100_000,
    random_state=0,
):
    """
    Pairwise TPR / FPR of cluster co-membership as a predictor of shared lineage.

    Random cell pairs are sampled; a pair is a true positive when both cells
    share the value in adata.obs[lineage_key], and a predicted positive when
    both cells share the value in adata.obs[cluster_key].

    Parameters
    ----------
    adata : AnnData
        Cells to evaluate; both columns below are read from adata.obs.
    lineage_key : str
        Column in adata.obs with the true lineage IDs.
    cluster_key : str
        Column in adata.obs with the cluster labels used as predictions.
    n_pairs : int
        Number of random cell pairs to sample.
    random_state : int
        Seed for the pair sampling.

    Returns
    -------
    metrics : dict
        Keys "TP", "FN", "FP", "TN" (int counts), "TPR", "FPR" (floats, NaN
        when the corresponding denominator is zero), "n_pairs" and
        "n_cells_used".
    """
    generator = np.random.default_rng(random_state)

    obs = adata.obs
    true_labels = obs[lineage_key].to_numpy()
    pred_labels = obs[cluster_key].to_numpy()
    total_cells = adata.n_obs
    # Cells with a missing lineage value are not filtered out here; all
    # `total_cells` cells are eligible for sampling.

    left, right = _sample_pairs(total_cells, n_pairs, generator)

    same_lineage = true_labels[left] == true_labels[right]
    same_cluster = pred_labels[left] == pred_labels[right]
    diff_lineage = ~same_lineage
    diff_cluster = ~same_cluster

    # Confusion-matrix counts over the sampled pairs.
    n_tp = np.sum(same_lineage & same_cluster)
    n_fn = np.sum(same_lineage & diff_cluster)
    n_fp = np.sum(diff_lineage & same_cluster)
    n_tn = np.sum(diff_lineage & diff_cluster)

    n_pos = n_tp + n_fn
    n_neg = n_fp + n_tn

    # Sensitivity; undefined when no same-lineage pair was sampled.
    if n_pos > 0:
        sensitivity = n_tp / n_pos
    else:
        sensitivity = np.nan

    # 1 - specificity; undefined when no different-lineage pair was sampled.
    if n_neg > 0:
        fall_out = n_fp / n_neg
    else:
        fall_out = np.nan

    metrics = {
        "TP": int(n_tp),
        "FN": int(n_fn),
        "FP": int(n_fp),
        "TN": int(n_tn),
        "TPR": float(sensitivity),
        "FPR": float(fall_out),
        "n_pairs": int(n_pairs),
        "n_cells_used": int(total_cells),
    }
    return metrics


def deploy_lineage_prediction_on_unlabeled(
    adata_test,
    lineage_key="lineage",
    lcl_embedding_key="X_lcl",
    baseline_embedding_key=None,
    lcl_cluster_key="lcl_cluster",
    baseline_cluster_key="baseline_cluster",
    lcl_resolution=1.0,
    baseline_resolution=1.0,
    n_pairs=100_000,
    n_neighbors=30,
    random_state=0,
):
    """
    Cluster cells in the LCL embedding, score the clusters against true
    lineages with pairwise TPR/FPR, and optionally repeat for a baseline
    embedding.

    Steps:
    1) Cluster the cells via `cluster_in_embedding` on `lcl_embedding_key`.
    2) Evaluate pairwise TPR/FPR via `evaluate_pairwise_tpr_fpr`.
    3) If `baseline_embedding_key` is not None, do the same for the baseline.

    Parameters
    ----------
    adata_test : AnnData
        Cells to cluster and evaluate. Cluster labels are written into it by
        `cluster_in_embedding`.
    lineage_key : str
        Column in adata_test.obs with true lineage IDs.
    lcl_embedding_key : str
        Key in adata_test.obsm with the LCL embedding (e.g. "X_lcl").
    baseline_embedding_key : str or None
        Optional key in adata_test.obsm with a baseline embedding. If None,
        the baseline step is skipped.
    lcl_cluster_key : str or None
        Column name for the LCL clusters. None falls back to "lcl_cluster".
    baseline_cluster_key : str or None
        Column name for the baseline clusters. None falls back to
        "baseline_cluster".
    lcl_resolution : float
        Leiden resolution for LCL clustering.
    baseline_resolution : float
        Leiden resolution for baseline clustering.
    n_pairs : int
        Number of random pairs for evaluation.
    n_neighbors : int
        Number of kNN neighbors used for both clusterings.
    random_state : int
        Seed for clustering and pair sampling. The baseline evaluation samples
        its pairs with `random_state + 1`, so the two metric sets are computed
        on different random pair samples.

    Returns
    -------
    results : dict
        {"LCL": {...}} and, only when a baseline embedding is given,
        "baseline": {...}; each value is the metrics dict returned by
        `evaluate_pairwise_tpr_fpr`.

    Raises
    ------
    KeyError
        If `lineage_key` is not a column of adata_test.obs (checked before
        any clustering is run).
    """
    # A caller passing None for a cluster key would otherwise have the labels
    # stored under a `None` column; use the documented default names instead.
    if lcl_cluster_key is None:
        lcl_cluster_key = "lcl_cluster"
    if baseline_cluster_key is None:
        baseline_cluster_key = "baseline_cluster"

    # Fail fast: the lineage column is only read after the (slow) clustering.
    if lineage_key not in adata_test.obs:
        raise KeyError(lineage_key)

    results = {}

    # --- LCL embedding ---
    cluster_in_embedding(
        adata_test,
        embedding_key=lcl_embedding_key,
        cluster_key=lcl_cluster_key,
        resolution=lcl_resolution,
        n_neighbors=n_neighbors,
        random_state=random_state,
    )
    res_lcl = evaluate_pairwise_tpr_fpr(
        adata_test,
        lineage_key=lineage_key,
        cluster_key=lcl_cluster_key,
        n_pairs=n_pairs,
        random_state=random_state,
    )
    results["LCL"] = res_lcl
    print("LCL-based prediction:")
    print(res_lcl)

    # --- Baseline embedding (optional) ---
    if baseline_embedding_key is not None:
        # NOTE(review): this re-runs the neighbor search on the same object,
        # presumably replacing the graph built for the LCL step — confirm if
        # the LCL graph is needed afterwards.
        cluster_in_embedding(
            adata_test,
            embedding_key=baseline_embedding_key,
            cluster_key=baseline_cluster_key,
            resolution=baseline_resolution,
            n_neighbors=n_neighbors,
            random_state=random_state,
        )
        res_base = evaluate_pairwise_tpr_fpr(
            adata_test,
            lineage_key=lineage_key,
            cluster_key=baseline_cluster_key,
            n_pairs=n_pairs,
            random_state=random_state + 1,
        )
        results["baseline"] = res_base
        print("\nBaseline-embedding prediction:")
        print(res_base)

    return results

In [25]:
# LCL-only evaluation on `adata_multi`: the baseline is disabled, so the
# baseline-specific arguments are left out, as are the arguments that only
# restated the function defaults (cluster key, resolution, pair count).
results = deploy_lineage_prediction_on_unlabeled(
    adata_multi,
    lineage_key="clone_id",
    lcl_embedding_key="X_lcl",
    baseline_embedding_key=None,
    n_neighbors=10,
    random_state=42,
)

LCL-based prediction:
{'TP': 353, 'FN': 62, 'FP': 527, 'TN': 99058, 'TPR': 0.8506024096385543, 'FPR': 0.005291961640809359, 'n_pairs': 100000, 'n_cells_used': 22238}


In [21]:
# LCL vs. baseline ("X_integrate") evaluation on `adata`.
# The baseline is enabled here, so it needs a real column name for its
# clusters; passing None would store the labels under a `None` column.
results = deploy_lineage_prediction_on_unlabeled(
    adata_test=adata,
    lineage_key="clone_id",
    lcl_embedding_key="X_lcl",
    baseline_embedding_key="X_integrate",
    lcl_cluster_key="lcl_cluster",
    baseline_cluster_key="baseline_cluster",
    lcl_resolution=1.0,
    baseline_resolution=1.0,
    n_pairs=100_000,
    n_neighbors=10,
    random_state=42,
)

LCL-based prediction:
{'TP': 666, 'FN': 5798, 'FP': 6371, 'TN': 87165, 'TPR': 0.10303217821782178, 'FPR': 0.06811281217926787, 'n_pairs': 100000, 'n_cells_used': 6534}

Baseline-embedding prediction:
{'TP': 2185, 'FN': 4231, 'FP': 20305, 'TN': 73279, 'TPR': 0.34055486284289277, 'FPR': 0.21697084971790048, 'n_pairs': 100000, 'n_cells_used': 6534}


In [9]:

# Parameters for a standalone Leiden run on the LCL embedding of `adata`.
embedding_key = "X_lcl"
cluster_key = "Leiden"
resolution = 1.0
n_neighbors = 10
random_state = 0

# Use the shared helper instead of repeating its neighbors + Leiden calls
# inline. The trailing semicolon keeps the returned AnnData from being echoed
# as cell output, matching the original cell, which displayed nothing.
cluster_in_embedding(
    adata,
    embedding_key=embedding_key,
    cluster_key=cluster_key,
    resolution=resolution,
    n_neighbors=n_neighbors,
    random_state=random_state,
);




In [13]:
# Number of cells per Leiden cluster, largest first.
adata.obs["Leiden"].value_counts()

Leiden
0     738
1     733
2     652
3     598
4     481
5     316
6     304
7     303
8     295
9     286
10    285
11    255
12    248
13    219
14    188
15    180
16    173
17    148
18     49
19     24
20     24
21     23
22     12
Name: count, dtype: int64