In [3]:
import os
import gc
import pandas as pd
import pickle
import matplotlib.pyplot as plt

import anndata as ad
import numpy as np
import yaml
import sys
import h5py
import logging
import scanpy as sc
from os.path import join
import scipy.io as sio
import scipy.sparse as sps
from sklearn.cluster import KMeans
import gzip
from scipy.io import mmread
from pathlib import Path, PurePath
from sklearn.metrics import adjusted_rand_score

import warnings
def wrap_warn_plot(adata, basis, color, **kwargs):
    """Call ``sc.pl.embedding`` with ``UserWarning`` messages silenced.

    Parameters
    ----------
    adata : object forwarded to ``sc.pl.embedding`` as its first argument.
    basis, color : forwarded as the ``basis`` / ``color`` keyword arguments.
    **kwargs : any further keyword arguments, forwarded unchanged.

    Returns
    -------
    None. The return value of ``sc.pl.embedding`` is discarded.

    NOTE(review): a later cell of this notebook re-defines this function with
    the same body.
    """
    plot_args = dict(basis=basis, color=color, **kwargs)
    with warnings.catch_warnings():
        # The filter only lives inside this context manager.
        warnings.simplefilter("ignore", category=UserWarning)
        sc.pl.embedding(adata, **plot_args)

def get_umap(ad, use_reps=[]):
    """Compute one UMAP per representation and store each under its own key.

    For every name in ``use_reps`` a 15-neighbour graph is built on that
    representation, ``sc.tl.umap`` is run, and whatever is then found in
    ``ad.obsm['X_umap']`` is stored under ``ad.obsm['<name>_umap']``.

    Parameters
    ----------
    ad : object accepted by ``sc.pp.neighbors`` / ``sc.tl.umap`` and exposing ``obsm``.
    use_reps : iterable of ``obsm`` keys; the default is never mutated here.

    Returns
    -------
    The same ``ad`` object, modified in place. ``obsm['X_umap']`` is left as
    written by the last ``sc.tl.umap`` call.
    """
    for rep in use_reps:
        sc.pp.neighbors(ad, n_neighbors=15, use_rep=rep)
        sc.tl.umap(ad)
        ad.obsm[f'{rep}_umap'] = ad.obsm['X_umap']
    return ad

from sklearn.metrics import adjusted_rand_score

def split_ob(ads, ad_ref, ob='obs', key='emb2'):
    """Cut one field of ``ad_ref`` into consecutive pieces and hand them to ``ads``.

    The piece sizes are the ``n_obs`` of each object in ``ads``, in order, so
    ``ad_ref`` is treated as the row-wise concatenation of ``ads``.

    Parameters
    ----------
    ads : sequence of objects exposing ``n_obs`` and ``obs`` / ``obsm``.
    ad_ref : object whose ``obs[key]`` or ``obsm[key]`` is split.
    ob : ``'obsm'`` splits ``ad_ref.obsm[key]``; any other value splits
        ``ad_ref.obs[key]`` (converted to a list first).
    key : name of the field to read from ``ad_ref`` and write into each target.

    Returns
    -------
    None; every element of ``ads`` is modified in place.
    """
    sizes = [member.n_obs for member in ads]
    # Boundaries between consecutive members (the last size is implied).
    cut_points = np.cumsum(sizes[:-1])

    if ob == 'obsm':
        pieces = np.split(ad_ref.obsm[key], cut_points)
        for member, piece in zip(ads, pieces):
            member.obsm[key] = piece
        return

    pieces = np.split(ad_ref.obs[key].to_list(), cut_points)
    for member, piece in zip(ads, pieces):
        member.obs[key] = piece

def eval_ads(ads, ref_key, src_key, exclude=[]):
    """Adjusted Rand index between two ``obs`` columns, one score per object.

    Parameters
    ----------
    ads : iterable of objects exposing an ``obs`` table.
    ref_key : ``obs`` column with the reference labels.
    src_key : ``obs`` column with the labels being evaluated.
    exclude : reference labels whose rows are dropped before scoring.

    Returns
    -------
    list with one ``adjusted_rand_score`` value per element of ``ads``.
    """
    def _ari(adata):
        keep = ~adata.obs[ref_key].isin(exclude)
        truth = adata.obs[ref_key].values[keep]
        guess = adata.obs[src_key].values[keep]
        return adjusted_rand_score(guess, truth)

    return [_ari(adata) for adata in ads]

