In [3]:
import numpy as np
from scipy.stats import poisson, nbinom, norm, multivariate_normal, truncnorm
from numpy.linalg import cholesky, LinAlgError

def nb_theta_mu_to_r_p(theta, mu):
    r = theta
    p = theta / (theta + mu)
    return r, p

def cdf_poisson(k, mu):
    return poisson.cdf(k, mu)

def cdf_nb_theta_mu(k, theta, mu):
    r, p = nb_theta_mu_to_r_p(theta, mu)
    return nbinom.cdf(k, r, p)

def cdf_zip(k, mu, pi):
    if k < 0:
        return 0.0
    if k == 0:
        return pi + (1 - pi) * poisson.pmf(0, mu)
    return pi + (1 - pi) * poisson.cdf(k, mu)

def cdf_zinb_theta_mu(k, theta, mu, pi):
    r, p = nb_theta_mu_to_r_p(theta, mu)
    if k < 0:
        return 0.0
    if k == 0:
        return pi + (1 - pi) * nbinom.pmf(0, r, p)
    return pi + (1 - pi) * nbinom.cdf(k, r, p)

# Compute P(a < Z < b) for Z ~ N(0, Sigma) w/ GHK.
# Assume a and b have shape (p,), sigma is (p,p). Use M
# draws.
def ghk_estimate_rectangle_prob(a, b, Sigma, M=10000):
    p = Sigma.shape[0]
    try:
        L = cholesky(Sigma)
    except LinAlgError:
        raise ValueError("Sigma is not positive definite")
    weights = np.zeros(M)
    for m in range(M):
        z = np.zeros(p)
        w = 1.0
        for i in range(p):
            mu_cond = np.dot(L[i, :i], z[:i])
            sigma_cond = L[i, i]
            # compute lower and upper truncation in u-space
            a_u = (a[i] - mu_cond) / sigma_cond
            b_u = (b[i] - mu_cond) / sigma_cond
            # truncated normal draw for u_i
            u_i = truncnorm.rvs(a_u, b_u, loc=0.0, scale=1.0)
            # weight contribution for this step
            w *= (norm.cdf(b_u) - norm.cdf(a_u))
            
            z[i] = mu_cond + sigma_cond * u_i
        weights[m] = w
    return weights.mean()

# Estimate P(X=x) under Gaussian copula w/ discrete marginals.
def gaussian_copula_point_probability_ghk(x, marg_params, Sigma,
                                           eps=1e-12, M=5000):

    x = np.asarray(x, dtype=int)
    p = x.shape[0]
    diag = np.sqrt(np.diag(Sigma))
    R = Sigma / np.outer(diag, diag)
    np.fill_diagonal(R, 1.0)

    a = np.empty(p)
    b = np.empty(p)
    for i, xi in enumerate(x):
        params = marg_params[i]
        if params[0] == 0:
            # non zero-inflated
            if params[1] == np.inf:
                mu = params[2]
                a[i] = cdf_poisson(xi - 1, mu)
                b[i] = cdf_poisson(xi, mu)
            else:
                theta = params[1]; mu = params[2]
                a[i] = cdf_nb_theta_mu(xi - 1, theta, mu)
                b[i] = cdf_nb_theta_mu(xi, theta, mu)
        else:
            # zero-inflated
            pi = params[0]
            if params[1] == np.inf:
                mu = params[2]
                a[i] = cdf_zip(xi - 1, mu, pi)
                b[i] = cdf_zip(xi, mu, pi)
            else:
                theta = params[1]; mu = params[2]
                a[i] = cdf_zinb_theta_mu(xi - 1, theta, mu, pi)
                b[i] = cdf_zinb_theta_mu(xi, theta, mu, pi)
        a[i] = np.clip(a[i], eps, 1 - eps)
        b[i] = np.clip(b[i], eps, 1 - eps)

    # z-space limits
    z_lower = norm.ppf(a)
    z_upper = norm.ppf(b)

    # Estimate P(z_lower < Z < z_upper)
    prob = ghk_estimate_rectangle_prob(z_lower, z_upper, R, M=M)
    return prob

def getPointScores(
    targetAD, 
    auxMargParams, auxCovMat, auxCopulaGenes, 
    synthMargParams, synthCovMat, synthCopulaGenes
):
    scores = {}
    for i, cell in enumerate(targetAD):
        print(i)
        try:
            auxExpr = cell[:,auxCopulaGenes].X.toarray().flatten()
            auxProb = gaussian_copula_point_probability_ghk(
                auxExpr, auxMargParams, auxCovMat, M=100
            )
            synthExpr = cell[:,synthCopulaGenes].X.toarray().flatten()
            synthProb = gaussian_copula_point_probability_ghk(
                synthExpr, synthMargParams, synthCovMat, M=100
            )
            scores[i] = auxProb / synthProb
        except:
            scores[i] = None
            
    return scores

In [79]:
import numpy as np
import torch
from scipy.stats import poisson, nbinom
from numpy.linalg import cholesky, LinAlgError
from math import sqrt
from scipy.stats import norm as sp_norm

def nb_theta_mu_to_r_p(theta, mu):
    r = theta
    p = theta / (theta + mu)
    return r, p

