### MIDAS cannot integrate ATAC with histone-modification data

In [None]:
import os
import gc
import pandas as pd
import pickle
import matplotlib.pyplot as plt

import anndata as ad
import numpy as np
import yaml
import sys
import h5py
import logging
import scanpy as sc
from os.path import join
import scipy.io as sio
import scipy.sparse as sps
from sklearn.cluster import KMeans
import gzip
from scipy.io import mmread
from pathlib import Path, PurePath
from sklearn.metrics import adjusted_rand_score


import warnings
def wrap_warn_plot(adata, basis, color, **kwargs):
    """Draw ``sc.pl.embedding`` for *adata* while ignoring UserWarnings.

    *basis*, *color* and any extra keyword arguments are forwarded unchanged
    to ``sc.pl.embedding``.
    """
    plot_kwargs = dict(basis=basis, color=color, **kwargs)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        sc.pl.embedding(adata, **plot_kwargs)

def get_umap(ad, use_reps=(), n_neighbors=15):
    """Compute one UMAP per representation and keep each under its own obsm key.

    For every key in *use_reps* a neighbour graph is built from
    ``ad.obsm[key]`` (``sc.pp.neighbors``), ``sc.tl.umap`` is run, and the
    resulting ``obsm['X_umap']`` is stored as ``obsm[f'{key}_umap']``.

    Parameters
    ----------
    ad : AnnData
        Modified in place. Its neighbour graph and ``obsm['X_umap']`` end up
        reflecting the last representation processed.
    use_reps : iterable of str
        obsm keys to embed. Defaults to an empty tuple (nothing is computed),
        which avoids a shared mutable default list.
    n_neighbors : int
        Passed to ``sc.pp.neighbors`` (15, as before).

    Returns
    -------
    AnnData
        The same *ad* object.
    """
    for use_rep in use_reps:
        sc.pp.neighbors(ad, use_rep=use_rep, n_neighbors=n_neighbors)
        sc.tl.umap(ad)
        # Copy the reference out before the next iteration overwrites X_umap.
        ad.obsm[f'{use_rep}_umap'] = ad.obsm['X_umap']
    return ad

from sklearn.metrics import adjusted_rand_score

def split_ob(ads, ad_ref, ob='obs', key='emb2'):
    """Scatter ``ad_ref``'s *key* back onto the objects in *ads*, in row order.

    *ad_ref* is assumed to be the row-wise concatenation of *ads* in the same
    order (not checked here). With ``ob='obsm'`` consecutive row slices of
    ``ad_ref.obsm[key]`` are written to each ``obsm[key]``; for any other
    value of *ob*, slices of ``ad_ref.obs[key]`` (as a NumPy array built from
    its list form) are written to each ``obs[key]``.
    """
    boundaries = np.cumsum([part.n_obs for part in ads][:-1])
    if ob == 'obsm':
        pieces = np.split(ad_ref.obsm[key], boundaries)
        container = 'obsm'
    else:
        pieces = np.split(ad_ref.obs[key].to_list(), boundaries)
        container = 'obs'
    for part, piece in zip(ads, pieces):
        getattr(part, container)[key] = piece

def eval_ads(ads, ref_key, src_key, exclude=()):
    """Return one adjusted Rand index per AnnData in *ads*.

    For each object, cells whose ``obs[ref_key]`` value is in *exclude* are
    dropped, then ``adjusted_rand_score`` compares ``obs[src_key]`` with
    ``obs[ref_key]``.

    Parameters
    ----------
    ads : iterable of AnnData
    ref_key : str
        obs column with the reference labels.
    src_key : str
        obs column with the predicted labels.
    exclude : list-like
        Reference labels to ignore. Defaults to an empty tuple, which avoids a
        shared mutable default list.

    Returns
    -------
    list of float
    """
    aris = []
    # `adata` rather than `ad`, so the loop does not shadow the anndata alias.
    for adata in ads:
        keep = ~adata.obs[ref_key].isin(exclude)
        gt = adata.obs[ref_key].values[keep]
        pred = adata.obs[src_key].values[keep]
        aris.append(adjusted_rand_score(pred, gt))
    return aris

