In [None]:
import argparse
import math
import os
import sys
from glob import glob

import episcanpy as epi
import imageio
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse
import torch
from scipy.sparse import issparse
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

#import model as AENet
#import dataloader

In [None]:
# parse arguments
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is an argparse trap: ``bool("False")`` is ``True`` because
    every non-empty string is truthy. This converter accepts the common
    spellings of true/false and rejects everything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


options = argparse.ArgumentParser()

# NOTE(review): nargs='+' yields a list when the flag is given, but the default
# is a plain string; nothing visible in this notebook reads args.file_path.
options.add_argument('--file_path', '-f', type=str, nargs='+', default = 'drive/MyDrive/B2_Model/dataset/pbmc_granulocyte_sorted_atac.h5ad')
options.add_argument('--save-dir', action="store", dest="save_dir",default="outputs/")
options.add_argument('-pt', action="store", dest="pretrained_file", default=None)
options.add_argument('-bs', action="store", dest="batch_size", default = 128, type = int)
options.add_argument('-ds', action="store", dest="datadir", default = "drive/MyDrive/MultiDomainTranslation/data/nuclear_crops_all_experiments/")
options.add_argument('-lamb', action="store", dest="lamb", default=0.0000001, type = float)
options.add_argument('-lamb2', action="store", dest="lamb2", default=0.001, type = float)
options.add_argument('--conditional', action="store_true")

# encoder_dim = [encoder_dim_0,encoder_dim_1]
options.add_argument('--encoder-dim-0',type=int,default=1024)
options.add_argument('--encoder-dim-1',type=int,default=128)
options.add_argument('--latent-dim',type=int,default=10)

# draw graph
options.add_argument('--outdir', '-o', type=str, default='output/', help='Output path')
options.add_argument('--embed', type=str, default='UMAP')
options.add_argument('--cluster_method', type=str, default='leiden')
options.add_argument('--cluster-num',type=int,default=30)

# model parameters
options.add_argument('--n-centroids', type=int, default = 20)
# n_hidden is used as the width of every hidden nn.Linear layer in the VAEs.
options.add_argument('--n-hidden',type=int,default = 1024, help = 'number of units in each hidden layer.')
options.add_argument('--n-latent',type=int,default = 10, help = 'number of dimensions of latent space.')
options.add_argument('--min-peaks', type=int, default=600, help='Remove low quality cells with few peaks')
options.add_argument('--min-genes', type=int, default=200, help='rna min genes')
options.add_argument('--min-cells', type=float, default=0.01, help='Remove low quality peaks')
options.add_argument('--atac-classifier',type=str,default="SVM",help="type of classifer used in the atac vae.")

# experiment settings
options.add_argument('--epochs', default = 100, type = int, help="number of training epochs.")
# The conventional double-dash spelling is added; the original single-dash
# spelling is kept as an alias so existing invocations keep working.
options.add_argument('--learning-rate', '-learning-rate', default=1e-4, type = float, help="learning rate.")
options.add_argument('--is-validate',default = False,type =_str2bool, help='whether to verify model validity with a classifier.')
# Batch sizes are handed to DataLoader, which requires integers.
options.add_argument('--train-batch-size', type=int, default=100, help='train batch size')
options.add_argument('--test-batch-size', type=int, default=100, help='test batch size')

# task settings
options.add_argument('--translation', default=True, type=_str2bool, help="whether to perform the translation between RNAseq and ATACseq.")
options.add_argument('--reconstruction', default=True, type=_str2bool, help="whether perform the reconstructon results.")
options.add_argument('--latent-space',default=False, type=_str2bool, help="whether to perform the latent space results.")
# An explicit empty argv keeps the notebook kernel's own command line out of
# the parser, so every option takes its default here.
args = options.parse_args(args=[])

In [None]:
def get_gamma(n_centroids, z, pi, mu_c, var_c):
    """
    Infer the cluster responsibilities gamma = q(c|x) from latent samples z.

    The formula is p(z)q(c|x) = q(z,c|x) = p(c|z)p(z);
    since c and x are conditionally independent given z,
    the result we need is:
        q(c|x) = p(c|z) = p(c)p(z|c)/p(z)

    Args:
        n_centroids: number of mixture components K.
        z: latent samples, N x D.
        pi: mixture weights p(c); repeated to N x K.
        mu_c: component means; repeated to N x D x K.
        var_c: component variances; repeated to N x D x K.

    Returns:
        (gamma, mu_c, var_c, pi), where gamma is N x K and each row sums to 1,
        and the other three are the batch-expanded inputs used in the computation.
    """
    N = z.size(0)
    z = z.unsqueeze(2).expand(N, z.size(1), n_centroids)  # NxDxK
    # Use the arguments: this is a free function, so there is no `self` to read
    # the mixture parameters from.
    pi = pi.repeat(N, 1)  # NxK
    mu_c = mu_c.repeat(N, 1, 1)  # NxDxK
    var_c = var_c.repeat(N, 1, 1) + 1e-8  # NxDxK; epsilon keeps log/division finite

    # p(c,z) = p(c)*p(z|c) as p_c_z; the trailing epsilon avoids a 0/0 below
    p_c_z = torch.exp(torch.log(pi) - torch.sum(0.5*torch.log(2*math.pi*var_c) + (z-mu_c)**2/(2*var_c), dim=1)) + 1e-10
    gamma = p_c_z / torch.sum(p_c_z, dim=1, keepdim=True)

    return gamma, mu_c, var_c, pi