def cdf_poisson_vec(k_arr, mu_arr):
    return poisson.cdf(k_arr, mu_arr)

def cdf_nb_theta_mu_vec(k_arr, theta_arr, mu_arr):
    r_arr, p_arr = theta_arr, theta_arr / (theta_arr + mu_arr)
    return nbinom.cdf(k_arr, r_arr, p_arr)

def cdf_zip_vec(k_arr, mu_arr, pi_arr):
    k_arr = np.asarray(k_arr)
    mu_arr = np.asarray(mu_arr)
    pi_arr = np.asarray(pi_arr)
    out = np.zeros_like(mu_arr, dtype=float)
    mask_neg = k_arr < 0
    out[mask_neg] = 0.0
    mask_zero = (k_arr == 0)
    if np.any(mask_zero):
        out[mask_zero] = pi_arr[mask_zero] + (1 - pi_arr[mask_zero]) * poisson.pmf(0, mu_arr[mask_zero])
    mask_pos = k_arr > 0
    if np.any(mask_pos):
        out[mask_pos] = pi_arr[mask_pos] + (1 - pi_arr[mask_pos]) * poisson.cdf(k_arr[mask_pos], mu_arr[mask_pos])
    return out

def cdf_zinb_theta_mu_vec(k_arr, theta_arr, mu_arr, pi_arr):
    k_arr = np.asarray(k_arr)
    theta_arr = np.asarray(theta_arr)
    mu_arr = np.asarray(mu_arr)
    pi_arr = np.asarray(pi_arr)
    out = np.zeros_like(mu_arr, dtype=float)
    mask_neg = k_arr < 0
    out[mask_neg] = 0.0
    mask_zero = (k_arr == 0)
    if np.any(mask_zero):
        r = theta_arr[mask_zero]
        p = theta_arr[mask_zero] / (theta_arr[mask_zero] + mu_arr[mask_zero])
        out[mask_zero] = pi_arr[mask_zero] + (1 - pi_arr[mask_zero]) * nbinom.pmf(0, r, p)
    mask_pos = k_arr > 0
    if np.any(mask_pos):
        r = theta_arr[mask_pos]
        p = theta_arr[mask_pos] / (theta_arr[mask_pos] + mu_arr[mask_pos])
        out[mask_pos] = pi_arr[mask_pos] + (1 - pi_arr[mask_pos]) * nbinom.cdf(k_arr[mask_pos], r, p)
    return out

_torch_sqrt2 = np.sqrt(2.0)

def torch_norm_cdf(x):
    return 0.5 * (1.0 + torch.erf(x / _torch_sqrt2))

def torch_norm_icdf(u):
    return _torch_sqrt2 * torch.erfinv(2.0 * u - 1.0)

# def ghk_estimate_rectangle_prob_torch_batch(z_lower, z_upper, R, M=10000, device=None, eps=1e-14):
#     if device is None:
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     else:
#         device = torch.device(device)

#     z_lower = np.asarray(z_lower, dtype=float)
#     z_upper = np.asarray(z_upper, dtype=float)
#     R = np.asarray(R, dtype=float)

#     B, p = z_lower.shape
#     # Cholesky of R (correlation -> positive definite check)
#     try:
#         L = cholesky(R)
#     except LinAlgError as e:
#         raise ValueError("R is not positive definite") from e
#     # convert to torch on device
#     L_t = torch.as_tensor(L, dtype=torch.float64, device=device)
#     z_lower_t = torch.as_tensor(z_lower, dtype=torch.float64, device=device)
#     z_upper_t = torch.as_tensor(z_upper, dtype=torch.float64, device=device)

#     # We'll perform M draws per sample in parallel: tensors of shape (B, M)
#     # Initialize log-weights for each (B,M)
#     logw = torch.zeros((B, M), dtype=torch.float64, device=device)

#     # z_i storage: we only need values up to current i
#     # We'll keep z_prev as tensor of shape (i, B, M) initially empty
#     z_prev = None  # will be a tensor (i, B, M)

#     sqrt2 = _torch_sqrt2

#     for i in range(p):
#         # compute mu_cond = L[i, :i] @ z_prev (for each sample and each draw)
#         if i == 0:
#             mu_cond = torch.zeros((B, M), dtype=torch.float64, device=device)
#         else:
#             # z_prev: (i, B, M), L_row (i,)
#             L_row = L_t[i, :i]  # vector length i
#             # tensordot over axis 0 of z_prev (the gene axis) with L_row
#             # result shape (B, M)
#             mu_cond = torch.tensordot(L_row, z_prev, dims=([0], [0]))

#         sigma_cond = float(L_t[i, i].item())  # scalar

#         # compute truncation on standard normal (u-space):
#         a_u = (z_lower_t[:, i:i+1] - mu_cond) / sigma_cond
#         b_u = (z_upper_t[:, i:i+1] - mu_cond) / sigma_cond

#         # compute Phi(a_u), Phi(b_u)
#         Phi_a = torch_norm_cdf(a_u)
#         Phi_b = torch_norm_cdf(b_u)

