In [None]:
# Residual Network for batch effect correction
# See more at: https://github.com/ushaham/BatchEffectRemoval
# Hoa Tran
# Update code from python version 2 to python version 3, Keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as pl
from matplotlib import rcParams
import time
from datetime import timedelta
import scanpy as sc
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
# Print the versions of the packages in use, so the run can be reproduced later
sc.logging.print_versions()

In [None]:
# Create the folders that receive the results and the figures
import os
dirname = os.getcwd()
print(dirname)
# NOTE(review): machine-specific absolute path; a later cell assigns data_dir
# again with a different location -- confirm which one is intended.
data_dir = os.path.join('/home/hoa/hoatran/demo_normalization/dataset/dataset9_Human_cell_atlas/')

# exist_ok=True makes the call a no-op when the folder is already there and
# creates missing parent folders, so no separate existence check is needed.
os.makedirs('./results/results_dataset9_HCA/', exist_ok=True)
save_dir = os.path.join(dirname, 'results/results_dataset9_HCA/')

# Folder for the figures; save_images (defined right after) writes PNG files into it
save_fig_dir='./figures/dataset9_HCA/'
os.makedirs(save_fig_dir, exist_ok=True)
def save_images(filename, save_fig_dir):
    """Save the current matplotlib figure as a PNG file and close it.

    Parameters
    ----------
    filename : str
        File name without extension; '.png' is appended.
    save_fig_dir : str
        Folder the image is written to. A trailing path separator is optional.
    """
    # os.path.join adds the separator when save_fig_dir does not end with one;
    # plain string concatenation would glue the folder name onto the file name.
    outname = os.path.join(save_fig_dir, filename + '.png')
    pl.savefig(outname, dpi=150)
    pl.close()

In [None]:
# myDataFn = 'filtered_genes_and_cells/HCA_genes_cells_filtered_filtered_UMI.txt'
# mySampleFn = 'HCA_genes_cells_filtered_filtered_cell_info_correct.txt'
# savefn = 'myRawData1.h5ad'
# adata = load_data(data_dir, myDataFn, mySampleFn, save_dir, savefn,saveh5ad=False)
# adata

In [None]:
# Read the count matrix through h5py (used only for big datasets)
# https://www.h5py.org/
import h5py
data_dir = os.path.join('/acrc/jinmiao/CJM_lab/hoatran/demo_normalization/dataset/dataset9_Human_cell_atlas/')
myDataFn = 'filtered_genes_and_cells/HCA_genes_cells_filtered_filtered_UMI.h5'
f = h5py.File(os.path.join(data_dir, myDataFn), 'r')
keys = [*f.keys()]
# Keys other than the two name vectors; the first one is read as the data matrix later
k2 = list(filter(lambda key: key not in ('gene_names', 'cell_names'), keys))
print(k2[0])
print(keys)


In [None]:
# Copy the first non-name dataset of the file into a numpy array
myData = np.array(f[k2[0]])
print(myData.shape)
gene_names = f['gene_names']
cell_names = f['cell_names']
print(gene_names.shape, cell_names.shape, sep='\n')
# Turn the stored byte strings into Python str
gene_names = [raw_name.decode() for raw_name in gene_names]
print(gene_names[1:3])

In [None]:
# cell_names = [x.decode() for x in cell_names]
print(cell_names[1:3])
# The comparison with mySample.index[1:3] is disabled here: mySample is only
# created by the pd.read_csv call in a later cell, so on a fresh kernel this
# line raised NameError. Run it after the sample sheet has been loaded.
# mySample.index[1:3]

In [None]:
import pandas as pd
# species vectors --> batch vector and batch label
# keep only necessary infos in sample file

# Tab-separated sample sheet: the first row is the header, the first column the index
mySampleFn = 'HCA_genes_cells_filtered_filtered_cell_info_correct.txt'
mySample = pd.read_csv(
    os.path.join(data_dir, mySampleFn),
    sep='\t',
    header=0,
    index_col=0,
)

# (rows, columns) of the sample sheet
mySample.shape

In [None]:
# Transpose the matrix before wrapping it, so its former columns become the
# observations (rows) of the AnnData object
adata = sc.AnnData(myData.T)
adata

In [None]:
# Attach cell and gene labels, then the batch annotations.
# NOTE(review): this assumes the rows of mySample are in the same order as the
# observations of adata -- confirm against the cell_names stored in the h5 file.
adata.obs_names = mySample.index
adata.var_names = gene_names
#     adata.obs['cell_type'] = mySample.loc[adata.obs_names,['celltype']]
# Select with a scalar column label so a Series (not a one-column DataFrame)
# is assigned to each obs column.
adata.obs['batch'] = mySample.loc[adata.obs_names, 'batch']
adata.obs['batchlb'] = mySample.loc[adata.obs_names, 'batchlb']

In [None]:
# Display the AnnData summary now that obs/var names and batch columns are set
adata

In [None]:
# Store the annotated AnnData object as an .h5ad file inside save_dir
savefn = 'HCA_genes_cells_filtered_filtered_UMI_adata.h5ad'
adata.write_h5ad(os.path.join(save_dir,savefn))

