In [1]:
import scanpy as sc

import os
from os.path import join
import scipy.io as sio
import scipy.sparse as sps
import gc
import numpy as np
import pandas as pd
import scvi
import h5py
import matplotlib.pyplot as plt

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


In [14]:
def subset_ad(ad, subset_index):
    """Return an independent copy of ``ad`` restricted to ``subset_index``.

    ``ad`` is indexed with ``subset_index`` and the selection is copied, so the
    returned object does not alias the input.
    (Presumably ``ad`` is an AnnData and ``subset_index`` a set of obs names.)
    """
    selection = ad[subset_index]
    return selection.copy()

def reorder(ad1, ad2):
    """Restrict two objects to their shared ``obs_names`` and align them.

    The intersection of ``ad1.obs_names`` and ``ad2.obs_names`` is computed
    once; both inputs are indexed with it and copied.

    Returns
    -------
    tuple
        ``(ad1_aligned, ad2_aligned)``, both selected with the same index of
        common names.
    """
    common_names = ad1.obs_names.intersection(ad2.obs_names)
    return ad1[common_names].copy(), ad2[common_names].copy()
   
def load_h5(path, dense=True):
    """Load the ``matrix`` group of an HDF5 feature-barcode file into an AnnData.

    The group is expected to hold ``barcodes``, ``data``, ``indices``,
    ``indptr``, ``shape`` and a ``features`` sub-group with ``feature_type``,
    ``id``, ``name`` and ``interval``. The stored matrix is read as CSC,
    cast to float32 and transposed so that rows correspond to barcodes.

    Parameters
    ----------
    path : str
        Path of the ``.h5`` file to read.
    dense : bool, default True
        If True (the original behaviour), ``X`` is a dense float32 ndarray.
        If False, ``X`` is kept as a scipy CSR matrix, which avoids
        materialising the full barcode x feature array in memory.

    Returns
    -------
    AnnData
        ``obs_names`` are the barcodes and ``var_names`` the feature ids;
        ``var`` has the columns ``type``, ``name`` and ``interval``.
    """
    def _as_str_list(dataset):
        # h5py yields bytes for stored strings; tolerate already-decoded values.
        return [v.decode('utf-8') if isinstance(v, bytes) else str(v) for v in dataset[:]]

    with h5py.File(path, 'r') as f:
        matrix = f['matrix']
        features = matrix['features']
        # Kept from the original: show the available keys for quick inspection.
        print(matrix.keys())
        print(features.keys())

        barcodes = _as_str_list(matrix['barcodes'])
        feature_type = _as_str_list(features['feature_type'])
        feature_id = _as_str_list(features['id'])
        feature_name = _as_str_list(features['name'])
        feature_interval = _as_str_list(features['interval'])

        # Everything is read into memory here, so the file can be closed
        # before the AnnData object is assembled.
        counts = sps.csc_matrix(
            (matrix['data'][:], matrix['indices'][:], matrix['indptr'][:]),
            shape=tuple(int(dim) for dim in matrix['shape'][:]),
        )

    # Transposing the CSC matrix gives CSR with one row per barcode.
    X = counts.astype(np.float32).T.tocsr()
    if dense:
        X = X.toarray()

    adata = sc.AnnData(X)
    adata.obs_names = barcodes
    adata.var_names = feature_id
    adata.var['type'] = feature_type
    adata.var['name'] = feature_name
    adata.var['interval'] = feature_interval
    return adata

def load_peak_expr(_dir):
    """Build an AnnData from the ``data.mtx`` / ``barcode.csv`` / ``feat.csv`` trio in ``_dir``.

    * ``data.mtx``    - Matrix Market file; it is transposed before use
      (presumably stored as features x barcodes - confirm against the export).
    * ``barcode.csv`` - its first column (read as the index) supplies ``obs_names``.
    * ``feat.csv``    - its ``x`` column supplies ``var_names``.

    Returns an AnnData whose ``X`` is a CSR matrix.
    """
    counts = sio.mmread(join(_dir, 'data.mtx'))

    barcode_table = pd.read_csv(join(_dir, 'barcode.csv'), index_col=0)
    barcodes = barcode_table.index.to_list()

    feature_table = pd.read_csv(join(_dir, 'feat.csv'), index_col=0)
    features = feature_table['x'].to_list()

    ad = sc.AnnData(sps.csr_matrix(counts.T))
    ad.obs_names = barcodes
    ad.var_names = features
    return ad

