In [None]:
# Residual Network for batch effect correction
# See more at: https://github.com/ushaham/BatchEffectRemoval
# Hoa Tran
# Updated the code from Python 2 to Python 3 (Keras)
import os
import random
import time
from datetime import timedelta

import numpy as np
import pandas as pd
import matplotlib.pyplot as pl
from matplotlib import rcParams
import scanpy as sc

from Calibration_Util import utils_resnet as utils     # our utils function
# Scanpy log level -- errors (0), warnings (1), info (2), hints (3).
sc.settings.verbosity = 3
# Record the library versions in the notebook output so the run can be reproduced.
sc.logging.print_versions()

In [None]:
# Name of the current working directory. It is used as a filename prefix for figures
# and as the (relative) output folder for the CSV / h5ad results written by later cells.
base_name = os.path.basename(os.getcwd())
# Later cells write files into `base_name/`; create it now so those writes cannot fail
# on a missing folder. (An empty name only happens when running from the filesystem root.)
if base_name:
    os.makedirs(base_name, exist_ok=True)
print(base_name)

In [None]:
def save_images(base_name, dpi=300, fig_type = ".png"):
    """Save the current matplotlib figure to ``base_name`` and close it.

    Parameters
    ----------
    base_name : str
        Output path. When it has no file extension, ``fig_type`` is appended.
        Its parent directory is created if it does not exist yet.
    dpi : int
        Resolution passed to ``savefig``.
    fig_type : str
        Extension (including the leading dot) used when ``base_name`` has none.
    """
    output_dir = os.path.dirname(base_name)
    # Only create a directory when the path actually has one; exist_ok makes an
    # already-existing directory a no-op instead of an error.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    _, ext = os.path.splitext(base_name)
    if ext == "":
        base_name = base_name + fig_type
    pl.savefig(base_name, dpi=dpi)
    # Close the figure so repeated calls do not accumulate open figures.
    pl.close()
    
def plotTSNE(adata, color_group, n_pcs=20, perplexity=30, save_filename='tsne', use_repx = False):
    """Compute t-SNE on ``adata``, plot it coloured by ``color_group`` and save the figure.

    With ``use_repx`` the embedding is computed from ``adata.X`` (``use_rep='X'``);
    otherwise no ``use_rep`` is passed and ``n_jobs=20`` is requested instead.
    The figure is written through ``save_images(save_filename)``.
    """
    random.seed(42)
    tsne_args = {'random_state': 0, 'n_pcs': n_pcs, 'perplexity': perplexity}
    if use_repx:
        tsne_args['use_rep'] = 'X'
    else:
        tsne_args['n_jobs'] = 20
    sc.tl.tsne(adata, **tsne_args)
    sc.pl.tsne(adata, color=color_group, show=False, wspace=.4)
    save_images(save_filename)
    
def plotUMAP(adata, color_group, save_filename, use_repx = False):
    """Build the neighbour graph, compute UMAP, plot it coloured by ``color_group`` and save it.

    With ``use_repx`` the graph is built from ``adata.X`` (``use_rep='X'``); otherwise
    it uses ``n_neighbors=10`` and ``n_pcs=20``. The figure is written through
    ``save_images(save_filename)``.
    """
    neighbor_args = {'use_rep': 'X'} if use_repx else {'n_neighbors': 10, 'n_pcs': 20}
    sc.pp.neighbors(adata, **neighbor_args)
    sc.tl.umap(adata)
    sc.pl.umap(adata, color=color_group, show=False, wspace=.4)
    save_images(save_filename)
    
    
def time_execute(t1, t2, usecase_name = 'MMDResNet',
                base_name = 'scGen'):
    """Print the elapsed time between ``t1`` and ``t2`` and save it as a one-row CSV.

    Parameters
    ----------
    t1, t2 : float
        Start and end timestamps in seconds (e.g. from ``time.time()``).
    usecase_name : str
        Value stored in the ``use_case`` column.
    base_name : str
        Output path prefix; the table is written to ``base_name + "_exetime.csv"``.
        Missing parent directories are created.
    """
    time_taken = t2 - t1
    time_taken_mins = divmod(time_taken, 60)          # (whole minutes, leftover seconds)
    time_taken_hours, rest = divmod(time_taken, 3600)
    hours_mins, hours_secs = divmod(rest, 60)
    print('Took seconds: '+str(timedelta(seconds=round(time_taken))))
    print('Took minutes: '+str(time_taken_mins))
    print('Took hours_minutes_seconds: ',str(time_taken_hours),str(hours_mins),str(hours_secs))

    data = {'use_case':usecase_name, 'exetime_secs':str(round(time_taken)),
           'exetimehours': str(time_taken_hours),
           'exetimemins': str(hours_mins),
           'exetimesecs':str(round(hours_secs))}

    df = pd.DataFrame(data, index =['exetime'])
    print(df)

    out_file = base_name + "_exetime.csv"
    # Callers pass prefixes such as "<dir>/MMDResNet"; make sure <dir> exists.
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_file)
    
    
def _write_embedding_csv(adata, obsm_key, col_prefix, path, n_dims=None):
    """Write ``adata.obsm[obsm_key]`` (first ``n_dims`` columns, all if None) plus batch/celltype labels to ``path``."""
    coords = np.asarray(adata.obsm[obsm_key])
    if n_dims is not None:
        # Slicing past the available width simply keeps every column.
        coords = coords[:, :n_dims]
    columns = [col_prefix + str(i + 1) for i in range(coords.shape[1])]
    df = pd.DataFrame(coords, columns=columns, index=adata.obs_names)
    df['batch'] = adata.obs['batch'].values
    df['celltype'] = adata.obs['cell_type'].values
    df.to_csv(path)


