In [6]:
import os
import gc
import pandas as pd
import pickle
import matplotlib.pyplot as plt

import anndata as ad
import numpy as np
import yaml
import sys
import h5py
import logging
import scanpy as sc
from os.path import join
import scipy.io as sio
import scipy.sparse as sps
from sklearn.cluster import KMeans
import gzip
from scipy.io import mmread
from pathlib import Path, PurePath
from sklearn.metrics import adjusted_rand_score


import warnings
def wrap_warn_plot(adata, basis, color, **kwargs):
    """Call ``sc.pl.embedding`` with ``UserWarning``s silenced.

    Parameters
    ----------
    adata : object forwarded to ``sc.pl.embedding`` as its first argument.
    basis : forwarded as the ``basis`` keyword.
    color : forwarded as the ``color`` keyword.
    **kwargs : any further keyword arguments, forwarded unchanged.
    """
    plot_kwargs = dict(kwargs, basis=basis, color=color)
    with warnings.catch_warnings():
        # The "ignore" filter only lives for the duration of this block.
        warnings.simplefilter("ignore", category=UserWarning)
        sc.pl.embedding(adata, **plot_kwargs)

def get_umap(ad, use_reps=[]):
    """Compute one UMAP per representation and store each under its own key.

    For every name in ``use_reps`` a 15-neighbour graph is built on
    ``ad.obsm[name]``, ``sc.tl.umap`` is run, and whatever is then found in
    ``ad.obsm['X_umap']`` is also stored as ``ad.obsm['<name>_umap']``.

    Returns
    -------
    The same ``ad`` object that was passed in (modified in place).
    """
    for use_rep in use_reps:
        sc.pp.neighbors(ad, n_neighbors=15, use_rep=use_rep)
        sc.tl.umap(ad)
        # 'X_umap' is overwritten on the next iteration, so keep a named entry.
        ad.obsm[f'{use_rep}_umap'] = ad.obsm['X_umap']
    return ad

from sklearn.metrics import adjusted_rand_score

def split_ob(ads, ad_ref, ob='obs', key='emb2'):
    """Distribute a field of a stacked reference object back onto its parts.

    ``ad_ref.obsm[key]`` (when ``ob == 'obsm'``) or ``ad_ref.obs[key]``
    (any other ``ob``) is cut into consecutive chunks whose lengths are the
    ``n_obs`` of the objects in ``ads``; chunk *k* is written to
    ``ads[k].obsm[key]`` / ``ads[k].obs[key]``.

    Nothing is returned; the objects in ``ads`` are modified in place.
    """
    # Cut points are the cumulative sizes of all parts except the last one.
    cut_points = np.cumsum([part.n_obs for part in ads][:-1])
    if ob == 'obsm':
        container = 'obsm'
        chunks = np.split(ad_ref.obsm[key], cut_points)
    else:
        container = 'obs'
        chunks = np.split(ad_ref.obs[key].to_list(), cut_points)
    for part, chunk in zip(ads, chunks):
        getattr(part, container)[key] = chunk

def eval_ads(ads, ref_key, src_key, exclude=[]):
    """Adjusted Rand index between two ``obs`` columns, one value per object.

    Parameters
    ----------
    ads : iterable of objects exposing an ``obs`` DataFrame.
    ref_key : ``obs`` column holding the reference labels.
    src_key : ``obs`` column holding the labels to score.
    exclude : reference labels whose rows are left out of the comparison.

    Returns
    -------
    list with one ARI per element of ``ads``, in the same order.
    """
    def _ari(ad):
        keep = ~ad.obs[ref_key].isin(exclude)
        reference = ad.obs[ref_key].values[keep]
        predicted = ad.obs[src_key].values[keep]
        return adjusted_rand_score(predicted, reference)

    return [_ari(ad) for ad in ads]

