In [None]:
# !pip install scgen
# Import packages
# The main package used here is scanpy
import numpy as np
import pandas as pd
import matplotlib.pyplot as pl
from matplotlib import rcParams
import scanpy as sc
import os
import time
from datetime import timedelta
import random
import scgen
# Report the scanpy version and the versions of its dependencies so the run can be reproduced.
print(sc.__version__)
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_versions()

In [None]:
# Name of the current working directory (last path component only, not the full path).
# Later cells use it both as a prefix for figure file names and as a relative folder
# in os.path.join(base_name, ...) for the model and result files.
# NOTE(review): the os.path.join uses imply a sub-folder with this name inside the
# working directory — confirm it exists before running the later cells.
base_name = os.path.basename(os.getcwd())
print(base_name)

In [None]:
def save_images(base_name, dpi=300, fig_type = ".png"):
    """Save the current matplotlib figure to disk and close it.

    Parameters
    ----------
    base_name : str
        Target path of the figure. If it has no file extension, ``fig_type``
        is appended.
    dpi : int
        Resolution passed to ``pl.savefig``.
    fig_type : str
        Extension (including the leading dot) used when ``base_name`` has none.
    """
    output_dir = os.path.dirname(base_name)
    # Create the target folder when it is missing. The previous check was inverted:
    # it called os.makedirs only for folders that already existed, which raised
    # FileExistsError, and never created missing ones.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    _, extension = os.path.splitext(base_name)
    if extension == "":
        base_name = base_name + fig_type
    pl.savefig(base_name, dpi=dpi)
    # Close the figure so repeated calls do not accumulate open figures.
    pl.close()
    
def plotTSNE(adata, color_group, n_pcs=20, perplexity=30, save_filename='tsne', use_repx = False):
    """Compute a t-SNE embedding with scanpy, plot it coloured by ``color_group`` and save it.

    With ``use_repx`` set, ``use_rep='X'`` is passed to ``sc.tl.tsne``;
    otherwise the call is made with ``n_jobs=20``. The figure is written
    through ``save_images(save_filename)``.
    """
    # Seeds Python's `random` module; sc.tl.tsne itself receives random_state=0.
    random.seed(42)
    tsne_options = {'random_state': 0, 'n_pcs': n_pcs, 'perplexity': perplexity}
    if use_repx:
        tsne_options['use_rep'] = 'X'
    else:
        tsne_options['n_jobs'] = 20
    sc.tl.tsne(adata, **tsne_options)
    sc.pl.tsne(adata, color=color_group, show=False, wspace=.4)
    save_images(save_filename)
    
def plotUMAP(adata, color_group, save_filename, use_repx = False):
    """Build the neighbour graph, compute a UMAP with scanpy, plot it and save the figure.

    With ``use_repx`` set, neighbours are computed with ``use_rep='X'``;
    otherwise with ``n_neighbors=10`` and ``n_pcs=20``. The plot is coloured
    by ``color_group`` and written through ``save_images(save_filename)``.
    """
    neighbor_options = {'use_rep': 'X'} if use_repx else {'n_neighbors': 10, 'n_pcs': 20}
    sc.pp.neighbors(adata, **neighbor_options)

    sc.tl.umap(adata)
    sc.pl.umap(adata, color=color_group, show=False, wspace=.4)
    save_images(save_filename)
    
    
def time_execute(t1, t2, usecase_name = 'scGen',
                base_name = 'scGen'):
    """Print the elapsed time between ``t1`` and ``t2`` and store it as a one-row CSV.

    Parameters
    ----------
    t1, t2 : number
        Start and end timestamps in seconds.
    usecase_name : str
        Label written to the ``use_case`` column.
    base_name : str
        Path prefix of the output; the file is ``base_name + "_exetime.csv"``.
    """
    elapsed = t2 - t1
    # Split the elapsed time into whole hours, then minutes and seconds of the remainder.
    whole_hours, remainder = divmod(elapsed, 3600)
    rem_minutes, rem_seconds = divmod(remainder, 60)

    print('Took seconds: ' + str(timedelta(seconds=round(elapsed))))
    # Printed as a (minutes, seconds) tuple.
    print('Took minutes: ' + str(divmod(elapsed, 60)))
    print('Took hours_minutes_seconds: ', str(whole_hours), str(rem_minutes), str(rem_seconds))

    record = {
        'use_case': usecase_name,
        'exetime_secs': str(round(elapsed)),
        'exetimehours': str(whole_hours),
        'exetimemins': str(rem_minutes),
        'exetimesecs': str(round(rem_seconds)),
    }

    summary = pd.DataFrame(record, index=['exetime'])
    print(summary)
    summary.to_csv(base_name + "_exetime.csv")
    
    
