In [1]:
import copy
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

import anndata as ad

### Load Data

In [2]:
# Data root. Defaults to the original author's local path; set the KB_DATA_DIR
# environment variable to run elsewhere without editing the notebook.
input_dir = os.environ.get("KB_DATA_DIR", "/Users/apple/Desktop/KB/data")
adata_train = ad.read_h5ad(input_dir + '/BiddyData/Biddy_train.h5ad')
adata_test = ad.read_h5ad(input_dir + '/BiddyData/Biddy_test.h5ad')

# Alternative baseline (not used below): scVI embeddings.
# X_train = np.load(input_dir+'/feat_LCL_2025/cell_tag/scvi_embedding/Biddy_scvi_train_embeddings.npy')
# X_test = np.load(input_dir+'/feat_LCL_2025/cell_tag/scvi_embedding/Biddy_scvi_test_embeddings.npy')

# LCL projection embeddings; row i must correspond to obs row i of the AnnData.
embed_dir = input_dir + "/feat_LCL_2025/cell_tag/feat_celltag_lambda005_unlab5_bs50_testAsPenalty"
X_train_LCL = np.load(embed_dir + '/train_proj_embed.npy')
X_test_LCL = np.load(embed_dir + '/test_proj_embed.npy')

# Later cells index the embeddings with boolean masks built from adata.obs,
# so the row counts must agree.
assert X_train_LCL.shape[0] == adata_train.n_obs, "train embeddings and AnnData have different row counts"
assert X_test_LCL.shape[0] == adata_test.n_obs, "test embeddings and AnnData have different row counts"

print(adata_train.shape, adata_test.shape)
print(X_train_LCL.shape, X_test_LCL.shape)

(5893, 2000) (641, 2000)
(5893, 32) (641, 32)


### Linear decoder: Day-12 embedding → Day-28 clone fate composition (softmax output, KL-divergence loss)

In [3]:


# -----------------------
# 1) Helpers: filter adata + embeddings together
# -----------------------
def filter_by_clone_future_size(
    adata,
    X,
    day_key="reprogramming_day",
    lineage_key="clone_id",
    future_day="28",
    min_future_cells=10,
):
    """
    Keep every cell (from all days) whose clone has at least ``min_future_cells``
    cells at ``future_day`` (e.g. Day 28).

    Counts are computed within this ``adata`` only, so filtering the train and
    test splits separately does not leak information between them.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.obs`` contains ``day_key`` and ``lineage_key`` columns.
    X : array-like
        Per-cell features whose rows are in the same order as ``adata.obs``.
    day_key, lineage_key : str
        Column names in ``adata.obs``.
    future_day : str or int
        Compared against the day column after casting both to ``str``.
    min_future_cells : int
        Minimum number of ``future_day`` cells a clone needs in order to be kept.

    Returns
    -------
    (adata_sub, X_sub)
        A filtered copy of ``adata`` and the matching rows of ``X``, in obs order.

    Raises
    ------
    ValueError
        If ``X`` does not have exactly one row per cell of ``adata``.
    """
    n_cells = adata.obs.shape[0]
    if X.shape[0] != n_cells:
        # A mismatch would otherwise fail with an opaque indexing error or,
        # worse, pair cells with the wrong feature rows.
        raise ValueError(
            f"X has {X.shape[0]} rows but adata has {n_cells} cells; they must be row-aligned."
        )

    day = adata.obs[day_key].astype(str)
    is_future = day == str(future_day)

    # Number of future-day cells per clone.
    counts = adata.obs.loc[is_future, lineage_key].value_counts()

    # Clones whose future-day count reaches the threshold.
    keep_clones = set(counts[counts >= int(min_future_cells)].index)

    # Keep all cells (every day) belonging to those clones.
    keep_mask = adata.obs[lineage_key].isin(keep_clones).to_numpy()
    adata_sub = adata[keep_mask].copy()
    X_sub = X[keep_mask]

    return adata_sub, X_sub