from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans
def search_louvain(ad, use_rep, n_neighbors=15, n_clusters=5):
    """Look for a Louvain resolution that yields exactly *n_clusters* clusters.

    A neighbour graph is built on ``ad.obsm[use_rep]``. Louvain is run at
    resolutions ``np.arange(0.1, 1.0, 0.1)``, and each labelling is stored in
    ``ad.obs['r=<resolution>']``. The first (lowest) resolution with exactly
    *n_clusters* clusters is copied into ``ad.obs['louvain_k']``. If no
    resolution matches, ``KMeans(n_clusters, random_state=0)`` on
    ``ad.obsm[use_rep]`` is used instead, with its labels stored as strings.

    Modifies *ad* in place and returns ``None``.
    """
    sc.pp.neighbors(ad, n_neighbors=n_neighbors, use_rep=use_rep)
    resolutions = np.arange(0.1, 1.0, 0.1)
    cluster_counts = []
    for res in resolutions:
        column = f'r={res}'
        sc.tl.louvain(ad, resolution=res, key_added=column)
        cluster_counts.append(ad.obs[column].nunique())

    hits = np.flatnonzero(np.array(cluster_counts) == n_clusters)
    if hits.size:
        ad.obs['louvain_k'] = ad.obs[f'r={resolutions[hits[0]]}'].to_list()
        return

    model = KMeans(n_clusters=n_clusters, random_state=0).fit(ad.obsm[use_rep])
    ad.obs['louvain_k'] = model.labels_.astype('str')

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import f1_score
def eval_labelTransfer(ad1, ad2, use_rep, lab_key, knn=10):
    """Return the symmetric kNN label-transfer macro-F1 between two AnnData objects.

    A ``KNeighborsClassifier(n_neighbors=knn)`` fitted on one object's
    ``obsm[use_rep]`` / ``obs[lab_key]`` predicts the other object's labels,
    and macro F1 is computed. This is done in both directions (ad1 -> ad2,
    then ad2 -> ad1) and the mean of the two scores is returned.
    FutureWarnings are silenced during the computation.
    """
    def _transfer_f1(train, test):
        clf = KNeighborsClassifier(n_neighbors=knn)
        clf.fit(train.obsm[use_rep], train.obs[lab_key].to_list())
        predicted = clf.predict(test.obsm[use_rep])
        return f1_score(test.obs[lab_key].values, predicted, average='macro')

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        f1_1 = _transfer_f1(ad1, ad2)
        f1_2 = _transfer_f1(ad2, ad1)
        return (f1_1 + f1_2) / 2

from scib.metrics import lisi
def eval_lisi(
        adata,
        batch_keys=('domain', 'batch'),
        label_keys=('gt',),
        use_rep='X_emb', use_neighbors=False,
    ):
    """Compute scib iLISI per batch key and cLISI per label key.

    Each listed ``adata.obs`` column is cast to ``category`` in place before
    scoring. Scores use ``scib.metrics.lisi.ilisi_graph`` /
    ``clisi_graph`` with ``k0=90``, no subsampling, ``scale=True`` and one
    core. The graph type is ``'knn'`` when *use_neighbors* is true and
    ``'embed'`` (built from ``obsm[use_rep]``) otherwise.

    Parameters
    ----------
    adata : AnnData
    batch_keys, label_keys : iterable of str
        obs columns to score. Defaults are tuples, which avoids shared
        mutable default lists.
    use_rep : str
        Embedding key passed to scib.
    use_neighbors : bool

    Returns
    -------
    pandas.DataFrame
        A single row whose columns are ``<key>_iLISI`` / ``<key>_cLISI``.
    """
    graph_type = 'knn' if use_neighbors else 'embed'
    res = {}
    for key in batch_keys:
        adata.obs[key] = adata.obs[key].astype('category')
        res[key + '_iLISI'] = lisi.ilisi_graph(
            adata,
            key,
            graph_type,
            use_rep=use_rep,
            k0=90,
            subsample=None,
            scale=True,
            n_cores=1,
            verbose=False,
        )
    for key in label_keys:
        adata.obs[key] = adata.obs[key].astype('category')
        res[key + '_cLISI'] = lisi.clisi_graph(
            adata,
            key,
            graph_type,
            use_rep=use_rep,
            batch_key=None,
            k0=90,
            subsample=None,
            scale=True,
            n_cores=1,
            verbose=False,
        )
    return pd.DataFrame.from_dict(res, orient='index').T

