In [2]:
def extract_latent_space(data, model):
    """Run ``model`` on one batch and return its latent codes and decoder outputs.

    Switches ``model`` to eval mode, calls ``model(data)`` and keeps five of the
    fourteen values the forward pass returns.

    Args:
        data: input batch, passed unchanged to ``model``.
        model: module whose forward returns exactly 14 items, in the order
            (bio_z, mu1, logvar1, batch_z, batch_mu, batch_logvar,
            bio_batch_pred, batch_batch_pred, mean, disp, pi,
            size_factor, size_mu, size_logvar).

    Returns:
        Tuple of numpy arrays ``(bio_z, batch_z, mean, disp, pi)``, each detached
        from the autograd graph and copied to host memory.
    """
    model.eval()
    (bio_z, _mu1, _logvar1,
     batch_z, _batch_mu, _batch_logvar,
     _bio_batch_pred, _batch_batch_pred,
     mean, disp, pi,
     _size_factor, _size_mu, _size_logvar) = model(data)

    def to_numpy(tensor):
        # Drop the graph and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    return tuple(to_numpy(t) for t in (bio_z, batch_z, mean, disp, pi))

In [None]:
import sys
import os
import importlib

# Location of the local scib checkout; override with the SCIB_PATH environment
# variable to run this notebook on another machine.
scib_path = os.environ.get('SCIB_PATH', '/home/haiping_liu/code/My_model/Batch_VAE/Results/scib')
# Prepend rather than append: with append, a pip-installed scib found earlier on
# sys.path (site-packages) would be imported instead of this checkout.
if scib_path not in sys.path:
    sys.path.insert(0, scib_path)

import scib
# Re-execute scib's package __init__ so edits to the local checkout are picked up
# without a kernel restart. reload() only re-runs this top-level module;
# submodules already in sys.modules are reused as-is.
scib = importlib.reload(scib)
print(scib.__file__)  # shows which scib copy was actually imported

## Inference and visualization

In [4]:
import sys
import os

# Repository root: two directory levels above the notebook's working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

# Register it once so the `models` / `utils` packages imported below resolve.
if project_root not in sys.path:
    sys.path.append(project_root)

from models.model import GeneVAE
from utils.dataset import GeneralDataset, GeneDataset

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np

# Load the trained GeneVAE checkpoint once. Tensors are mapped to the CPU, so
# loading does not require a GPU; the model is moved to the inference device later.
# NOTE(review): the checkpoint also stores a pickled config object; torch>=2.6
# defaults to weights_only=True, which may refuse it. Pass weights_only=False
# only for trusted checkpoints such as this one.
checkpoint_path = '/home/haiping_liu/code/My_model/Batch_VAE1/saved/models/Immune/1228_194336/checkpoint-epoch65.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# The stored config is either a wrapper exposing the raw dict as `_config`, or the dict itself.
config = checkpoint['config']._config if hasattr(checkpoint['config'], '_config') else checkpoint['config']
config_args = config['arch']['args']
model = GeneVAE(**config_args)
model.load_state_dict(checkpoint['state_dict'])

data_dir = "/home/haiping_liu/code/My_model/Batch_VAE1/Data/Gene_data/csv_format/human_immune.csv"
train_dataset = GeneDataset(data_dir)
# shuffle=False keeps batches in dataset order, so stacked latents can be matched
# to CSV rows (assuming GeneDataset indexes rows in file order -- TODO confirm).
dataloader = DataLoader(train_dataset, shuffle=False, batch_size=128)

# Encode the full dataset batch by batch and stack the per-batch outputs.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

cell_types = []
latent = []
batch_ids = []
combine_z2 = []

model.eval()
model.to(device)
with torch.no_grad():
    for features, batch_id, cell_type in dataloader:
        # Only the model input has to live on `device`; the label tensors are
        # merely collected, so they stay on the CPU (no GPU round trip).
        features = features.to(device)
        z1, z2, _mean, _disp, _pi = extract_latent_space(features, model)

        latent.append(z1)
        combine_z2.append(z2)
        cell_types.append(cell_type.cpu().numpy())
        batch_ids.append(batch_id.cpu().numpy())