from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans
def search_louvain(ad, use_rep, n_neighbors=15, n_clusters=5):
    """Find a Louvain resolution giving ``n_clusters`` clusters, else use KMeans.

    A neighbour graph is built on ``ad.obsm[use_rep]`` and Louvain is run at
    the resolutions ``np.arange(0.1, 1.0, 0.1)``; each run is stored in
    ``ad.obs['r=<resolution>']``. The labels of the first resolution that
    yields exactly ``n_clusters`` clusters are copied to ``ad.obs['louvain_k']``.
    If no resolution matches, ``ad.obs['louvain_k']`` receives KMeans labels
    (as strings) computed on ``ad.obsm[use_rep]`` with ``random_state=0``.

    Nothing is returned; ``ad`` is modified in place.
    """
    sc.pp.neighbors(ad, n_neighbors=n_neighbors, use_rep=use_rep)

    rs = np.arange(0.1, 1.0, 0.1)
    for r in rs:
        sc.tl.louvain(ad, resolution=r, key_added=f'r={r}')
    n_cs = np.array([ad.obs[f'r={r}'].nunique() for r in rs])

    matches = np.flatnonzero(n_cs == n_clusters)
    if matches.size >= 1:
        ri = matches[0]
        ad.obs['louvain_k'] = ad.obs[f'r={rs[ri]}'].to_list()
        return

    # No resolution produced the requested number of clusters.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(ad.obsm[use_rep])
    ad.obs['louvain_k'] = kmeans.labels_.astype('str')

import json
import copy
from matplotlib.image import imread
def load_spatial(path, adata, library_id='0'):
    """Attach the contents of a Visium-style ``spatial`` folder to ``adata``.

    Parameters
    ----------
    path : directory containing ``tissue_hires_image.png``,
        ``tissue_lowres_image.png``, ``scalefactors_json.json`` and a tissue
        positions table (``tissue_positions.csv`` with a header row, or the
        header-less ``tissue_positions_list.csv`` as a fallback).
    adata : object whose ``uns``, ``obs`` and ``obsm`` are filled in place.
    library_id : key under ``adata.uns['spatial']`` that receives the images
        and scale factors.

    Side effects
    ------------
    * ``adata.uns['spatial']`` is replaced by ``{library_id: {...}}``.
    * ``adata.obs`` gains ``in_tissue``, ``array_row`` and ``array_col``
      (left join on the index).
    * ``adata.obsm['spatial']`` holds the columns
      ``['pxl_row_in_fullres', 'pxl_col_in_fullres']`` in that order.

    Raises
    ------
    OSError
        If either image cannot be read; the underlying error is chained.
    """
    # Prefer the file with a header row; fall back to the legacy header-less
    # name when only that one exists. If neither exists the original path is
    # kept so that ``pd.read_csv`` reports the missing file as before.
    tissue_positions_file = join(path, "tissue_positions.csv")
    legacy_positions_file = join(path, "tissue_positions_list.csv")
    if not Path(tissue_positions_file).exists() and Path(legacy_positions_file).exists():
        tissue_positions_file = legacy_positions_file

    files = dict(
        tissue_positions_file=tissue_positions_file,
        scalefactors_json_file=join(path, "scalefactors_json.json"),
        hires_image=join(path, "tissue_hires_image.png"),
        lowres_image=join(path, "tissue_lowres_image.png"),
    )

    adata.uns["spatial"] = dict()
    adata.uns["spatial"][library_id] = dict()
    adata.uns["spatial"][library_id]["images"] = dict()
    for res in ["hires", "lowres"]:
        try:
            adata.uns["spatial"][library_id]["images"][res] = imread(
                str(files[f"{res}_image"])
            )
        except Exception as err:
            # Chain the cause so the real failure (missing file, decode
            # error, ...) stays visible in the traceback.
            raise OSError(f"Could not find '{res}_image'") from err

    # read json scalefactors
    adata.uns["spatial"][library_id]["scalefactors"] = json.loads(
        Path(files["scalefactors_json_file"]).read_bytes()
    )

    # read coordinates; only 'tissue_positions.csv' carries a header row
    positions = pd.read_csv(
        files["tissue_positions_file"],
        header=0 if Path(tissue_positions_file).name == "tissue_positions.csv" else None,
        index_col=0,
    )
    positions.columns = [
        "in_tissue",
        "array_row",
        "array_col",
        "pxl_col_in_fullres",
        "pxl_row_in_fullres",
    ]

    adata.obs = adata.obs.join(positions, how="left")

    adata.obsm["spatial"] = adata.obs[
        ["pxl_row_in_fullres", "pxl_col_in_fullres"]
    ].to_numpy()

    # The pixel coordinates now live in obsm; drop them from obs.
    adata.obs.drop(
        columns=["pxl_row_in_fullres", "pxl_col_in_fullres"],
        inplace=True,
    )

