In [1]:
!pip install scanpy
!pip install episcanpy



In [2]:
import scanpy as sc
import numpy as np
import pandas as pd
from scipy.sparse import issparse
import scipy
import anndata as ad
import torch.nn as nn
import episcanpy as epi
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim
import torch
from torch.autograd import Variable


import argparse

In [3]:
# parse arguments
options = argparse.ArgumentParser()

options.add_argument('--file_path', '-f', type=str, nargs='+', default = 'drive/MyDrive/B2_Model/dataset/pbmc_granulocyte_sorted_atac.h5ad')
options.add_argument('--save-dir', action="store", dest="save_dir",default="outputs/")
options.add_argument('-pt', action="store", dest="pretrained_file", default=None)
options.add_argument('-bs', action="store", dest="batch_size", default = 128, type = int)
options.add_argument('-ds', action="store", dest="datadir", default = "drive/MyDrive/MultiDomainTranslation/data/nuclear_crops_all_experiments/")
options.add_argument('-lamb', action="store", dest="lamb", default=0.0000001, type = float)
options.add_argument('-lamb2', action="store", dest="lamb2", default=0.001, type = float)
options.add_argument('--conditional', action="store_true")
    
# encoder_dim = [encoder_dim_0,encoder_dim_1]
options.add_argument('--encoder-dim-0',type=int,default=1024)
options.add_argument('--encoder-dim-1',type=int,default=128)
options.add_argument('--latent-dim',type=int,default=10)

# draw graph
options.add_argument('--outdir', '-o', type=str, default='output/', help='Output path')
options.add_argument('--embed', type=str, default='UMAP')
options.add_argument('--cluster_method', type=str, default='leiden')
options.add_argument('--cluster-num',type=int,default=30)
    
# model parameters
options.add_argument('--n-centroids', type=int, default = 20)
options.add_argument('--n-hidden',type=int,default = 1024, help = ' number of hidden layers.')
options.add_argument('--n-latent',type=int,default = 10, help = 'number of dimensions of latent space.')
options.add_argument('--min-peaks', type=int, default=600, help='Remove low quality cells with few peaks')
options.add_argument('--min-genes', type=int, default=200, help='rna min genes')
options.add_argument('--min-cells', type=float, default=0.01, help='Remove low quality peaks')
options.add_argument('--atac-classifier',type=str,default="SVM",help="type of classifer used in the atac vae.")

# experiment settings
options.add_argument('--epochs', default = 100, type = int, help="number of training epochs.")
options.add_argument('-learning-rate', default=1e-4, type = float, help="learning rate.")
options.add_argument('--is-validate',default = False,type =bool, help='whether to verify model validity with a classifier.')
options.add_argument('--train-batch-size', type=float, default=100, help='train batch size')
options.add_argument('--test-batch-size', type=float, default=100, help='test batch size')

# task settigns
options.add_argument('--translation', default=True, type=bool, help="whether to perform the translation between RNAseq and ATACseq.")
options.add_argument('--reconstruction', default=True, type=bool, help="whether perform the reconstructon results.")
options.add_argument('--latent-space',default=False, type=bool, help="whether to perform the latent space results.")
args = options.parse_args(args=[])

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
raw_atac = sc.read_h5ad("drive/MyDrive/B2_Model/dataset/E18_mouse_brain/e18_mouse_brain_fresh_5k_atac_fragments.h5ad")
raw_rna = sc.read_h5ad("drive/MyDrive/B2_Model/dataset/E18_mouse_brain/e18_mouse_brain_fresh_5k_rna_fragments.h5ad")

In [6]:
class sc_Dataset(Dataset):
    """
    Dataset for dataloader
    """
    def __init__(self, rna_adata, atac_adata):
        self.rna_adata = rna_adata
        self.atac_adata = atac_adata

        self.rna_shape = rna_adata.shape
        self.atac_shape = atac_adata.shape

    def __len__(self):
        return self.rna_adata.X.shape[0]
    
    def __getitem__(self, idx):
        rna_x = self.rna_adata.X[idx].toarray().squeeze()
        rna_y = self.rna_adata.obs_names[idx]
        atac_x = self.atac_adata.X[idx].squeeze()
        atac_y = self.atac_adata.obs_names[idx]
#        domain_id = self.rna_adata.obs['batch'].cat.codes[idx]
#        return x, domain_id, idx
        return rna_x,atac_x,rna_y,atac_y