from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans
def search_louvain(ad, use_rep, n_neighbors=15, n_clusters=5):
    """Look for a Louvain resolution giving ``n_clusters``; fall back to KMeans.

    A neighbour graph is built on ``ad.obsm[use_rep]`` and Louvain is run for
    each resolution in ``np.arange(0.1, 1.0, 0.1)``; every run is kept in
    ``ad.obs['r=<resolution>']``. The first resolution whose clustering has
    exactly ``n_clusters`` distinct labels is copied to ``ad.obs['louvain_k']``.
    If none matches, KMeans (``random_state=0``) on ``ad.obsm[use_rep]``
    provides the labels instead, stored as strings.

    Returns
    -------
    None; ``ad`` is modified in place.
    """
    sc.pp.neighbors(ad, n_neighbors=n_neighbors, use_rep=use_rep)

    resolutions = np.arange(0.1, 1.0, 0.1)
    cluster_counts = []
    for res in resolutions:
        column = f'r={res}'
        sc.tl.louvain(ad, resolution=res, key_added=column)
        cluster_counts.append(ad.obs[column].nunique())
    cluster_counts = np.array(cluster_counts)

    matches = np.where(cluster_counts == n_clusters)[0]
    if matches.size >= 1:
        # Lowest resolution that hits the requested number of clusters.
        ad.obs['louvain_k'] = ad.obs[f'r={resolutions[matches[0]]}'].to_list()
    else:
        fitted = KMeans(n_clusters=n_clusters, random_state=0).fit(ad.obsm[use_rep])
        ad.obs['louvain_k'] = fitted.labels_.astype('str')

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import f1_score
def eval_labelTransfer(ad1, ad2, use_rep, lab_key, knn=10):
    """Symmetric kNN label-transfer score between two objects.

    A ``KNeighborsClassifier`` with ``knn`` neighbours is fitted on
    ``obsm[use_rep]`` / ``obs[lab_key]`` of one object and used to predict the
    labels of the other; the macro-averaged F1 of that prediction is computed.
    This is done in both directions (ad1 -> ad2 first, then ad2 -> ad1).

    Returns
    -------
    Mean of the two macro F1 scores. ``FutureWarning`` is silenced while the
    classifiers run.
    """
    def _transfer_f1(source, target):
        # Train on `source`, score the predictions made for `target`.
        clf = KNeighborsClassifier(n_neighbors=knn)
        clf.fit(source.obsm[use_rep], source.obs[lab_key].to_list())
        predicted = clf.predict(target.obsm[use_rep])
        return f1_score(target.obs[lab_key].values, predicted, average='macro')

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        forward = _transfer_f1(ad1, ad2)
        backward = _transfer_f1(ad2, ad1)
        return (forward + backward) / 2

from scib.metrics import lisi
def eval_lisi(
        adata,
        batch_keys=['domain', 'batch'],
        label_keys = ['gt'],
        use_rep='X_emb', use_neighbors=False,
    ):
    """Collect iLISI / cLISI scores from ``scib.metrics.lisi`` into one table.

    Parameters
    ----------
    adata : object passed to the scib functions; its ``obs`` columns named in
        ``batch_keys`` and ``label_keys`` are cast to ``category`` in place.
    batch_keys : columns scored with ``lisi.ilisi_graph``.
    label_keys : columns scored with ``lisi.clisi_graph`` (``batch_key=None``).
    use_rep : forwarded as ``use_rep``.
    use_neighbors : when true the graph type passed on is ``'knn'``, otherwise
        ``'embed'``.

    Returns
    -------
    ``pandas.DataFrame`` built from the score dict and transposed, so the
    columns are ``'<key>_iLISI'`` / ``'<key>_cLISI'``.
    """
    graph_type = 'knn' if use_neighbors else 'embed'
    # Settings shared by both scib calls.
    shared = dict(
        use_rep=use_rep,
        k0=90,
        subsample=None,
        scale=True,
        n_cores=1,
        verbose=False,
    )

    scores = {}
    for key in batch_keys:
        adata.obs[key] = adata.obs[key].astype('category')
        scores[key+'_iLISI'] = lisi.ilisi_graph(adata, key, graph_type, **shared)

    for key in label_keys:
        adata.obs[key] = adata.obs[key].astype('category')
        scores[key+'_cLISI'] = lisi.clisi_graph(
            adata, key, graph_type, batch_key=None, **shared
        )

    return pd.DataFrame.from_dict(scores, orient='index').T

# Point the process at a specific R installation and rpy2 directory.
# NOTE(review): presumably these must be set before rpy2 is first imported
# (mclust_R below imports it lazily) — confirm. The paths are machine-specific.
os.environ.update({
    'R_HOME': '/disco_500t/xuhua/miniforge3/envs/Seurat5/lib/R',
    'R_USER': '/disco_500t/xuhua/miniforge3/envs/Seurat5/lib/python3.8/site-packages/rpy2',
})
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
    """Cluster ``adata.obsm[used_obsm]`` with R's ``Mclust`` through rpy2.

    Parameters
    ----------
    adata : object exposing ``obsm`` and ``obs``.
    num_cluster : passed to ``Mclust`` as its second positional argument.
    modelNames : passed to ``Mclust`` as its third positional argument.
    used_obsm : key of the embedding in ``adata.obsm``.
    random_seed : seeds both NumPy's global RNG and R (``set.seed``).

    Returns
    -------
    The same ``adata`` with ``obs['mclust']`` set to the integer-cast,
    categorical version of the second-to-last element of the ``Mclust`` result
    (presumably the classification vector — confirm against mclust docs).
    """
    np.random.seed(random_seed)

    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri as numpy2ri
    numpy2ri.activate()

    robjects.r['set.seed'](random_seed)
    mclust_fn = robjects.r['Mclust']

    embedding_r = numpy2ri.numpy2rpy(adata.obsm[used_obsm])
    fit = mclust_fn(embedding_r, num_cluster, modelNames)

    adata.obs['mclust'] = np.array(fit[-2])
    adata.obs['mclust'] = adata.obs['mclust'].astype('int').astype('category')
    return adata

   