from scib.metrics import lisi
def eval_lisi(
        adata,
        batch_keys=['domain', 'batch'],
        label_keys = ['gt'],
        use_rep='X_emb', use_neighbors=False,
    ):
    """Collect iLISI / cLISI scores from ``scib.metrics.lisi`` in one table.

    Parameters
    ----------
    adata : object passed to the scib LISI functions; every column named in
        ``batch_keys`` and ``label_keys`` is cast to ``category`` in place.
    batch_keys : ``obs`` columns scored with ``lisi.ilisi_graph``.
    label_keys : ``obs`` columns scored with ``lisi.clisi_graph``.
    use_rep : representation name forwarded as ``use_rep``.
    use_neighbors : when true the graph type ``'knn'`` is requested,
        otherwise ``'embed'``.

    Returns
    -------
    pandas.DataFrame with one column per score, named ``<key>_iLISI`` or
    ``<key>_cLISI``.
    """
    graph_type = 'knn' if use_neighbors else 'embed'
    # Settings shared by both metric calls.
    common = dict(
        use_rep=use_rep,
        k0=90,
        subsample=None,
        scale=True,
        n_cores=1,
        verbose=False,
    )

    scores = {}
    for key in batch_keys:
        adata.obs[key] = adata.obs[key].astype('category')
        scores[key+'_iLISI'] = lisi.ilisi_graph(adata, key, graph_type, **common)

    for key in label_keys:
        adata.obs[key] = adata.obs[key].astype('category')
        scores[key+'_cLISI'] = lisi.clisi_graph(
            adata, key, graph_type, batch_key=None, **common
        )

    return pd.DataFrame.from_dict(scores, orient='index').T

# Set the R_HOME / R_USER environment variables before mclust_R imports rpy2.
# NOTE(review): these are absolute, machine-specific paths; they must be
# adapted when the notebook is run elsewhere.
os.environ['R_HOME'] = '/disco_500t/xuhua/miniforge3/envs/Seurat5/lib/R'
os.environ['R_USER'] = '/disco_500t/xuhua/miniforge3/envs/Seurat5/lib/python3.8/site-packages/rpy2'
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
    """Cluster ``adata.obsm[used_obsm]`` with R's ``Mclust`` through rpy2.

    Parameters
    ----------
    adata : object with ``obsm`` and ``obs``; modified in place.
    num_cluster : second positional argument handed to ``Mclust``.
    modelNames : third positional argument handed to ``Mclust``.
    used_obsm : key of the matrix in ``adata.obsm`` to cluster.
    random_seed : seed applied to both NumPy and R (``set.seed``).

    Returns
    -------
    ``adata`` with a categorical ``obs['mclust']`` column holding integer labels.
    """
    np.random.seed(random_seed)

    import rpy2.robjects as robjects
    import rpy2.robjects.numpy2ri

    robjects.r.library("mclust")
    rpy2.robjects.numpy2ri.activate()
    robjects.r['set.seed'](random_seed)

    embedding = rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm])
    fit = robjects.r['Mclust'](embedding, num_cluster, modelNames)
    # Second-to-last element of the Mclust result; presumably the
    # classification vector -- TODO confirm against mclust's return value.
    labels = np.array(fit[-2])

    adata.obs['mclust'] = labels
    adata.obs['mclust'] = adata.obs['mclust'].astype('int').astype('category')
    return adata
    
def load_data(_dir):
    """Build an AnnData from ``features.tsv.gz``, ``barcodes.tsv.gz`` and ``matrix.mtx.gz``.

    The matrix read from ``matrix.mtx.gz`` is transposed and stored as CSR.
    Column 0 of the barcode table becomes ``obs_names``; columns 1, 0 and 2 of
    the feature table become ``var_names``, ``var['id']`` and ``var['type']``.

    Parameters
    ----------
    _dir : directory containing the three gzip-compressed files.
    """
    def _read_table(file_name):
        return pd.read_csv(join(_dir, file_name), compression='gzip', sep='\t', header=None)

    feature_table = _read_table('features.tsv.gz')
    barcode_table = _read_table('barcodes.tsv.gz')

    with gzip.open(join(_dir, 'matrix.mtx.gz'), 'rb') as handle:
        counts = mmread(handle)

    adata = sc.AnnData(sps.csr_matrix(counts.T))
    adata.obs_names = barcode_table[0].values
    adata.var_names = feature_table[1].values
    adata.var['id'] = feature_table[0].values
    adata.var['type'] = feature_table[2].values
    return adata

import json
import copy
from matplotlib.image import imread
def flip_coords(ads):
    """Negate ``obsm['spatial']`` and reverse its column order, in place.

    Parameters
    ----------
    ads : iterable of objects whose ``obsm['spatial']`` is a 2-D array.
    """
    for ad in ads:
        negated = -1 * ad.obsm['spatial']
        ad.obsm['spatial'] = negated[:, ::-1]