#         # numerical safety: clamp differences away from 0
#         diff = (Phi_b - Phi_a).clamp(min=eps)

#         # sample uniform in [Phi_a, Phi_b] for each (B,M)
#         U = torch.rand((B, M), dtype=torch.float64, device=device)
#         U = Phi_a + U * (Phi_b - Phi_a)

#         # invert to get standard normal truncated sample u_i
#         u_i = torch_norm_icdf(U)

#         # compute z_i = mu_cond + sigma_cond * u_i
#         z_i = mu_cond + sigma_cond * u_i  # (B, M)

#         # append z_i to z_prev
#         if z_prev is None:
#             z_prev = z_i.unsqueeze(0)  # shape (1, B, M)
#         else:
#             z_prev = torch.cat([z_prev, z_i.unsqueeze(0)], dim=0)  # shape (i+1, B, M)

#         # update log-weights
#         logw = logw + torch.log(diff)

#     # weights: exp(logw) (shape B x M), average over M draws
#     # To avoid overflow/underflow we can use log-sum-exp trick per row:
#     # prob_b = mean_m exp(logw[b,m]) = exp(logsumexp(logw[b,:]) - log(M))
#     # implement vectorized:
#     max_logw, _ = logw.max(dim=1, keepdim=True)  # (B,1)
#     exp_shift = torch.exp(logw - max_logw)  # (B,M)
#     sum_exp = exp_shift.sum(dim=1)  # (B,)
#     probs_t = (sum_exp * torch.exp(max_logw.squeeze(1))) / float(M)  # (B,)

#     probs = probs_t.cpu().numpy()
#     return probs