latent = np.concatenate(latent, axis=0)
z2 = np.concatenate(combine_z2, axis=0)
# Label values exactly as yielded by the loader; they are replaced below with
# the labels read from the CSV.
cell_types = np.concatenate(cell_types, axis=0)
batch_ids = np.concatenate(batch_ids, axis=0)

# Use the labels stored in the CSV for plotting and metrics. The loader's labels
# are presumably integer codes, since they are handled as tensors.
# Only the two label columns are parsed instead of the whole expression matrix.
data_df = pd.read_csv(data_dir, usecols=['cell_type', 'batch'])
# Latent rows must line up one-to-one with CSV rows (the loader uses shuffle=False).
if len(data_df) != latent.shape[0]:
    raise ValueError(
        f"CSV has {len(data_df)} rows but the latent matrix has {latent.shape[0]}; "
        "cannot align labels with embeddings"
    )
cell_types = data_df['cell_type'].values
batch_ids = data_df['batch'].values

adata_z1 = sc.AnnData(latent)
adata_z1.obs['batch'] = batch_ids
adata_z1.obs['cell_type'] = cell_types

In [None]:
# Full-data view of the biological latent space z1: PCA -> kNN graph -> UMAP.
# The X_pca computed here is also the embedding used by the silhouette/kBET metrics.
sc.pp.pca(adata_z1)
sc.pp.neighbors(adata_z1, use_rep='X_pca')
sc.tl.umap(adata_z1)
sc.pl.umap(adata_z1, color=['batch'], title="UMAP - Batch ID", size=10)
sc.pl.umap(adata_z1, color=['cell_type'], title="UMAP - Cell type", size=10)
plt.show()

In [42]:
import pandas as pd

# Export the z1 embedding (one row per cell) together with its batch and
# cell-type labels for use outside this notebook.
df = pd.DataFrame(
    adata_z1.X,
    columns=[f'latent_{i}' for i in range(adata_z1.X.shape[1])],
).assign(
    BATCH=adata_z1.obs['batch'].values,
    celltype=adata_z1.obs['cell_type'].values,
)
df.to_csv('human_berd.csv', index=False)

In [None]:
# Quick look at a 30% subsample of z1, with the kNN graph built on the raw latent
# coordinates. copy=True leaves adata_z1 intact. Subsampling in place would
# silently shrink the object every metric below is computed on, and would
# replace its X_pca-based neighbor graph.
adata_z1_sub = sc.pp.subsample(adata_z1, fraction=0.3, copy=True)
sc.pp.neighbors(adata_z1_sub, use_rep='X')
sc.tl.umap(adata_z1_sub)
sc.pl.umap(adata_z1_sub, color=['batch', 'cell_type'], title=["UMAP - Batch ID", "UMAP - Cell type"], size=10)
plt.show()

In [None]:
# Batch latent space z2, labelled with the same per-cell labels as z1.
# Use the full-length label arrays rather than adata_z1.obs. Assigning an obs
# Series aligns on the index, so if adata_z1 has been subsampled, cells missing
# from it would silently receive NaN labels.
adata_z2 = sc.AnnData(z2)
adata_z2.obs['batch'] = batch_ids
adata_z2.obs['cell_type'] = cell_types

# Visualization on a 30% subsample (adata_z2 is only used here, so in place is fine)
sc.pp.subsample(adata_z2, fraction=0.3)
sc.pp.neighbors(adata_z2, use_rep='X')
sc.tl.umap(adata_z2)
sc.settings.set_figure_params(fontsize=12)
sc.pl.umap(adata_z2, color=['batch', 'cell_type'], title=["UMAP - Batch ID", "UMAP - Cell type"], size=3)
plt.show()

## Evaluation

### 1. Batch effect

In [38]:
import scib
# scib graph iLISI (batch mixing) on the z1 embedding; kept for display below.
ilisi_score = scib.metrics.ilisi_graph(
    adata_z1,
    batch_key='batch',
    type_="full",
)

In [None]:
# Reuse the iLISI score computed by the ilisi_graph call above (same data, same
# batch_key and type_) instead of re-running the expensive graph computation.
ilisi_score

