In [12]:
import sys
sys.path.append('../')

import scanpy as sc
import numpy as np
import pandas as pd

import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

import scDART.utils as utils
import scDART.TI as ti
import scDART

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Train model

### PBMC

In [5]:
# Load both PBMC 10x modalities as AnnData objects.
# NOTE: absolute, machine-specific paths — adjust to your data location.
pbmc_rna = sc.read_h5ad('/data/pbmc_10x/RNA/adata_rna.h5ad')
pbmc_atac = sc.read_h5ad('/data/pbmc_10x/ATAC/adata_atac.h5ad')
(pbmc_rna, pbmc_atac)

(AnnData object with n_obs × n_vars = 10412 × 36601
     obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seurat_annotations', 'domain', 'cell_type',
 AnnData object with n_obs × n_vars = 10412 × 108377
     obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seurat_annotations', 'domain', 'cell_type')

In [3]:
# Peak-to-gene link table. Peaks are stored underscore-separated (e.g. 'chr_x_x');
# index by peak while also keeping the peak as a regular column.
gact = pd.read_csv('pbmc_gact.csv').set_index('peak')
gact = gact.assign(peak=gact.index.values)

# The ATAC AnnData uses dash-separated names ('chr-x-x'), so swap the separator.
valid_peak = np.array([peak.replace("_", '-') for peak in gact['peak'].to_numpy()])
# Genes that appear both in the RNA matrix and in the link table.
valid_gene = np.array(list(pbmc_rna.var_names.intersection(gact['gene.name'])))
(len(valid_peak), len(valid_gene))

(72372, 11046)

In [4]:
# Reduce the peak set: keep gact-linked peaks that are accessible (count > 0)
# in at least MIN_PEAK_CELL_FRAC of cells, then keep the RNA genes those peaks
# are linked to. `.toarray()` replaces the sparse `.A` shorthand, which SciPy
# deprecated in 1.11 and removed in 1.14.
# NOTE(review): if gact links one peak to several genes, valid_peak contains
# duplicates and the selected ATAC columns are duplicated too — confirm intended.
MIN_PEAK_CELL_FRAC = 0.05

count_atac_all = (pbmc_atac[:, valid_peak].X.toarray() > 0).astype('float')
hv_peak = count_atac_all.sum(axis=0) >= MIN_PEAK_CELL_FRAC * count_atac_all.shape[0]

hv_valid_peak = valid_peak[hv_peak]
# Identical to re-slicing pbmc_atac with hv_valid_peak, without densifying twice.
count_atac = count_atac_all[:, hv_peak]
del count_atac_all

# gact is indexed by underscore-separated peak names, so convert back for lookup.
hv_gene = gact.loc[[_.replace('-', '_') for _ in hv_valid_peak], 'gene.name'].to_numpy()
hv_valid_gene = np.array(list(pbmc_rna.var_names.intersection(hv_gene)))
count_rna = pbmc_rna[:, hv_valid_gene].X.toarray()

count_rna.shape, count_atac.shape

((10412, 7280), (10412, 21855))

In [5]:
# Unused alternative to the filtering cell: build the matrices from every
# gact-linked gene/peak (valid_gene / valid_peak) instead of the filtered
# hv_valid_gene / hv_valid_peak subset. Kept commented out for reference.
# count_rna = pbmc_rna[:, valid_gene].X.toarray()
# count_atac = (pbmc_atac[:, valid_peak].X.toarray() > 0).astype('float')
# count_rna.shape, count_atac.shape

In [6]:
# Library-size normalise RNA to 100 counts per cell, then log1p.
# A cell with no counts over the selected genes would divide by zero and put
# NaNs into training; its library size is treated as 1, so the row stays 0.
# NOTE: this rebinds count_rna / count_atac, so the cell is not idempotent —
# re-run the filtering cell before running this one again.
library_size = count_rna.sum(axis=1, keepdims=True)
count_rna = np.log1p(count_rna / np.where(library_size > 0, library_size, 1) * 100)
# `np.float` was removed in NumPy 1.24; float64 is the dtype it aliased.
count_atac = (count_atac > 0).astype(np.float64)

In [7]:
# Binary peak-by-gene prior: entry (row, col) is 1 when gact links that filtered
# peak to that filtered gene; every other entry stays 0.
peak2idx = {peak: row for row, peak in enumerate(hv_valid_peak)}  # chr1-x-x
gene2idx = {gene: col for col, gene in enumerate(hv_valid_gene)}
pbmc_coarse_reg = np.zeros((len(hv_valid_peak), len(hv_valid_gene)))

# gact stores peaks as chr_x_x; convert to chr-x-x to match peak2idx keys.
for peak_name, gene_name in zip(gact['peak'], gact['gene.name']):
    peak_key = peak_name.replace('_', '-')
    if gene_name in gene2idx and peak_key in peak2idx:
        pbmc_coarse_reg[peak2idx[peak_key], gene2idx[gene_name]] = 1

In [8]:
np.sum(pbmc_coarse_reg)  # number of peak-gene links in the 0/1 prior

17620.0

In [9]:
# All-in-one training cell: hyper-parameters, labels, peak-gene prior, fit, embed.
seeds = [0]
latent_dim = 4
# NOTE(review): learning_rate is not passed to scDART.scDART below — confirm
# whether the constructor accepts it before relying on this value.
learning_rate = 3e-4
n_epochs = 500
use_anchor = False
reg_d = 1
reg_g = 1
reg_mmd = 1
ts = [30, 50, 70]
use_potential = True

label_rna = pbmc_rna.obs.seurat_annotations.to_numpy()
label_atac = pbmc_atac.obs.seurat_annotations.to_numpy()
coarse_reg = pbmc_coarse_reg

# Pass the configured regularisation weights and the notebook-wide `device`
# so that editing the values above actually takes effect.
scDART_op = scDART.scDART(
    n_epochs=n_epochs, latent_dim=latent_dim, batch_size=512,
    ts=ts, use_anchor=use_anchor, use_potential=use_potential, k=10,
    reg_d=reg_d, reg_g=reg_g, reg_mmd=reg_mmd, l_dist_type='kl', seed=seeds[0],
    device=device,
)

scDART_op = scDART_op.fit(rna_count=count_rna,
                          atac_count=count_atac,
                          reg=coarse_reg,
                          rna_anchor=None,
                          atac_anchor=None)
z_rna, z_atac = scDART_op.transform(rna_count=count_rna,
                                    atac_count=count_atac,
                                    rna_anchor=None,
                                    atac_anchor=None)

Loaded Dataset
Calculating diffusion dist
running time(sec): 619.0065658092499
running time(sec): 591.1360161304474
running time(sec): 567.9182906150818
running time(sec): 530.4377608299255
running time(sec): 548.0922586917877
running time(sec): 568.3229320049286
init model
Training
epoch:  0
	 mmd loss: 0.123
	 ATAC dist loss: 0.316
	 RNA dist loss: 0.120
	 gene activity loss: 55581.633
	 anchor matching loss: 0.000
epoch:  100
	 mmd loss: 0.047
	 ATAC dist loss: 0.039
	 RNA dist loss: 0.026
	 gene activity loss: 36.223
	 anchor matching loss: 0.000
epoch:  200
	 mmd loss: 0.045
	 ATAC dist loss: 0.040
	 RNA dist loss: 0.025
	 gene activity loss: 1.348
	 anchor matching loss: 0.000
epoch:  300
	 mmd loss: 0.041
	 ATAC dist loss: 0.036
	 RNA dist loss: 0.024
	 gene activity loss: 0.174
	 anchor matching loss: 0.000
epoch:  400
	 mmd loss: 0.041
	 ATAC dist loss: 0.033
	 RNA dist loss: 0.023
	 gene activity loss: 0.143
	 anchor matching loss: 0.000
epoch:  500
	 mmd loss: 0.040
	 ATAC d

In [10]:
# Both modalities must share the latent_dim-dimensional space, one row per cell.
assert z_rna.shape[1] == z_atac.shape[1] == latent_dim
assert z_rna.shape[0] == count_rna.shape[0] and z_atac.shape[0] == count_atac.shape[0]
z_rna.shape, z_atac.shape

((10412, 4), (10412, 4))

In [11]:
import torch
import scDART.dataset as dataset
from torch.utils.data import DataLoader

# Wrap each modality so it can be fed through the trained sub-networks.
rna_dataset = dataset.dataset(count_rna, None)
atac_dataset = dataset.dataset(count_atac, None)

# Full-size, unshuffled batches: output rows stay aligned with the rows of
# count_rna / count_atac.
test_rna_loader = DataLoader(rna_dataset, batch_size = len(rna_dataset), shuffle = False)
test_atac_loader = DataLoader(atac_dataset, batch_size = len(atac_dataset), shuffle = False)

# Map ATAC profiles into gene space with the trained "gene_act" module.
# Use the notebook-wide `device` (the one training used) instead of a
# hard-coded 'cuda', so this also runs on CPU-only machines. Every batch is
# collected, so the result stays complete if the batch size above is changed.
with torch.no_grad():
    rna_atac = torch.cat([
        scDART_op.model_dict["gene_act"](data['count'].to(device)).cpu()
        for data in test_atac_loader
    ])


In [18]:
# Save the ATAC-derived pseudo-RNA matrix. Its rows come from count_atac, which
# was built from pbmc_atac, so the cell metadata is taken from pbmc_atac.obs
# (pbmc_rna.obs would only match if both objects list identical cells in order).
assert rna_atac.shape == (pbmc_atac.n_obs, len(hv_valid_gene))
ad_prna = sc.AnnData(rna_atac.numpy(), obs=pbmc_atac.obs.copy())

ad_prna.var_names = hv_valid_gene
ad_prna.write_h5ad('./adata_pbmc_pseudo_rna.h5ad')