In [32]:
import scanpy as sc

import os
from os.path import join
import scipy.io as sio
import scipy.sparse as sps
import gc
import numpy as np
import pandas as pd
import scvi
import matplotlib.pyplot as plt

In [2]:
def reorder(ad1, ad2):
    """Restrict two AnnData-like objects to their common cells, in one shared order.

    Despite the name, this also filters: cells present in only one object are
    dropped. Both outputs hold exactly the obs names found in both inputs, in
    the order returned by ``ad1.obs_names.intersection(ad2.obs_names)``.
    The inputs are not modified; independent copies are returned as a
    ``(ad1_subset, ad2_subset)`` tuple.
    """
    common = ad1.obs_names.intersection(ad2.obs_names)
    return tuple(ad[common].copy() for ad in (ad1, ad2))

def load_peak_expr(_dir):
    """Build a sparse AnnData from a directory holding an exported count matrix.

    ``_dir`` must contain:
      * ``data.mtx``    - MatrixMarket matrix laid out features x cells on disk
                          (it is transposed here so rows become cells);
      * ``barcode.csv`` - CSV (first column used as index) whose ``x`` column
                          lists the cell barcodes;
      * ``feat.csv``    - same layout, ``x`` column lists the feature names.

    Returns an AnnData whose ``X`` is a CSR matrix of shape
    (n_barcodes, n_features), with ``obs_names``/``var_names`` set from the CSVs.
    """
    def read_names(fname):
        # Both name files keep their values in a column called 'x'.
        return pd.read_csv(join(_dir, fname), index_col=0)['x'].to_list()

    matrix = sio.mmread(join(_dir, 'data.mtx'))
    barcodes = read_names('barcode.csv')
    features = read_names('feat.csv')

    adata = sc.AnnData(sps.csr_matrix(matrix.T))
    adata.obs_names = barcodes
    adata.var_names = features
    return adata

In [3]:
data_dir = '/disco_500t/xuhua/data/real_mosaic_cases/mouse_brain_rna+atac/'

# (source tag, RNA matrix, spot positions, ATAC peak directory) per slice, relative to data_dir.
# NOTE(review): slice s1 pairs the GSM6204636 RNA matrix with positions/peaks stored
# under GSM6204623 - confirm this pairing is intended.
SLICE_FILES = [
    ('s1',
     'rna+atac/GSM6204636_MouseBrain_20um_matrix.tsv',
     'rna+atac/GSM6204623_MouseBrain_20um_spatial_rna_part/tissue_positions_list.csv',
     'rna+atac/For_Imputation_Task/GSM6204623_peak_data'),
    ('s2',
     'rna+atac/GSM6753041_MouseBrain_20um_repATAC_matrix.tsv',
     'rna+atac/GSM6753041_MouseBrain_20um_repATAC_spatial/tissue_positions_list.csv',
     'rna+atac/For_Imputation_Task/GSM6758284_peak_data'),
    ('s3',
     'rna+atac/GSM6753043_MouseBrain_20um_100barcodes_ATAC_matrix.tsv',
     'rna+atac/GSM6753043_MouseBrain_20um_100barcodes_ATAC_spatial/tissue_positions_list.csv',
     'rna+atac/For_Imputation_Task/GSM6758285_peak_data'),
]


def load_slice(rna_tsv, positions_csv, peak_dir):
    """Load one slice's RNA and ATAC data and align them on shared barcodes.

    Parameters
    ----------
    rna_tsv : str
        Tab-separated gene x spot table; its columns are the spot barcodes.
        Read without ``index_col`` - relies on pandas inferring the gene-name
        column as the index (header one field shorter than the rows); confirm.
    positions_csv : str
        Headerless positions table indexed by barcode. Columns 2 and 3 are stored
        as ``obsm['spatial']`` for both modalities (presumably array row/col in
        the 10x ``tissue_positions_list`` layout - confirm).
    peak_dir : str
        Directory readable by ``load_peak_expr``.

    Returns
    -------
    (AnnData, AnnData)
        RNA and ATAC objects restricted to their shared barcodes, same order.
    """
    # The positions table is shared by both modalities, so read it only once.
    positions = pd.read_csv(positions_csv, index_col=0, header=None)

    df_rna = pd.read_csv(rna_tsv, sep='\t')
    ad_rna = sc.AnnData(df_rna.T, obsm={'spatial': positions.loc[df_rna.columns, [2, 3]].values})

    ad_atac = load_peak_expr(peak_dir)
    ad_atac.obsm['spatial'] = positions.loc[ad_atac.obs_names, [2, 3]].values
    return reorder(ad_rna, ad_atac)