def load_h5(path):
    """Read an HDF5 file with a top-level ``matrix`` group into a dense AnnData.

    The file is expected to contain ``matrix/{barcodes,data,indices,indptr,shape}``
    and ``matrix/features/{feature_type,id,name,interval}`` (this looks like a
    10x Genomics feature-barcode matrix — confirm). The keys of both groups are
    printed while loading.

    The CSC matrix stored in the file is cast to ``float32``, transposed and
    densified, so rows are barcodes and columns are features.

    Returns
    -------
    AnnData with ``obs_names`` = barcodes, ``var_names`` = feature ids and
    ``var`` columns ``type``, ``name`` and ``interval``.
    """
    def _as_text(dataset):
        # The string datasets are stored as bytes.
        return [entry.decode('utf-8') for entry in dataset[:]]

    with h5py.File(path, 'r') as f:
        print(f['matrix'].keys())
        print(f['matrix']['features'].keys())

        matrix = f['matrix']
        features = matrix['features']

        barcodes = _as_text(matrix['barcodes'])
        data = matrix['data'][:]
        indices = matrix['indices'][:]
        indptr = matrix['indptr'][:]
        shape = matrix['shape'][:]

        feature_type = _as_text(features['feature_type'])
        feature_id = _as_text(features['id'])
        feature_name = _as_text(features['name'])
        feature_interval = _as_text(features['interval'])

        stored = sps.csc_matrix((data, indices, indptr), shape=shape)
        # Transpose so that rows become barcodes; the result is a dense array.
        X = stored.astype(np.float32).T.toarray()

        adata = sc.AnnData(X)
        adata.obs_names = barcodes
        adata.var_names = feature_id
        adata.var['type'] = feature_type
        adata.var['name'] = feature_name
        adata.var['interval'] = feature_interval
    return adata

import json
import copy
from matplotlib.image import imread
def flip_coords(ads):
    """Negate ``obsm['spatial']`` and reverse its column order, in place.

    Parameters
    ----------
    ads : iterable of objects exposing ``obsm['spatial']`` as a 2-D array.

    Returns
    -------
    None.
    """
    for adata in ads:
        negated = -1 * adata.obsm['spatial']
        adata.obsm['spatial'] = negated
        # Reverse the column order (e.g. (a, b) -> (b, a)).
        adata.obsm['spatial'] = adata.obsm['spatial'][:, ::-1]

def reorder(ad1, ad2):
    """Restrict two objects to their shared ``obs_names``, in the same order.

    Returns
    -------
    Tuple ``(ad1_subset, ad2_subset)``; both are copies indexed by
    ``ad1.obs_names.intersection(ad2.obs_names)``.
    """
    common = ad1.obs_names.intersection(ad2.obs_names)
    return ad1[common].copy(), ad2[common].copy()

def load_peak_expr(_dir):
    """Build an AnnData from ``data.mtx``, ``barcode.csv`` and ``feat.csv`` in ``_dir``.

    The Matrix Market file is transposed before use, so it is presumably stored
    as features x barcodes — confirm. Barcodes are taken from the index of
    ``barcode.csv`` and feature names from column ``x`` of ``feat.csv``.

    Returns
    -------
    AnnData holding the transposed matrix as CSR.
    """
    raw_counts = sio.mmread(join(_dir, 'data.mtx'))
    barcode_table = pd.read_csv(join(_dir, 'barcode.csv'), index_col=0)
    feature_table = pd.read_csv(join(_dir, 'feat.csv'), index_col=0)

    adata = sc.AnnData(sps.csr_matrix(raw_counts.T))
    adata.obs_names = barcode_table.index.to_list()
    adata.var_names = feature_table['x'].to_list()
    return adata

In [4]:
def set_col2cat(ad, cols=[]):
    """Cast the listed ``ad.obs`` columns to pandas ``category`` dtype, in place.

    Parameters
    ----------
    ad : object exposing an ``obs`` table.
    cols : column names to convert; the default list is never mutated.
    """
    for column in cols:
        as_category = ad.obs[column].astype('category')
        ad.obs[column] = as_category

