In [1]:
import copy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

import anndata as ad

### Load Data

In [2]:
input_dir = "/Users/apple/Desktop/KB/data"
adata_train = ad.read_h5ad(input_dir+'/LarryData/train_test/Larry_200_train.h5ad')
adata_test = ad.read_h5ad(input_dir+'/LarryData/train_test/Larry_200_test.h5ad')

train_labels = adata_train.obs["clone_id"].to_numpy()
test_labels = adata_test.obs["clone_id"].to_numpy()

 
X_train = np.load(input_dir+'/feat_LCL_2025/Larry_top200/scvi_embedding/Larry_scvi_train_200_embeddings.npy')
X_test = np.load(input_dir+'/feat_LCL_2025/Larry_top200/scvi_embedding/Larry_scvi_test_200_embeddings.npy')

print(train_labels.shape, test_labels.shape)
print(X_train.shape, X_test.shape)

(10148,) (1225,)
(10148, 10) (1225, 10)


### Linear Layer with Softmax

In [3]:
# -----------------------
# (Optional) Filter to a trajectory by shared barcodes (same as your old Larry code)
# -----------------------
def filter_to_shared_barcodes(adata, X, keep_barcodes, barcode_key="Lib_Cellbarcode"):
    keep_mask = adata.obs[barcode_key].isin(keep_barcodes).to_numpy()
    return adata[keep_mask].copy(), X[keep_mask]


# -----------------------
# 1) Filter by lineage size at FUTURE timepoint (Larry: time_info == 6.0)
# -----------------------
def filter_by_clone_future_size(
    adata,
    X,
    day_key="time_info",          # <<< Larry
    lineage_key="clone_id",
    future_day=6.0,               # <<< Larry
    min_future_cells=10,
):
    """
    Keep all cells (all timepoints) whose clone has >= min_future_cells at future_day (e.g., 6.0).
    Computed within this adata only (no leakage across splits).
    Returns filtered (adata_sub, X_sub) aligned by obs order.
    """
    # IMPORTANT: time_info is numeric in Larry; compare numerically
    is_future = (adata.obs[day_key].to_numpy() == float(future_day))

    counts = adata.obs.loc[is_future, lineage_key].value_counts()
    keep_clones = set(counts[counts >= int(min_future_cells)].index)

    keep_mask = adata.obs[lineage_key].isin(keep_clones).to_numpy()
    adata_sub = adata[keep_mask].copy()
    X_sub = X[keep_mask]
    return adata_sub, X_sub


# -----------------------
# 2) Build early-time inputs + future lineage composition targets
#    Larry: early_day=2.0, future_day=6.0, celltype_key="state_info"
# -----------------------
def build_targets_from_future(
    X: np.ndarray,
    adata,
    early_day=2.0,                 # <<< Larry
    future_day=6.0,                # <<< Larry
    lineage_key="clone_id",
    celltype_key="state_info",     # <<< Larry
    terminal_types=("Undifferentiated", "Monocyte", "Neutrophil"),  # <<< Larry
    alpha_smooth=1e-3,
    drop_missing_future=True,
    day_key="time_info",           # <<< Larry
):
    terminal_types = list(terminal_types)
    C = len(terminal_types)

    # Future-time cells to compute lineage compositions
    is_future = (adata.obs[day_key].to_numpy() == float(future_day))
    adata_future = adata[is_future].copy()

    # clone_id -> probability vector over terminal_types
    clone_to_probs = {}
    for clone_id, df in adata_future.obs.groupby(lineage_key):
        counts = np.array([(df[celltype_key] == ct).sum() for ct in terminal_types], dtype=float)
        counts = counts + alpha_smooth
        probs = counts / counts.sum()
        clone_to_probs[clone_id] = probs

    # Early-time cells as inputs
    is_early = (adata.obs[day_key].to_numpy() == float(early_day))
    early_idx = np.where(is_early)[0]

    X_early = X[early_idx]
    clone_early = adata.obs.iloc[early_idx][lineage_key].to_numpy()

    y_prob = np.zeros((X_early.shape[0], C), dtype=float)
    keep = np.ones(X_early.shape[0], dtype=bool)

    for i, cid in enumerate(clone_early):
        if cid in clone_to_probs:
            y_prob[i] = clone_to_probs[cid]
        else:
            if drop_missing_future:
                keep[i] = False
            else:
                y_prob[i] = np.ones(C) / C

    X_early = X_early[keep]
    y_prob = y_prob[keep]
    y_prob = y_prob / y_prob.sum(axis=1, keepdims=True)

    return torch.tensor(X_early, dtype=torch.float32), torch.tensor(y_prob, dtype=torch.float32)