# Point rpy2 (imported lazily inside mclust_R) at the R installation of the
# Seurat5 conda environment; presumably rpy2 reads these when it is imported.
os.environ.update({
    'R_HOME': '/disco_500t/xuhua/miniforge3/envs/Seurat5/lib/R',
    'R_USER': '/disco_500t/xuhua/miniforge3/envs/Seurat5/lib/python3.8/site-packages/rpy2',
})
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
    """Cluster ``adata.obsm[used_obsm]`` with R's ``mclust::Mclust`` through rpy2.

    Parameters
    ----------
    adata : AnnData
        Modified in place: gains a categorical integer column ``obs['mclust']``.
    num_cluster : int
        Number of mixture components passed to ``Mclust``.
    modelNames : str
        Mclust model name (default ``'EEE'``).
    used_obsm : str
        obsm key of the embedding to cluster.
    random_seed : int
        Seed applied to NumPy and to R (``set.seed``) before fitting.

    Returns
    -------
    AnnData
        The same *adata* object.
    """
    np.random.seed(random_seed)
    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri as numpy2ri
    numpy2ri.activate()
    robjects.r['set.seed'](random_seed)

    r_embedding = numpy2ri.numpy2rpy(adata.obsm[used_obsm])
    fit = robjects.r['Mclust'](r_embedding, num_cluster, modelNames)
    # NOTE(review): element -2 of the Mclust result is taken as the per-cell
    # classification; this depends on mclust's result layout, so confirm it
    # for the installed version.
    labels = np.array(fit[-2])

    adata.obs['mclust'] = labels
    adata.obs['mclust'] = adata.obs['mclust'].astype('int').astype('category')
    return adata

def load_data(_dir):
    """Load a ``matrix.mtx.gz`` / ``features.tsv.gz`` / ``barcodes.tsv.gz`` directory.

    The matrix is read with ``mmread`` and transposed, so rows of ``X`` (CSR)
    are barcodes. ``var_names`` come from column 0 of the features file, and
    columns 1 and 2 are stored as ``var['name']`` and ``var['type']``.
    """
    def _read_gz_tsv(fname):
        return pd.read_csv(join(_dir, fname), compression='gzip', sep='\t', header=None)

    features = _read_gz_tsv('features.tsv.gz')
    barcodes = _read_gz_tsv('barcodes.tsv.gz')
    with gzip.open(join(_dir, 'matrix.mtx.gz'), 'rb') as handle:
        counts = mmread(handle)

    adata = sc.AnnData(sps.csr_matrix(counts.T))
    adata.obs_names = barcodes[0].values
    adata.var_names = features[0].values
    adata.var['name'] = features[1].values
    adata.var['type'] = features[2].values
    return adata

import json
import copy
from matplotlib.image import imread
def flip_coords(ads):
    """Negate both spatial coordinates and reverse the column order, per object.

    For each object, ``obsm['spatial']`` is replaced by ``-1 * spatial`` with
    its columns reversed; for 2 columns, (c0, c1) becomes (-c1, -c0).
    Note: a later cell in this notebook redefines ``flip_coords`` with
    different behaviour.
    """
    for adata in ads:
        negated = -1 * adata.obsm['spatial']
        adata.obsm['spatial'] = negated[:, ::-1]

def reorder(ad1, ad2):
    """Restrict two AnnData objects to their shared barcodes.

    Both returned objects are copies indexed by the same
    ``ad1.obs_names.intersection(ad2.obs_names)``, so their rows line up.
    """
    common = ad1.obs_names.intersection(ad2.obs_names)
    return ad1[common].copy(), ad2[common].copy()

def load_peak_expr(_dir):
    """Load ``data.mtx`` with ``barcode.csv`` / ``feat.csv`` names into an AnnData.

    The matrix is transposed, so rows of ``X`` (CSR) are barcodes. Names are
    read from column ``'x'`` of each CSV (first column used as index).
    """
    def _names(fname):
        return pd.read_csv(join(_dir, fname), index_col=0)['x'].to_list()

    matrix = sio.mmread(join(_dir, 'data.mtx'))
    adata = sc.AnnData(sps.csr_matrix(matrix.T))
    adata.obs_names = _names('barcode.csv')
    adata.var_names = _names('feat.csv')
    return adata

In [None]:
def set_col2cat(ad, cols=()):
    """Cast each listed ``ad.obs`` column to pandas ``category`` dtype in place.

    *cols* defaults to an empty tuple, which avoids a shared mutable default
    list; with no columns nothing is changed.
    """
    for col in cols:
        ad.obs[col] = ad.obs[col].astype('category')

def unify_colors(queries, color_key, ref_color_dict):
    """Give every object in *queries* the same colour per ``obs[color_key]`` category.

    ``obs[color_key]`` is cast to ``category``, and
    ``uns[f'{color_key}_colors']`` is set to the colours from
    *ref_color_dict*, in category order. A category missing from
    *ref_color_dict* raises ``KeyError``. Objects are modified in place, and
    *queries* is returned.
    """
    palette_key = f'{color_key}_colors'
    for query in queries:
        labels = query.obs[color_key].astype('category')
        query.obs[color_key] = labels
        query.uns[palette_key] = list(map(ref_color_dict.__getitem__, labels.cat.categories))
    return queries