In [None]:
# Display the graph iLISI score stored in `ilisi_score`.
ilisi_score

In [None]:
# scib graph connectivity per cell type, on the kNN graph currently stored in adata_z1.
graph_conn_score = scib.me.graph_connectivity(adata_z1, label_key="cell_type")
graph_conn_score

In [None]:
# scib batch silhouette (ASW), grouped by cell type, on the PCA embedding of z1.
asw_batch_score = scib.me.silhouette_batch(
    adata_z1,
    batch_key="batch",
    label_key="cell_type",
    embed='X_pca',
)
asw_batch_score

In [None]:
# kBET score (scib) on the PCA embedding of z1, using cell-type labels.
kbet_score = scib.me.kBET(
    adata_z1,
    batch_key="batch",
    label_key="cell_type",
    type_="embed",
    embed="X_pca",
)
kbet_score

In [None]:
# Pre-integration expression data used as the "before" side of the PCR comparison.
# A dedicated name is used instead of overwriting `data_dir`: the inference cell
# reads `data_dir`, so re-running it afterwards would feed this file to GeneDataset.
N_GENES = 2000  # assumes the first 2000 columns are gene features -- TODO confirm against the CSV schema
processed_data_path = "/home/haiping_liu/code/My_model/Batch_VAE/Results/data/processed_immune_data.csv"
table = pd.read_csv(processed_data_path)

data = table.iloc[:, :N_GENES].values

cell_type = table['cell_type'].values
batch_id = table['batch'].values

# build adata
adata = sc.AnnData(data)
adata.obs['batch'] = batch_id
adata.obs['cell_type'] = cell_type

# scib PCR comparison for the `batch` covariate: raw expression (before) vs z1 (after).
scib.me.pcr_comparison(adata, adata_z1, covariate="batch")

### 2. Biological information

In [None]:
# Cluster z1 with scib's resolution search (it keeps the clustering that scores
# best against label_key), then compute ARI against the cell-type labels.
# The obs column is 'cell_type' (underscore); 'cell type' is not present in obs.
sc.pp.neighbors(adata_z1)
scib.me.cluster_optimal_resolution(adata_z1, cluster_key="cluster", label_key="cell_type")
scib.me.ari(adata_z1, cluster_key="cluster", label_key="cell_type")

In [None]:
# NMI between the optimal clustering and the cell-type labels
# (the obs column is 'cell_type', not 'cell type').
scib.me.nmi(adata_z1, cluster_key="cluster", label_key="cell_type")

In [None]:
# Cell-type silhouette (ASW) on a freshly computed PCA of the current adata_z1.
sc.pp.pca(adata_z1)
asw_cell_type_score = scib.me.silhouette(
    adata_z1,
    label_key="cell_type",
    embed="X_pca",
)
asw_cell_type_score

In [None]:
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
)

LEIDEN_RESOLUTION = 0.4  # fixed Leiden resolution for the sklearn ARI/NMI comparison

sc.tl.leiden(adata_z1, resolution=LEIDEN_RESOLUTION, key_added='leiden_clusters')

# Plot clusters next to cell types. Spacing is passed via `wspace`: sc.pl.umap
# displays its figure itself by default, so a later plt.subplots_adjust would
# act on a new, empty figure instead.
sc.settings.set_figure_params(fontsize=12)
sc.pl.umap(
    adata_z1,
    color=['leiden_clusters', 'cell_type'],
    title=["UMAP - Leiden clusters", "UMAP - Cell type"],
    size=3,
    wspace=1,
    show=False,
)

cell_type_labels = adata_z1.obs['cell_type']
leiden_labels = adata_z1.obs['leiden_clusters']

# Adjusted Rand Index (ARI)
ari_score = adjusted_rand_score(cell_type_labels, leiden_labels)
print(f"Adjusted Rand Index (ARI) score: {ari_score}")

# Normalized Mutual Information (NMI)
nmi_score = normalized_mutual_info_score(cell_type_labels, leiden_labels)
print(f"Normalized Mutual Information (NMI) score: {nmi_score}")

plt.show()