def ghk_estimate_rectangle_prob_torch_batch(
    z_lower, z_upper, R, M=10000, device=None, eps=1e-12, min_sigma=1e-12
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    z_lower = np.asarray(z_lower, dtype=float)
    z_upper = np.asarray(z_upper, dtype=float)
    R = np.asarray(R, dtype=float)

    B, p = z_lower.shape
    try:
        L = cholesky(R)
    except LinAlgError as e:
        raise ValueError("R is not positive definite") from e

    L_t = torch.as_tensor(L, dtype=torch.float64, device=device)
    z_lower_t = torch.as_tensor(z_lower, dtype=torch.float64, device=device)
    z_upper_t = torch.as_tensor(z_upper, dtype=torch.float64, device=device)

    logw = torch.zeros((B, M), dtype=torch.float64, device=device)
    # z_prev = previuos (i, B, M)
    z_prev = None

    for i in range(p):
        if i == 0:
            mu_cond = torch.zeros((B, M), dtype=torch.float64, device=device)
        else:
            L_row = L_t[i, :i]
            mu_cond = torch.tensordot(L_row, z_prev, dims=([0], [0]))

        # conditional scale (coefficient of the new independent normal variate)
        sigma_cond = float(L_t[i, i].item())
        if sigma_cond <= 0.0:
            raise ValueError(f"Non-positive conditional scale L[{i},{i}]={sigma_cond}")
        # floor to avoid huge a_u/b_u from small sigma
        if sigma_cond < min_sigma:
            sigma_cond = min_sigma

        a_u = (z_lower_t[:, i:i+1] - mu_cond) / sigma_cond
        b_u = (z_upper_t[:, i:i+1] - mu_cond) / sigma_cond

        Phi_a = torch_norm_cdf(a_u)
        Phi_b = torch_norm_cdf(b_u)

        diff = (Phi_b - Phi_a)
        diff_clamped = diff.clamp(min=eps)

        # sample uniformly in the clamped interval (prevents exact 0/1 and erfinv(±1))
        U = Phi_a + torch.rand((B, M), dtype=torch.float64, device=device) * diff_clamped

        U = U.clamp(min=eps, max=1.0 - eps)
        u_i = torch_norm_icdf(U)
        z_i = mu_cond + sigma_cond * u_i

        # append z_i
        if z_prev is None:
            z_prev = z_i.unsqueeze(0)
        else:
            z_prev = torch.cat([z_prev, z_i.unsqueeze(0)], dim=0)

        # update log-weights using the clamped diff
        logw = logw + torch.log(diff_clamped)

    max_logw, _ = logw.max(dim=1, keepdim=True)
    exp_shift = torch.exp(logw - max_logw)
    sum_exp = exp_shift.sum(dim=1)
    probs_t = (sum_exp * torch.exp(max_logw.squeeze(1))) / float(M)

    probs = probs_t.cpu().numpy()
    return probs

# Compute marginal A/B over cell batch.
def build_ab_for_batch(x_batch, marg_params, eps=1e-12):
    x_batch = np.asarray(x_batch, dtype=int)
    B, p = x_batch.shape
    a = np.empty((B, p), dtype=float)
    b = np.empty((B, p), dtype=float)

    # Vectorized over genes (loop over p is fine if p moderate)
    for j in range(p):
        params = marg_params[j]
        xj = x_batch[:, j]
        if params[0] == 0:
            # not zero-inflated
            if params[1] == np.inf:
                mu = np.full(B, params[2], dtype=float)
                a[:, j] = cdf_poisson_vec(xj - 1, mu)
                b[:, j] = cdf_poisson_vec(xj, mu)
            else:
                theta = np.full(B, params[1], dtype=float)
                mu = np.full(B, params[2], dtype=float)
                a[:, j] = cdf_nb_theta_mu_vec(xj - 1, theta, mu)
                b[:, j] = cdf_nb_theta_mu_vec(xj, theta, mu)
        else:
            pi = np.full(B, params[0], dtype=float)
            if params[1] == np.inf:
                mu = np.full(B, params[2], dtype=float)
                a[:, j] = cdf_zip_vec(xj - 1, mu, pi)
                b[:, j] = cdf_zip_vec(xj, mu, pi)
            else:
                theta = np.full(B, params[1], dtype=float)
                mu = np.full(B, params[2], dtype=float)
                a[:, j] = cdf_zinb_theta_mu_vec(xj - 1, theta, mu, pi)
                b[:, j] = cdf_zinb_theta_mu_vec(xj, theta, mu, pi)

        a[:, j] = np.clip(a[:, j], eps, 1 - eps)
        b[:, j] = np.clip(b[:, j], eps, 1 - eps)
    return a, b

def gaussian_copula_point_probability_ghk_batch(
    X_batch, marg_params, Sigma, M=5000, device=None
):
    # convert Sigma to correlation matrix R (like original)
    diag = np.sqrt(np.diag(Sigma))
    R = Sigma / np.outer(diag, diag)
    np.fill_diagonal(R, 1.0)

    # Build a and b (cdf bounds)
    a_batch, b_batch = build_ab_for_batch(X_batch, marg_params)

    z_lower = sp_norm.ppf(a_batch)
    z_upper = sp_norm.ppf(b_batch)

    probs = ghk_estimate_rectangle_prob_torch_batch(z_lower, z_upper, R, M=M, device=device)
    return probs

def extract_row(cell, genes):
    arr = cell[:, genes].X
    if hasattr(arr, "toarray"):
        arr = arr.toarray().flatten()
    else:
        arr = np.asarray(arr).reshape(-1)
    return arr.astype(int)

def getPointScores_gpu(
    targetAD,
    auxMargParams, auxCovMat, auxCopulaGenes,
    synthMargParams, synthCovMat, synthCopulaGenes,
    batch_size=32, M=2000, device=None
):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    n_cells = len(targetAD)
    scores = {}
    idx = 0
    noAuxNans = 0
    noSynthNans = 0
    
    while idx < n_cells:
        if idx % 10000 < batch_size: 
            print(f"{idx} / {n_cells}")
        batch_cells = []
        batch_indices = list(range(idx, min(idx + batch_size, n_cells)))
        for i in batch_indices:
            try:
                r = extract_row(targetAD[i], auxCopulaGenes)
            except Exception as e:
                print(e)
                r = None
            batch_cells.append(r)

        valid_mask = [row is not None for row in batch_cells]
        if not any(valid_mask):
            for ii, i_cell in enumerate(batch_indices):
                scores[i_cell] = None
            idx += batch_size
            continue

        valid_rows_idx = [ii for ii, ok in enumerate(valid_mask) if ok]
        X_aux = np.vstack([batch_cells[ii] for ii in valid_rows_idx])
        X_synth = np.vstack([
            extract_row(targetAD[batch_indices[ii]], synthCopulaGenes)
            for ii in valid_rows_idx
        ])

        try:
            aux_probs = gaussian_copula_point_probability_ghk_batch(
                X_aux, auxMargParams, auxCovMat, M=M, device=device
            )
        except Exception as e:
            aux_probs = np.full((len(valid_rows_idx),), np.nan)
            print("Error computing aux_probs for batch starting at", idx, ":", e)

        try:
            synth_probs = gaussian_copula_point_probability_ghk_batch(
                X_synth, synthMargParams, synthCovMat, M=M, device=device
            )
        except Exception as e:
            synth_probs = np.full((len(valid_rows_idx),), np.nan)
            print("Error computing synth_probs for batch starting at", idx, ":", e)

        vi = 0
        minNonZero = np.min(aux_probs[aux_probs!=0])
        minAuxProb = minNonZero
        for ii, ok in enumerate(valid_mask):
            global_idx = batch_indices[ii]
            if not ok:
                scores[global_idx] = None
            else:
                a_p = aux_probs[vi]
                s_p = synth_probs[vi]
                if np.isnan(a_p) or np.isnan(s_p):
                    if np.isnan(a_p):
                        noAuxNans += 1
                    if np.isnan(s_p):
                        noSynthNans += 1
                        
                    scores[global_idx] = None
                elif a_p == 0:
                    scores[global_idx] = float(s_p / minAuxProb)
                else:
                    scores[global_idx] = float(s_p / a_p)
                    #scores[global_idx] = float(s_p)
                vi += 1
        idx += batch_size
    #print(f"% SYNTH NANS: {noSynthNans / n_cells}")
    #print(f"% SYNTH NANS: {noAuxNans / n_cells}")
    return scores

In [90]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
import random
import pickle

def getIndScores(auxAD, trainAD, auxModelDir, trainModelDir, split=None):
    cellTypes = trainAD.obs['cell_type'].value_counts().sort_values(ascending=True).index.tolist()
    individualScores = {}
    trainInds = set(trainAD.obs.individual.unique())
    for ct in cellTypes:
        nTrue = auxAD[(auxAD.obs.cell_type==ct)&(auxAD.obs.individual.isin(trainInds))].n_obs
        nTotal = auxAD[(auxAD.obs.cell_type==ct)].n_obs
        print(f"PERCENT MEMBERS: {nTrue/nTotal}")
        if auxAD[auxAD.obs.cell_type==ct].n_obs > 20000:
            continue
        synthModel = rdata.read_rds(f"{trainModelDir}/{ct}.rds")[str(ct)]
        #synthCopulaGenes = synthModel['gene_sel1']
        #synthCopulaGenes = [x - 1 for x in synthModel['gene_sel1']]
        synthCopulaGenes = synthModel['marginal_param1'].coords["dim_0"].data
        synthCovMat = synthModel['cov_mat']
        synthCopulaMarginals = synthModel['marginal_param1']

        auxModel = rdata.read_rds(f"{auxModelDir}/{ct}.rds")[str(ct)]
        #auxCopulaGenes = auxModel['gene_sel1']
        #auxCopulaGenes = [x - 1 for x in auxModel['gene_sel1']]
        auxCopulaGenes = auxModel['marginal_param1'].coords["dim_0"].data
        
        auxCovMat = auxModel['cov_mat']
        auxCopulaMarginals = auxModel['marginal_param1']

        auxCT = auxAD[auxAD.obs.cell_type==ct]
        all_inds = auxCT.obs["individual"].astype(str)
        actualLabels = all_inds.isin(trainInds).tolist()
        
        scores = getPointScores_gpu(
            auxCT,
            auxCopulaMarginals, auxCovMat, auxCopulaGenes, 
            synthCopulaMarginals, synthCovMat, synthCopulaGenes,
            batch_size=500, M=2000, device=None
        )
        with open(f'scores/split{split}_ct{ct}.pickle', 'wb') as handle:
            pickle.dump(scores, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
        nonUnderflowScores = [x for x in scores.values() if x is not None]
        if len(nonUnderflowScores) == 0:
            print(f"Underflows affected all cell types in ct {ct}")
            continue
        percentile_66 = np.percentile(nonUnderflowScores, 66)
        #percentile_66 = np.percentile(nonUnderflowScores, 90)
        min_score = np.min(nonUnderflowScores)
        max_score = np.max(nonUnderflowScores)

        predictedLabels = np.interp(
            [score if score is not None else np.nan for score in scores.values()],
            (min_score, max_score),
            (0, 1)
        )
#         predictedLabels = [bool(random.getrandbits(1)) for i in range(len(scores))]
#         for ind, score in scores.items():
#             if score is not None:
#                 predictedLabels[ind] = score > percentile_66
                
        auroc = roc_auc_score(actualLabels, predictedLabels)
        print(f"AUROC CT {ct}:", auroc)
#         tn, fp, fn, tp = confusion_matrix(actualLabels, predictedLabels).ravel()
#         print(f"CT {ct}: TP {tp / nTotal}, FP {fp / nTotal}")
        for ind, label in enumerate(predictedLabels):
            individual = auxCT.obs.iloc[ind]["individual"]
            if individual not in individualScores:
                individualScores[individual] = 0
            individualScores[individual] += label
                
    for individual, sumScore in individualScores.items():
        nCells = auxAD[auxAD.obs.individual==individual].n_obs
        individualScores[individual] /= nCells
    
    return individualScores


In [91]:
import scanpy as sc
import rdata


def runMIA(auxAdPath, trainAdPath, auxModelDir, trainModelDir, split=None):
    auxAD = sc.read_h5ad(auxAdPath)
    trainAD = sc.read_h5ad(trainAdPath)
    trainInds = set(trainAD.obs.individual.unique())
    indScores = getIndScores(auxAD, trainAD, auxModelDir, trainModelDir, split=split)
    percentile_66 = np.percentile(list(indScores.values()), 66)
    predictedLabels = []
    actualLabels = []
    
    for ind, score in indScores.items():
        if score is not None:
            predictedLabels.append(score > percentile_66)
            actualLabels.append(ind in trainInds)
    
    return roc_auc_score(actualLabels, predictedLabels)

# def runMIA2(auxAdPath, trainAdPath, auxModelDir, trainModelDir, split=None):
#     auxAD = sc.read_h5ad(auxAdPath)
#     trainAD = sc.read_h5ad(trainAdPath)
#     trainInds = set(trainAD.obs.individual.unique())
#     for ct in auxAD.obs.cell_type.unique():
#         if auxAD[auxAD.obs.cell_type==ct].n_obs > 20000:
#             continue
#         with open(f"scores/split{split}_ct{ct}.pickle", "rb") as f:
#             scores = pickle.load(f)
            
#         print(scores)
# #     indScores = getIndScores(auxAD, trainAD, auxModelDir, trainModelDir, split=split)
# #     percentile_66 = np.percentile(list(indScores.values()), 66)
# #     predictedLabels = []
# #     actualLabels = []
    
# #     for ind, score in indScores.items():
# #         if score is not None:
# #             predictedLabels.append(score > percentile_66)
# #             actualLabels.append(ind in trainInds)
    
# #     return roc_auc_score(actualLabels, predictedLabels)

for i in [3,1,2]:
    print(f"Split {i}")
    score = runMIA("train.h5ad", f"splits/{i}/train.h5ad", "models", f"splits/{i}/model", split=str(i))
    print(f"Score {score}")

Split 3
PERCENT MEMBERS: 0.3172645739910314
0 / 1784
AUROC CT 8: 0.8057625023934274
PERCENT MEMBERS: 0.4180618975139523
0 / 1971
AUROC CT 2: 0.536271277541243
PERCENT MEMBERS: 0.29753966429064155
0 / 4349
AUROC CT 6: 0.6516422263651702
PERCENT MEMBERS: 0.2981670929241262
0 / 4692
AUROC CT 15: 0.6530034142212986
PERCENT MEMBERS: 0.32225063938618925
0 / 7038
AUROC CT 10: 0.7042020232271805
PERCENT MEMBERS: 0.342255350386809
0 / 17709
10000 / 17709
AUROC CT 5: 0.4411277385273528
PERCENT MEMBERS: 0.36346033373198927
0 / 17559
10000 / 17559
AUROC CT 9: 0.6541106864622467
PERCENT MEMBERS: 0.3333931696809227
PERCENT MEMBERS: 0.3414830563071769
PERCENT MEMBERS: 0.32885756676557865
PERCENT MEMBERS: 0.3262453709234261
PERCENT MEMBERS: 0.31024163707254077
PERCENT MEMBERS: 0.293648705279264
PERCENT MEMBERS: 0.3172738155590868
Score 0.5738066736495586
Split 1
PERCENT MEMBERS: 0.3054932735426009
0 / 1784
AUROC CT 8: 0.8083272245299924
PERCENT MEMBERS: 0.3333333333333333
0 / 1971
AUROC CT 2: 0.488540

In [44]:
import scanpy as sc

ad1 = sc.read_h5ad("splits/1/train.h5ad")
ad2 = sc.read_h5ad("splits/2/train.h5ad")
ad3 = sc.read_h5ad("splits/3/train.h5ad")

In [58]:
ads = {1: ad1, 2: ad2, 3: ad3}
cts = ad1.obs.cell_type.unique()
for ct in cts:
    for split in [1,2,3]:
        rModel = rdata.read_rds(f"splits/{split}/model/{ct}.rds")
        copulaGenes = [x-1 for x in rModel[str(ct)]['gene_sel1']]
        ad = ads[split]
        expr = ad[ad.obs.cell_type==ct,copulaGenes].X != 0
        #print(expr.shape)
        density = (expr.sum(axis=1) / expr.shape[1]).mean()
        density = expr.mean()
        #print(density.shape)
        print(f"Split {split} CT {ct} density: {density}")

Split 1 CT 0 density: 0.3819313199512026
Split 2 CT 0 density: 0.3636809894173149
Split 3 CT 0 density: 0.36867927061793365
Split 1 CT 3 density: 0.411434081633677
Split 2 CT 3 density: 0.40001329004354264
Split 3 CT 3 density: 0.4011154692536468
Split 1 CT 13 density: 0.3989216445716977
Split 2 CT 13 density: 0.394272856997793
Split 3 CT 13 density: 0.39271322519123103
Split 1 CT 14 density: 0.42907974013721534
Split 2 CT 14 density: 0.42166655073733483
Split 3 CT 14 density: 0.4123314276335586
Split 1 CT 4 density: 0.3727609770669855
Split 2 CT 4 density: 0.3792973647738235
Split 3 CT 4 density: 0.3779255782125085
Split 1 CT 12 density: 0.3931845113390907
Split 2 CT 12 density: 0.3922982953299932
Split 3 CT 12 density: 0.390269994544849
Split 1 CT 8 density: 0.46740015175553645
Split 2 CT 8 density: 0.46753461166536847
Split 3 CT 8 density: 0.45435563472826634
Split 1 CT 5 density: 0.38541265697926175
Split 2 CT 5 density: 0.38972019712458766
Split 3 CT 5 density: 0.39405083406420705

In [59]:
print(ad1[ad1.obs.cell_type==0].n_obs)

81827


In [37]:
# ---------- Batched getPointScores (synthetic probs only) ----------
def getPointScores_gpuNew(
    targetAD,
    synthMargParams, synthCovMat, synthCopulaGenes,
    batch_size=32, M=2000, device=None
):
    """
    Compute synthetic probabilities only, batched.

    targetAD: AnnData-like object or list of cell slices
    synthMargParams, synthCovMat, synthCopulaGenes: parameters for synthetic model
    batch_size: number of cells processed together
    M: MC draws for GHK
    device: cuda/cpu
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    def extract_row(cell, genes):
        arr = cell[:, genes].X
        if hasattr(arr, "toarray"):
            arr = arr.toarray().flatten()
        else:
            arr = np.asarray(arr).reshape(-1)
        return arr.astype(int)

    n_cells = len(targetAD)
    scores = {}
    idx = 0

    while idx < n_cells:
        if idx % 10000 < batch_size:
            print(f"{idx} / {n_cells}")

        batch_cells = []
        batch_indices = list(range(idx, min(idx + batch_size, n_cells)))
        for i in batch_indices:
            try:
                r = extract_row(targetAD[i], synthCopulaGenes)
            except Exception:
                r = None
            batch_cells.append(r)

        valid_mask = [row is not None for row in batch_cells]
        if not any(valid_mask):
            for ii, i_cell in enumerate(batch_indices):
                scores[i_cell] = None
            idx += batch_size
            continue

        valid_rows_idx = [ii for ii, ok in enumerate(valid_mask) if ok]
        X_synth = np.vstack([batch_cells[ii] for ii in valid_rows_idx])

        try:
            synth_probs = gaussian_copula_point_probability_ghk_batch(
                X_synth, synthMargParams, synthCovMat, M=M, device=device
            )
        except Exception as e:
            synth_probs = np.full((len(valid_rows_idx),), np.nan)
            print("Error computing synth_probs for batch starting at", idx, ":", e)

        vi = 0
        for ii, ok in enumerate(valid_mask):
            global_idx = batch_indices[ii]
            if not ok or np.isnan(synth_probs[vi]):
                scores[global_idx] = None
            else:
                scores[global_idx] = float(synth_probs[vi])
                vi += 1
        idx += batch_size

    return scores


# ---------- Get individual scores via statistical test ----------
def getIndScoresNew(auxAD, trainAD, synthModelDir):
    """
    Compare synthetic probabilities between train individuals and others.
    Uses Mann-Whitney U test per cell type.
    """
    import rdata
    individualScores = {}
    trainInds = set(trainAD.obs.individual.unique())
    cellTypes = auxAD.obs.cell_type.unique()

#     for ct in cellTypes:
    for ct in [8]:
        auxCT = auxAD[auxAD.obs.cell_type == ct]
        if auxCT.n_obs == 0:
            continue

        synthModel = rdata.read_rds(f"{synthModelDir}/{ct}.rds")[str(ct)]
        synthCopulaGenes = [x - 1 for x in synthModel['gene_sel1']]
        synthCovMat = synthModel['cov_mat']
        synthCopulaMarginals = synthModel['marginal_param1']

        scores = getPointScores_gpuNew(
            auxCT,
            synthCopulaMarginals, synthCovMat, synthCopulaGenes,
            batch_size=500, M=2000, device=None
        )

        all_scores = np.array([scores[i] if scores[i] is not None else np.nan for i in range(len(scores))])
        labels = auxCT.obs['individual'].astype(str).isin(trainInds).to_numpy()

        # Separate train vs other cells
        train_scores = all_scores[labels]
        other_scores = all_scores[~labels]

        # Remove NaNs
        print(len(train_scores), len(other_scores))
        train_scores = train_scores[~np.isnan(train_scores)]
        other_scores = other_scores[~np.isnan(other_scores)]
        print(len(train_scores), len(other_scores))
        if len(train_scores) == 0 or len(other_scores) == 0:
            print(f"No valid scores for cell type {ct}")
            continue

        # Mann-Whitney U test (one-sided: train > others)
        stat, pval = mannwhitneyu(train_scores, other_scores, alternative='greater')
        print(f"CT {ct}: Mann-Whitney U p-value = {pval:.3e}")

        # Aggregate per individual: mean synthetic probability
        for ind in auxCT.obs['individual'].unique():
            ind_mask = auxCT.obs['individual'] == ind
            ind_scores = all_scores[ind_mask.to_numpy()]
            ind_scores = ind_scores[~np.isnan(ind_scores)]
            if len(ind_scores) == 0:
                continue
            individualScores[ind] = np.mean(ind_scores)

    return individualScores

def runMIA(auxAdPath, trainAdPath, auxModelDir, trainModelDir):
    auxAD = sc.read_h5ad(auxAdPath)
    trainAD = sc.read_h5ad(trainAdPath)
    trainInds = set(trainAD.obs.individual.unique())
    indScores = getIndScoresNew(auxAD, trainAD, trainModelDir)
    percentile_66 = np.percentile(list(indScores.values()), 66)
    predictedLabels = []
    actualLabels = []
    
    for ind, score in indScores.items():
        if score is not None:
            predictedLabels.append(score > percentile_66)
            actualLabels.append(ind in trainInds)
    
    return roc_auc_score(actualLabels, predictedLabels)

for i in [2,3,1]:
    print(f"Split {i}")
    score = runMIA("train.h5ad", f"splits/{i}/train.h5ad", "models", f"splits/{i}/model")
    print(f"Score {score}")

Split 2
0 / 1784
673 1111
0 0
No valid scores for cell type 8


IndexError: index -1 is out of bounds for axis 0 with size 0

In [11]:
import torch
torch.cuda.empty_cache()

In [46]:
scoresGPU = getPointScores_gpu(
    adata_sample,
    auxCopulaMarginals, auxCovMat, auxCopulaGenes, 
    synthCopulaMarginals, synthCovMat, synthCopulaGenes,
    batch_size=32, M=2000, device=None
)

In [95]:
import rdata
import pandas as pd
import scanpy as sc

synthModel = rdata.read_rds("splits/1/model/0.rds")['0']
synthCopulaGenes = synthModel['gene_sel1']
synthCovMat = synthModel['cov_mat']
synthCopulaMarginals = synthModel['marginal_param1']

auxModel = rdata.read_rds("models/0.rds")['0']
auxCopulaGenes = auxModel['gene_sel1']
auxCovMat = auxModel['cov_mat']
auxCopulaMarginals = auxModel['marginal_param1']

auxAD = sc.read_h5ad("train.h5ad")
auxCT0 = auxAD[auxAD.obs.cell_type==0]

trainAD = sc.read_h5ad("splits/1/train.h5ad")
trainInds = set(trainAD.obs.individual.unique())
all_inds = auxCT0.obs["individual"].astype(str)
actualLabels = all_inds.isin(trainInds).tolist()



# actualLabels = []
# for i, cell in enumerate(auxCT0):
#     ind = cell.obs["individual"].astype(str)[0]
#     if ind in trainInds:
#         actualLabels.append(True)
#     else:
#         actualLabels.append(False)

# idx = np.random.choice(auxCT0.n_obs, size=1000, replace=False)
# adata_sample = auxCT0[idx].copy()

# scores = getPointScores(
#     adata_sample, auxCopulaMarginals, auxCovMat, auxCopulaGenes, 
#     synthCopulaMarginals, synthCovMat, synthCopulaGenes
# )

In [105]:
auxAD.obs.cell_type.value_counts()

cell_type
0     234731
3     102609
14     80534
4      66968
13     40440
1      31044
12     22283
5      17709
9      17559
10      7038
15      4692
6       4349
2       1971
8       1784
Name: count, dtype: int64

In [100]:
auxCT0[0]

View of AnnData object with n_obs × n_vars = 1 × 1118
    obs: 'individual', 'cell_type', 'cell_label', 'barcode_col', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'gene_ids', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

In [57]:
scoresGPU = getPointScores_gpu(
    auxCT0,
    auxCopulaMarginals, auxCovMat, auxCopulaGenes, 
    synthCopulaMarginals, synthCovMat, synthCopulaGenes,
    batch_size=1000, M=2000, device=None
)

0 / 234731
1000 / 234731
2000 / 234731
3000 / 234731
4000 / 234731
5000 / 234731
6000 / 234731
7000 / 234731
8000 / 234731
9000 / 234731
10000 / 234731
11000 / 234731
12000 / 234731
13000 / 234731
14000 / 234731
15000 / 234731
16000 / 234731
17000 / 234731
18000 / 234731
19000 / 234731
20000 / 234731
21000 / 234731
22000 / 234731
23000 / 234731
24000 / 234731
25000 / 234731
26000 / 234731
27000 / 234731
28000 / 234731
29000 / 234731
30000 / 234731
31000 / 234731
32000 / 234731
33000 / 234731
34000 / 234731
35000 / 234731
36000 / 234731
37000 / 234731
38000 / 234731
39000 / 234731
40000 / 234731
41000 / 234731
42000 / 234731
43000 / 234731
44000 / 234731
45000 / 234731
46000 / 234731
47000 / 234731
48000 / 234731
49000 / 234731
50000 / 234731
51000 / 234731
52000 / 234731
53000 / 234731
54000 / 234731
55000 / 234731
56000 / 234731
57000 / 234731
58000 / 234731
59000 / 234731
60000 / 234731
61000 / 234731
62000 / 234731
63000 / 234731
64000 / 234731
65000 / 234731
66000 / 234731
67000 / 

In [96]:
# oldScores = scores
scores = scoresGPU

In [97]:
import random
percentile_66 = np.percentile([x for x in scores.values() if x is not None], 66)
predictedLabels = [bool(random.getrandbits(1)) for i in range(len(scores))]
for ind, score in scores.items():
    if score is not None:
        predictedLabels[ind] = score > percentile_66

In [98]:
from sklearn.metrics import roc_auc_score

auroc = roc_auc_score(actualLabels, predictedLabels)
print("AUROC:", auroc)

AUROC: 0.5399298219241364


In [29]:
import scanpy as sc
trainAD = sc.read_h5ad("splits/1/train.h5ad")



In [46]:
ct0 = trainAD[trainAD.obs.cell_type==0, copulaGenes]
cell = ct0[0].X.toarray().flatten()

In [54]:
gaussian_copula_point_probability_ghk(cell, copulaMarginals, covMat)

9.530491408619522e-17

In [63]:
vals = []
for i in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]:
    vals.append(gaussian_copula_point_probability_ghk(cell, copulaMarginals, covMat, M=i))

In [70]:
gaussian_copula_point_probability_ghk(cell, copulaMarginals, covMat, M=100)

8.94688859631545e-17

In [78]:
for i, c in enumerate(trainAD):
    print(i)
    print(c.X.toarray().flatten())
    break

0
[0. 0. 0. ... 0. 0. 0.]