def subset_ad(ad, subset_index):
    """Return an independent copy of ``ad[subset_index]``."""
    return ad[subset_index].copy()

def set_spatial(ad):
    """Build ``ad.obsm['spatial']`` from the ``array_row`` / ``array_col`` obs columns.

    The result has columns ``(array_col, -array_row)``. ``ad`` is modified in
    place and returned.

    The coordinates are copied explicitly (``to_numpy(copy=True)``). Under
    pandas Copy-on-Write, ``.values`` may return a read-only view, and the
    in-place negation would then fail. ``ad.obs`` itself is never modified.
    """
    coords = ad.obs[['array_row', 'array_col']].to_numpy(copy=True)
    spatial = coords[:, ::-1]  # -> (array_col, array_row)
    spatial[:, 1] = -1 * spatial[:, 1]
    ad.obsm['spatial'] = spatial
    return ad

In [None]:
from collections import Counter
def load_zu(_dir):
    """Stack every headerless CSV in *_dir* (sorted by file name) and split it.

    Returns ``(z, u)``: all columns except the last two, and the last two
    columns, of the vertically stacked rows.
    """
    blocks = [
        pd.read_csv(join(_dir, name), header=None).values
        for name in sorted(os.listdir(_dir))
    ]
    stacked = np.vstack(blocks)
    return stacked[:, :-2], stacked[:, -2:]

def wrap_warn_plot(adata, basis, color, **kwargs):
    """Draw ``sc.pl.embedding`` for *adata* while ignoring UserWarnings.

    *basis*, *color* and any extra keyword arguments are forwarded unchanged
    to ``sc.pl.embedding``.
    """
    plot_kwargs = dict(basis=basis, color=color, **kwargs)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        sc.pl.embedding(adata, **plot_kwargs)

def get_umap(ad, use_reps=(), n_neighbors=15):
    """Compute one UMAP per representation and keep each under its own obsm key.

    For every key in *use_reps* a neighbour graph is built from
    ``ad.obsm[key]`` (``sc.pp.neighbors``), ``sc.tl.umap`` is run, and the
    resulting ``obsm['X_umap']`` is stored as ``obsm[f'{key}_umap']``.

    Parameters
    ----------
    ad : AnnData
        Modified in place. Its neighbour graph and ``obsm['X_umap']`` end up
        reflecting the last representation processed.
    use_reps : iterable of str
        obsm keys to embed. Defaults to an empty tuple (nothing is computed),
        which avoids a shared mutable default list.
    n_neighbors : int
        Passed to ``sc.pp.neighbors`` (15, as before).

    Returns
    -------
    AnnData
        The same *ad* object.
    """
    for use_rep in use_reps:
        sc.pp.neighbors(ad, use_rep=use_rep, n_neighbors=n_neighbors)
        sc.tl.umap(ad)
        # Copy the reference out before the next iteration overwrites X_umap.
        ad.obsm[f'{use_rep}_umap'] = ad.obsm['X_umap']
    return ad

def add_name_prefix(list, prefix):
    """Return ``'<prefix>-<item>'`` for every item of *list*, in order.

    The parameter keeps its original (builtin-shadowing) name so keyword
    callers still work.
    """
    return ['{}-{}'.format(prefix, name) for name in list]

In [None]:
# Root of the MIDAS output tree; the trained model's predictions are read from
# <result_dir>/MB_RNA_ATAC_Histone/... further down.
result_dir = '/disco_500t/xuhua/gitrepo/midas/result'

In [None]:
def _load_rna_peak_pair(rna_tsv, peak_dir, positions_csv):
    """Load one section's RNA matrix and its paired peak matrix, aligned by barcode.

    Parameters
    ----------
    rna_tsv : str
        Tab-separated matrix whose columns are barcodes; it is transposed so
        rows of the AnnData are barcodes.
    peak_dir : str
        Directory read by ``load_peak_expr``.
    positions_csv : str
        Headerless spot table indexed by barcode. Columns 2 and 3 become
        ``obsm['spatial']`` for both modalities.

    Both objects keep a copy of their original ``X`` in ``layers['counts']``.

    Returns
    -------
    tuple
        ``(ad_rna, ad_peak)`` restricted to shared barcodes by ``reorder``.
    """
    # The spot table is shared by both modalities, so it is read only once.
    positions = pd.read_csv(positions_csv, header=None, index_col=0)

    df_rna = pd.read_csv(rna_tsv, sep='\t')
    ad_rna = sc.AnnData(df_rna.T, obsm={'spatial': positions.loc[df_rna.columns, [2, 3]].values})
    ad_rna.layers['counts'] = ad_rna.X.copy()

    ad_peak = load_peak_expr(peak_dir)
    ad_peak.obsm['spatial'] = positions.loc[ad_peak.obs_names, [2, 3]].values
    ad_peak.layers['counts'] = ad_peak.X.copy()

    return reorder(ad_rna, ad_peak)


