In [6]:
import sys
sys.path.append('../')

import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from sklearn.decomposition import PCA

import scDART.utils as utils
import scDART.TI as ti
import scDART

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Train model

### HFA-50k

In [7]:
ad_atac = sc.read_h5ad('/data/HumanFetal_50k/ATAC/adata_atac_raw.h5ad')
ad_rna = sc.read_h5ad('/data/HumanFetal_50k/RNA/adata_rna_raw.h5ad')

ad_rna = ad_rna[:, ~ad_rna.var.gene_short_name.duplicated().to_numpy()].copy()
ad_rna.var_names = ad_rna.var.gene_short_name.to_numpy()
ad_rna, ad_atac

(AnnData object with n_obs × n_vars = 20000 × 56622
     obs: 'All_reads', 'Assay', 'Batch', 'Development_day', 'Exon_reads', 'Experiment_batch', 'Fetus_id', 'Intron_reads', 'Main_cluster_name', 'Main_cluster_umap_1', 'Main_cluster_umap_2', 'Organ', 'Organ_cell_lineage', 'RT_group', 'Sex', 'Size_Factor', 'batch', 'obs_names', 'sample', 'cell_type', 'domain'
     var: 'exon_intron', 'gene_id', 'gene_short_name', 'gene_type', 'index', 'var_names',
 AnnData object with n_obs × n_vars = 30000 × 1050819
     obs: 'cell_type', 'domain')

In [3]:
gact = pd.read_csv('HFA50k_gact.csv')
gact = gact.set_index('peak')
gact['peak'] = gact.index.values

valid_peak = gact.peak.to_numpy()
valid_gene = np.array(np.intersect1d(ad_rna.var_names, gact['gene.name']))
len(valid_peak), len(valid_gene)

(625315, 32293)

In [5]:
# reduce peak number
count_atac = (ad_atac[:, valid_peak].X.A > 0).astype('float32')
hv_peak = count_atac.sum(axis=0) >= 0.04*count_atac.shape[0]

hv_valid_peak = valid_peak[hv_peak]
count_atac = (ad_atac[:, hv_valid_peak].X.A > 0).astype('float')

hv_gene = gact.loc[hv_valid_peak, 'gene.name'].to_numpy()
hv_valid_gene = np.array(np.intersect1d(ad_rna.var_names, hv_gene))
count_rna = ad_rna[:, hv_valid_gene].X.A

count_rna.shape, count_atac.shape

((20000, 9800), (30000, 11077))

In [6]:
count_rna = count_rna/np.sum(count_rna, axis = 1)[:, None] * 100
count_rna = np.log1p(count_rna)
count_atac = (count_atac > 0).astype(np.float)

In [7]:
coarse_reg = np.zeros((len(hv_valid_peak), len(hv_valid_gene)))
peak2idx = {v:i for i,v in enumerate(hv_valid_peak)}  # chr1_x_x
gene2idx = {v:i for i,v in enumerate(hv_valid_gene)}

for _, r in gact.loc[hv_valid_peak].iterrows():  # chr_x_x
    gn = r['gene.name']
    pn = r['peak']
    if (gn not in gene2idx) or (pn not in peak2idx):
        continue
    rix = peak2idx[pn]
    cix = gene2idx[gn]
    coarse_reg[rix, cix] = 1

In [8]:
coarse_reg.sum(), coarse_reg.shape

(11043.0, (11077, 9800))

In [9]:
# all in one
seeds = [0]
latent_dim = 4
learning_rate = 3e-4
n_epochs = 500
use_anchor = False
reg_d = 1
reg_g = 1
reg_mmd = 1
ts = [30, 50, 70]
use_potential = True

label_rna = ad_rna.obs.cell_type.to_numpy()
label_atac = ad_atac.obs.cell_type.to_numpy()

scDART_op = scDART.scDART(n_epochs = n_epochs, latent_dim = latent_dim, batch_size=512, \
        ts = ts, use_anchor = use_anchor, use_potential = use_potential, k = 10, \
        reg_d = 1, reg_g = 1, reg_mmd = 1, l_dist_type = 'kl', seed = seeds[0],\
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

scDART_op = scDART_op.fit(rna_count = count_rna, 
                          atac_count = count_atac, 
                          reg = coarse_reg, 
                          rna_anchor = None, 
                          atac_anchor = None)
z_rna, z_atac = scDART_op.transform(rna_count = count_rna, 
                                    atac_count = count_atac, 
                                    rna_anchor = None, 
                                    atac_anchor = None)

Loaded Dataset
Calculating diffusion dist
init model
Training
epoch:  0
	 mmd loss: 0.045
	 ATAC dist loss: 0.354
	 RNA dist loss: 0.149
	 gene activity loss: 36229.246
	 anchor matching loss: 0.000
epoch:  100
	 mmd loss: 0.031
	 ATAC dist loss: 0.091
	 RNA dist loss: 0.051
	 gene activity loss: 1.538
	 anchor matching loss: 0.000
epoch:  200
	 mmd loss: 0.111
	 ATAC dist loss: 0.324
	 RNA dist loss: 0.064
	 gene activity loss: 1077.469
	 anchor matching loss: 0.000
epoch:  300
	 mmd loss: 0.027
	 ATAC dist loss: 0.050
	 RNA dist loss: 0.049
	 gene activity loss: 0.042
	 anchor matching loss: 0.000
epoch:  400
	 mmd loss: 0.027
	 ATAC dist loss: 0.048
	 RNA dist loss: 0.049
	 gene activity loss: 0.042
	 anchor matching loss: 0.000
epoch:  500
	 mmd loss: 0.026
	 ATAC dist loss: 0.046
	 RNA dist loss: 0.049
	 gene activity loss: 0.018
	 anchor matching loss: 0.000
Fit finished
Transform finished


In [11]:
z_rna.shape, z_atac.shape

((20000, 4), (30000, 4))

In [12]:
import torch
import scDART.dataset as dataset
from torch.utils.data import DataLoader

rna_dataset = dataset.dataset(count_rna, None)
atac_dataset = dataset.dataset(count_atac, None)

# batch_size = int(max([len(self.rna_dataset),len(self.atac_dataset)])/4) if self.batch_size is None else self.batch_size
test_rna_loader = DataLoader(rna_dataset, batch_size = len(rna_dataset), shuffle = False)
test_atac_loader = DataLoader(atac_dataset, batch_size = len(atac_dataset), shuffle = False)

with torch.no_grad():
    for data in test_atac_loader:
        rna_atac = scDART_op.model_dict["gene_act"](data['count'].to('cuda')).cpu().detach()


In [8]:
ad_prna = sc.AnnData(rna_atac.numpy(), obs=ad_atac.obs.copy())
ad_prna.write_h5ad('./adata_hfa50k_pseudo_rna.h5ad')