In [1]:
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients


  from .autonotebook import tqdm as notebook_tqdm


In [None]:

# ----------------------------
# 1) Compose f and g into a single module: z = g(f(x))
# ----------------------------
class LCLFG(torch.nn.Module):
    """Chain two callables into one module whose output is ``g(f(x))``.

    Per the shape notes in this notebook, inputs are expected to be ``[B, G]``
    and outputs ``[B, Dproj]``; presumably ``f`` is the encoder and ``g`` the
    projection head (TODO confirm against where they are built).
    """

    def __init__(self, f, g):
        super().__init__()
        self.f, self.g = f, g

    def forward(self, x):
        # [B, G] --f--> hidden --g--> [B, Dproj]
        return self.g(self.f(x))

# NOTE(review): `f` and `g` are not defined anywhere in this notebook; they must
# already exist in the kernel from an earlier session/cell. TODO define or load
# them explicitly so Restart & Run All reproduces this cell.
model_fg = LCLFG(f=f, g=g)
model_fg.eval()

# ----------------------------
# 2) Precompute all z and lineage centroids
# ----------------------------
@torch.no_grad()
def compute_z_and_centroids(model_fg, X, lineage_ids, device="cuda", batch_size=1024):
    """Embed every row of ``X`` and compute one unit-norm centroid per lineage.

    Args:
        model_fg: module mapping a ``[b, G]`` batch to ``[b, D]`` embeddings
            (e.g. ``LCLFG``). It is moved to ``device`` and switched to eval mode.
        X: 2-D input tensor, one row per cell. It may stay on the CPU; only one
            batch at a time is copied to ``device``.
        lineage_ids: sequence (list, array, or 1-D tensor) holding one lineage
            label per row of ``X``.
        device: device the forward passes run on.
        batch_size: number of rows per forward pass.

    Returns:
        ``(Z, centroids, lineage_to_idx)``:
        ``Z`` is the ``[N, D]`` CPU tensor of embeddings.
        ``centroids`` maps each lineage to the L2-normalised mean of its rows of ``Z``.
        ``lineage_to_idx`` maps each lineage to its row indices, in order of
        first appearance.

    Raises:
        ValueError: if ``lineage_ids`` does not have one entry per row of ``X``.
    """
    # Elements of a tensor hash by identity, so every 0-d tensor would become its
    # own dict key; convert to Python scalars so equal labels group together.
    if isinstance(lineage_ids, torch.Tensor):
        lineage_ids = lineage_ids.tolist()
    if len(lineage_ids) != X.shape[0]:
        raise ValueError(
            f"lineage_ids has {len(lineage_ids)} entries but X has {X.shape[0]} rows"
        )

    # eval() so dropout / batch-norm layers do not perturb the embeddings.
    model_fg = model_fg.to(device).eval()

    # Copy one batch at a time to the device instead of the whole matrix.
    Z = []
    for i in range(0, X.shape[0], batch_size):
        z = model_fg(X[i:i + batch_size].to(device))
        Z.append(z.cpu())
    Z = torch.cat(Z, dim=0)  # [N, D]

    # map lineage -> row indices (insertion order = first appearance)
    lineage_to_idx = {}
    for i, lid in enumerate(lineage_ids):
        lineage_to_idx.setdefault(lid, []).append(i)

    # unit-norm centroid per lineage; epsilon guards against a zero mean vector
    centroids = {}
    for lid, idxs in lineage_to_idx.items():
        mean_z = Z[idxs].mean(dim=0)  # [D]
        centroids[lid] = mean_z / (mean_z.norm() + 1e-12)

    return Z, centroids, lineage_to_idx

# Embed every cell once and build a unit-norm centroid for each lineage.
# NOTE(review): `X` and `lineage_ids` come from kernel state not shown in this
# notebook, and the call relies on the default device="cuda".
Z_cpu, centroids, lineage_to_idx = compute_z_and_centroids(
    model_fg=model_fg,
    X=X,
    lineage_ids=lineage_ids,
)