# rna+H3K27me3
data_dir = '/disco_500t/xuhua/data/real_mosaic_cases/mouse_brain_rna+H3K27me3'
ad_mult_rna, ad_mult_atac = _load_rna_peak_pair(
    join(data_dir, 'rna+H3K27me3/GSM6753044_MouseBrain_20um_100barcodes_H3K27me3_matrix.tsv'),
    join(data_dir, 'rna+H3K27me3/peak_data'),
    join(data_dir, 'rna+H3K27me3/spatial/tissue_positions_list.csv'),
)

# rna+H3K4me3
data_dir = '/disco_500t/xuhua/data/real_mosaic_cases/mouse_brain_rna+H3K4me3'
ad_mult_rna2, ad_mult_atac2 = _load_rna_peak_pair(
    join(data_dir, 'rna+H3K4me3/GSM6753046_MouseBrain_20um_100barcodes_H3K4me3_matrix.tsv'),
    join(data_dir, 'rna+H3K4me3/peak_data'),
    join(data_dir, 'rna+H3K4me3/spatial/tissue_positions_list.csv'),
)

# rna+H3K27ac (files live under an 'rna+atac' sub-folder on disk)
data_dir = '/disco_500t/xuhua/data/real_mosaic_cases/mouse_brain_rna+H3K27ac/'
ad_mult_rna3, ad_mult_atac3 = _load_rna_peak_pair(
    join(data_dir, 'rna+atac/GSM6753045_MouseBrain_20um_100barcodes_H3K27ac_matrix.tsv'),
    join(data_dir, 'rna+atac/peak_data_3slices/GSM6753045_peak_data'),
    join(data_dir, 'rna+atac/GSM6753045_spatial/tissue_positions_list.csv'),
)

# rna+atac
data_dir = '/disco_500t/xuhua/data/real_mosaic_cases/mouse_brain_rna+atac/'
ad_mult_rna4, ad_mult_atac4 = _load_rna_peak_pair(
    join(data_dir, 'rna+atac/GSM6753043_MouseBrain_20um_100barcodes_ATAC_matrix.tsv'),
    # NOTE(review): the peak data accession (GSM6758285) differs from the
    # RNA/spatial accession (GSM6753043); barcodes are intersected by reorder,
    # but confirm this pairing is intended.
    join(data_dir, 'rna+atac/For_Imputation_Task/GSM6758285_peak_data'),
    join(data_dir, 'rna+atac/GSM6753043_MouseBrain_20um_100barcodes_ATAC_spatial/tissue_positions_list.csv'),
)

In [None]:
def flip_coords(ads):
    """Negate column 1 of each object's ``obsm['spatial']`` in place.

    This redefinition replaces the earlier ``flip_coords`` in this notebook,
    which also reversed the column order; this version only flips the second
    coordinate.
    """
    for adata in ads:
        coords = adata.obsm['spatial']
        coords[:, 1] = -1 * coords[:, 1]
        
# Only the RNA objects are flipped; the paired ad_mult_atac* objects keep their
# original obsm['spatial'] orientation.
flip_coords([ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4])

In [None]:
# Restrict all four RNA objects to the genes they have in common.
shared_gene = ad_mult_rna.var_names
for _rna in (ad_mult_rna2, ad_mult_rna3, ad_mult_rna4):
    shared_gene = np.intersect1d(_rna.var_names, shared_gene)
# Rebind the four names directly. Assigning to a loop variable would leave
# them unchanged, and naming it `ad` would clobber the anndata alias.
ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4 = (
    _rna[:, shared_gene].copy()
    for _rna in (ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4)
)

# Make barcodes unique across sections. Each RNA object and its paired peak
# object get the same prefixed names (they were aligned by reorder).
ad_mult_rna.obs_names = ad_mult_atac.obs_names = add_name_prefix(ad_mult_rna.obs_names, 'rna+H3K27me3')
ad_mult_rna2.obs_names = ad_mult_atac2.obs_names = add_name_prefix(ad_mult_rna2.obs_names, 'rna+H3K4me3')
ad_mult_rna3.obs_names = ad_mult_atac3.obs_names = add_name_prefix(ad_mult_rna3.obs_names, 'rna+H3K27ac')
ad_mult_rna4.obs_names = ad_mult_atac4.obs_names = add_name_prefix(ad_mult_rna4.obs_names, 'rna+atac')

