In [17]:
import anndata as ad
from os import sys
import scanpy as sc
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score
from scipy import stats
import os

sys.path.append('..')
from VAE.VAE_model import VAE

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [18]:
def load_VAE():
    autoencoder = VAE(
        num_genes=17789,
        device='cuda',
        seed=0,
        loss_ae='mse',
        hidden_dim=128,
        decoder_activation='ReLU',
    )
    autoencoder.load_state_dict(torch.load('../data/pbmc_AE/model_seed=0_step=199999.pt'))
    return autoencoder

In [19]:
adata = sc.read_10x_mtx(
    '../data/pbmc68k/data/pbmc68k/filtered_matrices_mex/hg19/',
    var_names='gene_symbols', 
    cache=True
)
adata.var_names_make_unique()
sc.pp.filter_cells(adata, min_genes=10)
sc.pp.filter_genes(adata, min_cells=3)

sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

celltype = pd.read_csv('../data/pbmc68k/data/pbmc68k/filtered_matrices_mex/68k_pbmc_barcodes_annotation.tsv', sep='\t')[
    'celltype'].values
adata.obs['celltype'] = celltype

In [20]:
cato = ['CD14+ Monocyte', 'CD19+ B', 'CD34+', 'CD4+ T Helper2', 'CD4+/CD25 T Reg',
        'CD4+/CD45RA+/CD25- Naive T', 'CD4+/CD45RO+ Memory', 'CD56+ NK',
        'CD8+ Cytotoxic T', 'CD8+/CD45RA+ Naive Cytotoxic', 'Dendritic']
index2 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

cell_gen_all = []
gen_class = []
for i in index2:
    npzfile = np.load(f'../samples/pbmc/label/l1/pbmc{i}.npz', allow_pickle=True)
    length = min(adata[adata.obs['celltype'] == cato[i]].X.shape[0], npzfile['samples'].shape[0])
    print(
        f"Class {cato[i]} - Real data size: {adata[adata.obs['celltype'] == cato[i]].X.shape[0]}, Generated data size: {length}")

    cell_gen_all.append(npzfile['samples'][:int(length)])
    gen_class += [f'gen {cato[i]}'] * int(length)

cell_gen_all = np.concatenate(cell_gen_all, axis=0)


Class CD14+ Monocyte - Real data size: 2862, Generated data size: 1024
Class CD19+ B - Real data size: 5908, Generated data size: 1024
Class CD34+ - Real data size: 277, Generated data size: 277
Class CD4+ T Helper2 - Real data size: 97, Generated data size: 97
Class CD4+/CD25 T Reg - Real data size: 6187, Generated data size: 1024
Class CD4+/CD45RA+/CD25- Naive T - Real data size: 1873, Generated data size: 1024
Class CD4+/CD45RO+ Memory - Real data size: 3061, Generated data size: 1024
Class CD56+ NK - Real data size: 8776, Generated data size: 1024
Class CD8+ Cytotoxic T - Real data size: 20773, Generated data size: 1024
Class CD8+/CD45RA+ Naive Cytotoxic - Real data size: 16666, Generated data size: 1024
Class Dendritic - Real data size: 2099, Generated data size: 1024


In [21]:
autoencoder = load_VAE()
cell_gen_all = autoencoder(torch.tensor(cell_gen_all).cuda(), return_decoded=True).cpu().detach().numpy()

gen_adata = ad.AnnData(cell_gen_all, dtype=np.float32)
gen_adata.obs['celltype'] = gen_class

auc_values = []

for cell_type in cato:
    print(f"\nEvaluating cell type: {cell_type}")

    real_data = adata[adata.obs['celltype'] == cell_type].X.toarray()
    gen_data = gen_adata[gen_adata.obs['celltype'] == f'gen {cell_type}'].X

    if real_data.shape[0] == 0 or gen_data.shape[0] == 0:
        print(f"No data for cell type {cell_type}. Skipping evaluation.")
        continue

    combined_data = np.concatenate((real_data, gen_data), axis=0)
    combined_labels = np.concatenate((np.ones(real_data.shape[0]), np.zeros(gen_data.shape[0])))


    combined_adata = ad.AnnData(combined_data, dtype=np.float32)
    sc.tl.pca(combined_adata, n_comps=2, svd_solver='arpack')
    pca_data = combined_adata.obsm['X_pca']


    X_train, X_val, y_train, y_val = train_test_split(pca_data, combined_labels, test_size=0.25, random_state=1)

    rfc = RandomForestClassifier(n_estimators=1000, max_depth=5, oob_score=True, class_weight="balanced",
                                 random_state=1)
    rfc.fit(X_train, y_train)

    y_pred_train = rfc.predict(X_train)
    y_pred_val = rfc.predict(X_val)

    auc = roc_auc_score(y_val, y_pred_val)
    auc_values.append((cell_type, auc))
    print(f"AUC for {cell_type}: {auc}")

print("\nSummary of AUC values for all cell types:")
for cell_type, auc in auc_values:
    print(f"{cell_type}: {auc:.4f}")




Evaluating cell type: CD14+ Monocyte




AUC for CD14+ Monocyte: 0.8583333333333332

Evaluating cell type: CD19+ B




AUC for CD19+ B: 0.7467967267933506

Evaluating cell type: CD34+




AUC for CD34+: 0.7304166666666666

Evaluating cell type: CD4+ T Helper2




AUC for CD4+ T Helper2: 0.5777591973244147

Evaluating cell type: CD4+/CD25 T Reg




AUC for CD4+/CD25 T Reg: 0.5543501451869737

Evaluating cell type: CD4+/CD45RA+/CD25- Naive T




AUC for CD4+/CD45RA+/CD25- Naive T: 0.5344770744572936

Evaluating cell type: CD4+/CD45RO+ Memory




AUC for CD4+/CD45RO+ Memory: 0.508309690091663

Evaluating cell type: CD56+ NK




AUC for CD56+ NK: 0.5815156304882332

Evaluating cell type: CD8+ Cytotoxic T




AUC for CD8+ Cytotoxic T: 0.5568047443293797

Evaluating cell type: CD8+/CD45RA+ Naive Cytotoxic




AUC for CD8+/CD45RA+ Naive Cytotoxic: 0.5850314824847053

Evaluating cell type: Dendritic




AUC for Dendritic: 0.6642091766645297

Summary of AUC values for all cell types:
CD14+ Monocyte: 0.8583
CD19+ B: 0.7468
CD34+: 0.7304
CD4+ T Helper2: 0.5778
CD4+/CD25 T Reg: 0.5544
CD4+/CD45RA+/CD25- Naive T: 0.5345
CD4+/CD45RO+ Memory: 0.5083
CD56+ NK: 0.5815
CD8+ Cytotoxic T: 0.5568
CD8+/CD45RA+ Naive Cytotoxic: 0.5850
Dendritic: 0.6642