In [None]:
# The input is already filtered, so the preprocessing calls below stay disabled
# sc.pp.filter_cells(adata, min_genes=300)
# sc.pp.filter_genes(adata, min_cells=10)
# sc.pp.log1p(adata)
# sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
adata

In [None]:
# The authors use 50 principal components to train the network model
npcs_train = 50  
sc.tl.pca(adata, svd_solver='arpack', n_comps=npcs_train)  # output save to adata.obsm['X_pca']

# Raw data visualization
# Cannot be applied to this dataset; the data would have to be downsampled first
npcs = 20  # our pre-defined
perplex = 30  # our pre-defined
# get_tsne_umap_raw(adata, perplex=30, npcs=20)
adata

In [None]:
# List the distinct batch ids; the next cell selects batch 1 and batch 2
np.unique(adata.obs['batch'])

In [None]:
# Extract data from batch 1 and batch 2
# Subsetting an AnnData object yields a view of adata. The explicit .copy()
# makes each subset an independent object, so the in-place sign flip of
# obsm['X_pca'] below does not depend on an implicit copy of a view.
adata1_filtered = adata[adata.obs['batch']==1,:].copy()
print(adata1_filtered)
adata2_filtered = adata[adata.obs['batch']==2,:].copy()
print(adata2_filtered)
adata1_filtered.obsm['X_pca'] *= -1 # multiply by -1 to match Seurat, as suggestion from Scanpy
adata2_filtered.obsm['X_pca'] *= -1 # multiply by -1 to match Seurat, as suggestion from Scanpy
print(adata2_filtered.obsm['X_pca'].shape[1])

In [None]:
import keras.optimizers
from keras.layers import Input, Dense, merge, BatchNormalization, Activation
from keras.models import Model
from keras import callbacks as cb
from keras.regularizers import l2
from keras.callbacks import LearningRateScheduler
import math
from keras import backend as K
import sklearn.preprocessing as prep
from sklearn import decomposition
from statsmodels.distributions.empirical_distribution import ECDF
from keras import initializers
from keras.layers import add
from Calibration_Util import CostFunctions as cf  # author func
from Calibration_Util import Monitoring as mn     # author func
from Calibration_Util import utils_resnet as utils     # our func


In [None]:
# Inputs and settings handed to utils.createMMDResNetModel in a later cell.
# target1 / source2 are the PCA embeddings of batch 1 and batch 2.
target1 = adata1_filtered.obsm['X_pca']
source2 = adata2_filtered.obsm['X_pca']
# NOTE(review): 50 repeats the value of npcs_train used for the PCA; presumably
# the two must stay equal -- confirm and consider deriving one from the other.
space_dim = 50
batch_size = 1000
penalty = 1e-2
nbeps = 20
val_split = 0.25
# Name passed as the last argument of createMMDResNetModel
savedfl = 'resnet_d9_target1_source2'


In [None]:
# Shapes of the target and source embeddings, one per line
print(target1.shape, source2.shape, sep='\n')

In [None]:
# t1 / t2 bracket model creation plus prediction; they are reused later for the
# execution-time report.
t1 = time.time()
calibMMDNet,block2_output = utils.createMMDResNetModel(
    target1, source2, space_dim, batch_size, penalty,
    save_dir, nbeps, val_split, savedfl,
)
# align, calibrate source data to target data
afterCalib2 = calibMMDNet.predict(source2)
print(afterCalib2.shape)
t2 = time.time()
print('Took', timedelta(seconds=t2-t1))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
# Plot the distribution of the source data before and after calibration,
# each time next to the target data
# from Calibration_Util import ScatterDemo as sd

save_file_bf2 = os.path.join(save_fig_dir,'d12_distribution_before_target1_source2.png')
save_file_af2 = os.path.join(save_fig_dir, 'd12_distribution_after_target1_source2.png')

# The PCs most correlated with the batch are {1 and 2} or {3 and 5}
pc1, pc2 = 0, 1
# First the uncorrected source, then the calibrated source, each against target1
for batch2_pcs, out_png in ((source2, save_file_bf2), (afterCalib2, save_file_af2)):
    utils.myScatterHistDemo(target1[:,pc1], target1[:,pc2], batch2_pcs[:,pc1], batch2_pcs[:,pc2], out_png)