def save_output_csv(adata, save_dir, usecase_name = 'MMDResNet', n_pcs = 20):
    """Save UMAP, t-SNE and PCA coordinates with batch / cell-type labels as CSV files.

    Files written into ``save_dir``:
      * ``<usecase_name>_umap.csv`` -- columns UMAP1..UMAPk, batch, celltype
      * ``<usecase_name>_tsne.csv`` -- columns tSNE_1..tSNE_k, batch, celltype (for visualization)
      * ``<usecase_name>_pca.csv``  -- first ``n_pcs`` PCs as X_pca1..X_pcaN, batch, celltype
        (used for ASW evaluation; fewer columns if ``X_pca`` is narrower)

    Requires ``adata.obsm`` to hold 'X_umap', 'X_tsne' and 'X_pca', and ``adata.obs``
    to hold 'batch' and 'cell_type'.
    """
    _write_embedding_csv(adata, 'X_umap', 'UMAP',
                         os.path.join(save_dir, usecase_name + '_umap.csv'))
    _write_embedding_csv(adata, 'X_tsne', 'tSNE_',
                         os.path.join(save_dir, usecase_name + '_tsne.csv'))
    _write_embedding_csv(adata, 'X_pca', 'X_pca',
                         os.path.join(save_dir, usecase_name + '_pca.csv'), n_dims=n_pcs)



In [None]:
# The count table was exported from R as genes x cells and transposed to cells x genes
# so it can be loaded into an AnnData object.
# expr_mtx: total filtered data = [data_batch1, data_batch2]
data_dir = 'dataset2'   # input folder; the h5ad copy below is written here as well
expr_filename = os.path.join(data_dir, 'filtered_total_batch1_seqwell_batch2_10x_transpose.txt')
adata = sc.read_text(expr_filename, delimiter='\t', first_column_names=True, dtype='float64')
print(adata)

# Read sample info; rows are looked up by cell name (adata.obs_names) below.
metadata_filename = os.path.join(data_dir, 'filtered_total_sample_ext_organ_celltype_batch.txt')
sample_adata = pd.read_csv(metadata_filename, header=0, index_col=0, sep='\t')
print(sample_adata.values.shape)
print(sample_adata.keys())
print(sample_adata.index)

adata.obs['batch'] = sample_adata.loc[adata.obs_names, "batch"]
print(len(adata.obs['batch']))
adata.obs['cell_type'] = sample_adata.loc[adata.obs_names, "cell_type"]
print(len(adata.obs['cell_type']))

# Save output into h5ad for quick reloading
adata.write_h5ad(os.path.join(data_dir, 'dataset2_cellatlas.h5ad'))
print(adata)

In [None]:
# The authors extract 50 PCA components to train the network model
npcs_train = 50
sc.tl.pca(adata, svd_solver='arpack', n_comps=npcs_train)  # output saved to adata.obsm['X_pca']

# Raw (uncorrected) data visualization, coloured by batch and cell type
color_group = ['batch', 'cell_type']
# NOTE(review): plotUMAP rebuilds the neighbour graph itself (n_neighbors=10, n_pcs=20),
# so this graph is overwritten before UMAP is computed.
sc.pp.neighbors(adata, n_neighbors=15, n_pcs=20)
plotTSNE(adata, color_group, 20, 30, base_name + '_d2_raw_tsne')
plotUMAP(adata, color_group, base_name + '_d2_raw_umap')

In [None]:
# Split the cells by batch. The ResNet calibrates (aligns) one batch onto the other;
# either direction is possible -- both were tried in our work and the better output kept.
# Below, batch 2 is used as the source and batch 1 as the target.
in_batch1 = adata.obs['batch'] == 1
in_batch2 = adata.obs['batch'] == 2
adata1 = adata[in_batch1, :].copy()
adata2 = adata[in_batch2, :].copy()
print(adata1)
print(adata2)

In [None]:
import keras.optimizers
from keras.layers import Input, Dense, merge, BatchNormalization, Activation
from keras.models import Model
from keras import callbacks as cb
from keras.regularizers import l2
from keras.callbacks import LearningRateScheduler
import math
from keras import backend as K
import sklearn.preprocessing as prep
from sklearn import decomposition
from statsmodels.distributions.empirical_distribution import ECDF
from keras import initializers
from keras.layers import add
from Calibration_Util import CostFunctions as cf  # author func
from Calibration_Util import Monitoring as mn     # author func