# ----------------------------
# 3) Scalar attribution target: s(x) = cosine(z(x), centroid[lineage])
#    Captum's IntegratedGradients needs forward() to return one scalar per example ([B])
# ----------------------------
class CosineToCentroid(torch.nn.Module):
    """Score each input by how well its embedding aligns with a fixed vector.

    Captum attributes one scalar per example, so ``forward`` reduces the
    ``[B, D]`` embedding produced by ``model_fg`` to a ``[B]`` score.

    Args:
        model_fg: module producing ``[B, D]`` embeddings.
        centroid_vec: ``[D]`` reference vector. A copy is stored as the
            ``centroid`` buffer, so it follows ``.to(device)``. It is not
            re-normalised here; pass a unit-norm vector to get a true cosine.
    """

    def __init__(self, model_fg, centroid_vec):
        super().__init__()
        self.model_fg = model_fg
        self.register_buffer("centroid", centroid_vec.clone())

    def forward(self, x):
        emb = self.model_fg(x)  # [B, D]
        # epsilon avoids division by zero for an all-zero embedding
        unit = emb / (emb.norm(dim=1, keepdim=True) + 1e-12)
        return torch.sum(unit * self.centroid[None, :], dim=1)  # [B]

# ----------------------------
# 4) Run IG for a set of cells in one lineage
# ----------------------------
def run_ig_for_lineage(
    model_fg, X, idxs, centroid_vec, baseline="zero",
    n_steps=64, method="gausslegendre", device="cuda",
    internal_batch_size=None,
):
    """Integrated-Gradients attributions of cosine-to-centroid for selected rows.

    The scalar explained for each selected row is the cosine similarity between
    its embedding ``model_fg(x)`` and ``centroid_vec`` (via ``CosineToCentroid``).

    Args:
        model_fg: embedding module; moved to ``device`` and set to eval mode.
        X: 2-D input tensor ``[N, G]``. It may stay on the CPU; only the
            selected rows are copied to ``device``. For ``baseline="mean"`` the
            mean over all rows is computed on X's own device and then moved.
        idxs: row indices (list or 1-D tensor) of the rows to explain.
        centroid_vec: ``[D]`` target vector, e.g. a lineage centroid.
        baseline: ``"zero"`` (all-zero input) or ``"mean"`` (mean row of ``X``).
        n_steps, method: forwarded to Captum ``IntegratedGradients.attribute``.
        device: device the attribution runs on.
        internal_batch_size: forwarded to Captum. It caps how many of the
            ``len(idxs) * n_steps`` interpolated inputs go through the model per
            pass. ``None`` (the default) evaluates them all at once.

    Returns:
        ``(attributions, delta)`` on the CPU: ``[len(idxs), G]`` attributions
        and ``[len(idxs)]`` convergence deltas.

    Raises:
        ValueError: if ``idxs`` is empty or ``baseline`` is not recognised.
    """
    if len(idxs) == 0:
        raise ValueError("idxs must select at least one row")
    if baseline not in ("zero", "mean"):
        raise ValueError("baseline must be 'zero' or 'mean'")

    model_fg = model_fg.to(device).eval()

    # Index on X's own device, then move only the selected rows.
    if isinstance(idxs, torch.Tensor):
        idxs = idxs.to(X.device)
    inputs = X[idxs].detach().to(device).requires_grad_(True)

    # [1, G] baseline; Captum broadcasts it across the batch.
    if baseline == "zero":
        # zeros_like keeps dtype/device consistent with the inputs.
        base = torch.zeros_like(inputs[:1]).detach()
    else:
        base = X.mean(dim=0, keepdim=True).detach().to(device)

    # Build scalar model
    scalar_model = CosineToCentroid(model_fg, centroid_vec.to(device)).eval()
    ig = IntegratedGradients(scalar_model)

    attributions, delta = ig.attribute(
        inputs,
        baselines=base,
        n_steps=n_steps,
        method=method,
        internal_batch_size=internal_batch_size,
        return_convergence_delta=True,
    )
    # attributions: [B, G], delta: [B]
    return attributions.detach().cpu(), delta.detach().cpu()


In [None]:

# Example: explain up to MAX_CELLS cells from the first lineage encountered.
MAX_CELLS = 200  # caps IG cost: the model is evaluated at n_steps points per cell
some_lineage = next(iter(lineage_to_idx))
idxs = lineage_to_idx[some_lineage][:MAX_CELLS]
centroid_vec = centroids[some_lineage]  # [D]
attr, delta = run_ig_for_lineage(model_fg, X, idxs, centroid_vec, baseline="zero")
# Signed deltas can cancel in a plain mean, so also report their magnitude.
print(
    "attr shape:", attr.shape,
    "delta mean:", delta.mean().item(),
    "|delta| mean:", delta.abs().mean().item(),
    "|delta| max:", delta.abs().max().item(),
)