In [28]:
def train(args,epoch,atac_model,rna_model,trainloader,atac_optimizer,rna_optimizer):
    
    atac_model.train()
    rna_model.train()

    train_loss = 0
    total_clf_loss = 0
    
    atac_recon_output, rna_recon_output = None, None
    atac_latents_output, rna_latents_output = None, None
    atac_trans_output, rna_trans_output = None, None
    atac_targets, rna_targets = None,None

    for batch_idx, train_batch in enumerate(trainloader):

        rna_batch, atac_batch, rna_batch_y, atac_batch_y = train_batch
        rna_batch, atac_batch = rna_batch.type(torch.FloatTensor), atac_batch.type(torch.FloatTensor)
        rna_batch, atac_batch = rna_batch.cuda(), atac_batch.cuda()
        rna_batch_y, atac_batch_y = list(rna_batch_y), list(atac_batch_y)
 
        if len(rna_batch) == 1 or len(atac_batch) == 1:
          continue

        atac_optimizer.zero_grad()
        rna_optimizer.zero_grad()

        atac_recon_inputs, atac_latents, atac_mu, atac_logvar = atac_model(atac_batch)
        atac_loss = loss_function(atac_recon_inputs, atac_batch, atac_mu, atac_logvar, atac_latents)

        rna_recon_inputs, rna_latents, rna_mu, rna_logvar = rna_model(rna_batch)
        rna_loss = loss_function(rna_recon_inputs, rna_batch, rna_mu, rna_logvar, rna_latents)

        atac_output = atac_model.decode(rna_latents)
        rna_output = rna_model.decode(atac_latents)
        
        rna_shared_loss = compute_shared_loss(rna_output,rna_batch)
        atac_shared_loss = compute_shared_loss(atac_output,atac_batch)
        
        atac_loss += atac_shared_loss
        rna_loss += rna_shared_loss
        
        """
        atac_extra_loss = 0
        if batch_idx != 0:
            atac_extra_loss = copmute_atac_regularization(atac_model,atac_recon_inputs,atac_batch)
        
        atac_loss += atac_extra_loss
        _atac_recon_inputs = atac_recon_inputs.reshape(-1,1)
        _atac_batch = atac_batch.reshape(-1,1)
        atac_model.classifier.fit(_atac_recon_inputs.cpu().detach().numpy(),_atac_batch.cpu().detach().numpy())
        """
        
        atac_loss.backward(retain_graph=True)
        rna_loss.backward()

        atac_optimizer.step()
        rna_optimizer.step()
        

        ## visualization
        # atac encoder + rna decoder
        #atac_recon_inputs, atac_latents, atac_mu, atac_logvar = atac_model(atac_batch)
        
        #atac_latents = torch.sigmoid(atac_latents)
        #atac_recon_inputs = torch.sign(atac_recon_inputs)
        
        if epoch == args.epochs-1:
            
            if rna_targets is None or atac_targets is None:
                rna_targets, atac_targets = rna_batch_y, atac_batch_y
            else:
                rna_targets = rna_targets + rna_batch_y
                atac_targets = atac_targets + atac_batch_y

            if args.reconstruction is True:
                atac_recon_numpy = atac_recon_inputs.cpu().detach().numpy()
                if atac_recon_output is None:
                    atac_recon_output = atac_recon_numpy
                else:
                    atac_recon_output = np.concatenate((atac_recon_output,atac_recon_numpy),axis=0)
                
                rna_recon_numpy = rna_recon_inputs.cpu().detach().numpy()

                if rna_recon_output is None:
                    rna_recon_output = rna_recon_numpy
                else:
                    rna_recon_output = np.concatenate((rna_recon_output,rna_recon_numpy),axis=0)

            
            if args.latent_space is True:
                atac_latent_numpy = atac_latents.cpu().detach().numpy()
                
                if atac_latents_output is None:
                    atac_latents_output = atac_latent_numpy
                else:
                    atac_latents_output = np.concatenate((atac_latents_output,atac_latent_numpy),axis=0)
                    
                rna_latent_numpy = rna_latents.cpu().detach().numpy()
                
                if rna_latents_output is None:
                    rna_latents_output = rna_latent_numpy
                else:
                    rna_latents_output = np.concatenate((rna_latents_output,rna_latent_numpy),axis=0)
                
            if args.translation is True:
                
                atac_output = atac_model.decode(rna_latents)
                atac_trans_numpy = atac_output.cpu().detach().numpy()
                
                if atac_trans_output is None:
                    atac_trans_output = atac_trans_numpy
                else:
                    atac_trans_output = np.concatenate((atac_trans_output,atac_trans_numpy),axis=0)

                rna_output = rna_model.decode(atac_latents)
                rna_trans_numpy = rna_output.cpu().detach().numpy()
                
                if rna_trans_output is None:
                    rna_trans_output = rna_trans_numpy
                else:
                    rna_trans_output = np.concatenate((rna_trans_output,rna_trans_numpy),axis=0)
                
                
        ## method 1
        #rna_output = torch.sigmoid(rna_output)
        #rna_output = torch.sign(rna_output)

        # rna encoder + atac decoder
        #rna_recon_inputs, rna_latents, rna_mu, rna_logvar = rna_model(rna_batch)
        #atac_output = atac_model.decode(rna_latents)
        
        ## method 2
        #atac_output = torch.sigmoid(atac_output)
        #atac_output = torch.sign(atac_output)
        
        """
        if args.conditional:
            targets = Variable(samples['binary_label'])
            if torch.cuda.is_available():
                targets = targets.cuda()
            clf_outputs = netCondClf(latents)
            class_clf_loss = CE(clf_outputs, targets.view(-1).long())
            loss += args.lamb2 * class_clf_loss
            total_clf_loss += class_clf_loss.data.item() * inputs.size(0)
        """
        print('Epoch: {} rna loss:{} atac loss:{}'.format(epoch, rna_loss / len(trainloader), atac_loss / len(trainloader)))
  
    #print(rna_latents_output)
    if epoch == args.epochs-1:
        
        with open("drive/MyDrive/output_enhancer/rna_targets_train.txt",'w') as output_targets:
            for target in rna_targets:
                output_targets.write(target)
                output_targets.write('\n')
        
        with open("drive/MyDrive/output_enhancer/atac_targets_train.txt",'w') as output_targets:
            for target in atac_targets:
                output_targets.write(target)
                output_targets.write('\n')            
        
        if args.reconstruction is True:
            np.save("drive/MyDrive/output_enhancer/rna_recon_output_train.npy",rna_latents_output)
            np.save("drive/MyDrive/output_enhancer/atac_recon_output_train.npy",atac_latents_output)
            
        if args.latent_space is True:
            np.save("drive/MyDrive/output_enhancer/rna_latents_output_train.npy",rna_latents_output)
            np.save("drive/MyDrive/output_enhancer/atac_latents_output_train.npy",atac_latents_output)
        
        if args.translation is True:
            np.save("drive/MyDrive/output_enhancer/rna_trans_output_train.npy", rna_trans_output)
            np.save("drive/MyDrive/output_enhancer/atac_trans_output_train.npy",atac_trans_output)
        