def reorder(ad1, ad2):
    """Restrict two objects to their common ``obs_names``, in a shared order.

    Returns
    -------
    tuple ``(ad1_sub, ad2_sub)``: copies of both inputs indexed by
    ``ad1.obs_names.intersection(ad2.obs_names)``.
    """
    common = ad1.obs_names.intersection(ad2.obs_names)
    return ad1[common].copy(), ad2[common].copy()

def load_peak_expr(_dir):
    """Build an AnnData from ``data.mtx``, ``barcode.csv`` and ``feat.csv``.

    The matrix in ``data.mtx`` is transposed and stored as CSR. Observation
    and variable names come from the ``'x'`` column of ``barcode.csv`` and
    ``feat.csv`` respectively (both read with their first column as index).
    """
    def _names(file_name):
        return pd.read_csv(join(_dir, file_name), index_col=0)['x'].to_list()

    matrix = sio.mmread(join(_dir, 'data.mtx'))
    cell_names = _names('barcode.csv')
    feature_names = _names('feat.csv')

    adata = sc.AnnData(sps.csr_matrix(matrix.T))
    adata.obs_names = cell_names
    adata.var_names = feature_names
    return adata

In [7]:
def set_col2cat(ad, cols=[]):
    """Cast the listed ``ad.obs`` columns to the ``category`` dtype, in place.

    Parameters
    ----------
    ad : object exposing an ``obs`` DataFrame.
    cols : names of the columns to convert; nothing happens when empty.
    """
    for column_name in cols:
        as_category = ad.obs[column_name].astype('category')
        ad.obs[column_name] = as_category

def unify_colors(queries, color_key, ref_color_dict):
    """Give every object the same category-to-colour mapping for one column.

    For each object, ``obs[color_key]`` is cast to ``category`` and
    ``uns['<color_key>_colors']`` is set to the colours looked up in
    ``ref_color_dict``, one per category, in category order.

    Returns
    -------
    The ``queries`` argument itself (its elements are modified in place).
    """
    for query in queries:
        query.obs[color_key] = query.obs[color_key].astype('category')
        categories = query.obs[color_key].cat.categories
        palette = []
        for category in categories:
            palette.append(ref_color_dict[category])
        query.uns[f'{color_key}_colors'] = palette
    return queries

def subset_ad(ad, subset_index):
    """Return an independent copy of ``ad[subset_index]``."""
    return ad[subset_index].copy()

def set_spatial(ad):
    """Derive ``ad.obsm['spatial']`` from the array grid coordinates in ``obs``.

    The result has two columns: ``array_col`` followed by the negated
    ``array_row``.

    The array is assembled from scratch instead of being edited in place:
    the array returned by ``DataFrame.values`` can be read-only (pandas
    Copy-on-Write), in which case an in-place write raises ``ValueError``.

    Parameters
    ----------
    ad : object with an ``obs`` table containing ``array_row`` and
        ``array_col``, and a writable ``obsm`` mapping.

    Returns
    -------
    The same ``ad`` object.
    """
    grid = ad.obs[['array_row', 'array_col']].values
    ad.obsm['spatial'] = np.column_stack((grid[:, 1], -1 * grid[:, 0]))
    return ad

In [8]:
from collections import Counter
def load_zu(_dir):
    """Stack all header-less CSV files of a directory and split off the last two columns.

    Files are read in sorted name order and concatenated row-wise.

    Returns
    -------
    tuple ``(z, u)``: ``z`` holds every column except the last two, ``u``
    holds the last two columns.
    """
    tables = [
        pd.read_csv(join(_dir, file_name), header=None).values
        for file_name in sorted(os.listdir(_dir))
    ]
    stacked = np.vstack(tables)
    return stacked[:, :-2], stacked[:, -2:]

def wrap_warn_plot(adata, basis, color, **kwargs):
    """Call ``sc.pl.embedding`` with ``UserWarning``s silenced.

    NOTE(review): a function with this name is also defined in the first
    cell of the notebook; this later definition replaces it at run time.

    Parameters
    ----------
    adata : object forwarded to ``sc.pl.embedding`` as its first argument.
    basis : forwarded as the ``basis`` keyword.
    color : forwarded as the ``color`` keyword.
    **kwargs : any further keyword arguments, forwarded unchanged.
    """
    plot_kwargs = dict(kwargs, basis=basis, color=color)
    with warnings.catch_warnings():
        # The "ignore" filter only lives for the duration of this block.
        warnings.simplefilter("ignore", category=UserWarning)
        sc.pl.embedding(adata, **plot_kwargs)