def align_slices(slices):
    """Subset slices to features present in all of them and tag cells by slice.

    ``slices`` is a list of ``(src, ad_rna, ad_atac)``. Returns
    ``(aligned, shared_gene, shared_peak)``:
      * ``aligned`` - list of ``(ad_rna, ad_atac)`` copies restricted to the shared
        genes/peaks, with obs names prefixed ``'{src}-'`` and ``obs['src'] == src``;
      * ``shared_gene`` / ``shared_peak`` - the feature names kept.
    """
    # Chained intersection in slice order: first.intersection(second).intersection(third)...
    shared_gene = slices[0][1].var_names
    shared_peak = slices[0][2].var_names
    for _, ad_rna, ad_atac in slices[1:]:
        shared_gene = shared_gene.intersection(ad_rna.var_names)
        shared_peak = shared_peak.intersection(ad_atac.var_names)

    aligned = []
    for src, ad_rna, ad_atac in slices:
        ad_rna = ad_rna[:, shared_gene].copy()
        ad_atac = ad_atac[:, shared_peak].copy()
        for ad in (ad_rna, ad_atac):
            # Prefix barcodes with the slice tag so they stay unique after concatenation.
            ad.obs_names = [f'{src}-{bc}' for bc in ad.obs_names]
            ad.obs['src'] = [src] * ad.n_obs
        aligned.append((ad_rna, ad_atac))
    return aligned, shared_gene, shared_peak


# The unaligned per-slice objects only live for the duration of the call, so the
# full-size (pre-intersection) matrices are not kept around in the kernel.
aligned_slices, shared_gene, shared_peak = align_slices([
    (src, *load_slice(join(data_dir, rna_tsv), join(data_dir, positions_csv), join(data_dir, peak_dir)))
    for src, rna_tsv, positions_csv, peak_dir in SLICE_FILES
])
(ad1_rna, ad1_atac), (ad2_rna, ad2_atac), (ad3_rna, ad3_atac) = aligned_slices
del aligned_slices

In [4]:
# Stack the three slices cell-wise so feature selection sees every slice.
# NOTE(review): this includes the slice that each CV fold later holds out (its ATAC is
# the imputation target) - confirm that sharing feature selection across folds is intended.
ad_rna_all = sc.concat([ad1_rna, ad2_rna, ad3_rna])
ad_atac_all = sc.concat([ad1_atac, ad2_atac, ad3_atac])

# seurat_v3 ranking with the slice tag as batch; results land in var['highly_variable'].
sc.pp.highly_variable_genes(ad_rna_all, flavor='seurat_v3', n_top_genes=5000, batch_key='src')
hvg_mask = ad_rna_all.var['highly_variable'].to_numpy()
hvg_names = ad_rna_all.var.index[hvg_mask].to_numpy()

# Peaks go through the same procedure on their counts (no TF-IDF step is applied).
sc.pp.highly_variable_genes(ad_atac_all, flavor='seurat_v3', n_top_genes=50000, batch_key='src')
hvp_mask = ad_atac_all.var['highly_variable'].to_numpy()
hvp_names = ad_atac_all.var.index[hvp_mask].to_numpy()
del hvg_mask, hvp_mask

In [5]:
# Restrict every slice to the selected genes / peaks.
ad1_rna, ad2_rna, ad3_rna = [ad[:, hvg_names].copy() for ad in (ad1_rna, ad2_rna, ad3_rna)]
ad1_atac, ad2_atac, ad3_atac = [ad[:, hvp_names].copy() for ad in (ad1_atac, ad2_atac, ad3_atac)]

## filter feat names
# Keep only peaks whose name starts with 'chr' (other names are presumably
# non-chromosomal contigs or non-peak features - confirm). All slices share the same
# var_names here, so the list computed from slice 1 applies to every slice.
filtered_atac_feats = [feat for feat in ad1_atac.var_names if feat.startswith('chr')]
ad1_atac, ad2_atac, ad3_atac = [ad[:, filtered_atac_feats].copy() for ad in (ad1_atac, ad2_atac, ad3_atac)]

In [6]:
# Per-slice objects indexed by slice position (0 -> s1, 1 -> s2, 2 -> s3).
RNA_ADS = [ad1_rna, ad2_rna, ad3_rna]
ATAC_ADS = [ad1_atac, ad2_atac, ad3_atac]
# Row positions each slice occupies when the RNA slices are stacked in order:
# slice k starts after the n_obs of all earlier slices.
IDS = [
    offset + np.arange(ad.n_obs)
    for offset, ad in zip(np.cumsum([0] + [prev.n_obs for prev in RNA_ADS[:-1]]), RNA_ADS)
]
n_batches = len(RNA_ADS)

