In [1]:
import numpy as np
import scanpy as sc
from scipy.sparse import issparse
from scipy.spatial.distance import cdist

# train_adata = ad.read_h5ad("100_ind_1000_gene_ct0.h5ad")
# holdout_adata = ad.read_h5ad("group1_holdout.h5ad")
# synthetic_adata = ad.read_h5ad("1000_genes_100_ind_subset1_out.h5ad")

synthetic_adata = sc.read_h5ad("good_synth.h5ad")
sc.pp.filter_cells(synthetic_adata, min_counts=10)
sc.pp.filter_genes(synthetic_adata, min_cells=3)
sc.pp.normalize_total(synthetic_adata, target_sum=1e4)
sc.pp.log1p(synthetic_adata)
synth_genes = synthetic_adata.var.index

train_adata = sc.read_h5ad("onek1k_annotated_train.h5ad")
train_adata = train_adata[:,synth_genes]
sc.pp.normalize_total(train_adata, target_sum=1e4)
sc.pp.log1p(train_adata)

holdout_adata = sc.read_h5ad("onek1k_annotated_test.h5ad")
holdout_adata = holdout_adata[:,synth_genes]
sc.pp.normalize_total(holdout_adata, target_sum=1e4)
sc.pp.log1p(holdout_adata)


ST = train_adata.X
SC = holdout_adata.X
SG = synthetic_adata.X

  view_to_actual(adata)
  view_to_actual(adata)


In [6]:
from scipy.sparse import vstack
m = ST.shape[0]
training_indices = set(range(m))
member_labels = np.concatenate([np.ones(m), np.zeros(SC.shape[0])])
combined = vstack([ST, SC])
n_records = combined.shape[0]

# Compute min distance per record (for median heuristic radius)
SG_cuda = torch.tensor(SG.toarray(), device='cuda', dtype=torch.float32)
min_dists = np.zeros(n_records)
for start in range(0, n_records, 500):
    print(f"{start}/{n_records}")
    end = min(start + 500, n_records)
    batch = combined[start:end]
    batch_cuda = torch.tensor(batch.toarray(), device='cuda', dtype=torch.float32)
    # dists = pairwise_euclidean(batch_cuda, SG_cuda)
    dists = torch.cdist(batch_cuda, SG_cuda)
    min_dists[start:end] = torch.min(dists.cpu(), axis=1).values

r = np.median(min_dists)
print(f"Neighborhood radius {r}")

0/1267733
500/1267733
1000/1267733
1500/1267733
2000/1267733
2500/1267733
3000/1267733
3500/1267733
4000/1267733
4500/1267733
5000/1267733
5500/1267733
6000/1267733
6500/1267733
7000/1267733
7500/1267733
8000/1267733
8500/1267733
9000/1267733
9500/1267733
10000/1267733
10500/1267733
11000/1267733
11500/1267733
12000/1267733
12500/1267733
13000/1267733
13500/1267733
14000/1267733
14500/1267733
15000/1267733
15500/1267733
16000/1267733
16500/1267733
17000/1267733
17500/1267733
18000/1267733
18500/1267733
19000/1267733
19500/1267733
20000/1267733
20500/1267733
21000/1267733
21500/1267733
22000/1267733
22500/1267733
23000/1267733
23500/1267733
24000/1267733
24500/1267733
25000/1267733
25500/1267733
26000/1267733
26500/1267733
27000/1267733
27500/1267733
28000/1267733
28500/1267733
29000/1267733
29500/1267733
30000/1267733
30500/1267733
31000/1267733
31500/1267733
32000/1267733
32500/1267733
33000/1267733
33500/1267733
34000/1267733
34500/1267733
35000/1267733
35500/1267733
36000/1267733
36

In [7]:
batch_size = 500
scores = np.zeros(n_records)
eta = 1e-12
for start in range(0, n_records, batch_size):
    print(f"{start}/{n_records}")
    end = min(start + batch_size, n_records)
    batch = combined[start:end]
    batch_cuda = torch.tensor(batch.toarray(), device='cuda', dtype=torch.float32)
    dists = torch.cdist(batch_cuda, SG_cuda)
    within_r = dists <= r

    for i in range(end - start):
        mask = within_r[i]
        if mask.any():
            scores[start + i] = torch.mean(torch.log(dists[i, mask] + eta)).item()
        else:
            scores[start + i] = 0.0

0/1267733
500/1267733
1000/1267733
1500/1267733
2000/1267733
2500/1267733
3000/1267733
3500/1267733
4000/1267733
4500/1267733
5000/1267733
5500/1267733
6000/1267733
6500/1267733
7000/1267733
7500/1267733
8000/1267733
8500/1267733
9000/1267733
9500/1267733
10000/1267733
10500/1267733
11000/1267733
11500/1267733
12000/1267733
12500/1267733
13000/1267733
13500/1267733
14000/1267733
14500/1267733
15000/1267733
15500/1267733
16000/1267733
16500/1267733
17000/1267733
17500/1267733
18000/1267733
18500/1267733
19000/1267733
19500/1267733
20000/1267733
20500/1267733
21000/1267733
21500/1267733
22000/1267733
22500/1267733
23000/1267733
23500/1267733
24000/1267733
24500/1267733
25000/1267733
25500/1267733
26000/1267733
26500/1267733
27000/1267733
27500/1267733
28000/1267733
28500/1267733
29000/1267733
29500/1267733
30000/1267733
30500/1267733
31000/1267733
31500/1267733
32000/1267733
32500/1267733
33000/1267733
33500/1267733
34000/1267733
34500/1267733
35000/1267733
35500/1267733
36000/1267733
36

In [9]:
from sklearn.metrics import roc_auc_score

top_indices = np.argsort(scores)[-m:]
training_indices = set(range(m))
top_set = set(top_indices)
auc = roc_auc_score(member_labels, scores)
print(f"AUC: {auc}")

AUC: 0.503885430534968