In [8]:
def test(args,epoch,atac_model,rna_model,testloader):
    
    atac_model.eval()
    rna_model.eval()

    test_loss = 0
    total_atac_loss = 0
    total_rna_loss = 0

    atac_latents_output, rna_latents_output = None, None
    atac_recon_output, rna_recon_output = None, None
    atac_trans_output, rna_trans_output = None, None
    atac_targets, rna_targets = None, None

    for batch_idx, test_batch in enumerate(testloader):

        rna_batch, atac_batch, rna_batch_y, atac_batch_y = test_batch
        rna_batch, atac_batch = rna_batch.type(torch.FloatTensor), atac_batch.type(torch.FloatTensor)
        rna_batch, atac_batch = rna_batch.cuda(), atac_batch.cuda()
        rna_batch_y, atac_batch_y = list(rna_batch_y), list(atac_batch_y)
        if len(rna_batch) == 1 or len(atac_batch) == 1:
            continue
 
        atac_recon_inputs, atac_latents, atac_mu, atac_logvar = atac_model(atac_batch)
        atac_loss = loss_function(atac_recon_inputs, atac_batch, atac_mu, atac_logvar, atac_latents)

        rna_recon_inputs, rna_latents, rna_mu, rna_logvar = rna_model(rna_batch)
        rna_loss = loss_function(rna_recon_inputs, rna_batch, rna_mu, rna_logvar, rna_latents)
        
        total_atac_loss += atac_loss
        total_rna_loss += rna_loss

        if epoch == args.epochs-1:

            if rna_targets is None or atac_targets is None:
                rna_targets, atac_targets = rna_batch_y, atac_batch_y
            else:
                rna_targets = rna_targets + rna_batch_y
                atac_targets = atac_targets + atac_batch_y

            if args.reconstruction is True:
                atac_recon_numpy = atac_recon_inputs.cpu().detach().numpy()
                if atac_recon_output is None:
                    atac_recon_output = atac_recon_numpy
                else:
                    atac_recon_output = np.concatenate((atac_recon_output,atac_recon_numpy),axis=0)
                
                rna_recon_numpy = rna_recon_inputs.cpu().detach().numpy()

                if rna_recon_output is None:
                    rna_recon_output = rna_recon_numpy
                else:
                    rna_recon_output = np.concatenate((rna_recon_output,rna_recon_numpy),axis=0)

            
            if args.latent_space is True:
                atac_latent_numpy = atac_latents.cpu().detach().numpy()
                
                if atac_latents_output is None:
                    atac_latents_output = atac_latent_numpy
                else:
                    atac_latents_output = np.concatenate((atac_latents_output,atac_latent_numpy),axis=0)
                    
                rna_latent_numpy = rna_latents.cpu().detach().numpy()
                
                if rna_latents_output is None:
                    rna_latents_output = rna_latent_numpy
                else:
                    rna_latents_output = np.concatenate((rna_latents_output,rna_latent_numpy),axis=0)
                
            if args.translation is True:
                
                atac_output = atac_model.decode(rna_latents)
                atac_trans_numpy = atac_output.cpu().detach().numpy()
                
                if atac_trans_output is None:
                    atac_trans_output = atac_trans_numpy
                else:
                    atac_trans_output = np.concatenate((atac_trans_output,atac_trans_numpy),axis=0)

                rna_output = rna_model.decode(atac_latents)
                rna_trans_numpy = rna_output.cpu().detach().numpy()
                
                if rna_trans_output is None:
                    rna_trans_output = rna_trans_numpy
                else:
                    rna_trans_output = np.concatenate((rna_trans_output,rna_trans_numpy),axis=0)
    
    total_atac_loss /= len(testloader)
    total_rna_loss /= len(testloader)

    print('Test atac loss:{} Test rna loss:{} '.format(total_atac_loss, total_rna_loss))
    
    if epoch == args.epochs-1:

        with open("drive/MyDrive/output_enhancer/rna_targets_test.txt",'w') as output_targets:
            for target in rna_targets:
                output_targets.write(target)
                output_targets.write('\n')
        
        with open("drive/MyDrive/output_enhancer/atac_targets_test.txt",'w') as output_targets:
            for target in atac_targets:
                output_targets.write(target)
                output_targets.write('\n')   
                
        if args.reconstruction is True:
            np.save("drive/MyDrive/output_enhancer/rna_recon_output_test.npy",rna_latents_output)
            np.save("drive/MyDrive/output_enhancer/atac_recon_output_test.npy",atac_latents_output)
            
        if args.latent_space is True:
            np.save("drive/MyDrive/output_enhancer/rna_latents_output_test.npy",rna_latents_output)
            np.save("drive/MyDrive/output_enhancer/atac_latents_output_test.npy",atac_latents_output)
        
        if args.translation is True:
            np.save("drive/MyDrive/output_enhancer/rna_trans_output_test.npy", rna_trans_output)
            np.save("drive/MyDrive/output_enhancer/atac_trans_output_test.npy",atac_trans_output)
    
    return total_atac_loss, total_rna_loss