# -----------------------
# 2) Build Day12 inputs + lineage composition targets from Day28
# -----------------------
def build_targets_from_future(
    X: np.ndarray,
    adata,
    early_day="12",
    future_day="28",
    lineage_key="clone_id",
    celltype_key="cell_type",
    terminal_types=("iEP", "Fibroblast", "Ambiguous"),
    alpha_smooth=1e-3,
    drop_missing_future=True,
    day_key="reprogramming_day",
):
    """
    Pair each ``early_day`` cell's features with its clone's ``future_day`` fate
    composition.

    For every clone, the target is the smoothed fraction of its ``future_day``
    cells in each of ``terminal_types``. Future cells of any other type are
    ignored. Every ``early_day`` cell inherits its clone's vector.

    Parameters
    ----------
    X : np.ndarray
        Per-cell features with rows aligned to ``adata.obs``.
    adata
        Object exposing ``.obs``, a DataFrame with ``day_key``, ``lineage_key``
        and ``celltype_key`` columns.
    early_day, future_day : str or int
        Compared against the day column after casting both to ``str``.
    lineage_key, celltype_key : str
        Column names in ``adata.obs``.
    terminal_types : sequence of str
        Fate classes, in output-column order.
    alpha_smooth : float
        Pseudo-count added to every class before normalising.
    drop_missing_future : bool
        If True, drop early cells whose clone has no usable future-day
        composition. If False, give those cells a uniform target.
    day_key : str
        Column holding the day label. Defaults to the previously hard-coded
        ``"reprogramming_day"``.

    Returns
    -------
    (X_early, y_prob) : tuple of torch.FloatTensor
        Shapes ``(n_early, n_features)`` and ``(n_early, len(terminal_types))``.
        Each row of ``y_prob`` sums to 1.

    Raises
    ------
    ValueError
        If ``X`` does not have exactly one row per cell in ``adata.obs``.
    """
    terminal_types = list(terminal_types)
    C = len(terminal_types)
    obs = adata.obs

    if X.shape[0] != obs.shape[0]:
        raise ValueError(
            f"X has {X.shape[0]} rows but adata has {obs.shape[0]} cells; they must be row-aligned."
        )

    day = obs[day_key].astype(str)

    # Only the metadata is needed here, so subset obs instead of copying the
    # whole AnnData (including its expression matrix).
    future_obs = obs.loc[(day == str(future_day)).to_numpy()]

    # clone_id -> probability vector over terminal_types
    clone_to_probs = {}
    # observed=True: for a categorical clone column, skip categories with no
    # future-day cells instead of giving them a smoothed-uniform target.
    for clone_id, df in future_obs.groupby(lineage_key, observed=True, sort=False):
        counts = np.array([(df[celltype_key] == ct).sum() for ct in terminal_types], dtype=float)
        counts = counts + alpha_smooth
        total = counts.sum()
        if total <= 0:
            # No terminal-type cells and no smoothing: the composition would be
            # 0/0 (NaN), so treat the clone as having no future information.
            continue
        clone_to_probs[clone_id] = counts / total

    # Early-day cells are the model inputs.
    early_idx = np.flatnonzero((day == str(early_day)).to_numpy())
    X_early = X[early_idx]
    clone_early = obs[lineage_key].to_numpy()[early_idx]

    uniform = np.ones(C) / C
    y_prob = np.zeros((len(early_idx), C), dtype=float)
    keep = np.ones(len(early_idx), dtype=bool)

    for i, cid in enumerate(clone_early):
        probs = clone_to_probs.get(cid)
        if probs is not None:
            y_prob[i] = probs
        elif drop_missing_future:
            # No usable future cells for this lineage in this split.
            keep[i] = False
        else:
            y_prob[i] = uniform

    X_early = X_early[keep]
    y_prob = y_prob[keep]

    # Rows are already normalised; this guards against floating-point drift.
    y_prob = y_prob / y_prob.sum(axis=1, keepdims=True)

    return torch.tensor(X_early, dtype=torch.float32), torch.tensor(y_prob, dtype=torch.float32)


