In [None]:
import scvi
import scanpy as sc
import numpy as np
import scgen

In [None]:
import os

# Directory that holds the benchmark ``<dataset>.h5ad`` files.
# The default is the original local path; set the SCRNA_H5AD_DIR environment
# variable to run the notebook on another machine without editing this cell.
# os.path.join(..., "") guarantees a trailing separator, so code that appends
# a file name directly to dir_path keeps working.
dir_path = os.path.join(
    os.environ.get(
        "SCRNA_H5AD_DIR",
        "/home/krushna/Documents/Data_integration/SCRNA_Datasets/All_h5ad/",
    ),
    "",
)
def load_data(dataset, batch, n_top_genes=2000):
    """Read ``<dataset>.h5ad`` from ``dir_path`` and preprocess it.

    Steps, in order:
      1. drop genes whose total count is below 3,
      2. keep a copy of the unnormalized matrix in ``adata.layers["counts"]``,
      3. normalize every cell to a total of 1e4 and apply ``log1p``,
      4. store the log-normalized data (all remaining genes) in ``adata.raw``,
      5. subset to the ``n_top_genes`` highly variable genes, selected with
         ``batch`` as the batch key.

    Parameters
    ----------
    dataset : str
        File stem of the ``.h5ad`` file inside ``dir_path``.
    batch : str
        Name of the ``adata.obs`` column passed as ``batch_key`` to the
        highly-variable-gene selection.
    n_top_genes : int, optional
        Number of highly variable genes to keep (default 2000).

    Returns
    -------
    The preprocessed AnnData object, subset to the selected genes.
    """
    adata = sc.read_h5ad(dir_path + dataset + '.h5ad')
    sc.pp.filter_genes(adata, min_counts=3)
    adata.layers["counts"] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    adata.raw = adata
    # flavor="seurat" expects logarithmized data, so it must run on the
    # log-normalized adata.X. Do not pass layer="counts" here: raw counts are
    # only valid input for the "seurat_v3" flavor.
    sc.pp.highly_variable_genes(
        adata,
        flavor="seurat",
        n_top_genes=n_top_genes,
        batch_key=batch,
        subset=True,
    )

    return adata
    
# Dataset name -> name used as the batch key for that dataset.
batch_key_dic = {
    "Immune_Human": "batch",
    "Immune_human_mouse": "batch",
    "Lung": "batch",
    "Mouse_brain": "batch",
    "Pancreas": "tech",
    "Simulation1": "Batch",
    "Simulation2": "Batch",
    "Human_Mouse": "batch",
    "Human_Retina": "Batch",
}
# Dataset name -> name used as the cell-type label key for that dataset.
cell_type_key_dic = {
    "Immune_Human": "final_annotation",
    "Immune_human_mouse": "final_annotation",
    "Lung": "cell_type",
    "Mouse_brain": "cell_type",
    "Pancreas": "celltype",
    "Simulation1": "Group",
    "Simulation2": "Group",
    "Human_Mouse": "celltype",
    "Human_Retina": "Subcluster",
}

In [None]:
# Choose the benchmark dataset and look up its batch / cell-type key names.
dataset = "Immune_Human"
batch, cell_type = batch_key_dic[dataset], cell_type_key_dic[dataset]

In [None]:
# Read and preprocess the selected dataset with load_data.
adata = load_data(dataset=dataset, batch=batch)

In [None]:
import time

# Wall-clock reference point; the elapsed time is printed in a later cell.
start_time = time.time()

In [None]:
# `train` is another name for `adata` (no copy is made); scGen is set up on
# it with the dataset's batch and cell-type keys.
train = adata
scgen.SCGEN.setup_anndata(
    train,
    batch_key=batch,
    labels_key=cell_type,
)

In [None]:
# Build the scGen model on the AnnData registered via setup_anndata.
model = scgen.SCGEN(train)

In [None]:
# Fit scGen: at most 100 epochs, batch size 32, early stopping enabled
# with a patience of 25.
model.train(
    early_stopping=True,
    early_stopping_patience=25,
    max_epochs=100,
    batch_size=32,
)

In [None]:
# Run scGen's batch removal with the trained model; the result is used as
# the integrated dataset in the cells that follow.
adata_int = model.batch_removal()
# Bare expression so the notebook displays the returned object.
adata_int

In [None]:
# Report the seconds elapsed since start_time was recorded.
end_time = time.time()
print('total time taken:', end_time - start_time)