def unify_colors(queries, color_key, ref_color_dict):
    """Give every object the same label -> colour mapping for ``color_key``.

    For each object the ``obs[color_key]`` column is cast to ``category`` and
    ``uns['<color_key>_colors']`` is set to the colours looked up in
    ``ref_color_dict``, ordered like the column's categories.

    Raises
    ------
    KeyError if a category is missing from ``ref_color_dict``.

    Returns
    -------
    ``queries`` itself; its elements are modified in place.
    """
    palette_slot = f'{color_key}_colors'
    for query in queries:
        query.obs[color_key] = query.obs[color_key].astype('category')
        categories = query.obs[color_key].cat.categories
        query.uns[palette_slot] = [ref_color_dict[label] for label in categories]
    return queries

def subset_ad(ad, subset_index):
    """Return an independent copy of ``ad[subset_index]``."""
    return ad[subset_index].copy()

def set_spatial(ad):
    """Derive ``obsm['spatial']`` from the ``array_row`` / ``array_col`` columns.

    The resulting coordinates are ``(array_col, -array_row)`` for every row of
    ``ad.obs``.

    Returns
    -------
    The same ``ad`` object, modified in place.
    """
    grid_positions = ad.obs[['array_row', 'array_col']].values
    ad.obsm['spatial'] = grid_positions
    # Swap the two columns so that the column index comes first ...
    ad.obsm['spatial'] = ad.obsm['spatial'][:, ::-1]
    # ... and flip the sign of what is now the second column (the row index).
    flipped_rows = -1 * ad.obsm['spatial'][:, 1]
    ad.obsm['spatial'][:, 1] = flipped_rows
    return ad

In [5]:
from collections import Counter
def load_zu(_dir):
    """Stack every header-less CSV in ``_dir`` and split off the last two columns.

    Files are read in lexicographically sorted name order and stacked row-wise.

    Returns
    -------
    Tuple ``(z, u)``: ``z`` holds all but the last two columns, ``u`` the last
    two columns of the stacked array.
    """
    chunks = [
        pd.read_csv(join(_dir, file_name), header=None).values
        for file_name in sorted(os.listdir(_dir))
    ]
    stacked = np.vstack(chunks)
    return stacked[:, :-2], stacked[:, -2:]

def wrap_warn_plot(adata, basis, color, **kwargs):
    """Call ``sc.pl.embedding`` with ``UserWarning`` messages silenced.

    All arguments are forwarded (``basis`` and ``color`` by keyword); the
    return value of ``sc.pl.embedding`` is discarded.

    NOTE(review): this repeats a definition from the first cell of the
    notebook; the later definition is the one in effect after a full run.
    """
    forwarded = dict(basis=basis, color=color, **kwargs)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        sc.pl.embedding(adata, **forwarded)

def get_umap(ad, use_reps=[]):
    """Compute one UMAP per representation and keep each under ``'<rep>_umap'``.

    For every key in ``use_reps``: build a 15-neighbour graph on that
    representation, run ``sc.tl.umap`` and store the content of
    ``ad.obsm['X_umap']`` under ``ad.obsm['<rep>_umap']``.

    NOTE(review): this repeats a definition from the first cell of the
    notebook.

    Returns
    -------
    The same ``ad`` object, modified in place.
    """
    for rep_key in use_reps:
        sc.pp.neighbors(ad, n_neighbors=15, use_rep=rep_key)
        sc.tl.umap(ad)
        ad.obsm[f'{rep_key}_umap'] = ad.obsm['X_umap']
    return ad

In [7]:
# Load the three raw feature-barcode matrices (E15.5, E13.5, E18.5) plus the
# shared spot metadata. The path is machine-specific.
data_dir = '/disco_500t/xuhua/data/MISAR_seq/'
ad_bridge = load_h5(join(data_dir, 'E15_5-S1_raw_feature_bc_matrix.h5'))
ad_test1 = load_h5(join(data_dir, 'E13_5-S1_raw_feature_bc_matrix.h5'))
ad_test2 = load_h5(join(data_dir, 'E18_5-S1_raw_feature_bc_matrix.h5'))  # inconsistent peak name across batches
# NOTE(review): peak_mat and peak_spot_name are not used anywhere in this cell
# (peak_mat is deleted again at the end) — confirm whether they are still needed.
peak_mat = sps.csr_matrix(sio.mmread(join(data_dir, 'BaiduDisk/section1/peak_mat.mtx')).T)
peak_spot_name = pd.read_csv(join(data_dir, 'BaiduDisk/section1/peak_spot_names.csv')).x.values

meta = pd.read_csv(join(data_dir, 'BaiduDisk/section1/meta_data.csv'), index_col=0)