In [7]:
# Destination for the per-fold imputed ATAC files (created if missing).
output_dir = join('.', 'MB_3slices_RNA+ATAC')
os.makedirs(output_dir, exist_ok=True)

In [8]:
# Leave-one-slice-out: train MultiVI on the paired RNA+ATAC slices plus the held-out
# slice's RNA only, then impute the held-out slice's accessibility.
for i in range(n_batches):
    print(f'==> cv{i+1}')
    train_idx = [j for j in range(n_batches) if j != i]
    RNA_data = sc.concat([RNA_ADS[j] for j in train_idx])
    ATAC_data = sc.concat([ATAC_ADS[j] for j in train_idx])   # input raw count data
    # Copy so the 'modality' column added below does not leak into RNA_ADS[i]
    # (and the ad{i+1}_rna global it aliases) across folds and later cells.
    test_RNA_data = RNA_ADS[i].copy()

    # Paired training cells: RNA and ATAC features side by side for the same barcodes.
    adata_paired = sc.concat([RNA_data, ATAC_data], merge='same', axis=1)
    adata_paired.var_names_make_unique()
    adata_paired.var['modality'] = ['Gene Expression']*RNA_data.shape[1] + ['Peaks']*ATAC_data.shape[1]
    test_RNA_data.var['modality'] = ['Gene Expression']*test_RNA_data.shape[1]

    adata_mvi = scvi.data.organize_multiome_anndatas(adata_paired, rna_anndata=test_RNA_data)
    # Sorting the modality labels puts 'Gene Expression' features before 'Peaks'.
    adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()

    scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality')
    mvi = scvi.model.MULTIVI(
        adata_mvi,
        n_genes=(adata_mvi.var['modality'] == 'Gene Expression').sum(),
        n_regions=(adata_mvi.var['modality'] == 'Peaks').sum(),
    )
    mvi.view_anndata_setup()
    mvi.train(max_epochs=100, use_gpu='cuda:1')
    imputed_accessibility = mvi.get_accessibility_estimates()

    # Positional split: assumes organize_multiome_anndatas keeps the paired cells first
    # and appends the RNA-only (held-out) cells after them. The check below verifies it.
    pred = sc.AnnData(imputed_accessibility[adata_paired.n_obs:])
    # Drop the trailing '_<suffix>' presumably appended to obs names by
    # organize_multiome_anndatas, recovering the 's{k}-<barcode>' names.
    pred.obs_names = [name.rsplit('_', 1)[0] for name in pred.obs_names]
    if pred.n_obs != test_RNA_data.n_obs or set(pred.obs_names) != set(test_RNA_data.obs_names):
        raise ValueError(
            f'cv{i}: imputed rows do not match the held-out slice barcodes '
            f'({pred.n_obs} predicted vs {test_RNA_data.n_obs} held-out cells)'
        )
    # Put rows in exactly the held-out slice's order before saving.
    pred = pred[test_RNA_data.obs_names].copy()

    # NOTE: files are named cv0..cv{n-1} while the progress print uses cv1..cvn;
    # the file names are kept unchanged since downstream evaluation may rely on them.
    pred.write_h5ad(join(output_dir, f'cv{i}_imputedATAC.h5ad'))

==> cv1


  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)

For instance checks, use `isinstance(X, (anndata.experimental.CSRDataset, anndata.experimental.CSCDataset))` instead.

For creation, use `anndata.experimental.sparse_dataset(X)` instead.

  return _abc_instancecheck(cls, instance)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Epoch 1/100:   0%|                                                                                                                                                                        | 0/100 [00:00<?, ?it/s]


For instance checks, use `isinstance(X, (anndata.experimental.CSRDataset, anndata.experimental.CSCDataset))` instead.

For creation, use `anndata.experimental.sparse_dataset(X)` instead.

  return _abc_instancecheck(cls, instance)


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [16:16<00:00,  9.60s/it, loss=2.19e+03, v_num=1]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [16:16<00:00,  9.76s/it, loss=2.19e+03, v_num=1]
==> cv2


  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [12:09<00:00,  7.46s/it, loss=2.08e+03, v_num=1]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [12:09<00:00,  7.29s/it, loss=2.08e+03, v_num=1]
==> cv3


  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [12:00<00:00,  7.12s/it, loss=1.16e+03, v_num=1]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [12:00<00:00,  7.20s/it, loss=1.16e+03, v_num=1]