# Record the source section of every cell (used as the batch key later).
ad_mult_rna.obs['src'] = ad_mult_atac.obs['src'] = ['rna+H3K27me3']*ad_mult_rna.n_obs
ad_mult_rna2.obs['src'] = ad_mult_atac2.obs['src'] = ['rna+H3K4me3']*ad_mult_rna2.n_obs
ad_mult_rna3.obs['src'] = ad_mult_atac3.obs['src'] = ['rna+H3K27ac']*ad_mult_rna3.n_obs
ad_mult_rna4.obs['src'] = ad_mult_atac4.obs['src'] = ['rna+atac']*ad_mult_rna4.n_obs

In [None]:
ad_rna_all = sc.concat([ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4])

# Joint RNA HVG selection across the four sections, with 'src' as batch.
# flavor='seurat_v3' expects raw counts in .X.
sc.pp.highly_variable_genes(ad_rna_all, flavor='seurat_v3', n_top_genes=10000, batch_key='src')
hvg_names = ad_rna_all.var.query('highly_variable').index.to_numpy()

# Highly variable peaks are chosen separately for each peak modality. The loop
# variable is deliberately not `ad`, which would clobber the
# `import anndata as ad` alias.
HVP_NAMES = []
for _peak in [ad_mult_atac, ad_mult_atac2, ad_mult_atac3, ad_mult_atac4]:
    sc.pp.highly_variable_genes(_peak, flavor='seurat_v3', n_top_genes=100000)
    HVP_NAMES.append(_peak.var.query('highly_variable').index.to_numpy())

In [None]:
# RNA objects keep the shared HVGs; each peak object keeps its own variable peaks.
ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4 = (
    _rna[:, hvg_names].copy()
    for _rna in (ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4)
)
ad_mult_atac, ad_mult_atac2, ad_mult_atac3, ad_mult_atac4 = (
    _peak[:, HVP_NAMES[i]].copy()
    for i, _peak in enumerate((ad_mult_atac, ad_mult_atac2, ad_mult_atac3, ad_mult_atac4))
)

In [None]:
# Sanity check: (n_cells, n_features) of the rna+atac peak object after the HVP subset.
ad_mult_atac4.shape

In [None]:
## filter feat names
filtered_atac_feats = [_ for _ in ad_mult_atac.var_names if _.startswith('chr')]
ad_mult_atac = ad_mult_atac[:, filtered_atac_feats].copy()
filtered_atac_feats = [_ for _ in ad_mult_atac2.var_names if _.startswith('chr')]
ad_mult_atac2 = ad_mult_atac2[:, filtered_atac_feats].copy()
filtered_atac_feats = [_ for _ in ad_mult_atac3.var_names if _.startswith('chr')]
ad_mult_atac3 = ad_mult_atac3[:, filtered_atac_feats].copy()
filtered_atac_feats = [_ for _ in ad_mult_atac4.var_names if _.startswith('chr')]
ad_mult_atac4 = ad_mult_atac4[:, filtered_atac_feats].copy()

In [None]:
def split_list_byChr(input_list):
    """Return chromosome names in order of appearance, one entry per consecutive run.

    The chromosome of a feature is the text before its first ``'-'``
    (presumably names look like ``chrN-start-end``). Adjacent features on the
    same chromosome collapse into one entry. A chromosome that reappears after
    a different one is therefore listed again.

    The old local variable ``chr`` shadowed the builtin; ``groupby`` replaces
    the hand-written run detection.
    """
    from itertools import groupby
    return [chrom for chrom, _ in groupby(name.split('-')[0] for name in input_list)]

In [None]:
tmp_out_dir = '/disco_500t/xuhua/gitrepo/midas/data/processed/MB_RNA_ATAC_Histone'
feat_dir = join(tmp_out_dir, 'feat')
os.makedirs(feat_dir, exist_ok=True)

# Chromosome order (one entry per consecutive run) for each peak modality.
atac1_feat_chunks, atac2_feat_chunks, atac3_feat_chunks, atac4_feat_chunks = (
    split_list_byChr(_peak.var_names)
    for _peak in (ad_mult_atac, ad_mult_atac2, ad_mult_atac3, ad_mult_atac4)
)

In [None]:
# All peak modalities must list their chromosome chunks in the same order,
# because their per-chunk sizes share one row index in feat_dims.csv.
# Plain list equality is used: an element-wise NumPy comparison of arrays with
# different lengths does not produce a per-element array to call .all() on.
(atac1_feat_chunks == atac2_feat_chunks,
 atac1_feat_chunks == atac3_feat_chunks,
 atac1_feat_chunks == atac4_feat_chunks)