def save_output_csv(adata, save_dir, usecase_name = 'scGen'):
    """Export the UMAP, t-SNE and PCA coordinates of ``adata`` as CSV files.

    Three files are written into ``save_dir``: ``<usecase_name>_umap.csv``,
    ``<usecase_name>_tsne.csv`` and ``<usecase_name>_pca.csv``. Each holds one
    row per cell (indexed by ``adata.obs_names``), the embedding columns, and
    the ``batch`` and ``celltype`` columns taken from ``adata.obs['batch']``
    and ``adata.obs['cell_type']``.

    Parameters
    ----------
    adata : AnnData-like
        Must provide ``obsm['X_umap']``, ``obsm['X_tsne']``, ``obsm['X_pca']``,
        ``obs['batch']``, ``obs['cell_type']`` and ``obs_names``.
    save_dir : str
        Output folder; created if it does not exist.
    usecase_name : str
        Prefix of the output file names.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def _write_embedding(coords, column_prefix, file_suffix):
        # One CSV per embedding: numbered coordinate columns plus batch / cell type labels.
        columns = [column_prefix + str(i + 1) for i in range(coords.shape[1])]
        df = pd.DataFrame(coords, columns=columns, index=adata.obs_names)
        df['batch'] = pd.Series(adata.obs['batch'], index=adata.obs_names)
        df['celltype'] = pd.Series(adata.obs['cell_type'], index=adata.obs_names)
        df.to_csv(os.path.join(save_dir, usecase_name + file_suffix))

    # UMAP and t-SNE output for visualization
    _write_embedding(adata.obsm['X_umap'], "UMAP", '_umap.csv')
    _write_embedding(adata.obsm['X_tsne'], "tSNE_", '_tsne.csv')

    # PCA output for evaluation (ASW): at most the first 20 components. Column names
    # follow the actual width, so objects with fewer than 20 components also work.
    _write_embedding(adata.obsm['X_pca'][:, :20], "X_pca", '_pca.csv')



In [None]:
from scipy import sparse
import anndata
# Using trained model to correct, normalize data 
# Using batch removal function from scGen package 

# corrected_adata = scgen.batch_removal(network, total_ann)

# If scgen.batch_removal does not work or does not return the cell names,
# use the function below instead: batch_removal_v2
# Hoa Tran
def batch_removal_v2(network, adata):
    """Batch-correct ``adata`` in the latent space of ``network``, keeping cell names.

    The cells are encoded with ``network.to_latent``. For every cell type found
    in at least two batches, the latent vectors of each batch are shifted so
    that the batch mean matches the mean of the batch with the most cells of
    that type. Cell types found in a single batch are left unshifted. The
    latent vectors are then passed to ``network.reconstruct``.

    Parameters
    ----------
    network
        Trained model providing ``to_latent`` and ``reconstruct``.
    adata
        Annotated data whose ``obs`` contains ``"cell_type"``, ``"batch"`` and
        ``"cell_name"``.

    Returns
    -------
    AnnData with the reconstructed matrix, ``obs`` columns ``cell_type``,
    ``batch`` and ``cell_name``, ``var_names`` copied from ``adata`` and
    ``obs_names`` set to the cell names. Rows are grouped by cell type (shared
    cell types first), so the row order can differ from ``adata``.
    """
    # Encode every cell; a sparse matrix is densified first. `.toarray()` is used
    # instead of the `.A` shorthand, which recent SciPy releases no longer provide.
    if sparse.issparse(adata.X):
        latent_all = network.to_latent(adata.X.toarray())
    else:
        latent_all = network.to_latent(adata.X)
    adata_latent = anndata.AnnData(latent_all)
    adata_latent.obs["cell_type"] = adata.obs["cell_type"].tolist()
    adata_latent.obs["batch"] = adata.obs["batch"].tolist()
    adata_latent.obs["cell_name"] = adata.obs["cell_name"].tolist()   #Hoa keep cell name infos
    unique_cell_types = np.unique(adata_latent.obs["cell_type"])
    shared_ct = []       # per-cell-type subsets present in >= 2 batches (shifted below)
    not_shared_ct = []   # per-cell-type subsets present in a single batch (kept as is)
    for cell_type in unique_cell_types:
        temp_cell = adata_latent[adata_latent.obs["cell_type"] == cell_type]
        if len(np.unique(temp_cell.obs["batch"])) < 2:
            not_shared_ct.append(temp_cell)
            continue
        batch_list = {}   # batch label -> subset of this cell type in that batch
        batch_ind = {}    # batch label -> boolean mask of that batch within temp_cell
        max_batch = 0
        max_batch_ind = ""
        batches = np.unique(temp_cell.obs["batch"])
        for i in batches:
            temp = temp_cell[temp_cell.obs["batch"] == i]
            temp_ind = temp_cell.obs["batch"] == i
            # Track the batch holding the most cells of this type; it is the reference.
            if max_batch < len(temp):
                max_batch = len(temp)
                max_batch_ind = i
            batch_list[i] = temp
            batch_ind[i] = temp_ind
        max_batch_ann = batch_list[max_batch_ind]
        for study in batch_list:
            # Shift each batch by the difference between the reference mean and its own
            # mean (the shift is zero for the reference batch itself).
            delta = np.average(max_batch_ann.X, axis=0) - np.average(batch_list[study].X, axis=0)
            batch_list[study].X = delta + batch_list[study].X
            # NOTE(review): writing back through a subset of a subset relies on anndata's
            # view-assignment behaviour — confirm with the installed anndata version.
            temp_cell[batch_ind[study]].X = batch_list[study].X
        shared_ct.append(temp_cell)
    # NOTE(review): if no cell type is shared between batches, shared_ct is empty and
    # this concatenate call fails — confirm whether that case can occur.
    all_shared_ann = anndata.AnnData.concatenate(*shared_ct, batch_key="concat_batch")
    del all_shared_ann.obs["concat_batch"]
    if len(not_shared_ct) < 1:
        # use_data=True presumably marks the input as already being latent — TODO confirm.
        corrected = anndata.AnnData(network.reconstruct(all_shared_ann.X, use_data=True))
        corrected.obs["cell_type"] = all_shared_ann.obs["cell_type"].tolist()
        corrected.obs["batch"] = all_shared_ann.obs["batch"].tolist()
        corrected.obs["cell_name"] = all_shared_ann.obs["cell_name"].tolist() #Hoa keep cell name infos
        corrected.var_names = adata.var_names.tolist()
        corrected.obs_names = corrected.obs['cell_name'] #Hoa assign cell name infos
        return corrected
    else:
        all_not_shared_ann = anndata.AnnData.concatenate(*not_shared_ct, batch_key="concat_batch")
        all_corrected_data = anndata.AnnData.concatenate(all_shared_ann, all_not_shared_ann, batch_key="concat_batch")
        del all_corrected_data.obs["concat_batch"]
        corrected = anndata.AnnData(network.reconstruct(all_corrected_data.X, use_data=True), )
        # Labels are rebuilt in the same order as the concatenation: shared first, then not shared.
        corrected.obs["cell_type"] = all_shared_ann.obs["cell_type"].tolist() + all_not_shared_ann.obs[
            "cell_type"].tolist()
        corrected.obs["batch"] = all_shared_ann.obs["batch"].tolist() + all_not_shared_ann.obs["batch"].tolist()
        corrected.obs["cell_name"] = all_shared_ann.obs["cell_name"].tolist() + all_not_shared_ann.obs[
            "cell_name"].tolist()     #Hoa keep cell name infos
        corrected.var_names = adata.var_names.tolist()
        corrected.obs_names = corrected.obs['cell_name'] #Hoa assign cell name infos
        return corrected

In [None]:
# Read data from read count text table, data in R: genes x cells, 
# Transpose data to cells x genes in order to include to anndata object
# Folder holding the input tables; the .h5ad copy is written there as well.
# NOTE(review): `data_dir` was used for the .h5ad output but never defined in the
# notebook (NameError on a fresh kernel). The input folder is assumed here —
# change it if the .h5ad file should be written elsewhere.
data_dir = 'dataset2'

# Read the count table. In R the data is genes x cells; the file holds the
# transposed table (cells x genes) so it can be loaded into an AnnData object.
expr_filename = os.path.join(data_dir, 'filtered_total_batch1_seqwell_batch2_10x_transpose.txt')
adata = sc.read_text(expr_filename, delimiter='\t', first_column_names=True, dtype='float64')
print(adata)

# Read sample info
metadata_filename = os.path.join(data_dir, "filtered_total_sample_ext_organ_celltype_batch.txt")
sample_adata = pd.read_csv(metadata_filename, header=0, index_col=0, sep='\t')
print(sample_adata.values.shape)
print(sample_adata.keys())
print(sample_adata.index)

# Attach batch and cell type labels, aligned to the cells of the count table by name.
adata.obs['batch'] = sample_adata.loc[adata.obs_names, "batch"]
print(len(adata.obs['batch']))
adata.obs['cell_type'] = sample_adata.loc[adata.obs_names, "cell_type"]
print(len(adata.obs['cell_type']))

# Save output into h5ad, easy to access
adata.write_h5ad(os.path.join(data_dir,'dataset2_cellatlas.h5ad'))

In [None]:
# Observe batch effect in filtered data
# Observe the batch effect in the filtered (uncorrected) data.
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata,n_neighbors=15, n_pcs=20)
sc.tl.umap(adata)
sc.pl.umap(adata, color=["batch"], wspace=.3, show=False)
# Name the figure after base_name like every other figure of this notebook
# (it was saved as 'dataset10_umap', a leftover from another dataset).
save_images(base_name + '_umap')
# Grouping variables used to colour this plot and the plots of later cells.
color_group = ["cell_type","batch"]
plotTSNE(adata, color_group, 20, 90, base_name + '_tsne')

In [None]:
print("Create a network")
# Start of the timed section; the matching t2 is taken after the correction step.
t1 = time.time()
# Initialize scGen; the input dimension is the number of genes (columns of adata).
import scgen
network = scgen.VAEArith(x_dimension=adata.shape[1], model_path=os.path.join(base_name,'dataset2_scgen_model'))
# TODO: check whether batch_size is appropriate
print("Train a network")
# Requirement: adata must contain two columns: adata.obs["cell_type"] and adata.obs["batch"]
network.train(train_data=adata, n_epochs=100, batch_size=50)

In [None]:
print("Correct data")
# Correct the data with batch_removal_v2.
# Input: adata and the trained network model.
# batch_removal_v2 reads adata.obs['cell_name'], so store the cell names there first.
adata.obs['cell_name'] = adata.obs_names
adata.obs['batch'] = adata.obs['batch'].astype('category')
corrected_adata = batch_removal_v2(network, adata)
# End of the timed section started at t1 (network creation, training and correction).
t2 = time.time()
print('Took '+str(timedelta(seconds=t2-t1)))
# corrected_adata = scgen.batch_removal(network, adata1)
print(corrected_adata)

In [None]:
# Print the elapsed time and store it in <base_name>/scGen_exetime.csv.
# NOTE(review): the <base_name>/ folder presumably has to exist already — confirm.
time_execute(t1, t2, 'scGen', os.path.join(base_name,'scGen'))

In [None]:
# PCA of the corrected data, keeping 20 components.
sc.tl.pca(corrected_adata, svd_solver='arpack', n_comps=20)
corrected_adata.obsm['X_pca'] *= -1 # multiply by -1 to match Seurat visualization

In [None]:
# Embed and plot the corrected data.
plotTSNE(corrected_adata, color_group, 20, 90, base_name + '_scgen_corrected_tsne')
plotUMAP(corrected_adata, color_group, base_name + '_scgen_corrected_umap')
# Export the embeddings of the corrected object. The uncorrected `adata` was passed
# before, so the CSV files held the pre-correction UMAP / t-SNE / PCA coordinates.
save_output_csv(corrected_adata, base_name)

In [None]:
# Store the corrected data twice: as CSV files via write_csvs and as a single .h5ad file.
# NOTE(review): the .h5ad path lies inside a <base_name>/ folder — confirm it exists.
corrected_adata.write_csvs(base_name + "_corrected_dataset2")
corrected_adata.write_h5ad(os.path.join(base_name,'corrected_dataset2.h5ad'))