In [15]:
# Load the three MISAR-seq section-1 stages (E13.5, E15.5, E18.5): RNA from the
# raw feature-barcode .h5 files, ATAC from the per-stage peak exports, and the
# spot metadata. After this cell each stage has an RNA and an ATAC AnnData
# with identical, identically ordered obs_names.
data_dir = '/disco_500t/xuhua/data/MISAR_seq/'
ad_bridge = load_h5(join(data_dir, 'E15_5-S1_raw_feature_bc_matrix.h5'))
ad_test1 = load_h5(join(data_dir, 'E13_5-S1_raw_feature_bc_matrix.h5'))
ad_test2 = load_h5(join(data_dir, 'E18_5-S1_raw_feature_bc_matrix.h5'))  # inconsistent peak name across batches
# NOTE(review): peak_mat and peak_spot_name are only referenced by the
# commented-out block below; peak_mat is deleted at the end of this cell.
# Loading them looks unnecessary - confirm and drop.
peak_mat = sps.csr_matrix(sio.mmread(join(data_dir, 'BaiduDisk/section1/peak_mat.mtx')).T)
peak_spot_name = pd.read_csv(join(data_dir, 'BaiduDisk/section1/peak_spot_names.csv')).x.values

# Spot-level metadata, indexed by its first column.
meta = pd.read_csv(join(data_dir, 'BaiduDisk/section1/meta_data.csv'), index_col=0)

# Prefix barcodes with the stage/section id (presumably the naming used in
# meta.index, which the obs_names are intersected with below).
ad_bridge.obs_names = [f'E15_5-S1#{_}' for _ in ad_bridge.obs_names]
ad_test1.obs_names = [f'E13_5-S1#{_}' for _ in ad_test1.obs_names]
ad_test2.obs_names = [f'E18_5-S1#{_}' for _ in ad_test2.obs_names]

# split rna and peak
# Only the 'Gene Expression' features of the .h5 files are kept; the ATAC
# side comes from the separate peak exports loaded further down.
ad15_rna = ad_bridge[:, ad_bridge.var['type'] == 'Gene Expression'].copy()
ad13_rna = ad_test1[:, ad_test1.var['type'] == 'Gene Expression'].copy()
ad18_rna = ad_test2[:, ad_test2.var['type'] == 'Gene Expression'].copy()

# subset peak matrices
# (Superseded approach, kept for reference: slicing the joint peak_mat by spot name.)
# bridge_mask = np.in1d(peak_spot_name, ad_bridge.obs_names)
# ad15_atac = sc.AnnData(peak_mat[bridge_mask])
# ad15_atac.obs_names = peak_spot_name[bridge_mask]
# test1_mask = np.in1d(peak_spot_name, ad_test1.obs_names)
# ad13_atac = sc.AnnData(peak_mat[test1_mask])
# ad13_atac.obs_names = peak_spot_name[test1_mask]
# test2_mask = np.in1d(peak_spot_name, ad_test2.obs_names)
# ad18_atac = sc.AnnData(peak_mat[test2_mask])
# ad18_atac.obs_names = peak_spot_name[test2_mask]
# Current approach: per-stage peak matrices, renamed with the same prefixes as the RNA.
ad13_atac = load_peak_expr(join(data_dir, 'S1-E13-E15-18-peak_data/E13'))
ad15_atac = load_peak_expr(join(data_dir, 'S1-E13-E15-18-peak_data/E15'))
ad18_atac = load_peak_expr(join(data_dir, 'S1-E13-E15-18-peak_data/E18'))
ad13_atac.obs_names = [f'E13_5-S1#{_}' for _ in ad13_atac.obs_names]
ad15_atac.obs_names = [f'E15_5-S1#{_}' for _ in ad15_atac.obs_names]
ad18_atac.obs_names = [f'E18_5-S1#{_}' for _ in ad18_atac.obs_names]

# Keep only the RNA barcodes that have a metadata row.
ad15_rna = subset_ad(ad15_rna, ad15_rna.obs_names.intersection(meta.index))
ad13_rna = subset_ad(ad13_rna, ad13_rna.obs_names.intersection(meta.index))
ad18_rna = subset_ad(ad18_rna, ad18_rna.obs_names.intersection(meta.index))