In [None]:
class ATAC_classifier(nn.Module):
    """Single square linear layer (n_input -> n_input) with Xavier-normal weights and zero bias.

    Args:
        args: run configuration; accepted for signature parity with the other
            models but not read here.
        n_input: number of input (and output) features of the linear layer.

    NOTE(review): no ``forward`` is defined in the visible code, so calling an
    instance raises the default ``nn.Module`` NotImplementedError.
    """

    # The constructor must be spelled __init__; with the previous `_init_`
    # spelling it was never invoked and no layer was ever created.
    def __init__(self, args, n_input):
        super(ATAC_classifier, self).__init__()
        # network
        self.n_input = n_input
        self.transform = nn.Linear(self.n_input, self.n_input)

        nn.init.xavier_normal_(self.transform.weight)
        nn.init.zeros_(self.transform.bias)
        
        

In [None]:
class RNA_VAE(nn.Module):
    """Fully connected variational autoencoder.

    Args:
        args: namespace providing ``n_latent`` (latent size) and ``n_hidden``
            (width of every hidden layer).
        n_input: number of input features; the decoder outputs the same size.
    """
    def __init__(self, args,n_input):
        super(RNA_VAE, self).__init__()
        self.n_latent = args.n_latent
        self.n_input = n_input
        self.n_hidden = args.n_hidden

        # NOTE(review): the first block applies ReLU before BatchNorm while the
        # following blocks apply BatchNorm first. The layer order is kept
        # unchanged so state_dict keys and numerical behaviour are preserved.
        self.encoder = nn.Sequential(
                                nn.Linear(self.n_input, self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                nn.BatchNorm1d(self.n_hidden),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.n_hidden, self.n_hidden),
                                )

        # Two heads on the encoder output: fc1 -> mu, fc2 -> logvar (see forward).
        self.fc1 = nn.Linear(self.n_hidden, self.n_latent)
        self.fc2 = nn.Linear(self.n_hidden, self.n_latent)

        self.decoder = nn.Sequential(nn.Linear(self.n_latent, self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_input),
                                    )

    def forward(self, x):
        """Encode, sample and decode; returns (reconstruction, z, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, z, mu, logvar

    def encode(self, x):
        """Return the (mu, logvar) heads computed from the encoder output."""
        h = self.encoder(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, 1) and std = exp(0.5 * logvar).

        The noise is created with ``randn_like`` so it always matches the
        device and dtype of the inputs; allocating it on CUDA merely because a
        GPU exists fails when the model itself lives on the CPU.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to the input space."""
        return self.decoder(z)

    def get_latent_var(self, x):
        """Return one sampled latent vector per row of x."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z

    def generate(self, z):
        """Decode the given latent vectors (same computation as ``decode``)."""
        res = self.decode(z)
        return res

In [None]:
class ATAC_VAE(nn.Module):
    """Fully connected variational autoencoder.

    Args:
        args: namespace providing ``n_latent``, ``n_hidden``, ``n_centroids``
            and ``atac_classifier`` (name of the scikit-learn classifier to
            attach as ``self.classifier``).
        n_input: number of input features; the decoder outputs the same size.
    """
    def __init__(self, args,n_input):
        super(ATAC_VAE, self).__init__()
        self.n_latent = args.n_latent
        self.n_input = n_input
        self.n_hidden = args.n_hidden
        self.n_centroids = args.n_centroids

        # scikit-learn is imported lazily so that only the selected estimator's
        # module is loaded.
        if args.atac_classifier == "RandomForest":
            from sklearn.ensemble import RandomForestClassifier
            self.classifier = RandomForestClassifier()
        elif args.atac_classifier == "DecisionTree":
            from sklearn import tree
            self.classifier = tree.DecisionTreeClassifier()
        elif args.atac_classifier == "GBDT":
            from sklearn.ensemble import GradientBoostingClassifier
            self.classifier = GradientBoostingClassifier()
        elif args.atac_classifier == "AdaBoost":
            from sklearn.ensemble import  AdaBoostClassifier
            self.classifier = AdaBoostClassifier()
        elif args.atac_classifier == "SVM":
            from sklearn.svm import SVC
            self.classifier = SVC(kernel='rbf')
        else:
            # Always define the attribute so an unrecognised name does not
            # surface later as an AttributeError far from its cause.
            self.classifier = None

        # The Gaussian-mixture prior parameters (pi, mu_c, var_c) are currently
        # disabled; see get_gamma for the matching inference routine.

        # Two heads on the encoder output: fc1 -> mu, fc2 -> logvar (see forward).
        self.fc1 = nn.Linear(self.n_hidden, self.n_latent)
        self.fc2 = nn.Linear(self.n_hidden, self.n_latent)

        # Encoder layers are kept as individual attributes (rather than a
        # Sequential) so their state_dict keys stay unchanged.
        self.layer_1 = nn.Linear(self.n_input,self.n_hidden)
        self.batch_norm_1 = nn.BatchNorm1d(self.n_hidden)
        self.layer_2 = nn.Linear(self.n_hidden,self.n_hidden)
        self.batch_norm_2 = nn.BatchNorm1d(self.n_hidden)
        self.layer_3 = nn.Linear(self.n_hidden,self.n_hidden)
        self.batch_norm_3 = nn.BatchNorm1d(self.n_hidden)
        self.layer_4 = nn.Linear(self.n_hidden,self.n_hidden)
        self.batch_norm_4 = nn.BatchNorm1d(self.n_hidden)
        self.layer_5 = nn.Linear(self.n_hidden,self.n_hidden)
        # Despite its name this is the ReLU activation used by encode(), not a
        # loss; the attribute name is kept for compatibility.
        self.loss_func = nn.ReLU()

        self.decoder = nn.Sequential(nn.Linear(self.n_latent, self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_hidden),
                                     nn.BatchNorm1d(self.n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(self.n_hidden, self.n_input),
                                    )

    def forward(self, x):
        """Encode, sample and decode; returns (reconstruction, z, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, z, mu, logvar

    def encode(self, x):
        """Return the (mu, logvar) heads computed from the encoder output."""
        # First block: Linear -> ReLU -> BatchNorm (order as in the original).
        x = self.layer_1(x)
        x = self.loss_func(x)
        x = self.batch_norm_1(x)
        # Remaining blocks: Linear -> BatchNorm -> ReLU.
        x = self.layer_2(x)
        x = self.batch_norm_2(x)
        x = self.loss_func(x)
        x = self.layer_3(x)
        x = self.batch_norm_3(x)
        x = self.loss_func(x)
        x = self.layer_4(x)
        x = self.batch_norm_4(x)
        x = self.loss_func(x)
        h = self.layer_5(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, 1) and std = exp(0.5 * logvar).

        The noise is created with ``randn_like`` so it always matches the
        device and dtype of the inputs; allocating it on CUDA merely because a
        GPU exists fails when the model itself lives on the CPU.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to the input space."""
        return self.decoder(z)

    def get_latent_var(self, x):
        """Return one sampled latent vector per row of x."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z

    def generate(self, z):
        """Decode the given latent vectors (same computation as ``decode``)."""
        res = self.decode(z)
        return res

In [None]:
def read_mtx(path):
    """
    Read mtx format data folder including: 
        matrix file: e.g. count.mtx or matrix.mtx
        barcode file: e.g. barcode.txt
        feature file: e.g. feature.txt

    ``path`` may be the folder itself or the matrix file inside it; in the
    latter case that file is used as the matrix and the barcode/feature files
    are looked up in its folder.

    Returns:
        The AnnData object read from the matrix, oriented so that rows match
        the barcode file and columns match the feature file when those exist.

    Raises:
        FileNotFoundError: if no matrix file can be located.
    """
    if os.path.isfile(path):
        folder = os.path.dirname(path) or '.'
        matrix_file = path
    else:
        folder = path
        matrix_file = None

    # Sorted so that the choice among several candidate files is deterministic.
    filenames = sorted(glob(os.path.join(folder, '*')))

    if matrix_file is None:
        for filename in filenames:
            # Match on the file name only: testing the full path would let a
            # directory called e.g. "dataset" satisfy the 'data' check.
            name = os.path.basename(filename)
            if ('count' in name or 'matrix' in name or 'data' in name) and ('mtx' in name):
                matrix_file = filename  # as before, the last match wins

    if matrix_file is None:
        raise FileNotFoundError("No mtx matrix file found in {}".format(path))

    adata = sc.read_mtx(matrix_file).T

    matrix_abspath = os.path.abspath(matrix_file)
    for filename in filenames:
        if os.path.abspath(filename) == matrix_abspath:
            continue  # never parse the matrix itself as a barcode/feature table
        name = os.path.basename(filename)
        if 'barcode' in name:
            barcode = pd.read_csv(filename, sep='\t', dtype = str, header=None).iloc[:, -1].values
            print(len(barcode), adata.shape[0])
            # A row-count mismatch means the matrix is stored the other way round.
            if len(barcode) != adata.shape[0]:
                adata = adata.transpose()
            adata.obs = pd.DataFrame(index=barcode)
            print(adata.shape)
        if 'gene' in name or 'peaks' in name or 'feature' in name:
            gene = pd.read_csv(filename, sep='\t', dtype = str, header=None).iloc[:, -1].values
            if len(gene) != adata.shape[1]:
                adata = adata.transpose()
            adata.var = pd.DataFrame(index=gene)
    return adata

In [None]:
def load_file(path):  
    """
    Load single cell dataset from file

    Resolution order: ``path + '.h5ad'`` if it exists, then a folder in mtx
    format, then a single file dispatched on its extension
    (.csv/.txt/.tsv with optional .gz, .mtx, .h5, .h5ad).

    Returns:
        AnnData whose ``X`` is a scipy CSR matrix and whose variable names
        have been made unique.

    Raises:
        ValueError: if the path does not exist or its extension is not supported.
    """
    if os.path.exists(path+'.h5ad'):
        adata = sc.read_h5ad(path+'.h5ad')
    elif os.path.isdir(path): # mtx format
        adata = read_mtx(path)
    elif os.path.isfile(path):
        if path.endswith(('.csv', '.csv.gz')):
            adata = sc.read_csv(path).T
        elif path.endswith(('.txt', '.txt.gz', '.tsv', '.tsv.gz')):
            df = pd.read_csv(path, sep='\t', index_col=0).T
            # AnnData is reached through scanpy, which re-exports it; the bare
            # name was never imported in this notebook.
            adata = sc.AnnData(df.values, dict(obs_names=df.index.values), dict(var_names=df.columns.values))
        elif path.endswith(('.mtx', '.mtx.gz')):
            # A lone matrix file is read directly; transposed like the csv
            # branch — assumes features x cells on disk, TODO confirm.
            adata = sc.read_mtx(path).T
        elif path.endswith('.h5'): 
            adata = sc.read_10x_h5(path)
        elif path.endswith('.h5ad'):
            adata = sc.read_h5ad(path)
        else:
            # Without this branch `adata` stayed unbound and the failure
            # surfaced as an unrelated UnboundLocalError below.
            raise ValueError("Unsupported file format: {}".format(path))
    else:
        raise ValueError("File {} not exists".format(path))
        
    if not issparse(adata.X):
        adata.X = scipy.sparse.csr_matrix(adata.X)
    
    adata.var_names_make_unique()
    return adata

In [None]:
def preprocessing_atac(
        adata, 
        min_genes=200, 
        min_cells=0.01, 
        n_top_genes=30000,
        target_sum=None
    ):
    """
    Filter, select and normalize an scATAC AnnData object.

    Steps: make ``X`` sparse, drop features present in fewer than ``min_cells``
    cells (values below 1 are read as a fraction of the cell count), keep
    ``n_top_genes`` features via episcanpy, then normalize totals per cell.
    ``min_genes`` is accepted but unused because cell filtering and
    binarization are disabled.

    Returns the selected-and-normalized copy; the feature filter is applied
    to the ``adata`` passed in.
    """
    print('Raw dataset for scATAC shape: {}'.format(adata.shape))

    matrix_is_dense = not issparse(adata.X)
    if matrix_is_dense:
        adata.X = scipy.sparse.csr_matrix(adata.X)

    # Filtering features: turn a fractional threshold into an absolute count.
    feature_threshold = min_cells * adata.shape[0] if min_cells < 1 else min_cells
    sc.pp.filter_genes(adata, min_cells=feature_threshold)

    # Finding variable features (returns a new object because copy=True).
    selected = epi.pp.select_var_feature(
        adata,
        nb_features=n_top_genes,
        show=False,
        copy=True,
    )

    # Normalizing total per cell.
    sc.pp.normalize_total(selected, target_sum=target_sum)

    print('Processed dataset shape: {}'.format(selected.shape))
    return selected

In [None]:
def preprocessing_rna(
    adata,
    min_genes=200, 
    min_cells=3, 
    target_sum=1e4
):
    """
    Filter and log-normalize an scRNA AnnData object in place.

    Steps: make ``X`` sparse, drop genes seen in fewer than ``min_cells``
    cells, normalize totals per cell to ``target_sum``, apply log1p, and store
    the result as ``adata.raw``. ``min_genes`` is accepted but unused because
    cell filtering is disabled.

    Returns the same ``adata`` object that was passed in.
    """
    print('Raw dataset for scRNA shape: {}'.format(adata.shape))

    matrix_is_dense = not issparse(adata.X)
    if matrix_is_dense:
        adata.X = scipy.sparse.csr_matrix(adata.X)

    # Filtering genes, then LogNormalize; each call mutates `adata`.
    sc.pp.filter_genes(adata, min_cells=min_cells)
    sc.pp.normalize_total(adata, target_sum=target_sum)
    sc.pp.log1p(adata)

    # Keep the processed matrix reachable through .raw.
    adata.raw = adata

    print('Processed dataset shape: {}'.format(adata.shape))
    return adata

In [None]:
class sc_Dataset(Dataset):
    """
    Dataset for dataloader

    Serves row ``idx`` of the RNA and ATAC objects together with their
    observation names. Rows are paired purely by position.

    NOTE(review): the length is taken from the RNA object only and nothing
    checks that both objects hold the same cells in the same order — verify
    at the call site.
    """
    def __init__(self, rna_adata, atac_adata):
        self.rna_adata = rna_adata
        self.atac_adata = atac_adata

        self.rna_shape = rna_adata.shape
        self.atac_shape = atac_adata.shape

    def __len__(self):
        return self.rna_adata.X.shape[0]

    @staticmethod
    def _row_as_array(matrix, idx):
        """Return row ``idx`` of a sparse or dense matrix as a squeezed ndarray."""
        row = matrix[idx]
        # Only sparse rows have .toarray(); dense input used to fail here.
        if issparse(row):
            row = row.toarray()
        return np.asarray(row).squeeze()

    def __getitem__(self, idx):
        """Return (rna_x, atac_x, rna_name, atac_name) for row ``idx``."""
        rna_x = self._row_as_array(self.rna_adata.X, idx)
        rna_y = self.rna_adata.obs_names[idx]
        atac_x = self._row_as_array(self.atac_adata.X, idx)
        atac_y = self.atac_adata.obs_names[idx]
        return rna_x,atac_x,rna_y,atac_y

In [None]:
## compute loss
def loss_function(recon_x, x, mu, logvar, latents):
    """Mean-squared reconstruction error, plus a KL term weighted by ``args.lamb``.

    The KL term (summed over all elements, not averaged) is only added when
    the global ``args.lamb`` is positive. ``latents`` is accepted but unused.
    """
    reconstruction = nn.functional.mse_loss(recon_x, x)

    # Guard clause: with a non-positive weight the loss is pure reconstruction.
    if not args.lamb > 0:
        return reconstruction

    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + args.lamb * kl_divergence

def compute_shared_loss(real_seq,generated_seq):
    """Return the mean-squared error between ``generated_seq`` and ``real_seq``."""
    return nn.functional.mse_loss(generated_seq, real_seq)

def compute_loss(latent, output):
    """Sum of ``latent*log(output) + (1-latent)*log(1-output)``, divided by ``latent.shape[0]``.

    NOTE(review): this is the Bernoulli log-likelihood *without* a leading
    minus sign, so it is non-positive and grows towards 0 as the fit improves.
    Confirm the intended sign before minimizing it as a loss.
    """
    n_rows = latent.shape[0]
    # ones_like follows the device and dtype of `output`; the previous
    # hard-coded .cuda() failed on CPU-only machines.
    ones = torch.ones_like(output)
    loss = 1/n_rows * torch.sum(latent*torch.log(output) + (ones-latent)*torch.log(ones-output))

    return loss

def compute_atac_regularization(classifier,generated_seq,real_seq):
    """Sum, over paired samples, of the MSE between ``classifier.predict(generated)`` and ``real``.

    Args:
        classifier: any object exposing ``predict(sample)``.
        generated_seq: iterable of generated samples (iterated row by row).
        real_seq: iterable of the matching real samples (tensors).

    Returns:
        The accumulated loss (a tensor, or the int 0 when there are no samples).

    NOTE(review): each sample is handed to ``predict`` unchanged; a
    scikit-learn estimator would presumably need a 2-D numpy array here —
    confirm against the intended classifier.
    """
    # The class is spelled MSELoss; `MSEloss` raised AttributeError.
    atac_regularization_fn = torch.nn.MSELoss()
    atac_regularization = 0

    for generated_sample,real_sample in zip(generated_seq,real_seq):
        predicted_sample = classifier.predict(generated_sample)
        # predict() may return a numpy array; bring it onto the target's
        # dtype/device so the loss can be evaluated.
        predicted_sample = torch.as_tensor(predicted_sample, dtype=real_sample.dtype, device=real_sample.device)
        atac_regularization += atac_regularization_fn(predicted_sample,real_sample)
    print("MSE loss:{}".format(atac_regularization))

    return atac_regularization

In [None]:
def train(args,epoch,atac_model,rna_model,trainloader,atac_optimizer,rna_optimizer):
    """Run one training epoch of the paired ATAC/RNA autoencoders.

    For every batch each model reconstructs its own modality
    (``loss_function``) and each decoder is additionally asked to reproduce
    its modality from the *other* model's latent sample
    (``compute_shared_loss``). Both optimizers are stepped once per batch and
    the mean losses are printed once per epoch.

    On the last epoch (``epoch == args.epochs - 1``) the cell names and,
    depending on ``args.reconstruction`` / ``args.latent_space`` /
    ``args.translation``, the per-cell outputs are written to ``args.outdir``
    (falls back to ``output/`` when the attribute is missing).

    Args:
        args: run configuration (``epochs``, ``reconstruction``,
            ``latent_space``, ``translation``, optionally ``outdir``).
        epoch: zero-based index of the current epoch.
        atac_model, rna_model: models returning (recon, latents, mu, logvar)
            and exposing ``decode``.
        trainloader: iterable of (rna_x, atac_x, rna_names, atac_names) batches.
        atac_optimizer, rna_optimizer: optimizers of the respective models.
    """
    def _to_numpy(tensor):
        return tensor.cpu().detach().numpy()

    def _save_targets(filename, targets):
        with open(os.path.join(outdir, filename), 'w') as output_targets:
            for target in targets:
                output_targets.write(target)
                output_targets.write('\n')

    def _save_array(filename, chunks):
        # Nothing is written when no batch contributed.
        if chunks:
            np.save(os.path.join(outdir, filename), np.concatenate(chunks, axis=0))

    atac_model.train()
    rna_model.train()

    # Follow the device the model lives on instead of hard-coding CUDA
    # (assumes both models share a device — TODO confirm).
    device = next(atac_model.parameters()).device
    is_last_epoch = epoch == args.epochs - 1
    outdir = getattr(args, 'outdir', 'output/')

    # Per-batch chunks are gathered in lists and concatenated once at the end;
    # concatenating inside the loop re-copies everything on every batch.
    rna_targets, atac_targets = [], []
    collected = {key: [] for key in ('rna_recon', 'atac_recon',
                                     'rna_latents', 'atac_latents',
                                     'rna_trans', 'atac_trans')}
    rna_loss_sum, atac_loss_sum, n_batches = 0.0, 0.0, 0

    for train_batch in trainloader:

        rna_batch, atac_batch, rna_batch_y, atac_batch_y = train_batch
        rna_batch = rna_batch.float().to(device)
        atac_batch = atac_batch.float().to(device)

        # BatchNorm cannot compute batch statistics from one sample in train mode.
        if len(rna_batch) == 1 or len(atac_batch) == 1:
            continue

        atac_optimizer.zero_grad()
        rna_optimizer.zero_grad()

        atac_recon_inputs, atac_latents, atac_mu, atac_logvar = atac_model(atac_batch)
        atac_loss = loss_function(atac_recon_inputs, atac_batch, atac_mu, atac_logvar, atac_latents)

        rna_recon_inputs, rna_latents, rna_mu, rna_logvar = rna_model(rna_batch)
        rna_loss = loss_function(rna_recon_inputs, rna_batch, rna_mu, rna_logvar, rna_latents)

        # Cross-modal translation: decode each modality from the other's latents.
        atac_output = atac_model.decode(rna_latents)
        rna_output = rna_model.decode(atac_latents)

        rna_loss = rna_loss + compute_shared_loss(rna_output,rna_batch)
        atac_loss = atac_loss + compute_shared_loss(atac_output,atac_batch)

        # Gradients are additive, so one backward pass over the sum gives the
        # same gradients as two passes with retain_graph=True.
        (atac_loss + rna_loss).backward()

        atac_optimizer.step()
        rna_optimizer.step()

        rna_loss_sum += rna_loss.item()
        atac_loss_sum += atac_loss.item()
        n_batches += 1

        if is_last_epoch:
            rna_targets.extend(rna_batch_y)
            atac_targets.extend(atac_batch_y)

            if args.reconstruction:
                collected['atac_recon'].append(_to_numpy(atac_recon_inputs))
                collected['rna_recon'].append(_to_numpy(rna_recon_inputs))

            if args.latent_space:
                collected['atac_latents'].append(_to_numpy(atac_latents))
                collected['rna_latents'].append(_to_numpy(rna_latents))

            if args.translation:
                # Re-decoded after the optimizer step, as before; no_grad only
                # avoids building an autograd graph for export-only values.
                with torch.no_grad():
                    collected['atac_trans'].append(_to_numpy(atac_model.decode(rna_latents)))
                    collected['rna_trans'].append(_to_numpy(rna_model.decode(atac_latents)))

    # One summary line per epoch (mean over the batches actually processed)
    # instead of one line per batch scaled by the number of batches.
    if n_batches:
        print('Epoch: {} rna loss:{} atac loss:{}'.format(epoch, rna_loss_sum / n_batches, atac_loss_sum / n_batches))

    if is_last_epoch:
        os.makedirs(outdir, exist_ok=True)

        _save_targets("rna_targets_train.txt", rna_targets)
        _save_targets("atac_targets_train.txt", atac_targets)

        if args.reconstruction:
            # Save the reconstructions themselves; the latent arrays used to
            # be written under these file names.
            _save_array("rna_recon_output_train.npy", collected['rna_recon'])
            _save_array("atac_recon_output_train.npy", collected['atac_recon'])

        if args.latent_space:
            _save_array("rna_latents_output_train.npy", collected['rna_latents'])
            _save_array("atac_latents_output_train.npy", collected['atac_latents'])

        if args.translation:
            _save_array("rna_trans_output_train.npy", collected['rna_trans'])
            _save_array("atac_trans_output_train.npy", collected['atac_trans'])
        
        


In [None]:
def test(args,epoch,atac_model,rna_model,testloader):
    
    atac_model.eval()
    rna_model.eval()

    test_loss = 0
    total_atac_loss = 0
    total_rna_loss = 0

    atac_latents_output, rna_latents_output = None, None
    atac_recon_output, rna_recon_output = None, None
    atac_trans_output, rna_trans_output = None, None
    atac_targets, rna_targets = None, None

    for batch_idx, test_batch in enumerate(testloader):

        rna_batch, atac_batch, rna_batch_y, atac_batch_y = test_batch
        rna_batch, atac_batch = rna_batch.type(torch.FloatTensor), atac_batch.type(torch.FloatTensor)
        rna_batch, atac_batch = rna_batch.cuda(), atac_batch.cuda()
        rna_batch_y, atac_batch_y = list(rna_batch_y), list(atac_batch_y)
        if len(rna_batch) == 1 or len(atac_batch) == 1:
            continue
 
        atac_recon_inputs, atac_latents, atac_mu, atac_logvar = atac_model(atac_batch)
        atac_loss = loss_function(atac_recon_inputs, atac_batch, atac_mu, atac_logvar, atac_latents)

        rna_recon_inputs, rna_latents, rna_mu, rna_logvar = rna_model(rna_batch)
        rna_loss = loss_function(rna_recon_inputs, rna_batch, rna_mu, rna_logvar, rna_latents)
        
        total_atac_loss += atac_loss
        total_rna_loss += rna_loss

        if epoch == args.epochs-1:

            if rna_targets is None or atac_targets is None:
                rna_targets, atac_targets = rna_batch_y, atac_batch_y
            else:
                rna_targets = rna_targets + rna_batch_y
                atac_targets = atac_targets + atac_batch_y

            if args.reconstruction is True:
                atac_recon_numpy = atac_recon_inputs.cpu().detach().numpy()
                if atac_recon_output is None:
                    atac_recon_output = atac_recon_numpy
                else:
                    atac_recon_output = np.concatenate((atac_recon_output,atac_recon_numpy),axis=0)
                
                rna_recon_numpy = rna_recon_inputs.cpu().detach().numpy()

                if rna_recon_output is None:
                    rna_recon_output = rna_recon_numpy
                else:
                    rna_recon_output = np.concatenate((rna_recon_output,rna_recon_numpy),axis=0)

            
            if args.latent_space is True:
                atac_latent_numpy = atac_latents.cpu().detach().numpy()
                
                if atac_latents_output is None:
                    atac_latents_output = atac_latent_numpy
                else:
                    atac_latents_output = np.concatenate((atac_latents_output,atac_latent_numpy),axis=0)
                    
                rna_latent_numpy = rna_latents.cpu().detach().numpy()
                
                if rna_latents_output is None:
                    rna_latents_output = rna_latent_numpy
                else:
                    rna_latents_output = np.concatenate((rna_latents_output,rna_latent_numpy),axis=0)
                
            if args.translation is True:
                
                atac_output = atac_model.decode(rna_latents)
                atac_trans_numpy = atac_output.cpu().detach().numpy()
                
                if atac_trans_output is None:
                    atac_trans_output = atac_trans_numpy
                else:
                    atac_trans_output = np.concatenate((atac_trans_output,atac_trans_numpy),axis=0)

                rna_output = rna_model.decode(atac_latents)
                rna_trans_numpy = rna_output.cpu().detach().numpy()
                
                if rna_trans_output is None:
                    rna_trans_output = rna_trans_numpy
                else:
                    rna_trans_output = np.concatenate((rna_trans_output,rna_trans_numpy),axis=0)
    
    total_atac_loss /= len(testloader)
    total_rna_loss /= len(testloader)

    print('Test atac loss:{} Test rna loss:{} '.format(total_atac_loss, total_rna_loss))
    
    if epoch == args.epochs-1:

        with open("output/rna_targets_test.txt",'w') as output_targets:
            for target in rna_targets:
                output_targets.write(target)
                output_targets.write('\n')
        
        with open("output/atac_targets_test.txt",'w') as output_targets:
            for target in atac_targets:
                output_targets.write(target)
                output_targets.write('\n')   
                
        if args.reconstruction is True:
            np.save("output/rna_recon_output_test.npy",rna_latents_output)
            np.save("output/atac_recon_output_test.npy",atac_latents_output)
            
        if args.latent_space is True:
            np.save("output/rna_latents_output_test.npy",rna_latents_output)
            np.save("output/atac_latents_output_test.npy",atac_latents_output)
        
        if args.translation is True:
            np.save("output/rna_trans_output_test.npy", rna_trans_output)
            np.save("output/atac_trans_output_test.npy",atac_trans_output)
    
    return total_atac_loss, total_rna_loss


In [None]:
def cluster(adata):
    """Cluster cells on ``adata.obsm['latent']`` using ``args.cluster_method``.

    A 30-neighbour graph is always built on the 'latent' representation.
    'leiden' then runs scanpy's Leiden clustering; 'kmeans' stores string
    labels in ``adata.obs['kmeans']`` using ``args.cluster_num`` clusters.
    Any other method name leaves ``adata`` with only the neighbour graph.

    Returns the same ``adata`` object.
    """
    sc.pp.neighbors(adata, n_neighbors=30, use_rep='latent')
    if args.cluster_method == 'leiden':
        sc.tl.leiden(adata)
    elif args.cluster_method == 'kmeans':
        # Imported lazily, like the other scikit-learn estimators in this notebook.
        from sklearn.cluster import KMeans
        # The cluster count comes from --cluster-num; the previous `k` was
        # never defined anywhere.
        kmeans = KMeans(n_clusters=args.cluster_num, n_init=20, random_state=0)
        adata.obs['kmeans'] = kmeans.fit_predict(adata.obsm['latent']).astype(str)
    return adata


In [None]:
def save(epoch, models=None):
    """Write model checkpoints for ``epoch`` under ``<args.save_dir>/models``.

    Args:
        epoch: epoch index used in the file name.
        models: optional mapping ``name -> nn.Module``. When omitted, the
            notebook-level ``atac_model`` and ``rna_model`` are saved (the
            previous body referenced a global ``model`` that is never defined).

    Each module is stored as ``<name>_<epoch>.pth`` holding its state_dict
    with all tensors copied to the CPU. The modules themselves are left on
    whatever device they were on.
    """
    if models is None:
        models = {"atac": atac_model, "rna": rna_model}

    model_dir = os.path.join(args.save_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    for name, net in models.items():
        state = net.state_dict()
        # Copy tensors to the CPU for a portable checkpoint instead of moving
        # the live model back and forth between devices.
        cpu_state = type(state)((key, value.detach().cpu()) for key, value in state.items())
        torch.save(cpu_state, os.path.join(model_dir, "{}_{}.pth".format(name, epoch)))

In [None]:
## main

## Load the paired ATAC / RNA datasets, preprocess them and build the loaders.

atac_path = 'B2_Model/dataset/E18_mouse_brain/e18_mouse_brain_fresh_5k_atac_fragments.h5ad'
rna_path = 'B2_Model/dataset/E18_mouse_brain/e18_mouse_brain_fresh_5k_rna_fragments.h5ad'

# Make sure the output folder exists before anything is written to it.
outdir = args.outdir+'/'
if not os.path.exists(outdir):
    os.makedirs(outdir)

atac_adata = load_file(atac_path)
rna_adata = load_file(rna_path)

# NOTE(review): args.min_cells (0.01 by default) is a fraction for the ATAC
# preprocessing but is passed unchanged to the RNA preprocessing, which hands
# it to filter_genes as-is — confirm this is intended.
atac_adata = preprocessing_atac(atac_adata, min_genes=args.min_peaks, min_cells=args.min_cells)
rna_adata = preprocessing_rna(rna_adata, min_genes=args.min_genes, min_cells=args.min_cells)

# Positional 80/20 split (no shuffling); each modality is cut at 80% of its own length.
n_train_atac = int(len(atac_adata)*0.8)
n_train_rna = int(len(rna_adata)*0.8)

train_atac, test_atac = atac_adata[:n_train_atac], atac_adata[n_train_atac:]
train_rna, test_rna = rna_adata[:n_train_rna], rna_adata[n_train_rna:]

# sc_Dataset takes the RNA object first, then the ATAC object.
train_scDataset = sc_Dataset(train_rna,train_atac)
test_scDataset = sc_Dataset(test_rna,test_atac)

trainloader = DataLoader(train_scDataset,batch_size=args.train_batch_size,drop_last=False,shuffle=True)
testloader = DataLoader(test_scDataset,batch_size=args.test_batch_size, drop_last=False, shuffle=True)

# Cell counts and the input dimension of each modality.
train_cell_num, input_dim_atac = train_atac.shape[0], train_atac.shape[1]
test_cell_num = test_atac.shape[0]
input_dim_rna = train_rna.shape[1]

print("The cell number of train set is: ", train_cell_num)
print("The cell number of test set is: ", test_cell_num)
print("The input dim of atac data is: ", input_dim_atac)
print("The input dim of rna data is: ", input_dim_rna)

Raw dataset for scATAC shape: (4474, 172193)




Processed dataset shape: (4474, 30040)
Raw dataset for scRNA shape: (4474, 32285)
Processed dataset shape: (4474, 24526)
The cell number of train set is:  3579
The cell number of test set is:  895
The input dim of atac data is:  30040
The input dim of rna data is:  24526


In [None]:
# De-duplicate the variable names of every split, in the same order as before:
# ATAC train/test first, then RNA train/test.
for _split in (train_atac, test_atac, train_rna, test_rna):
    _split.var_names_make_unique()

#import episcanpy as epi
#atac_data = epi.pp.select_var_feature(atac_data, nb_features=30000, show=False, copy=True)

In [None]:
# Place the models on the GPU when one is available and fall back to the CPU
# otherwise; an unconditional .cuda() raises on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
atac_model = ATAC_VAE(args,n_input=input_dim_atac)
rna_model = RNA_VAE(args,n_input=input_dim_rna)
atac_model, rna_model = atac_model.to(device), rna_model.to(device)

#if args.pretrained_file is not None:
#    model.load_state_dict(torch.load(args.pretrained_file))
#    print("Pre-trained model loaded")
#    sys.stdout.flush()

In [None]:
# One Adam optimizer per model, both using the configured learning rate.
atac_optimizer = optim.Adam(params=atac_model.parameters(), lr=args.learning_rate)
rna_optimizer = optim.Adam(params=rna_model.parameters(), lr=args.learning_rate)

In [None]:
# Alternate one training pass and one evaluation pass per epoch. Both
# functions compare `epoch` with `args.epochs - 1` to decide when to write
# their output files, so files are only produced on the final iteration.
for epoch in range(args.epochs):

    train(args,epoch,atac_model,rna_model,trainloader,atac_optimizer,rna_optimizer)
    test(args,epoch,atac_model,rna_model,testloader)

Epoch: 0 rna loss:0.021122898906469345 atac loss:0.07802591472864151
Epoch: 0 rna loss:0.019635867327451706 atac loss:0.06382790207862854
Epoch: 0 rna loss:0.018651017919182777 atac loss:0.07189861685037613
Epoch: 0 rna loss:0.017443103715777397 atac loss:0.06527723371982574
Epoch: 0 rna loss:0.016666190698742867 atac loss:0.05098056048154831
Epoch: 0 rna loss:0.015791453421115875 atac loss:0.06674210727214813
Epoch: 0 rna loss:0.015148041769862175 atac loss:0.060279540717601776
Epoch: 0 rna loss:0.014279693365097046 atac loss:0.05454335734248161
Epoch: 0 rna loss:0.013706696219742298 atac loss:0.05197516083717346
Epoch: 0 rna loss:0.013210800476372242 atac loss:0.05643736571073532
Epoch: 0 rna loss:0.012745600193738937 atac loss:0.07362181693315506
Epoch: 0 rna loss:0.012171589769423008 atac loss:0.06605735421180725
Epoch: 0 rna loss:0.011846805922687054 atac loss:0.06581301987171173
Epoch: 0 rna loss:0.011700717732310295 atac loss:0.07230173051357269
Epoch: 0 rna loss:0.0110210906714