def get_umap(ad, use_reps=[]):
    """Compute one UMAP per representation and store each under its own key.

    NOTE(review): a function with this name is also defined in the first
    cell of the notebook; this later definition replaces it at run time.

    For every name in ``use_reps`` a 15-neighbour graph is built on
    ``ad.obsm[name]``, ``sc.tl.umap`` is run, and whatever is then found in
    ``ad.obsm['X_umap']`` is also stored as ``ad.obsm['<name>_umap']``.

    Returns
    -------
    The same ``ad`` object that was passed in (modified in place).
    """
    for use_rep in use_reps:
        sc.pp.neighbors(ad, n_neighbors=15, use_rep=use_rep)
        sc.tl.umap(ad)
        # 'X_umap' is overwritten on the next iteration, so keep a named entry.
        ad.obsm[f'{use_rep}_umap'] = ad.obsm['X_umap']
    return ad

In [9]:
# Directory holding MIDAS results (absolute, machine-specific path).
# NOTE(review): the same path is assigned again to `out_dir` in a later cell.
result_dir = '/disco_500t/xuhua/gitrepo/midas/result'

In [10]:
# --- Sample s3: read the filtered feature-barcode matrix and split it by
# feature type into an RNA object and an ADT object. ---
data_dir = '/disco_500t/xuhua/data/spatial_multi_omics/lymp_node/LN-2024-new/outs'

ad3 = load_data(join(data_dir, 'filtered_feature_bc_matrix'))
ad3_rna = ad3[:, ad3.var['type']=='Gene Expression'].copy()
ad3_adt = ad3[:, ad3.var['type']=='Antibody Capture'].copy()
# Images, scale factors and spot coordinates are attached while obs_names are
# still the raw barcodes (load_spatial joins the positions table on the index).
load_spatial(join(data_dir, 'spatial'), ad3_rna)
load_spatial(join(data_dir, 'spatial'), ad3_adt)

# Tag the sample and prefix the barcodes with the sample id.
ad3_rna.obs['src'] = ad3_adt.obs['src'] = ['s3']*ad3_rna.n_obs
ad3_rna.obs_names = [f's3-{x}' for x in ad3_rna.obs_names]
ad3_adt.obs_names = [f's3-{x}' for x in ad3_adt.obs_names]

ad3_rna.var_names_make_unique()
ad3_adt.var_names_make_unique()

# --- Samples s1 (lymph_A1) and s2 (lymph_D1): pre-built h5ad files plus a CSV
# whose 'manual' column is stored as obs['lab']. ---
data_dir = '/disco_500t/xuhua/data/spatial_multi_omics/lymp_tonsil_ramen'

ad_a1_rna = sc.read_h5ad(join(data_dir, 'lymph_A1/adata_RNA.h5ad'))
ad_a1_adt = sc.read_h5ad(join(data_dir, 'lymph_A1/adata_ADT.h5ad'))
meta1 = pd.read_csv(join(data_dir, 'lymph_A1/A1_LN_cloupe_Kwoh.csv'), index_col=0) 
# Labels are looked up by the original barcodes, so this must run before the
# obs_names are prefixed below.
ad_a1_rna.obs['lab'] = meta1.loc[ad_a1_rna.obs_names, 'manual'].to_list()
ad_a1_adt.obs['lab'] = meta1.loc[ad_a1_adt.obs_names, 'manual'].to_list()
ad_a1_rna.obs['src'] = ad_a1_adt.obs['src'] = ['s1'] * ad_a1_rna.n_obs
ad_a1_rna.obs_names = [f's1-{x}' for x in ad_a1_rna.obs_names]
ad_a1_adt.obs_names = [f's1-{x}' for x in ad_a1_adt.obs_names]
ad_a1_rna.var_names_make_unique()
ad_a1_adt.var_names_make_unique()