# Attach metadata as .obs. The ATAC objects are not intersected with
# meta.index first, so every ATAC barcode must be present in meta
# (meta.loc raises KeyError otherwise).
ad15_rna.obs = meta.loc[ad15_rna.obs_names].copy()
ad15_atac.obs = meta.loc[ad15_atac.obs_names].copy()
ad13_rna.obs = meta.loc[ad13_rna.obs_names].copy()
ad13_atac.obs = meta.loc[ad13_atac.obs_names].copy()
ad18_rna.obs = meta.loc[ad18_rna.obs_names].copy()
ad18_atac.obs = meta.loc[ad18_atac.obs_names].copy()

# Align the two modalities row-for-row. E13/E15: ATAC follows the RNA order.
# E18 is the other way round: RNA follows the ATAC order.
ad15_atac = ad15_atac[ad15_rna.obs_names].copy()
ad13_atac = ad13_atac[ad13_rna.obs_names].copy()
ad18_rna  = ad18_rna[ad18_atac.obs_names].copy()  # keep the same obs_name order that was used when the E18 ATAC was originally extracted

# Release the large intermediates; gc.collect() is the last expression, so
# its return value is what the cell displays.
del peak_mat, ad_bridge, ad_test1, ad_test2
gc.collect()

<KeysViewHDF5 ['barcodes', 'data', 'features', 'indices', 'indptr', 'shape']>
<KeysViewHDF5 ['_all_tag_keys', 'feature_type', 'genome', 'id', 'interval', 'name']>
<KeysViewHDF5 ['barcodes', 'data', 'features', 'indices', 'indptr', 'shape']>
<KeysViewHDF5 ['_all_tag_keys', 'feature_type', 'genome', 'id', 'interval', 'name']>
<KeysViewHDF5 ['barcodes', 'data', 'features', 'indices', 'indptr', 'shape']>
<KeysViewHDF5 ['_all_tag_keys', 'feature_type', 'genome', 'id', 'interval', 'name']>


4708

In [28]:
# Tag every spot with its stage of origin; 'src' is used as the batch key for
# feature selection. The label list is sized from the RNA object in each
# pair, so the RNA and ATAC objects of a stage must have the same n_obs.
ad13_rna.obs['src'] = ['e13'] * ad13_rna.n_obs
ad13_atac.obs['src'] = ['e13'] * ad13_rna.n_obs

ad15_rna.obs['src'] = ['e15'] * ad15_rna.n_obs
ad15_atac.obs['src'] = ['e15'] * ad15_rna.n_obs

ad18_rna.obs['src'] = ['e18'] * ad18_rna.n_obs
ad18_atac.obs['src'] = ['e18'] * ad18_rna.n_obs

In [29]:
# Stack the three stages per modality and pick highly variable features with
# the seurat_v3 flavour, using 'src' (stage) as the batch key.
ad_rna_all = sc.concat([ad13_rna, ad15_rna, ad18_rna])
ad_atac_all = sc.concat([ad13_atac, ad15_atac, ad18_atac])

# Top 5000 genes.
sc.pp.highly_variable_genes(
    ad_rna_all,
    flavor='seurat_v3',
    n_top_genes=5000,
    batch_key='src',
)
hvg_names = ad_rna_all.var.index[ad_rna_all.var['highly_variable'].to_numpy()].to_numpy()

# Top 50000 peaks.
sc.pp.highly_variable_genes(
    ad_atac_all,
    flavor='seurat_v3',
    n_top_genes=50000,
    batch_key='src',
)
hvp_names = ad_atac_all.var.index[ad_atac_all.var['highly_variable'].to_numpy()].to_numpy()

In [30]:
# Restrict every stage to the selected genes ...
ad13_rna = ad13_rna[:, hvg_names].copy()
ad15_rna = ad15_rna[:, hvg_names].copy()
ad18_rna = ad18_rna[:, hvg_names].copy()

# ... and to the selected peaks.
ad13_atac = ad13_atac[:, hvp_names].copy()
ad15_atac = ad15_atac[:, hvp_names].copy()
ad18_atac = ad18_atac[:, hvp_names].copy()

## filter feat names
# Keep only peaks whose name starts with 'chr'. The list is derived from the
# E13 object and applied to all three stages.
filtered_atac_feats = ad13_atac.var_names[ad13_atac.var_names.str.startswith('chr')].tolist()
ad13_atac = ad13_atac[:, filtered_atac_feats].copy()
ad15_atac = ad15_atac[:, filtered_atac_feats].copy()
ad18_atac = ad18_atac[:, filtered_atac_feats].copy()