In [9]:
class RNA_VAE(nn.Module):
    """Fully connected variational Autoencoder"""
    def __init__(self, args,n_input):
        super(RNA_VAE, self).__init__()
        self.n_latent = args.n_latent
        self.n_input = n_input
        self.n_hidden = args.n_hidden

        self.encoder = nn.Sequential(
                                nn.Linear(self.n_input, self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                )

        self.fc1 = nn.Linear(self.n_hidden, self.n_latent)
        self.fc2 = nn.Linear(self.n_hidden, self.n_latent)

        self.decoder = nn.Sequential(nn.Linear(self.n_latent, self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_input),
                                    )
        self.encoder = self.encoder
        self.decoder = self.decoder

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, z, mu, logvar

    def encode(self, x):
        x = x
        h = self.encoder(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
    def decode(self, z):
        return self.decoder(z)

    def get_latent_var(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z
 
    def generate(self, z):
        res = self.decode(z)
        return res

In [10]:
class ATAC_VAE(nn.Module):
    """Fully connected variational Autoencoder"""
    def __init__(self, args,n_input):
        super(ATAC_VAE, self).__init__()
        self.n_latent = args.n_latent
        self.n_input = n_input
        self.n_hidden = args.n_hidden
        self.n_centroids = args.n_centroids
        
        if args.atac_classifier == "RandomForest":
            from sklearn.ensemble import RandomForestClassifier
            self.classifier = RandomForestClassifier()
        elif args.atac_classifier == "DecisionTree":
            from sklearn import tree
            self.classifier = tree.DecisionTreeClassifier()
        elif args.atac_classifier == "GBDT":
            from sklearn.ensemble import GradientBoostingClassifier
            self.classifier = GradientBoostingClassifier()
        elif args.atac_classifier == "AdaBoost":
            from sklearn.ensemble import  AdaBoostClassifier
            self.classifier = AdaBoostClassifier()
        elif args.atac_classifier == "SVM":
            from sklearn.svm import SVC
            self.classifier = SVC(kernel='rbf')
        
        # condition c for converting the distribution to standard normal distribution
        #z_dim = dims[1]

        # init c_params
        #self.pi = nn.Parameter(torch.ones(n_centroids)/n_centroids)  # pc
        #self.mu_c = nn.Parameter(torch.zeros(z_dim, n_centroids)) # mu
        #self.var_c = nn.Parameter(torch.ones(z_dim, n_centroids)) # sigma^2

        self.fc1 = nn.Linear(self.n_hidden, self.n_latent)
        self.fc2 = nn.Linear(self.n_hidden, self.n_latent)

        self.layer_1 = nn.Linear(self.n_input,self.n_hidden)
        self.batch_norm_1 = nn.BatchNorm1d(self.n_hidden)
        self.layer_2 = nn.Linear(self.n_hidden,self.n_hidden)
        self.batch_norm_2 = nn.BatchNorm1d(self.n_hidden)
        self.layer_3 = nn.Linear(self.n_hidden,self.n_hidden)
        self.batch_norm_3 = nn.BatchNorm1d(self.n_hidden)
        self.layer_4 = nn.Linear(self.n_hidden,self.n_hidden)
        self.batch_norm_4 = nn.BatchNorm1d(self.n_hidden)
        self.layer_5 = nn.Linear(self.n_hidden,self.n_hidden)
        self.loss_func = nn.ReLU()

        self.decoder = nn.Sequential(nn.Linear(self.n_latent, self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_input),
                                    )
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        #gamma, mu_c, var_c, pi = get_gamma(n_centroids, z, pi, mu_c, var_c)
        res = self.decode(z)
        return res, z, mu, logvar

    def encode(self, x):

        x = self.layer_1(x)
        x = self.loss_func(x)
        x = self.batch_norm_1(x)
        x = self.layer_2(x)
        x = self.batch_norm_2(x)
        x = self.loss_func(x)
        x = self.layer_3(x)
        x = self.batch_norm_3(x)
        x = self.loss_func(x)
        x = self.layer_4(x)
        x = self.batch_norm_4(x)
        x = self.loss_func(x)
        h = self.layer_5(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
    def decode(self, z):
        return self.decoder(z)

    def get_latent_var(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z
 
    def generate(self, z):
        res = self.decode(z)
        return res

In [11]:
# compute loss
def loss_function(recon_x, x, mu, logvar, latents):
    MSE = nn.MSELoss()
    lloss = MSE(recon_x,x)

    if args.lamb>0:
        KL_loss = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        lloss = lloss + args.lamb*KL_loss

    return lloss

def compute_shared_loss(real_seq,generated_seq):
    
    MSE = nn.MSELoss()
    lloss = MSE(generated_seq,real_seq)
    
    return lloss

In [12]:
def preprocessing_atac(
        adata, 
        min_genes=200, 
        min_cells=0.01, 
        n_top_genes=30000,
        target_sum=None
    ):
    """
    preprocessing
    """
    print('Raw dataset for scATAC shape: {}'.format(adata.shape))

    if not issparse(adata.X):
        adata.X = scipy.sparse.csr_matrix(adata.X)
    
    print(np.max(adata.X))    
    
    #binarize:
    #epi.pp.binarize(adata)
    
    #Filtering cells
    #sc.pp.filter_cells(adata, min_genes=min_genes)
    
    #'Filtering genes'
    if min_cells < 1:
        min_cells = min_cells * adata.shape[0]
    sc.pp.filter_genes(adata, min_cells=min_cells)
    
    #Finding variable features
    adata = epi.pp.select_var_feature(adata, nb_features=n_top_genes, show=False, copy=True)
    
    #Normalizing total per cell
    sc.pp.normalize_total(adata, target_sum=target_sum)
        

    print('Processed dataset shape: {}'.format(adata.shape))
    return adata

In [13]:
def preprocessing_rna(
    adata,
    min_genes=200, 
    min_cells=0.01, 
    target_sum=1e4
):
    """
    preprocessing
    """
    print('Raw dataset for scRNA shape: {}'.format(adata.shape))
    if not issparse(adata.X):
        adata.X = scipy.sparse.csr_matrix(adata.X)

    #Filtering cells
    #sc.pp.filter_cells(adata, min_genes=min_genes)

    #Flitering genes
    sc.pp.filter_genes(adata, min_cells=min_cells)

    #LogNormalize
    sc.pp.normalize_total(adata, target_sum=target_sum)
    sc.pp.log1p(adata)
    adata.raw = adata

    print('Processed dataset shape: {}'.format(adata.shape))
    return adata

In [14]:
pre_atac = preprocessing_atac(raw_atac)

Raw dataset for scATAC shape: (4482, 172193)
239.0




Processed dataset shape: (4482, 30002)


In [15]:
# read atac dataset (raw)
df = pre_atac.to_df()
genes = [column for column in df]
gene_names,gene_starts,gene_ends = [],[],[]

for gene in genes: 
  tmp_gene = gene.split('-')
  gene_names.append(tmp_gene[0])
  gene_starts.append(tmp_gene[1])
  gene_ends.append(tmp_gene[2])

In [16]:
# read enhancer dataset

enhancer_names,enhancer_starts,enhancer_ends = [],[],[]
enhancers = []
with open("drive/MyDrive/Brain.bed.txt",'r') as inputFile:
    for line in inputFile:
        enhancers.append(line)

for enhancer in enhancers:
    
    enhancer = enhancer.replace("\t"," ").replace("\n"," ")
    tmp_enhancer = enhancer.split(" ")
    enhancer_names.append(tmp_enhancer[0])
    enhancer_starts.append(tmp_enhancer[1])
    enhancer_ends.append(tmp_enhancer[2])

In [17]:
enhancer_separated, gene_separated = {},{}
mixed_sequences = {}

count = 0
for name,start,end in zip(enhancer_names,enhancer_starts,enhancer_ends):
    
    if name not in enhancer_separated:
        enhancer_separated[name] = []
        
    enhancer_tuple = (int(start),int(end),"enhancer")
    enhancer_separated[name].append(enhancer_tuple)

for idx,(name,start,end) in enumerate(zip(gene_names,gene_starts,gene_ends)):

    if name not in enhancer_separated.keys():
        continue
        
    if name not in gene_separated:
        gene_separated[name] = []
        
    gene_tuple = (int(start),int(end),"gene")
    gene_separated[name].append(gene_tuple)

In [18]:
mixed_sequences = {}

count = 0
for item in gene_separated.items():
    
    name, gene_sequences = item
    if name not in enhancer_separated.keys():
        continue
    else:
        mixed_sequences[name] = []
        for gene_seq in gene_sequences:
            tmp_gene_tuple = (gene_seq[0],gene_seq[1],"gene")
            mixed_sequences[name].append(tmp_gene_tuple)
    
for item in enhancer_separated.items():
    
    name,enh_sequences = item
    for enh_seq in enh_sequences:
        tmp_enh_tuple = (enh_seq[0],enh_seq[1],"enhancer")
        mixed_sequences[name].append(tmp_enh_tuple)

In [19]:
for item in mixed_sequences.items():
    
    name,sequences = item
    result_seqs = sorted(sequences,key=lambda x:x[0])
    mixed_sequences[name] = result_seqs

In [20]:
gene_enhancers = []

for item in mixed_sequences.items():
    
    name,sequences = item
    for idx,seq in enumerate(sequences):
        
        if idx == len(sequences)-1:
            continue
            
        if not seq[1] > sequences[idx+1][0]:
            continue
        else:
            if seq[2] == "enhancer":
                tmp_tuple = (sequences[idx+1][0],sequences[idx+1][1],name)
                gene_enhancers.append(tmp_tuple)
            else:
                tmp_tuple = (seq[0],seq[1],name)
                gene_enhancers.append(tmp_tuple)
            print("seq 1:{}\t seq 2:{}".format(seq,sequences[idx+1]))

seq 1:(14293760, 14294619, 'gene')	 seq 2:(14294550, 14299720, 'enhancer')
seq 1:(25162990, 25165980, 'enhancer')	 seq 2:(25165768, 25166699, 'gene')
seq 1:(33814021, 33814848, 'gene')	 seq 2:(33814180, 33814420, 'enhancer')
seq 1:(34801495, 34802369, 'gene')	 seq 2:(34801950, 34812330, 'enhancer')
seq 1:(36603070, 36603960, 'gene')	 seq 2:(36603100, 36603750, 'enhancer')
seq 1:(38185274, 38186160, 'gene')	 seq 2:(38185450, 38186300, 'enhancer')
seq 1:(38450440, 38451840, 'enhancer')	 seq 2:(38450484, 38451400, 'gene')
seq 1:(39251358, 39252257, 'gene')	 seq 2:(39251580, 39255920, 'enhancer')
seq 1:(60097735, 60098579, 'gene')	 seq 2:(60098550, 60099640, 'enhancer')
seq 1:(66862580, 66863940, 'enhancer')	 seq 2:(66862712, 66863570, 'gene')
seq 1:(75142351, 75143230, 'gene')	 seq 2:(75142380, 75143000, 'enhancer')
seq 1:(75179690, 75180030, 'enhancer')	 seq 2:(75179934, 75180772, 'gene')
seq 1:(75236178, 75236987, 'gene')	 seq 2:(75236240, 75236490, 'enhancer')
seq 1:(75244480, 75244670

In [21]:
gene_enhancers[1]

(25165768, 25166699, 'chr1')

In [22]:
peak_names = list()
for i in range(len(gene_enhancers)):
  peaks = gene_enhancers[i][2] + '-' + str(gene_enhancers[i][0]) + '-' + str(gene_enhancers[i][1])
  peak_names.append(peaks)

enhancer_atac = df[peak_names]


In [36]:
rna_adata = preprocessing_rna(raw_rna)
print(rna_adata)
atac_adata = ad.AnnData(enhancer_atac,dtype="float32")
print(type(rna_adata))
print(type(atac_adata))

train_atac = atac_adata[:int(len(atac_adata)*0.8)]
test_atac = atac_adata[int(len(atac_adata)*0.8):]

train_rna = rna_adata[:int(len(rna_adata)*0.8)]
test_rna = rna_adata[int(len(rna_adata)*0.8):]

train_scDataset = sc_Dataset(train_rna,train_atac)
test_scDataset = sc_Dataset(test_rna,test_atac)

trainloader = DataLoader(train_scDataset,batch_size=args.train_batch_size,drop_last=False,shuffle=True)
testloader = DataLoader(test_scDataset,batch_size=args.test_batch_size, drop_last=False, shuffle=True)

train_cell_num = train_atac.shape[0]
test_cell_num = test_atac.shape[0]
#the input dimension of the atac data
input_dim_atac = train_atac.shape[1]
input_dim_rna = train_rna.shape[1]

print("The cell number of train set is: ", train_cell_num)
print("The cell number of test set is: ", test_cell_num)
print("The input dim of atac data is: ", input_dim_atac)
print("The input dim of rna data is: ", input_dim_rna)

Raw dataset for scRNA shape: (4482, 24528)


Variable names are not unique. To make them unique, call `.var_names_make_unique`.


Processed dataset shape: (4482, 24528)
AnnData object with n_obs × n_vars = 4482 × 24528
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ATAC', 'nFeature_ATAC', 'nucleosome_signal', 'nucleosome_percentile', 'TSS.enrichment', 'TSS.percentile'
    var: 'features', 'n_cells'
    uns: 'log1p'
<class 'anndata._core.anndata.AnnData'>
<class 'anndata._core.anndata.AnnData'>
The cell number of train set is:  3585
The cell number of test set is:  897
The input dim of atac data is:  1416
The input dim of rna data is:  24528


In [37]:
train_atac.var_names_make_unique()
test_atac.var_names_make_unique()

train_rna.var_names_make_unique()
test_rna.var_names_make_unique()

#import episcanpy as epi
#atac_data = epi.pp.select_var_feature(atac_data, nb_features=30000, show=False, copy=True)

Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


In [38]:
atac_model = ATAC_VAE(args,n_input=input_dim_atac)
rna_model = RNA_VAE(args,n_input=input_dim_rna)
atac_model, rna_model = atac_model.cuda(), rna_model.cuda()

#if args.pretrained_file is not None:
#    model.load_state_dict(torch.load(args.pretrained_file))
#    print("Pre-trained model loaded")
#    sys.stdout.flush()

In [39]:
atac_optimizer = optim.Adam(atac_model.parameters(), lr = args.learning_rate)
rna_optimizer = optim.Adam(rna_model.parameters(), lr = args.learning_rate)

In [315]:
for epoch in range(args.epochs):

    train(args,epoch,atac_model,rna_model,trainloader,atac_optimizer,rna_optimizer)
    test(args,epoch,atac_model,rna_model,testloader)

Epoch: 0 rna loss:0.006259705871343613 atac loss:0.036455877125263214
Epoch: 0 rna loss:0.006197888404130936 atac loss:0.036590464413166046
Epoch: 0 rna loss:0.0066709816455841064 atac loss:0.039920151233673096
Epoch: 0 rna loss:0.006084425840526819 atac loss:0.039966244250535965
Epoch: 0 rna loss:0.006549293641000986 atac loss:0.03601861000061035
Epoch: 0 rna loss:0.00630862545222044 atac loss:0.03479922190308571
Epoch: 0 rna loss:0.006344647146761417 atac loss:0.03635953739285469
Epoch: 0 rna loss:0.006455526687204838 atac loss:0.04183858633041382
Epoch: 0 rna loss:0.006770836189389229 atac loss:0.04235236719250679
Epoch: 0 rna loss:0.006488323211669922 atac loss:0.03511472046375275
Epoch: 0 rna loss:0.006428971420973539 atac loss:0.037621207535266876
Epoch: 0 rna loss:0.006196955218911171 atac loss:0.03405233845114708
Epoch: 0 rna loss:0.006492896471172571 atac loss:0.038306355476379395
Epoch: 0 rna loss:0.006392261479049921 atac loss:0.04246721789240837
Epoch: 0 rna loss:0.00630837

In [316]:
clusted_samples_idx = [2222, 1864, 628, 1921, 2234, 2222, 3113, 2339, 3068, 238, 2222, 3865, 4471]

In [317]:
clustered_samples = enhancer_atac.iloc[[2222, 1864, 628, 1921, 2234, 2222, 3113, 2339, 3068, 238, 2222, 3865, 4471]]

In [318]:
clustered_samples.index

Index(['GACCGTTCACAGGAAT-1', 'CTAAGGTTCACAGACT-1', 'AGCATTTCATTGCAGC-1',
       'CTATGACAGTTATGTG-1', 'GACCTGATCGCAAACT-1', 'GACCGTTCACAGGAAT-1',
       'GTATTGCAGTTAGTGC-1', 'GAGTCATTCGCGCTAA-1', 'GTAGCCATCCCATAGG-1',
       'AATTACCCAGGGAGGA-1', 'GACCGTTCACAGGAAT-1', 'TCTATGTTCCTTGAGG-1',
       'TTTGTGAAGCCGCTAA-1'],
      dtype='object')

In [377]:
clustered_atac_adata = ad.AnnData(clustered_samples,dtype="float32")
clustered_atac = clustered_atac_adata.X.squeeze()
print(clustered_atac.shape)

np.set_printoptions(precision=4,suppress=True)
clustered_atac_ana = clustered_atac[0]
print(clustered_atac_ana)
print(clustered_atac_ana.shape)

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


(13, 1416)
[90.     20.3566 21.2481 20.5349 20.     20.3566 20.5349 20.     20.1783
 20.3566 20.7132 20.3566 20.     20.     20.     20.7132 20.     20.
 20.1783 20.3566 20.     21.4264 20.     20.8915 20.3566 20.     20.
 20.5349 20.1783 20.1783 20.3566 20.1783 20.     20.     20.     20.
 20.     20.     20.     20.5349 20.5349 20.1783 20.     20.     21.0698
 20.     20.1783 20.     21.2481 20.     20.3566 20.8915 20.3566 20.
 20.     21.0698 20.7132 20.8915 20.     20.     20.     20.     20.
 20.     20.1783 20.8915 20.     20.3566 20.7132 20.     20.5349 20.7132
 20.7132 20.7132 20.     20.1783 20.1783 20.     20.3566 21.0698 20.
 20.     20.     20.3566 20.     20.     20.7132 20.3566 20.     20.3566
 20.     20.5349 20.3566 20.3566 20.3566 20.     20.3566 20.3566 20.
 20.3566 20.     20.     20.3566 20.1783 20.3566 20.7132 20.3566 20.3566
 20.3566 20.     20.     20.7132 20.3566 20.5349 20.     20.     20.3566
 20.     20.     20.     20.3566 20.     20.8915 20.     20.3566 20.

In [365]:
def compute_jacobian(inputs, output):
  """
  :param inputs: Batch X Size (e.g. Depth X Width X Height)
  :param output: Batch X Classes
  :return: jacobian: Batch X Classes X Size
  """
  assert inputs.requires_grad

  num_classes = output.size()[1]
  print("num_classes:{}".format(num_classes))
  jacobian = torch.zeros(num_classes, *inputs.size())
  grad_output = torch.zeros(*output.size())

  if inputs.is_cuda:
    grad_output = grad_output.cuda()
    jacobian = jacobian.cuda()

  for i in range(num_classes):
    grad_output.zero_()
    grad_output[:, i] = 1
    output.backward(grad_output, retain_graph=True)
    jacobian[i] = inputs.grad.data
    if i %100==0:
      print(i)

  return torch.transpose(jacobian, dim0=0, dim1=1)

In [366]:
clustered_atac_eps = 10*torch.ones((clustered_atac_ana.shape))
clustered_atac_ana = torch.add(clustered_atac_ana,clustered_atac_eps)

clustered_atac_tensor = clustered_atac_tensor
clustered_atac_input_added = clustered_atac_tensor.cuda().requires_grad_()
clustered_atac_input_origin = clustered_atac_tensor.cuda().requires_grad_()

clustered_atac_recon_inputs_added, clustered_atac_latents_added, clustered_atac_mu_added, clustered_atac_logvar_added = atac_model(clustered_atac_input_added)
clustered_atac_output_added = rna_model.decode(clustered_atac_latents_added)
clustered_atac_recon_inputs_origin, clustered_atac_latents_origin, clustered_atac_mu_origin, clustered_atac_logvar_origin = atac_model(clustered_atac_input_origin)
clustered_atac_output_origin = rna_model.decode(clustered_atac_latents_origin)
clustered_derivative = torch.sub(clustered_atac_output_added,clustered_atac_output_origin)
jacobian_mat = compute_jacobian(clustered_atac_input_added,clustered_derivative)
print(jacobian_mat.shape)

num_classes:24528
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18000
18100
18

In [373]:
jacobian_mat_squeezed = jacobian_mat.squeeze()
jacobian_mat_squeezed_numpy = jacobian_mat_squeezed.cpu().detach().numpy()
print(jacobian_mat_squeezed)

tensor([[ 6.9324e-04, -7.2893e-05, -4.2670e-04,  ..., -7.9549e-04,
          2.4947e-05, -3.2211e-04],
        [ 6.4632e-04,  1.4845e-04, -1.8757e-04,  ..., -4.6115e-04,
          2.2009e-04, -6.4536e-04],
        [ 2.4125e-05,  4.1106e-04, -6.1556e-05,  ..., -4.3516e-04,
          3.7101e-04, -5.1749e-04],
        ...,
        [-1.4954e+00,  3.7584e-01,  8.8122e-01,  ...,  2.7002e-01,
          6.4553e-01,  1.6389e-01],
        [-1.4955e+00,  3.7584e-01,  8.8116e-01,  ...,  2.7001e-01,
          6.4549e-01,  1.6395e-01],
        [-1.4961e+00,  3.7635e-01,  8.8173e-01,  ...,  2.7043e-01,
          6.4614e-01,  1.6367e-01]], device='cuda:0')


In [371]:
np.save("drive/MyDrive/output_enhancer/enhancer_jacobian.npy",jacobian_mat_squeezed_numpy)

In [225]:
clustered_atac_recon_inputs_added, clustered_atac_latents_added, clustered_atac_mu_added, clustered_atac_logvar_added = atac_model(clustered_atac_input_added)
clustered_atac_output_added = rna_model.decode(clustered_atac_latents_added)

In [224]:
clustered_atac_input_tensor = clustered_atac_tensor.cuda()
clustered_atac_recon_inputs_origin, clustered_atac_latents_origin, clustered_atac_mu_origin, clustered_atac_logvar_origin = atac_model(clustered_atac_input_tensor)
clustered_atac_output_origin = rna_model.decode(clustered_atac_latents_origin)

In [252]:
derivation_y = torch.autograd.grad(torch.sub(clustered_atac_output_added,clustered_atac_output_origin),torch.from_numpy(clustered_atac[0]))
derivation_y_numpy = derivation_y.cpu().detach().numpy()

RuntimeError: ignored

In [241]:
derivation_max = np.argmax(derivation_y_numpy,axis=0)

(24528,)