# Prefix every barcode with its sample id so that it can be looked up in `meta`
# (presumably meta.index uses the '<sample>#<barcode>' form — confirm).
ad_bridge.obs_names = [f'E15_5-S1#{_}' for _ in ad_bridge.obs_names]
ad_test1.obs_names = [f'E13_5-S1#{_}' for _ in ad_test1.obs_names]
ad_test2.obs_names = [f'E18_5-S1#{_}' for _ in ad_test2.obs_names]

# split rna and peak
# Only the 'Gene Expression' features are kept from the h5 files; the ATAC
# peaks are loaded separately below from the per-stage peak directories.
ad15_rna = ad_bridge[:, ad_bridge.var['type'] == 'Gene Expression'].copy()
ad13_rna = ad_test1[:, ad_test1.var['type'] == 'Gene Expression'].copy()
ad18_rna = ad_test2[:, ad_test2.var['type'] == 'Gene Expression'].copy()

ad13_atac = load_peak_expr(join(data_dir, 'S1-E13-E15-18-peak_data/E13'))
ad15_atac = load_peak_expr(join(data_dir, 'S1-E13-E15-18-peak_data/E15'))
ad18_atac = load_peak_expr(join(data_dir, 'S1-E13-E15-18-peak_data/E18'))
ad13_atac.obs_names = [f'E13_5-S1#{_}' for _ in ad13_atac.obs_names]
ad15_atac.obs_names = [f'E15_5-S1#{_}' for _ in ad15_atac.obs_names]
ad18_atac.obs_names = [f'E18_5-S1#{_}' for _ in ad18_atac.obs_names]

# Keep only RNA spots that have an entry in the metadata table.
ad15_rna = subset_ad(ad15_rna, ad15_rna.obs_names.intersection(meta.index))
ad13_rna = subset_ad(ad13_rna, ad13_rna.obs_names.intersection(meta.index))
ad18_rna = subset_ad(ad18_rna, ad18_rna.obs_names.intersection(meta.index))

# Replace each obs table with the matching metadata rows.
# NOTE(review): the ATAC objects are not intersected with `meta` first, so
# these lookups rely on every ATAC barcode being present in meta.index — confirm.
ad15_rna.obs = meta.loc[ad15_rna.obs_names].copy()
ad15_atac.obs = meta.loc[ad15_atac.obs_names].copy()
ad13_rna.obs = meta.loc[ad13_rna.obs_names].copy()
ad13_atac.obs = meta.loc[ad13_atac.obs_names].copy()
ad18_rna.obs = meta.loc[ad18_rna.obs_names].copy()
ad18_atac.obs = meta.loc[ad18_atac.obs_names].copy()

# Align the two modalities spot-by-spot. For E13/E15 the ATAC object follows
# the RNA order; for E18 it is the other way round (RNA follows ATAC).
ad15_atac = ad15_atac[ad15_rna.obs_names].copy()
ad13_atac = ad13_atac[ad13_rna.obs_names].copy()
ad18_rna  = ad18_rna[ad18_atac.obs_names].copy()  # keep the same obs_name order that was used when the E18 ATAC data was originally extracted

# Free the large intermediates; the raw h5 objects are no longer needed.
del peak_mat, ad_bridge, ad_test1, ad_test2
gc.collect()

<KeysViewHDF5 ['barcodes', 'data', 'features', 'indices', 'indptr', 'shape']>
<KeysViewHDF5 ['_all_tag_keys', 'feature_type', 'genome', 'id', 'interval', 'name']>
<KeysViewHDF5 ['barcodes', 'data', 'features', 'indices', 'indptr', 'shape']>
<KeysViewHDF5 ['_all_tag_keys', 'feature_type', 'genome', 'id', 'interval', 'name']>
<KeysViewHDF5 ['barcodes', 'data', 'features', 'indices', 'indptr', 'shape']>
<KeysViewHDF5 ['_all_tag_keys', 'feature_type', 'genome', 'id', 'interval', 'name']>


2131

In [8]:
# Tag every spot of both modalities with its stage label; the 'src' column is
# used as batch key for feature selection further down.
for _rna, _atac, _stage in (
    (ad13_rna, ad13_atac, 'e13'),
    (ad15_rna, ad15_atac, 'e15'),
    (ad18_rna, ad18_atac, 'e18'),
):
    _rna.obs['src'] = _atac.obs['src'] = [_stage]*_rna.n_obs

In [9]:
# pd.DataFrame([x.split('#')[1] for x in ad13_rna.obs_names]).to_csv(join(data_dir, 'E13_filtered_barcode.csv'))
# pd.DataFrame([x.split('#')[1] for x in ad15_rna.obs_names]).to_csv(join(data_dir, 'E15_filtered_barcode.csv'))
# pd.DataFrame([x.split('#')[1] for x in ad18_rna.obs_names]).to_csv(join(data_dir, 'E18_filtered_barcode.csv'))