In [31]:
# Sanity check: E13 ATAC and RNA should have the same number of rows (spots).
(ad13_atac.shape, ad13_rna.shape)

((1777, 49988), (1777, 5000))

In [32]:
# Per-stage AnnData objects, index-aligned across the two lists
# (position 0 = E13, 1 = E15, 2 = E18). Each position is one fold.
RNA_ADS = [ad13_rna, ad15_rna, ad18_rna]
ATAC_ADS = [ad13_atac, ad15_atac, ad18_atac]
assert len(RNA_ADS) == len(ATAC_ADS), 'RNA_ADS and ATAC_ADS must hold one entry per stage'
# Derived from the lists so it cannot drift from them.
n_batches = len(RNA_ADS)

In [33]:
# Folder that receives the imputed ATAC matrix of every fold; created if it
# does not exist yet.
output_dir = './Misar_E13-E15-E18'
os.makedirs(name=output_dir, exist_ok=True)

In [34]:
# Leave-one-stage-out cross-validation with MultiVI.
# For fold i the other stages provide paired RNA+ATAC training data; stage i
# contributes RNA only, and its accessibility is imputed and written to disk.
# NOTE: the progress message is 1-based (cv1..cvN) while the output file name
# is 0-based (cv0..cv{N-1}); the file name is left unchanged so existing
# consumers of these files keep working.
for i in range(n_batches):
    print(f'==> cv{i+1}')
    # Every stage except the held-out one, in ascending order.
    train_idx = [idx for idx in range(n_batches) if idx != i]
    RNA_data = sc.concat([RNA_ADS[idx] for idx in train_idx])
    ATAC_data = sc.concat([ATAC_ADS[idx] for idx in train_idx])   # input raw count data
    # Work on a copy so the 'modality' column added below does not leak into
    # the shared RNA_ADS[i] object.
    test_RNA_data = RNA_ADS[i].copy()

    # Paired training matrix: genes first, then peaks, on the same cells.
    adata_paired = sc.concat([RNA_data, ATAC_data], merge='same', axis=1)
    adata_paired.var_names_make_unique()
    adata_paired.var['modality'] = ['Gene Expression']*RNA_data.shape[1]+['Peaks']*ATAC_data.shape[1]
    test_RNA_data.var['modality'] = ['Gene Expression']*test_RNA_data.shape[1]

    # Combine the paired cells with the RNA-only held-out cells, then order the
    # features so that all genes come before all peaks.
    adata_mvi = scvi.data.organize_multiome_anndatas(adata_paired, rna_anndata=test_RNA_data)
    adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()

    scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality')
    mvi = scvi.model.MULTIVI(
        adata_mvi,
        n_genes=int((adata_mvi.var['modality']=='Gene Expression').sum()),
        n_regions=int((adata_mvi.var['modality']=='Peaks').sum())
    )
    mvi.view_anndata_setup()
    mvi.train(max_epochs=100,use_gpu='cuda:1')
    imputed_accessibility = mvi.get_accessibility_estimates()

    # Rows after the paired cells are taken as the held-out cells.
    pred = sc.AnnData(imputed_accessibility[adata_paired.n_obs:])
    assert pred.n_obs == test_RNA_data.n_obs, (
        'imputed rows after the paired block do not match the held-out cells'
    )
    # Drop the suffix after the last '_' that the concatenation appended to
    # the obs names, which restores the original barcodes.
    obs_name = [name.rsplit('_',1)[0] for name in list(pred.obs_names)]

    pred.obs_names = obs_name
    pred.write_h5ad(join(output_dir, f'cv{i}_imputedATAC.h5ad'))

==> cv1


  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [08:19<00:00,  5.22s/it, loss=3.38e+03, v_num=1]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [08:19<00:00,  5.00s/it, loss=3.38e+03, v_num=1]
==> cv2


  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:43<00:00,  3.21s/it, loss=2.84e+03, v_num=1]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:43<00:00,  3.43s/it, loss=2.84e+03, v_num=1]
==> cv3


  return multi_anndata.concatenate(other, join="outer", batch_key=modality_key)


  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:41<00:00,  3.15s/it, loss=2.98e+03, v_num=1]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:41<00:00,  3.41s/it, loss=2.98e+03, v_num=1]