ad_d1_rna = sc.read_h5ad(join(data_dir, 'lymph_D1/adata_RNA.h5ad'))
ad_d1_adt = sc.read_h5ad(join(data_dir, 'lymph_D1/adata_ADT.h5ad'))
meta2 = pd.read_csv(join(data_dir, 'lymph_D1/D1_LN_cloupe_Kwoh.csv'), index_col=0) 
# Same ordering constraint as for s1: label lookup first, prefixing afterwards.
ad_d1_rna.obs['lab'] = meta2.loc[ad_d1_rna.obs_names, 'manual'].to_list()
ad_d1_adt.obs['lab'] = meta2.loc[ad_d1_adt.obs_names, 'manual'].to_list()
ad_d1_rna.obs['src'] = ad_d1_adt.obs['src'] = ['s2'] * ad_d1_rna.n_obs
ad_d1_rna.obs_names = [f's2-{x}' for x in ad_d1_rna.obs_names]
ad_d1_adt.obs_names = [f's2-{x}' for x in ad_d1_adt.obs_names]
ad_d1_rna.var_names_make_unique()
ad_d1_adt.var_names_make_unique()

# Unify feature names: keep only the genes / proteins present in all three
# samples and give every object the same feature order.
shared_gene = ad_a1_rna.var_names.intersection(ad_d1_rna.var_names).intersection(ad3_rna.var_names)
shared_prot = ad_a1_adt.var_names.intersection(ad_d1_adt.var_names).intersection(ad3_adt.var_names)

ad_a1_rna, ad_d1_rna, ad3_rna = ad_a1_rna[:, shared_gene].copy(), ad_d1_rna[:, shared_gene].copy(), ad3_rna[:, shared_gene].copy()
ad_a1_adt, ad_d1_adt, ad3_adt = ad_a1_adt[:, shared_prot].copy(), ad_d1_adt[:, shared_prot].copy(), ad3_adt[:, shared_prot].copy()

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


In [11]:
# Stack the three samples per modality.
ad_rna_all = sc.concat([ad_a1_rna, ad_d1_rna, ad3_rna])
ad_adt_all = sc.concat([ad_a1_adt, ad_d1_adt, ad3_adt])

# Flag 5000 highly variable genes on the stacked RNA data, using 'src' as batch key.
sc.pp.highly_variable_genes(ad_rna_all, n_top_genes=5000, flavor="seurat_v3", batch_key="src")

# Restrict each RNA object to the flagged genes.
ad_a1_rna, ad_d1_rna, ad3_rna = (
    _sample[:, ad_rna_all.var.query('highly_variable').index].copy()
    for _sample in (ad_a1_rna, ad_d1_rna, ad3_rna)
)

In [12]:
# Per-modality lists; position k in each list belongs to the same sample
# (a1, d1, sample 3).
n_batches = 3
RNA_ADS = [ad_a1_rna, ad_d1_rna, ad3_rna]
ADT_ADS = [ad_a1_adt, ad_d1_adt, ad3_adt]
mod_dict = dict(rna=RNA_ADS, adt=ADT_ADS)