In [None]:
# Save the integrated data next to the notebook (relative path).
# NOTE(review): the file name is hard-coded for Immune_Human and does not
# follow the `dataset` variable; update it when switching datasets.
adata_int.write('scGen-Immune_Human.h5ad')

In [None]:
# Reload the integrated data from disk; presumably this lets the evaluation
# cells be run without retraining (confirm intended workflow).
# NOTE(review): the file name is hard-coded for Immune_Human and does not
# follow the `dataset` variable.
adata_int = sc.read_h5ad("scGen-Immune_Human.h5ad")

In [None]:
import scIB

# Score the integration (adata_int) against the preprocessed input (adata).
# The trajectory metric asks for a precomputed pseudotime, so trajectory_
# is left off in this run.
# NOTE(review): organism='mouse' is passed although the dataset selected above
# is Immune_Human; presumably it is only read when cell_cycle_ is enabled -
# confirm against the scIB version in use.
results, ilisi_all, clisi_all, kbet_all = scIB.metrics.metrics(
    adata,
    adata_int,
    batch_key=batch,
    label_key=cell_type,
    type_=None,
    verbose=False,
    # --- clustering agreement (ARI / NMI) ---
    cluster_key='cluster',
    cluster_nmi=None,
    ari_=True,
    nmi_=True,
    nmi_method='arithmetic',
    nmi_dir=None,
    # --- silhouette ---
    silhouette_=True,
    embed='X_pca',
    si_metric='euclidean',
    # --- isolated labels ---
    isolated_labels_=True,  # backwards compatibility
    isolated_labels_f1_=False,
    isolated_labels_asw_=False,
    n_isolated=None,
    # --- kBET / LISI ---
    kBET_=True,
    kBET_sub=0.5,
    lisi_graph_=True,
    lisi_raw=True,
    # --- remaining switches ---
    pcr_=True,
    graph_conn_=True,
    hvg_score_=False,
    cell_cycle_=False,
    organism='mouse',
    trajectory_=False,
)

In [None]:
# Display the metrics result returned by the scIB call.
results

In [None]:
import numpy as np

# Write the iLISI / cLISI values returned by the metrics run to CSV files
# named after the dataset.
np.savetxt(f"{dataset}_ilisi.csv", ilisi_all, delimiter=",")
np.savetxt(f"{dataset}_clisi.csv", clisi_all, delimiter=",")
# kBET: turn every entry into a single row and stack the rows into one table.
np.savetxt(
    f"{dataset}_kbet_all.csv",
    np.concatenate(
        [np.array(entry).reshape(1, -1) for entry in kbet_all],
        axis=0,
    ),
    delimiter=',',
)

In [None]:
import matplotlib as mpl

# Set the default figure resolution to 300 dpi.
mpl.rcParams.update({'figure.dpi': 300})

In [None]:
# Neighbour graph and UMAP on the integrated data, then one plot per
# annotation: cell type, batch, and the 'cluster' column.
sc.pp.neighbors(adata_int)  # use_rep = 'final_embeddings'
sc.tl.umap(adata_int)
for color_key in (cell_type, batch, 'cluster'):
    sc.pl.umap(adata_int, color=color_key, frameon=False)

In [None]:
import scIB

# Second metrics run with only ARI and the trajectory metric switched on.
# The trajectory metric asks for a precomputed pseudotime.
# NOTE(review): this rebinds results / ilisi_all / clisi_all / kbet_all from
# the earlier metrics run.
results, ilisi_all, clisi_all, kbet_all = scIB.metrics.metrics(
    adata,
    adata_int,
    batch_key=batch,
    label_key=cell_type,
    type_=None,
    verbose=False,
    # --- enabled in this run ---
    ari_=True,
    trajectory_=True,
    # --- clustering settings ---
    cluster_key='cluster',
    cluster_nmi=None,
    nmi_=False,
    nmi_method='arithmetic',
    nmi_dir=None,
    # --- silhouette ---
    silhouette_=False,
    embed='X_pca',
    si_metric='euclidean',
    # --- isolated labels ---
    isolated_labels_=False,  # backwards compatibility
    isolated_labels_f1_=False,
    isolated_labels_asw_=False,
    n_isolated=None,
    # --- kBET / LISI ---
    kBET_=False,
    kBET_sub=0.5,
    lisi_graph_=False,
    lisi_raw=False,
    # --- remaining switches ---
    pcr_=False,
    graph_conn_=False,
    hvg_score_=False,
    cell_cycle_=False,
    organism='mouse',
)
results