# -----------------------
# 3) Linear decoder
# -----------------------
class LinearSoftmax(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.fc(x)  # logits


# -----------------------
# 4) Train with early stopping + CellTag-style logs
# -----------------------
def train_kl_earlystop(
    model,
    X_train,
    y_train,
    lr=5e-3,
    weight_decay=1e-4,
    max_epochs=5000,
    batch_size=256,
    val_frac=0.2,
    patience=150,
    min_delta=1e-5,
    seed=42,
    device=None,
    print_every=50,
    verbose=True,
):
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    model = model.to(device)
    X_train = X_train.to(device)
    y_train = y_train.to(device)

    n = X_train.shape[0]
    g = torch.Generator(device="cpu").manual_seed(seed)
    perm = torch.randperm(n, generator=g)

    n_val = int(round(val_frac * n))
    val_idx = perm[:n_val]
    tr_idx = perm[n_val:]

    X_tr, y_tr = X_train[tr_idx], y_train[tr_idx]
    X_val, y_val = X_train[val_idx], y_train[val_idx]

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val = float("inf")
    best_state = copy.deepcopy(model.state_dict())
    best_epoch = -1
    bad_epochs = 0

    @torch.no_grad()
    def eval_loss(Xe, ye):
        model.eval()
        log_probs = torch.log_softmax(model(Xe), dim=1)
        return criterion(log_probs, ye).item()

    for ep in range(1, max_epochs + 1):
        model.train()

        perm_tr = torch.randperm(X_tr.shape[0], device=device)
        Xs = X_tr[perm_tr]
        ys = y_tr[perm_tr]

        total = 0.0
        for start in range(0, Xs.shape[0], batch_size):
            xb = Xs[start:start + batch_size]
            yb = ys[start:start + batch_size]

            logits = model(xb)
            log_probs = torch.log_softmax(logits, dim=1)
            loss = criterion(log_probs, yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += loss.item() * xb.shape[0]

        train_loss = total / Xs.shape[0]
        val_loss = eval_loss(X_val, y_val)

        improved = (best_val - val_loss) > min_delta
        if improved:
            best_val = val_loss
            best_state = copy.deepcopy(model.state_dict())
            best_epoch = ep
            bad_epochs = 0
        else:
            bad_epochs += 1

        if verbose and (ep == 1 or ep % print_every == 0):
            print(
                f"Epoch {ep}/{max_epochs} | train={train_loss:.6f} | val={val_loss:.6f} "
                f"| best_val={best_val:.6f} (ep {best_epoch}) | bad={bad_epochs}/{patience} | device={device}"
            )

        if bad_epochs >= patience:
            if verbose:
                print(f"Early stopping at epoch {ep}. Best val={best_val:.6f} at epoch {best_epoch}.")
            break

    model.load_state_dict(best_state)
    return model, {"best_epoch": best_epoch, "best_val": best_val, "device": device}


@torch.no_grad()
def eval_kl(model, X, y):
    model.eval()
    device = next(model.parameters()).device
    X = X.to(device)
    y = y.to(device)
    criterion = nn.KLDivLoss(reduction="batchmean")
    log_probs = torch.log_softmax(model(X), dim=1)
    return criterion(log_probs, y).item()


# -----------------------
# 5) One-threshold experiment for Larry top200
# -----------------------
def run_one_threshold_experiment_larry200(
    adata_train, X_train,
    adata_test, X_test,
    lineage_threshold: int,
    terminal_types=("Undifferentiated", "Monocyte", "Neutrophil"),
    device="mps",
    seed=42,
    # time/celltype keys for Larry
    day_key="time_info",
    early_day=2.0,
    future_day=6.0,
    lineage_key="clone_id",
    celltype_key="state_info",
    # training hyperparams
    lr=5e-3,
    weight_decay=1e-4,
    max_epochs=5000,
    batch_size=256,
    val_frac=0.2,
    patience=150,
    min_delta=1e-5,
    print_every=50,
    alpha_smooth=1e-3,
    # optional: provide trajectory barcodes to pre-filter BOTH train and test
    trajectory_barcodes=None,
    barcode_key="Lib_Cellbarcode",
):
    """
    Larry top200 compositional prediction:
      predict Day6 lineage state_info composition from Day2 embeddings.

    If trajectory_barcodes is provided (set/list/np.array of Lib_Cellbarcode),
    we pre-filter train/test to those barcodes before everything else.
    """
    t = int(lineage_threshold)

    # Optional trajectory filter (matches your old Larry pipeline)
    if trajectory_barcodes is not None:
        adata_train, X_train = filter_to_shared_barcodes(adata_train, X_train, trajectory_barcodes, barcode_key=barcode_key)
        adata_test,  X_test  = filter_to_shared_barcodes(adata_test,  X_test,  trajectory_barcodes, barcode_key=barcode_key)

    # 1) filter by future clone size (time_info == 6.0)
    ad_tr_f, X_tr_f = filter_by_clone_future_size(
        adata_train, X_train,
        day_key=day_key, lineage_key=lineage_key, future_day=future_day,
        min_future_cells=t
    )
    ad_te_f, X_te_f = filter_by_clone_future_size(
        adata_test, X_test,
        day_key=day_key, lineage_key=lineage_key, future_day=future_day,
        min_future_cells=t
    )

    # 2) build (Day2 -> Day6 composition) pairs
    X_tr2, y_tr = build_targets_from_future(
        X_tr_f, ad_tr_f,
        day_key=day_key,
        early_day=early_day,
        future_day=future_day,
        lineage_key=lineage_key,
        celltype_key=celltype_key,
        terminal_types=terminal_types,
        alpha_smooth=alpha_smooth,
    )
    X_te2, y_te = build_targets_from_future(
        X_te_f, ad_te_f,
        day_key=day_key,
        early_day=early_day,
        future_day=future_day,
        lineage_key=lineage_key,
        celltype_key=celltype_key,
        terminal_types=terminal_types,
        alpha_smooth=alpha_smooth,
    )

    # 3) train linear decoder
    model = LinearSoftmax(input_size=X_tr2.shape[1], output_size=len(terminal_types))
    model, hist = train_kl_earlystop(
        model,
        X_tr2, y_tr,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        batch_size=batch_size,
        val_frac=val_frac,
        patience=patience,
        min_delta=min_delta,
        seed=seed,
        device=device,
        print_every=print_every,
        verbose=True,
    )

    # 4) evaluate test KL
    kl_test = eval_kl(model, X_te2, y_te)
    print(f"Best epoch: {hist['best_epoch']}")
    print(f"Test KL (linear): {kl_test:.4f}")
    print(f"Training device: {hist['device']}")

    summary = {
        "min_future_cells_time6": t,
        "train_cells_total_after_filter": ad_tr_f.n_obs,
        "test_cells_total_after_filter": ad_te_f.n_obs,
        "train_time2_cells_used": X_tr2.shape[0],
        "test_time2_cells_used": X_te2.shape[0],
        "best_epoch": hist["best_epoch"],
        "val_KL_best": hist["best_val"],
        "test_KL": kl_test,
        "device": hist["device"],
        "early_day": early_day,
        "future_day": future_day,
        "terminal_types": terminal_types,
    }
    return summary


In [4]:
summary = run_one_threshold_experiment_larry200(
    adata_train, X_train,
    adata_test,  X_test,
    lineage_threshold=0,
    terminal_types=("Undifferentiated", "Monocyte", "Neutrophil"),
    device="mps",
    seed=42,
    max_epochs=5000,
    patience=150,
    print_every=50,
    trajectory_barcodes=None,   # or trajectory_barcodes
)
print(summary)

Epoch 1/5000 | train=0.653260 | val=0.943850 | best_val=0.943850 (ep 1) | bad=0/150 | device=mps
Epoch 50/5000 | train=0.377575 | val=0.536161 | best_val=0.536161 (ep 50) | bad=0/150 | device=mps
Epoch 100/5000 | train=0.338137 | val=0.488631 | best_val=0.488631 (ep 100) | bad=0/150 | device=mps
Epoch 150/5000 | train=0.324246 | val=0.468707 | best_val=0.468707 (ep 150) | bad=0/150 | device=mps
Epoch 200/5000 | train=0.317175 | val=0.454790 | best_val=0.454790 (ep 200) | bad=0/150 | device=mps
Epoch 250/5000 | train=0.312892 | val=0.444028 | best_val=0.444028 (ep 250) | bad=0/150 | device=mps
Epoch 300/5000 | train=0.310095 | val=0.436366 | best_val=0.436366 (ep 300) | bad=0/150 | device=mps
Epoch 350/5000 | train=0.308198 | val=0.431188 | best_val=0.431188 (ep 350) | bad=0/150 | device=mps
Epoch 400/5000 | train=0.306853 | val=0.427773 | best_val=0.427773 (ep 400) | bad=0/150 | device=mps
Epoch 450/5000 | train=0.305837 | val=0.425547 | best_val=0.425547 (ep 450) | bad=0/150 | device=