# -----------------------
# 3) Linear decoder
# -----------------------
class LinearSoftmax(nn.Module):
    """
    One affine layer mapping input features to per-class scores.

    Despite the name, ``forward`` applies no softmax: it returns raw logits.
    Callers are expected to apply ``softmax`` / ``log_softmax`` themselves.

    Parameters
    ----------
    input_size : int
        Number of input features.
    output_size : int
        Number of output classes.
    """

    def __init__(self, input_size, output_size):
        super(LinearSoftmax, self).__init__()
        # The attribute name "fc" fixes the state_dict keys (fc.weight / fc.bias).
        self.fc = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return unnormalised logits with last dimension ``output_size``."""
        logits = self.fc(x)
        return logits


# -----------------------
# 4) Train with early stopping, logging progress at epoch 1 and every `print_every` epochs
# -----------------------
def train_kl_earlystop(
    model,
    X_train,
    y_train,
    lr=5e-3,
    weight_decay=1e-4,
    max_epochs=5000,
    batch_size=256,
    val_frac=0.2,
    patience=150,
    min_delta=1e-5,
    seed=42,
    device=None,
    print_every=50,     # <<< prints Epoch 1 and every print_every epochs
    verbose=True,
):
    """
    Train ``model`` to match target distributions under a KL-divergence loss,
    with early stopping on a held-out validation split.

    ``model`` must return unnormalised logits. ``log_softmax`` is applied here,
    and ``nn.KLDivLoss(reduction="batchmean")`` compares the result with
    ``y_train``.

    Parameters
    ----------
    model : nn.Module
        Maps ``(n, d)`` inputs to ``(n, C)`` logits.
    X_train : torch.Tensor
        Inputs of shape ``(n, d)``.
    y_train : torch.Tensor
        Target probability rows of shape ``(n, C)``.
    lr, weight_decay : float
        AdamW hyper-parameters.
    max_epochs : int
        Upper bound on the number of epochs.
    batch_size : int
        Mini-batch size.
    val_frac : float
        Fraction of samples held out for validation. At least one sample is
        always held out and at least one is kept for training.
    patience : int
        Stop after this many consecutive epochs without a validation
        improvement larger than ``min_delta``.
    min_delta : float
        Minimum decrease in validation loss that counts as an improvement.
    seed : int
        Seeds a private CPU generator that drives both the train/val split and
        the per-epoch shuffling. The global torch RNG is not used for either.
    device : str or None
        ``None`` picks ``"cuda"``, then ``"mps"``, then ``"cpu"``.
    print_every : int
        Print a log line at epoch 1 and every ``print_every`` epochs (if > 0).
    verbose : bool
        Enable progress printing.

    Returns
    -------
    (model, history)
        ``model`` on ``device`` with the best-validation weights restored, and
        a dict with per-epoch ``train_loss`` / ``val_loss`` lists plus
        ``best_epoch``, ``best_val`` and ``device``.

    Raises
    ------
    ValueError
        If fewer than two samples are given (no train/val split is possible).
    """
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    n = X_train.shape[0]
    if n < 2:
        raise ValueError(f"Need at least 2 samples for a train/val split, got {n}.")

    # Move to device
    model = model.to(device)
    X_train = X_train.to(device)
    y_train = y_train.to(device)

    # One seeded generator drives the split and every epoch's shuffle, so a
    # run is reproducible for a given seed and initial model state.
    g = torch.Generator(device="cpu").manual_seed(seed)
    perm = torch.randperm(n, generator=g)

    # Keep both splits non-empty: an empty validation set makes the batchmean
    # loss NaN (nothing ever "improves"), and an empty training set divides by 0.
    n_val = min(max(int(round(val_frac * n)), 1), n - 1)
    val_idx = perm[:n_val].to(device)
    tr_idx = perm[n_val:].to(device)

    X_tr, y_tr = X_train[tr_idx], y_train[tr_idx]
    X_val, y_val = X_train[val_idx], y_train[val_idx]
    n_tr = X_tr.shape[0]

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val = float("inf")
    best_state = copy.deepcopy(model.state_dict())
    best_epoch = -1
    bad_epochs = 0

    history = {"train_loss": [], "val_loss": [], "best_epoch": None, "best_val": None, "device": device}

    @torch.no_grad()
    def eval_loss(Xe, ye):
        model.eval()
        log_probs = torch.log_softmax(model(Xe), dim=1)
        return criterion(log_probs, ye).item()

    for ep in range(1, max_epochs + 1):
        model.train()

        # Reshuffle the training rows each epoch using the seeded generator.
        perm_tr = torch.randperm(n_tr, generator=g).to(device)
        Xs = X_tr[perm_tr]
        ys = y_tr[perm_tr]

        total = 0.0
        for start in range(0, n_tr, batch_size):
            xb = Xs[start:start + batch_size]
            yb = ys[start:start + batch_size]

            logits = model(xb)
            log_probs = torch.log_softmax(logits, dim=1)
            loss = criterion(log_probs, yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # batchmean is a per-sample mean; weight by batch size so the epoch
            # average is exact even when the last batch is smaller.
            total += loss.item() * xb.shape[0]

        train_loss = total / n_tr
        val_loss = eval_loss(X_val, y_val)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        # early stopping bookkeeping
        improved = (best_val - val_loss) > min_delta
        if improved:
            best_val = val_loss
            best_state = copy.deepcopy(model.state_dict())
            best_epoch = ep
            bad_epochs = 0
        else:
            bad_epochs += 1

        if verbose and (ep == 1 or (print_every > 0 and ep % print_every == 0)):
            print(
                f"Epoch {ep}/{max_epochs} | train={train_loss:.6f} | val={val_loss:.6f} "
                f"| best_val={best_val:.6f} (ep {best_epoch}) | bad={bad_epochs}/{patience} | device={device}"
            )

        if bad_epochs >= patience:
            if verbose:
                print(f"Early stopping at epoch {ep}. Best val={best_val:.6f} at epoch {best_epoch}.")
            break

    # restore best
    model.load_state_dict(best_state)
    history["best_epoch"] = best_epoch
    history["best_val"] = best_val
    return model, history


# -----------------------
# 5) Evaluate KL on test (device-safe)
# -----------------------
@torch.no_grad()
def eval_kl(model, X, y):
    """
    Batch-mean KL(y || softmax(model(X))), computed without gradients.

    ``X`` and ``y`` are moved to the device of the model's first parameter.
    The model is switched to eval mode and left there.

    Returns
    -------
    float
        The scalar loss value.
    """
    target_device = next(iter(model.parameters())).device
    model.eval()
    log_q = torch.log_softmax(model(X.to(target_device)), dim=1)
    kl = nn.functional.kl_div(log_q, y.to(target_device), reduction="batchmean")
    return kl.item()


# -----------------------
# 6) One-threshold experiment (replaces the sweep wrapper)
# -----------------------
def run_one_threshold_experiment(
    adata_train, X_train,
    adata_test, X_test,
    lineage_threshold: int,
    terminal_types=("iEP", "Fibroblast", "Ambiguous"),
    device=None,
    seed=42,
    # training hyperparams
    lr=5e-3,
    weight_decay=1e-4,
    max_epochs=5000,
    batch_size=256,
    val_frac=0.2,
    patience=150,
    min_delta=1e-5,
    print_every=50,
    alpha_smooth=1e-3,
):
    """
    Run one Day-12 -> Day-28 fate-prediction experiment for a single lineage threshold.

    Steps:
      1. Keep clones with >= ``lineage_threshold`` Day-28 cells. Train and test
         are filtered separately.
      2. Build (Day-12 embedding, clone fate composition) pairs.
      3. Train a linear decoder with KL loss and early stopping, printing progress.
      4. Report the KL divergence on the test pairs.

    Parameters
    ----------
    adata_train, adata_test : AnnData
        Train / test cells.
    X_train, X_test : np.ndarray
        Embeddings row-aligned with the corresponding AnnData.
    lineage_threshold : int
        Minimum number of Day-28 cells a clone needs.
    terminal_types : sequence of str
        Fate classes (output columns).
    device : str or None
        ``None`` auto-selects CUDA, then MPS, then CPU. Previously the default
        was ``"mps"``, which failed on machines without an Apple GPU.
    seed : int
        Seeds the decoder's weight initialisation (via a forked RNG, so global
        state is untouched) and the trainer's split and shuffling.
    lr, weight_decay, max_epochs, batch_size, val_frac, patience, min_delta, print_every
        Passed through to the trainer.
    alpha_smooth : float
        Pseudo-count used when building target compositions.

    Returns
    -------
    dict
        Summary with cell counts after filtering, best epoch, best validation
        KL, test KL and device.

    Raises
    ------
    ValueError
        If filtering leaves fewer than 2 training pairs or no test pairs.
    """
    t = int(lineage_threshold)

    # 1) filter train/test separately (no leakage)
    ad_tr_f, X_tr_f = filter_by_clone_future_size(adata_train, X_train, min_future_cells=t)
    ad_te_f, X_te_f = filter_by_clone_future_size(adata_test, X_test, min_future_cells=t)

    # 2) build Day12 -> composition pairs
    X_tr12, y_tr = build_targets_from_future(
        X_tr_f, ad_tr_f,
        terminal_types=terminal_types,
        alpha_smooth=alpha_smooth,
    )
    X_te12, y_te = build_targets_from_future(
        X_te_f, ad_te_f,
        terminal_types=terminal_types,
        alpha_smooth=alpha_smooth,
    )

    # Fail with an actionable message rather than a crash or a NaN test KL.
    if X_tr12.shape[0] < 2:
        raise ValueError(
            f"Only {X_tr12.shape[0]} training Day-12 cells remain at lineage_threshold={t}; "
            "lower the threshold."
        )
    if X_te12.shape[0] == 0:
        raise ValueError(
            f"No test Day-12 cells remain at lineage_threshold={t}; lower the threshold."
        )

    # 3) train linear decoder with early stopping (prints progress).
    # Seed the weight init inside a forked RNG so results depend on `seed`
    # without changing the caller's global torch RNG state.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        model = LinearSoftmax(input_size=X_tr12.shape[1], output_size=len(terminal_types))
    model, hist = train_kl_earlystop(
        model,
        X_tr12, y_tr,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        batch_size=batch_size,
        val_frac=val_frac,
        patience=patience,
        min_delta=min_delta,
        seed=seed,
        device=device,
        print_every=print_every,
        verbose=True,
    )

    # 4) evaluate on test
    kl_test = eval_kl(model, X_te12, y_te)

    print(f"Best epoch: {hist['best_epoch']}")
    print(f"Test KL (linear): {kl_test:.4f}")
    print(f"Training device: {hist['device']}")

    summary = {
        "min_future_cells_day28": t,
        "train_cells_total_after_filter": ad_tr_f.n_obs,
        "test_cells_total_after_filter": ad_te_f.n_obs,
        "train_day12_cells_used": X_tr12.shape[0],
        "test_day12_cells_used": X_te12.shape[0],
        "best_epoch": hist["best_epoch"],
        "val_KL_best": hist["best_val"],
        "test_KL": kl_test,
        "device": hist["device"],
    }
    return summary

In [4]:
# device=None lets the trainer choose CUDA, then Apple MPS, then CPU, so the
# notebook also runs on machines without an Apple GPU (on a Mac it still uses MPS).
summary = run_one_threshold_experiment(
    adata_train, X_train_LCL,
    adata_test,  X_test_LCL,
    lineage_threshold=3,
    terminal_types=("iEP", "Fibroblast", "Ambiguous"),
    device=None,
    seed=42,
    max_epochs=5000,
    patience=150,
    print_every=50,
)
pd.Series(summary)

Epoch 1/5000 | train=0.476202 | val=0.390638 | best_val=0.390638 (ep 1) | bad=0/150 | device=mps
Epoch 50/5000 | train=0.085646 | val=0.108993 | best_val=0.108993 (ep 50) | bad=0/150 | device=mps
Epoch 100/5000 | train=0.053764 | val=0.069635 | best_val=0.069635 (ep 100) | bad=0/150 | device=mps
Epoch 150/5000 | train=0.041409 | val=0.052458 | best_val=0.052458 (ep 150) | bad=0/150 | device=mps
Epoch 200/5000 | train=0.034928 | val=0.042980 | best_val=0.042980 (ep 200) | bad=0/150 | device=mps
Epoch 250/5000 | train=0.030853 | val=0.036896 | best_val=0.036896 (ep 250) | bad=0/150 | device=mps
Epoch 300/5000 | train=0.027947 | val=0.032689 | best_val=0.032689 (ep 300) | bad=0/150 | device=mps
Epoch 350/5000 | train=0.025777 | val=0.029702 | best_val=0.029702 (ep 350) | bad=0/150 | device=mps
Epoch 400/5000 | train=0.024157 | val=0.027542 | best_val=0.027542 (ep 400) | bad=0/150 | device=mps
Epoch 450/5000 | train=0.022934 | val=0.025965 | best_val=0.025965 (ep 450) | bad=0/150 | device=