In [1]:
# Write one dataset folder per cross-validation split. In split i the
# modality `mod` of batch i is withheld; every other (batch, modality) pair is
# exported under subset_<batch>/{mat,vec,mask}.
# NOTE(review): this cell's execution count (In [1]) is out of order; it needs
# the variables defined by the cells above it.
for i in range(n_batches):  # train test split
    for mod in ['adt']:   # missing mod
        tmp_out_dir = f'/disco_500t/xuhua/gitrepo/midas/data/processed/Lymph_cv{i+1}_missing{mod}'
        print(tmp_out_dir)
        feat_dir = join(tmp_out_dir, 'feat')
        os.makedirs(feat_dir, exist_ok=True)

        # Feature dimensions and names, taken from the first sample.
        df_feat_dims = pd.DataFrame(np.array([ad_a1_rna.n_vars, ad_a1_adt.n_vars]).reshape(1, -1), columns=['rna', 'adt'])
        df_feat_rna_names = pd.DataFrame(ad_a1_rna.var_names, columns=['x'])
        df_feat_adt_names = pd.DataFrame(ad_a1_adt.var_names, columns=['x'])
        df_feat_dims.to_csv(join(feat_dir, 'feat_dims.csv'))
        df_feat_rna_names.to_csv(join(feat_dir, 'feat_names_rna.csv'))
        df_feat_adt_names.to_csv(join(feat_dir, 'feat_names_adt.csv'))

        # Collect, per batch, the modalities that are kept in this split.
        subsets, mods = [], []
        for bi in range(n_batches):
            tmp_set, tmp_mod = [], []
            for bmod in ['rna', 'adt']:
                if (bi==i) and (bmod == mod):
                    continue
                tmp_set.append(mod_dict[bmod][bi])
                tmp_mod.append(bmod)
            subsets.append(tmp_set)
            mods.append(tmp_mod)

        print(mods)
        for si in range(n_batches):
            for fname in ['mask', 'mat', 'vec']:
                os.makedirs(join(tmp_out_dir, f'subset_{si}/{fname}'), exist_ok=True)
            tmp_dir = join(tmp_out_dir, f'subset_{si}')
            # The loop variable is not called `ad`, so the module alias from
            # `import anndata as ad` is not overwritten at top level.
            for adata_mod, mi in zip(subsets[si], mods[si]):
                # `.toarray()` instead of `.A`: the `.A` shortcut is deprecated
                # and no longer available in recent SciPy releases.
                mat = adata_mod.X.toarray() if sps.issparse(adata_mod.X) else adata_mod.X
                df_mat = pd.DataFrame(mat, index=adata_mod.obs_names, columns=adata_mod.var_names)
                df_mat.to_csv(join(tmp_dir, f'mat/{mi}.csv'))

                # One header-less CSV per observation, named by its row number.
                os.makedirs(join(tmp_dir, f'vec/{mi}'), exist_ok=True)
                for idx, mati in enumerate(mat):
                    pd.DataFrame(mati.reshape(1, -1)).to_csv(join(tmp_dir, 'vec/{}/{:05d}.csv'.format(mi, idx)), header=None, index=None)
                # Rewritten for every modality of the subset; the last one wins.
                pd.DataFrame(adata_mod.obs_names, columns=['x']).to_csv(join(tmp_dir, 'cell_names.csv'))

                # save mask (all ones: every feature is marked as present)
                pd.DataFrame(np.ones(adata_mod.n_vars, dtype='int')).to_csv(join(tmp_dir, f'mask/{mi}.csv'))

In [13]:
# Print, for each split, the training command followed by the translation
# command; the latter loads the checkpoint saved after the last epoch (ep-1).
e, ep = 1, 2000
for i in range(3):  # train test split
    for mod in ['adt']:   # missing mod
        training_command = f'CUDA_VISIBLE_DEVICES=1 python run.py --exp e{e} --task Lymph_cv{i+1}_missing{mod} --epoch_num {ep}'
        run_command = 'CUDA_VISIBLE_DEVICES=1 python run.py --task Lymph_cv{}_missing{} --act translate --init_model sp_{:08d} --exp e{}'.format(
            i+1, mod, ep-1, e
        )
        print(training_command, run_command, sep='\n')

In [14]:
import csv
def csv_read(path):
    """Read a purely numeric CSV file into a 2-D float array.

    Rows that cannot be converted to floats are reported on stdout and
    skipped. Blank lines are skipped silently: ``csv.reader`` yields an empty
    list for them, which would otherwise make ``np.vstack`` fail on ragged
    input.

    Parameters
    ----------
    path : path of the CSV file.

    Returns
    -------
    numpy.ndarray with one row per accepted CSV row.

    Raises
    ------
    ValueError
        If the file contains no numeric row at all, or the accepted rows
        have different lengths.
    """
    rows = []
    with open(path, mode='r', newline='') as file:
        for row in csv.reader(file):
            if not row:
                continue  # blank line
            try:
                rows.append([float(item) for item in row])
            except ValueError as err:
                print(f"Error converting to float: {err}")
    if not rows:
        raise ValueError(f"no numeric rows found in '{path}'")
    return np.vstack(rows)

def collect_csv(_dir):
    """Read every file of a directory with ``csv_read`` and stack the results.

    Files are processed in sorted name order and concatenated row-wise.

    Returns
    -------
    numpy.ndarray holding the rows of all files.
    """
    arrays = [csv_read(join(_dir, file_name)) for file_name in sorted(os.listdir(_dir))]
    return np.vstack(arrays)
    
import copy
def binarize(Xs, bin_thr=0):
    """Return dense copies of the inputs with entries above ``bin_thr`` set to 1.

    Sparse inputs are densified with ``.toarray()`` (the ``.A`` shortcut is
    deprecated and no longer available in recent SciPy releases); dense
    inputs are deep-copied. The inputs themselves are never modified.

    NOTE(review): entries less than or equal to ``bin_thr`` keep their
    original value, so the result is only strictly 0/1 when those entries
    are already 0 or 1 -- confirm this is intended.

    Parameters
    ----------
    Xs : iterable of dense arrays or SciPy sparse matrices.
    bin_thr : threshold; entries strictly greater than it become 1.

    Returns
    -------
    list of dense arrays, one per input, in the same order.
    """
    rs = []
    for X in Xs:
        # .toarray() already returns a fresh array, so no extra copy is needed.
        X = X.toarray() if sps.issparse(X) else copy.deepcopy(X)
        X[X > bin_thr] = 1
        rs.append(X)
    return rs