In [None]:
## Quantitative evaluation: MMD ###
# MMD with the scales used for training
# Compute the Maximum Mean Discrepancy (MMD) distance.
# Inputs: source data before correction, source data after correction, target data (ground truth, filtered data) and one layer of the network
# If the MMD distance gets smaller --> the two distributions have shifted closer together, i.e. less batch effect
# If the MMD distance barely changes --> this method cannot remove the batch effect
def calculMMD(target, source, afterCalib, block2_output, n_samples=1000):
    """Print and return the MMD to the target before and after calibration.

    Rows are drawn at random, with replacement, from ``source`` and ``target``.
    The same row indices are applied to ``source`` and ``afterCalib``, so both
    measurements use the same cells.

    Parameters
    ----------
    target : array, indexed by row
        Target batch; also passed to ``cf.MMD`` as its second argument.
    source : array, indexed by row
        Source batch before calibration.
    afterCalib : array, indexed by row
        Source batch after calibration; must have as many rows as ``source``.
    block2_output
        Passed to ``cf.MMD`` as its first argument.
    n_samples : int, optional
        Number of rows drawn from each of ``source`` and ``target``
        (default 1000, the value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(mmd_before, mmd_after)``.

    Raises
    ------
    ValueError
        If ``afterCalib`` and ``source`` differ in their number of rows.
    """
    if afterCalib.shape[0] != source.shape[0]:
        # The same indices are applied to both arrays; differing lengths would
        # compare unrelated rows or fail with an IndexError.
        raise ValueError('afterCalib must have the same number of rows as source')
    sourceInds = np.random.randint(low=0, high=source.shape[0], size=n_samples)
    targetInds = np.random.randint(low=0, high=target.shape[0], size=n_samples)
    mmd_before = K.eval(cf.MMD(block2_output,target).cost(K.variable(value=source[sourceInds]), K.variable(value=target[targetInds])))
    mmd_after = K.eval(cf.MMD(block2_output,target).cost(K.variable(value=afterCalib[sourceInds]), K.variable(value=target[targetInds])))
    print('MMD before calibration: ' + str(mmd_before))
    print('MMD after calibration: ' + str(mmd_after))
    return mmd_before,mmd_after
        
# MMD between batch 1 (target) and batch 2 (source), before and after calibration
print('Distance b1 to b2')    
calculMMD(target1, source2, afterCalib2, block2_output)

In [None]:
# Put the embeddings back in the row order of adata. Stacking target1 on top
# of afterCalib2 only matches adata.obs_names when every batch-1 cell precedes
# every batch-2 cell; placing the rows through the batch masks holds for any order.
is_batch1 = (adata.obs['batch'] == 1).values
is_batch2 = (adata.obs['batch'] == 2).values
if not (is_batch1 | is_batch2).all():
    raise ValueError('adata contains cells that belong to neither batch 1 nor batch 2')
pca_corrected = np.empty((adata.n_obs, target1.shape[1]),
                         dtype=np.result_type(target1, afterCalib2))
pca_corrected[is_batch1] = target1
pca_corrected[is_batch2] = afterCalib2
adata.obsm['X_pca'] = pca_corrected
# npcs = 20  # our pre-defined
# perplex = 30  # our pre-defined
# sc.pp.neighbors(adata,n_neighbors=15, n_pcs=npcs)
# sc.tl.tsne(adata, random_state=0, n_pcs=npcs, perplexity=perplex)
# sc.tl.umap(adata)
print(adata)

In [None]:
# Column labels X_pca1 ... X_pca20 for the first 20 corrected components
colnpc = ["X_pca"+str(k) for k in range(1, 21)]

df = pd.DataFrame(pca_corrected[:, :20], columns=colnpc, index=adata.obs_names)
df['batch'] = pd.Series(adata.obs['batch'], index=adata.obs_names)
# df['celltype'] = pd.Series(adata.obs['cell_type'], index=adata.obs_names)
df.to_csv(save_dir+'resnet_pca_predicted.csv')

In [None]:
# Save the execution time to a file for later evaluation.
# t1 and t2 come from the cell that builds the model and predicts on source2.
filename = 'resnet_exetime.csv'
usecase_name = 'resnet_exetime' 
utils.getExecutionTime(t1, t2, save_dir, usecase_name, filename)  # t1: start time, t2: end time       

In [None]:
npcs = 20  # our pre-defined
perplex = 30  # our pre-defined
nb_neighbors = 15
# 'cell_type' is never written to adata.obs in this notebook (its assignment is
# commented out), so only the keys that actually exist are kept.
# NOTE(review): presumably plotUMAP colours the plot by these obs columns -- confirm.
color_group = [key for key in ("batchlb", "cell_type") if key in adata.obs.columns]
save_fn_tsne = 'resnet_tsne'
save_fn_umap = 'resnet_umap'
utils.plotUMAP(adata, color_group, save_fn_umap, save_fig_dir, npcs, nb_neighbors, False)
# utils.plotTSNE(adata, color_group, save_fn_tsne, save_fig_dir, npcs, perplex, False)

In [None]:
# Quick look at a few batch labels
adata.obs['batchlb'][1:5]

In [None]:
# t-SNE plot of the corrected data; only the batch label is passed as colour group
npcs = 20  # our pre-defined
perplex = 30  # our pre-defined
save_fn_tsne = 'resnet_tsne'
color_group = ["batchlb"]
utils.plotTSNE(adata, color_group, save_fn_tsne, save_fig_dir, npcs, perplex, False)

In [None]:
# total_ann is not defined anywhere in this notebook (NameError on a fresh
# kernel); adata is the AnnData object that carries the corrected X_pca.
# NOTE(review): confirm that adata is the object save_output_txt expects.
utils.save_output_txt(adata, save_dir)