In [10]:
# Pool the three stages per modality so that features are selected on the
# combined data, using 'src' as the batch key.
ad_rna_all = sc.concat([ad13_rna, ad15_rna, ad18_rna])
ad_atac_all = sc.concat([ad13_atac, ad15_atac, ad18_atac])

# Top 5000 highly variable genes (seurat_v3 flavour).
sc.pp.highly_variable_genes(ad_rna_all, n_top_genes=5000, flavor='seurat_v3', batch_key='src')
hvg_names = ad_rna_all.var.index[ad_rna_all.var['highly_variable'].to_numpy()].to_numpy()

# Top 50000 highly variable peaks, selected the same way.
sc.pp.highly_variable_genes(ad_atac_all, n_top_genes=50000, flavor='seurat_v3', batch_key='src')
hvp_names = ad_atac_all.var.index[ad_atac_all.var['highly_variable'].to_numpy()].to_numpy()

In [11]:
# Restrict every stage to the shared highly variable genes ...
ad13_rna = ad13_rna[:, hvg_names].copy()
ad15_rna = ad15_rna[:, hvg_names].copy()
ad18_rna = ad18_rna[:, hvg_names].copy()

# ... and to the shared highly variable peaks.
ad13_atac = ad13_atac[:, hvp_names].copy()
ad15_atac = ad15_atac[:, hvp_names].copy()
ad18_atac = ad18_atac[:, hvp_names].copy()

## filter feat names
# Only peaks whose name starts with 'chr' are kept; the list is derived from
# the E13 object and applied to all three stages.
filtered_atac_feats = [peak for peak in ad13_atac.var_names if peak.startswith('chr')]
ad13_atac = ad13_atac[:, filtered_atac_feats].copy()
ad15_atac = ad15_atac[:, filtered_atac_feats].copy()
ad18_atac = ad18_atac[:, filtered_atac_feats].copy()

In [12]:
def split_list_byChr(input_list):
    """Return the prefix of each consecutive run of names sharing one prefix.

    The prefix of a name is everything before its first ``'-'``. Walking
    through ``input_list`` in order, a prefix is recorded every time it differs
    from the previously recorded one, so a prefix that re-appears after a
    different one is listed again.

    Parameters
    ----------
    input_list : iterable of str, e.g. peak names like ``'chr1-100-200'``.

    Returns
    -------
    list of str, one entry per run; empty for empty input.
    """
    run_prefixes = []
    for feature_name in input_list:
        # Named `chrom` rather than `chr` to avoid shadowing the builtin.
        chrom = feature_name.split("-")[0]
        if not run_prefixes or chrom != run_prefixes[-1]:
            run_prefixes.append(chrom)
    return run_prefixes

In [13]:
# Per-modality lists in a fixed stage order (E13, E15, E18); the position in
# each list is used as the batch / fold index further down.
RNA_ADS = [ad13_rna, ad15_rna, ad18_rna]
ATAC_ADS = [ad13_atac, ad15_atac, ad18_atac]
mod_dict = dict(rna=RNA_ADS, atac=ATAC_ADS)

In [15]:
# Export one input directory per cross-validation fold. In fold `i` the
# modality `mod` of batch `i` is withheld; every other (batch, modality) pair
# is written out as CSV.
mod = 'atac'  # modality withheld from the held-out batch

# The feature tables do not depend on the fold, so they are built once.
atac_feat_chunks = split_list_byChr(ad13_atac.var_names) # split the chr-? in order
atac_chr_count = Counter([_.split('-')[0] for _ in ad13_atac.var_names]) # count the frequency of each chr-?
# NOTE(review): the counts are totals per chromosome prefix while the chunk
# list has one entry per consecutive run, so this assumes the peaks of each
# chromosome are contiguous in var_names — confirm.
df_feat_dims = pd.DataFrame(np.array([atac_chr_count[_] for _ in atac_feat_chunks]).reshape(-1, 1), columns=['atac'])
df_feat_dims['rna'] = ad13_rna.n_vars
df_feat_rna_names = pd.DataFrame(ad13_rna.var_names, columns=['x'])
df_feat_adt_names = pd.DataFrame(ad13_atac.var_names, columns=['x'])  # ATAC peak names, despite the variable name