In [None]:
from itertools import groupby


def _chr_chunk_sizes(var_names):
    """Return the number of features in each consecutive same-chromosome run, in order.

    Run lengths (rather than a per-chromosome Counter) keep the sizes aligned
    with the chunk list even if a chromosome appears in more than one run; a
    Counter would assign that chromosome's full total to every run.
    """
    return [sum(1 for _ in run) for _, run in groupby(var_names, key=lambda name: name.split('-')[0])]


# One row per chromosome chunk for each peak modality; the RNA column
# holds the total gene count on every row.
df_feat_dims = pd.DataFrame({
    'H3K27me3': _chr_chunk_sizes(ad_mult_atac.var_names),
    'H3K4me3': _chr_chunk_sizes(ad_mult_atac2.var_names),
    'H3K27ac': _chr_chunk_sizes(ad_mult_atac3.var_names),
    'atac': _chr_chunk_sizes(ad_mult_atac4.var_names),
})
df_feat_dims['rna'] = ad_mult_rna.n_vars
df_feat_dims.to_csv(join(feat_dir, 'feat_dims.csv'))

# Feature names per modality, one column 'x'.
for _mod, _adata in [
    ('rna', ad_mult_rna),
    ('H3K27me3', ad_mult_atac),
    ('H3K4me3', ad_mult_atac2),
    ('H3K27ac', ad_mult_atac3),
    ('atac', ad_mult_atac4),
]:
    pd.DataFrame(_adata.var_names, columns=['x']).to_csv(join(feat_dir, f'feat_names_{_mod}.csv'))

In [None]:
# Write each (RNA, peak-modality) pair as one MIDAS subset:
#   subset_<i>/mat/<mod>.csv          dense cells x features matrix with names
#   subset_<i>/vec/<mod>/<row>.csv    one headerless CSV per cell
#   subset_<i>/mask/<mod>.csv         all-ones feature mask
#   subset_<i>/cell_names.csv         cell names (column 'x')
subsets = [[ad_mult_rna, ad_mult_atac], [ad_mult_rna2, ad_mult_atac2], [ad_mult_rna3, ad_mult_atac3], [ad_mult_rna4, ad_mult_atac4]]
mods = [['rna', 'H3K27me3'], ['rna', 'H3K4me3'], ['rna', 'H3K27ac'], ['rna', 'atac']]

for si, (subset, subset_mods) in enumerate(zip(subsets, mods)):
    tmp_dir = join(tmp_out_dir, f'subset_{si}')
    for fname in ['mask', 'mat', 'vec']:
        os.makedirs(join(tmp_dir, fname), exist_ok=True)

    for adata, mi in zip(subset, subset_mods):
        # .toarray() instead of the sparse `.A` shortcut, which newer SciPy removed.
        mat = adata.X.toarray() if sps.issparse(adata.X) else adata.X
        df_mat = pd.DataFrame(mat, index=adata.obs_names, columns=adata.var_names)
        df_mat.to_csv(join(tmp_dir, f'mat/{mi}.csv'))

        os.makedirs(join(tmp_dir, f'vec/{mi}'), exist_ok=True)
        for idx, mati in enumerate(mat):
            pd.DataFrame(mati.reshape(1, -1)).to_csv(join(tmp_dir, 'vec/{}/{:05d}.csv'.format(mi, idx)), header=None, index=None)

        # save mask
        pd.DataFrame(np.ones(adata.n_vars, dtype='int')).to_csv(join(tmp_dir, f'mask/{mi}.csv'))

    # Both objects of a subset carry identical obs_names (assigned together
    # earlier), so the cell-name file is written once per subset.
    pd.DataFrame(subset[-1].obs_names, columns=['x']).to_csv(join(tmp_dir, 'cell_names.csv'))

In [None]:
for e, ep in enumerate([2000], 1):
    # The checkpoint name uses ep - 1, presumably a 0-based epoch index.
    training_command = f'CUDA_VISIBLE_DEVICES=1 python run.py --exp e{e} --task MB_RNA_ATAC_Histone --epoch_num {ep}'
    run_command = ('CUDA_VISIBLE_DEVICES=1 python run.py --task MB_RNA_ATAC_Histone '
                   f'--act predict_all_latent_bc --init_model sp_{ep - 1:08d} --exp e{e}')
    print(training_command, run_command, sep='\n')
    print(run_command)

In [None]:
e, ep = 1, 2000
model_pt = f'sp_{ep - 1:08d}'