In [None]:
# Align batch 2 (source) onto batch 1 (target) in PCA space.
target = adata1.obsm['X_pca']
source = adata2.obsm['X_pca']
space_dim = target.shape[1]  # number of PCs (npcs_train = 50)
batch_size = 30
penalty = 1e-2
nbeps = 50          # NOTE(review): presumably the number of epochs -- confirm in utils.createMMDResNetModel
val_split = 0.15
savedfl = 'resnet_dataset2'
# Output folder passed to createMMDResNetModel (presumably where the model is saved -- confirm).
save_dir = base_name
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
# Seed numpy's global RNG for reproducibility; the Keras backend RNG is not seeded here.
np.random.seed(42)
t1 = time.time()
calibMMDNet, block2_output = utils.createMMDResNetModel(target, source, space_dim, batch_size, penalty, save_dir, nbeps, val_split, savedfl)


In [None]:
# Align (calibrate) the source data to the target data, then stop the timer.
afterCalib = calibMMDNet.predict(source)
t2 = time.time()
# Last expression so the shape is actually displayed (it was a dead mid-cell expression).
afterCalib.shape

In [None]:
# time_execute writes "<base_name>/MMDResNet_exetime.csv"; ensure the folder exists first.
if base_name:
    os.makedirs(base_name, exist_ok=True)
time_execute(t1, t2, 'MMDResNet', os.path.join(base_name,'MMDResNet'))

In [None]:
# Visualize distribution
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
# Scatter/histogram plots of two PCA columns: source vs target before calibration,
# and calibrated source vs target after, to see how the source distribution moved.

figures_dir = os.path.join(base_name, 'figures')
os.makedirs(figures_dir, exist_ok=True)

fn_bf = os.path.join(figures_dir, 'distribution_d2_before_calibration.png')
fn_af = os.path.join(figures_dir, 'distribution_d2_after_calibration.png')
# Column indices (0-based) into the PCA matrices. The authors note the PCs most correlated
# with batch are {1 and 2} or {3 and 5}. NOTE(review): confirm whether that note is 0- or 1-based.
pc1 = 1
pc2 = 2
utils.myScatterHistDemo(target[:,pc1], target[:,pc2], source[:,pc1], source[:,pc2], fn_bf)
utils.myScatterHistDemo(target[:,pc1], target[:,pc2], afterCalib[:,pc1], afterCalib[:,pc2], fn_af)

In [None]:
## Quantitative evaluation: MMD ###
# Maximum Mean Discrepancy between a random sample of source cells (before and after
# calibration) and a random sample of target cells, using cf.MMD built from the
# network's block2_output and the target data (the scales used for training).
# Smaller MMD -> the two distributions are closer, i.e. less batch effect.
# If the value barely changes, the method did not remove the batch effect.
n_eval = 1000
# Seeded sampler so the reported numbers are reproducible; indices are drawn with replacement.
eval_rng = np.random.RandomState(42)
sourceInds = eval_rng.randint(low=0, high=source.shape[0], size=n_eval)
targetInds = eval_rng.randint(low=0, high=target.shape[0], size=n_eval)
mmd_before = K.eval(cf.MMD(block2_output,target).cost(K.variable(value=source[sourceInds]), K.variable(value=target[targetInds])))
mmd_after = K.eval(cf.MMD(block2_output,target).cost(K.variable(value=afterCalib[sourceInds]), K.variable(value=target[targetInds])))
print('MMD before calibration: ' + str(mmd_before))
print('MMD after calibration: ' + str(mmd_after))

In [None]:
# Attach each batch's PCA coordinates to its own AnnData. `target` was read from
# adata1.obsm['X_pca'], so only adata2 receives new (calibrated) coordinates.
for sub_adata, embedding in ((adata1, target), (adata2, afterCalib)):
    sub_adata.obsm['X_pca'] = embedding

In [None]:
# Write the corrected PCA back into the full object. Rows are placed by batch mask so
# every cell keeps its own coordinates, instead of assuming all batch-1 cells come first.
in_batch1 = (adata.obs['batch'] == 1).values
in_batch2 = (adata.obs['batch'] == 2).values
assert in_batch1.sum() == target.shape[0] and in_batch2.sum() == afterCalib.shape[0]
pca_corrected = np.array(adata.obsm['X_pca'], copy=True)
pca_corrected[in_batch1] = target
pca_corrected[in_batch2] = afterCalib
adata.obsm['X_pca'] = pca_corrected  # update corrected results to pca vectors in adata

color_group = ['batch', 'cell_type']
npcs = 20
perplex = 30
sc.pp.neighbors(adata, n_neighbors=15, n_pcs=npcs)
plotTSNE(adata, color_group, npcs, perplex, base_name + '_resnet_corrected_tsne')
plotUMAP(adata, color_group, base_name + '_resnet_corrected_umap')
if base_name:
    os.makedirs(base_name, exist_ok=True)
save_output_csv(adata, base_name)

In [None]:
# Persist the object whose X_pca now holds the ResNet-corrected coordinates.
corrected_h5ad = os.path.join(base_name, 'resnet_corrected_pca_dataset2.h5ad')
adata.write_h5ad(corrected_h5ad)