from sklearn.metrics import roc_auc_score
def eval_AUC_all(gt_X, pr_X, bin_thr=1):
    """ROC-AUC over all entries of a matrix after thresholding the ground truth.

    ``gt_X`` is passed through ``binarize`` with ``bin_thr`` and flattened;
    ``pr_X`` is flattened as is and used as the score.

    Returns
    -------
    The value returned by ``roc_auc_score``.
    """
    (gt_binary,) = binarize([gt_X], bin_thr)
    return roc_auc_score(gt_binary.flatten(), pr_X.flatten())

def PCCs(gt_X, pr_X):
    """Pearson correlation between two matrices, row by row and column by column.

    Returns
    -------
    tuple ``(pcc_cell, pcc_peak)``: a list with one coefficient per row of
    ``gt_X`` and a list with one coefficient per column of ``gt_X``.
    """
    pcc_cell = []
    for row in range(gt_X.shape[0]):
        pcc_cell.append(np.corrcoef(gt_X[row, :], pr_X[row, :])[0, 1])

    pcc_peak = []
    for col in range(gt_X.shape[1]):
        pcc_peak.append(np.corrcoef(gt_X[:, col], pr_X[:, col])[0, 1])

    return pcc_cell, pcc_peak

def cal_cmd(pred, true):
    """Correlation-matrix distance between the row correlations of two matrices.

    Rows that are entirely zero in ``pred`` or in ``true`` are dropped from
    both matrices. The row-by-row correlation matrices ``C_pred`` and
    ``C_true`` of what remains are then compared as
    ``1 - trace(C_pred @ C_true) / (||C_pred||_F * ||C_true||_F + 1e-8)``.

    A warning is printed when 5% or more of the rows are dropped. Each
    dropped row is counted once, even if it is all-zero in both matrices.

    Parameters
    ----------
    pred, true : 2-D arrays with the same number of rows.

    Returns
    -------
    The distance as a scalar.
    """
    # True for rows that contain at least one non-zero entry in both inputs.
    keep = pred.any(axis=1) & true.any(axis=1)
    rm_p = int((~keep).sum()) / pred.shape[0]
    if rm_p >= .05:
        print(f'Warning: too many rows ({rm_p:.1%}) with all zeros')

    pred_array = pred[keep].copy()
    true_array = true[keep].copy()
    corr_pred = np.corrcoef(pred_array, dtype=np.float32)
    corr_true = np.corrcoef(true_array, dtype=np.float32)

    x = np.trace(corr_pred.dot(corr_true))
    y = np.linalg.norm(corr_pred, 'fro') * np.linalg.norm(corr_true, 'fro')
    # The small constant guards against division by zero.
    cmd = 1 - x/(y+1e-8)
    return cmd

from os.path import join

In [15]:
# For every split, load the ADT matrix predicted from RNA for the held-out
# batch and save it as an h5ad file that reuses the obs table and var_names
# of the corresponding ground-truth ADT object.
# NOTE(review): the `*_atac` variable names are kept unchanged, but the data
# handled here is ADT (path 'rna_to_adt', objects from ADT_ADS).
out_dir = '/disco_500t/xuhua/gitrepo/midas/result'
for i in range(3):
    pr_atac = collect_csv(join(out_dir, f'Lymph_cv{i+1}_missingadt/e1/default/predict/sp_00001999/subset_{i}/x_trans/rna_to_adt'))
    # `.toarray()` instead of `.A`: the `.A` shortcut is deprecated and no
    # longer available in recent SciPy releases. `gt_atac` is not used
    # further in this cell.
    gt_atac = ADT_ADS[i].X.toarray() if sps.issparse(ADT_ADS[i].X) else ADT_ADS[i].X

    ad_pred = sc.AnnData(pr_atac, obs=ADT_ADS[i].obs.copy())
    ad_pred.var_names = ADT_ADS[i].var_names.copy()
    ad_pred.write_h5ad(f'/disco_500t/xuhua/gitrepo/BridgeNorm/figures/imputation/Lymph/midas/cv{i+1}_imputedADT.h5ad')