for i in range(3):  # train test split
    tmp_out_dir = f'/disco_500t/xuhua/gitrepo/midas/data/processed/Misar-E13-15-18_cv{i+1}_missing{mod}'
    os.makedirs(tmp_out_dir, exist_ok=True)
    print(tmp_out_dir)
    feat_dir = join(tmp_out_dir, 'feat')
    os.makedirs(feat_dir, exist_ok=True)

    df_feat_dims.to_csv(join(feat_dir, 'feat_dims.csv'))
    df_feat_rna_names.to_csv(join(feat_dir, 'feat_names_rna.csv'))
    df_feat_adt_names.to_csv(join(feat_dir, 'feat_names_atac.csv'))

    # Collect, per batch, the modalities that are available in this fold.
    subsets, mods = [], []
    for bi in range(3):
        tmp_set, tmp_mod = [], []
        for bmod in ['rna', 'atac']:
            if (bi==i) and (bmod == mod):
                continue  # withheld modality of the held-out batch
            tmp_set.append(mod_dict[bmod][bi])
            tmp_mod.append(bmod)
        subsets.append(tmp_set)
        mods.append(tmp_mod)

    for si in range(3):
        for fname in ['mask', 'mat', 'vec']:
            os.makedirs(join(tmp_out_dir, f'subset_{si}/{fname}'), exist_ok=True)
        tmp_dir = join(tmp_out_dir, f'subset_{si}')
        # The loop variable is called `adata`, not `ad`: at notebook top level
        # `ad` would overwrite the `import anndata as ad` module alias.
        for adata, mi in zip(subsets[si], mods[si]):
            # `.toarray()` replaces the `.A` shorthand, which is deprecated and
            # no longer available in recent SciPy releases.
            mat = adata.X.toarray() if sps.issparse(adata.X) else adata.X
            df_mat = pd.DataFrame(mat, index=adata.obs_names, columns=adata.var_names)
            df_mat.to_csv(join(tmp_dir, f'mat/{mi}.csv'))

            # One header-less, single-row CSV per spot.
            os.makedirs(join(tmp_dir, f'vec/{mi}'), exist_ok=True)
            for idx, mati in enumerate(mat):
                pd.DataFrame(mati.reshape(1, -1)).to_csv(join(tmp_dir, 'vec/{}/{:05d}.csv'.format(mi, idx)), header=None, index=None)
            # Written once per modality; the last modality of the subset wins.
            pd.DataFrame(adata.obs_names, columns=['x']).to_csv(join(tmp_dir, 'cell_names.csv'))

            # save mask
            pd.DataFrame(np.ones(adata.n_vars, dtype='int')).to_csv(join(tmp_dir, f'mask/{mi}.csv'))

In [14]:
# Print the shell commands for training and for the translation run of each
# cross-validation fold. Nothing is executed here.
e, ep = 1, 2000  # experiment id, number of training epochs
for i in range(3):  # train test split
    for mod in ['atac']:   # missing mod
        training_command = f'CUDA_VISIBLE_DEVICES=1 python run.py --exp e{e} --task Misar-E13-15-18_cv{i+1}_missing{mod} --epoch_num {ep}'

        # The checkpoint passed to --init_model is named after epoch `ep-1`.
        run_command = 'CUDA_VISIBLE_DEVICES=1 python run.py --task Misar-E13-15-18_cv{}_missing{} --act translate --init_model sp_{:08d} --exp e{}'.format(
            i+1, mod, ep-1, e
        )
        print(training_command, run_command, sep='\n')

In [16]:
import csv
def csv_read(path):
    """Read a header-less numeric CSV file into a 2-D float array.

    Rows are parsed with ``csv.reader`` and every field is converted with
    ``float``. Blank lines are skipped. A row containing a field that cannot
    be converted is reported on stdout and dropped; note that dropping a row
    shifts all following rows, so callers relying on row order should treat
    such a message as an error.

    Parameters
    ----------
    path : path of the CSV file.

    Returns
    -------
    ``numpy.ndarray`` with one row per parsed CSV row.

    Raises
    ------
    ValueError
        From ``np.vstack`` if no row could be parsed or the rows have
        different lengths.
    """
    rows = []
    with open(path, mode='r', newline='') as file:
        for row in csv.reader(file):
            if not row:
                # csv.reader yields [] for a blank line; keeping it would make
                # np.vstack fail on the ragged input.
                continue
            try:
                rows.append([float(item) for item in row])
            except ValueError as e:
                print(f"Error converting to float: {e}")
    return np.vstack(rows)

def collect_csv(_dir):
    """Read every file in ``_dir`` with ``csv_read`` and stack the results row-wise.

    Files are processed in lexicographically sorted name order, so the row
    order of the result follows the file names.

    Returns
    -------
    ``numpy.ndarray`` with the rows of all files.
    """
    per_file = [csv_read(join(_dir, file_name)) for file_name in sorted(os.listdir(_dir))]
    return np.vstack(per_file)
    
import copy
def binarize(Xs, bin_thr=0):
    """Set every entry greater than ``bin_thr`` to 1 in a copy of each matrix.

    Entries that are not greater than ``bin_thr`` are left unchanged, so the
    result is only strictly 0/1 when all such entries are already 0 or 1
    (e.g. non-negative integer counts with ``bin_thr`` 0 or 1).

    Parameters
    ----------
    Xs : iterable of dense arrays or SciPy sparse matrices.
    bin_thr : threshold; compared with ``>``.

    Returns
    -------
    list of dense arrays, one per input; the inputs are not modified.
    """
    results = []
    for X in Xs:
        if sps.issparse(X):
            # toarray() already returns a fresh dense array; it also replaces
            # the `.A` shorthand, which is gone in recent SciPy releases.
            dense = X.toarray()
        else:
            dense = copy.deepcopy(X)
        dense[dense > bin_thr] = 1
        results.append(dense)
    return results