# Attach MIDAS's predicted joint latents (z) and the trailing two columns (u)
# of each subset to the matching RNA object, then embed all sections together.
tmp_dir = f'{result_dir}/MB_RNA_ATAC_Histone/e{e}/default/predict/{model_pt}'
target_ads = [ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4]
for i, target in enumerate(target_ads):
    z, u = load_zu(join(tmp_dir, f'subset_{i}/z/joint'))
    target.obsm['z'] = z
    target.obsm['u'] = u

ad_mosaic = get_umap(sc.concat(target_ads), ['z'])

wrap_warn_plot(ad_mosaic, 'z_umap', ['src'])

In [None]:
# for n_clusters in [4, 6, 8, 10, 11]:
#     try:

#         from sklearn.cluster import KMeans
#         kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(ad_mosaic.obsm['z'])
#         ad_mosaic.obs['kmeans'] = kmeans.labels_.astype('str')
#         clust_key = 'kmeans'

#         # ad_mosaic = mclust_R(ad_mosaic, n_clusters, used_obsm='z')  
#         # clust_key = 'mclust'
#         # ad_mosaic.obs['mclust'] = ad_mosaic.obs['mclust'].astype('str')
    
#     except:
        
#         from sklearn.cluster import KMeans
#         print("mclust failed, try kmeans")
#         kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(ad_mosaic.obsm['z'])
#         ad_mosaic.obs['kmeans'] = kmeans.labels_.astype('str')
#         clust_key = 'kmeans'

#         # ad_mosaic = mclust_R(ad_mosaic, n_clusters, used_obsm='z')  
#         # clust_key = 'mclust'
#         # ad_mosaic.obs['mclust'] = ad_mosaic.obs['mclust'].astype('str')
        
#     split_ob([ad_bridge_rna, ad_test_rna, ad_test_adt], ad_mosaic, ob='obs', key=clust_key)
#     r = eval_ads([ad_mosaic, ad_bridge_rna, ad_test_rna, ad_test_adt], 'label', clust_key, exclude=['Exclude'])
#     print(n_clusters, ['{:.4f}'.format(_) for _ in r])



In [None]:
# Section-mixing score (iLISI over 'src') of the uncorrected MIDAS embedding.
lisi_res = eval_lisi(ad_mosaic, batch_keys=['src'], label_keys=[], use_rep='z', use_neighbors=False)
lisi_res['src_iLISI'].iloc[0]

### Batch correction

In [None]:
from batchCorr import HARMONY

# Batch-correct the MIDAS embedding with section ('src') as the batch label
# (presumably Harmony, judging by the helper's name), then copy the result back
# to the per-section RNA objects. This relies on ad_mosaic being their concatenation.
ad_mosaic.obsm['z_har'] = HARMONY(
    pd.DataFrame(ad_mosaic.obsm['z']),
    ad_mosaic.obs['src'].to_list(),
)
split_ob(
    [ad_mult_rna, ad_mult_rna2, ad_mult_rna3, ad_mult_rna4],
    ad_mosaic,
    ob='obsm',
    key='z_har',
)

# try:

#     from sklearn.cluster import KMeans
#     print("mclust failed, try kmeans")
#     kmeans = KMeans(n_clusters=ad_mosaic.obs['Combined_Clusters_annotation'].nunique(), random_state=0).fit(ad_mosaic.obsm['z_har'])
#     ad_mosaic.obs['kmeans'] = kmeans.labels_.astype('str')
#     clust_key = 'kmeans'

# except:

#     ad_mosaic = mclust_R(ad_mosaic, ad_mosaic.obs.Combined_Clusters_annotation.nunique(), used_obsm='z_har')  
#     clust_key = 'mclust'
    
# split_ob([ad_bridge_rna, ad_test1_rna, ad_test2_atac], ad_mosaic, ob='obs',  key=clust_key)

# Section-mixing score (iLISI over 'src') after batch correction.
lisi_res = eval_lisi(ad_mosaic, batch_keys=['src'], label_keys=[], use_rep='z_har', use_neighbors=False)
r2 = lisi_res['src_iLISI'].iloc[0]
r2

## Visualization

In [None]:
fig_dir = '/disco_500t/xuhua/gitrepo/BridgeNorm/figures/MB_RNA+ATAC+Histone/midas'
os.makedirs(fig_dir, exist_ok=True)

# Export the raw and the batch-corrected embeddings, one row per cell.
for obsm_key, out_name in [('z', 'X_emb.csv'), ('z_har', 'X_emb_har.csv')]:
    df = pd.DataFrame(ad_mosaic.obsm[obsm_key], index=ad_mosaic.obs_names.to_list())
    df.to_csv(join(fig_dir, out_name))