from sklearn.metrics import roc_auc_score
def eval_AUC_all(gt_X, pr_X, bin_thr=1):
    """ROC AUC over all entries of a predicted matrix against a thresholded truth.

    ``gt_X`` is passed through ``binarize`` with ``bin_thr`` and flattened;
    ``pr_X`` is flattened as-is and used as the score.

    NOTE(review): the default threshold here is 1 while ``binarize`` itself
    defaults to 0 — confirm this is intended.

    Returns
    -------
    The value returned by ``roc_auc_score``.
    """
    (gt_binary,) = binarize([gt_X], bin_thr)
    labels = gt_binary.flatten()
    scores = pr_X.flatten()
    return roc_auc_score(labels, scores)

def PCCs(gt_X, pr_X):
    """Pearson correlation between two 2-D arrays, per row and per column.

    Parameters
    ----------
    gt_X, pr_X : 2-D arrays indexed as ``[row, column]``; the row and column
        counts are taken from ``gt_X``.

    Returns
    -------
    Tuple ``(pcc_cell, pcc_peak)`` of lists: one coefficient per row and one
    per column, each taken from the off-diagonal of ``np.corrcoef``.
    """
    def _pearson(a, b):
        return np.corrcoef(a, b)[0, 1]

    pcc_cell = []
    for row in range(gt_X.shape[0]):
        pcc_cell.append(_pearson(gt_X[row, :], pr_X[row, :]))

    pcc_peak = []
    for col in range(gt_X.shape[1]):
        pcc_peak.append(_pearson(gt_X[:, col], pr_X[:, col]))

    return pcc_cell, pcc_peak

def cal_cmd(pred, true):
    """Correlation-matrix distance between the row-correlations of two arrays.

    Rows that are entirely zero in ``pred`` or in ``true`` are removed from
    both arrays. The row-by-row correlation matrices of the remaining data are
    then compared as ``1 - trace(Cp @ Ct) / (||Cp||_F * ||Ct||_F + 1e-8)``.

    A warning is printed when 5% or more of the rows are removed. Each removed
    row is counted once, even if it is all-zero in both inputs.

    Parameters
    ----------
    pred, true : 2-D dense arrays with the same number of rows.

    Returns
    -------
    The distance as a scalar (0 when both correlation matrices are identical).
    """
    # A row is dropped if it is all-zero in either input.
    all_zero = ~pred.any(axis=1) | ~true.any(axis=1)
    rm_p = all_zero.sum() / pred.shape[0]
    if rm_p >= .05:
        print(f'Warning: too many rows ({rm_p:.1%}) with all zeros')

    keep = ~all_zero
    pred_array = pred[keep]
    true_array = true[keep]
    corr_pred = np.corrcoef(pred_array, dtype=np.float32)
    corr_true = np.corrcoef(true_array, dtype=np.float32)

    x = np.trace(corr_pred.dot(corr_true))
    y = np.linalg.norm(corr_pred, 'fro') * np.linalg.norm(corr_true, 'fro')
    # The small constant guards against division by zero.
    cmd = 1 - x/(y+1e-8)
    return cmd

In [17]:
# Collect the predicted ATAC profiles of every fold and store them as h5ad,
# re-using the obs table and peak names of the corresponding measured data.
out_dir = '/disco_500t/xuhua/gitrepo/midas/result'
for i in range(3):
    # Predictions for the held-out batch `i` of fold `i+1`, one CSV per chunk.
    pr_atac = collect_csv(join(out_dir, f'Misar-E13-15-18_cv{i+1}_missingatac/e1/default/predict/sp_00001999/subset_{i}/x_trans/rna_to_atac'))
    # `.toarray()` replaces the `.A` shorthand, which is deprecated and no
    # longer available in recent SciPy releases.
    # NOTE(review): gt_atac is not used further in this cell — confirm it is
    # needed by later cells.
    gt_atac = ATAC_ADS[i].X.toarray() if sps.issparse(ATAC_ADS[i].X) else ATAC_ADS[i].X

    # Assumes the prediction rows are in the same order as ATAC_ADS[i].obs — TODO confirm.
    ad_pred = sc.AnnData(pr_atac, obs=ATAC_ADS[i].obs.copy())
    ad_pred.var_names = ATAC_ADS[i].var_names.copy()
    ad_pred.write_h5ad(f'/disco_500t/xuhua/gitrepo/BridgeNorm/figures/imputation/Misar_E13-E15-E18/midas/cv{i+1}_imputedATAC.h5ad')
