In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle

import anndata as ad
import numpy as np
import pandas as pd
import yaml
import sys
import scanpy as sc
import scipy.sparse as sps
import scipy.io as sio

import scglue
import seaborn as sns

from os.path import join

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Filesystem locations: the checked-out code repository and the DOGMA data directory.
root_dir = '/home/yanxh/gitrepo/multi-omics-matching/neurips2021_multimodal_topmethods-main'
data_dir = "/home/sda1/yanxh/data/DOGMA"

# Run parameters; `output_pretrain` is a directory path located under `root_dir`.
par = {
    'output_pretrain': os.path.join(
        root_dir,
        'output/pretrain/clue/dogma_StimSplit.clue_train.output_pretrain/',
    ),
}

In [5]:
print('Reading `mtx` files...')


def _load_count_matrix(rel_path):
    """Read a Matrix Market file under ``data_dir`` and return its transpose as CSR."""
    return sps.csr_matrix(sio.mmread(join(data_dir, rel_path)).T)


def _load_feature_names(rel_path):
    """Return the 'VariableFeatures(data_ref)' column of a CSV under ``data_dir`` as an array."""
    table = pd.read_csv(join(data_dir, rel_path))
    return table['VariableFeatures(data_ref)'].to_numpy()


# Count matrices, one per modality (transposed on load; presumably cells x features).
rna_count_mat = _load_count_matrix('RNA/rna_mat_count.mtx')
adt_count_mat = _load_count_matrix('ADT/adt_mat_count.mtx')
atac_count_mat = _load_count_matrix('ATAC/atac_mat_count.mtx')

# Feature names matching the columns of each matrix (assumed; not checked here).
rna_names = _load_feature_names('RNA/hvg_names.csv')
adt_names = _load_feature_names('ADT/adt_names.csv')
atac_names = _load_feature_names('ATAC/hvp_names.csv')

# Cell identifiers and per-cell metadata; only three metadata columns are kept.
cell_names = pd.read_csv(join(data_dir, 'cell_names.csv'))['x'].to_numpy()
meta_columns = ['stim', 'predicted.celltype.l1', 'predicted.celltype.l2']
meta_data = pd.read_csv(join(data_dir, 'metadata.csv'), index_col=0)[meta_columns].copy()
# `batch` duplicates `stim` under a second column name.
meta_data['batch'] = meta_data['stim'].to_numpy()

# Positional row indices: 'Control' rows form the training set, 'Stim' rows the test set.
batch_labels = meta_data['batch']
train_idx = np.flatnonzero((batch_labels == 'Control').to_numpy())
test_idx = np.flatnonzero((batch_labels == 'Stim').to_numpy())

# Displayed as the cell result: matrix shapes and split sizes.
(
    rna_count_mat.shape,
    adt_count_mat.shape,
    atac_count_mat.shape,
    train_idx.size,
    test_idx.size,
)

Reading `mtx` files...


((13763, 2000), (13763, 210), (13763, 50285), 7624, 6139)

In [6]:
# `utils` is not an installed package: it is imported from the CLUE method's
# resources directory inside the repository, so that directory must be on
# sys.path before the import statement runs.
clue_resources_dir = os.path.join(root_dir, 'src/match_modality/methods/clue/resources')
sys.path.append(clue_resources_dir)
import utils

In [7]:
# NOTE(review): the message mentions `h5ad`, but this cell builds the AnnData
# objects from the matrices already in memory; nothing is read from disk here.
print('Reading `h5ad` files...')


def _training_adata(count_mat, feature_names):
    """Build an AnnData for the training rows of ``count_mat``.

    The training rows (``train_idx``) of the matrix become ``X``, the matching
    rows of ``meta_data`` become ``obs``, ``feature_names`` become ``var_names``,
    and a float32 copy of ``X`` is stored in ``layers["counts"]``.
    """
    adata = sc.AnnData(sps.csr_matrix(count_mat[train_idx]), obs=meta_data.iloc[train_idx])
    adata.var_names = feature_names
    adata.layers["counts"] = adata.X.astype(np.float32)
    return adata


ad_mult_rna = _training_adata(rna_count_mat, rna_names)
ad_mult_adt = _training_adata(adt_count_mat, adt_names)
ad_mult_atac = _training_adata(atac_count_mat, atac_names)

# Modality labels, plus `omics`, which selects the hyperparameter preset.
mod1_feature_type, mod2_feature_type, mod3_feature_type = 'GEX', 'ADT', 'ATAC'
omics = 'cite'

Reading `h5ad` files...


  ad_mult_rna = sc.AnnData(sps.csr_matrix(rna_count_mat[train_idx]), obs=meta_data.iloc[train_idx])
  ad_mult_adt = sc.AnnData(sps.csr_matrix(adt_count_mat[train_idx]), obs=meta_data.iloc[train_idx])
  ad_mult_atac = sc.AnnData(sps.csr_matrix(atac_count_mat[train_idx]), obs=meta_data.iloc[train_idx])


In [8]:
# Hyperparameter presets, selected by `omics`.
# Where the values are consumed in this notebook:
#   - n_genes                      -> utils.GEXPreprocessing
#   - latent_dim, x2u_*, u2x_*,
#     du_*, dropout, random_seed   -> scglue.models.SCCLUEModel
#   - lam_*, normalize_u           -> model.compile
if omics == "cite":
    n_genes = 5000
    latent_dim = 20
    x2u_h_depth = 2
    x2u_h_dim = 512
    u2x_h_depth = 1
    u2x_h_dim = 128
    du_h_depth = 2
    du_h_dim = 128
    dropout = 0.2
    lam_data = 1.0
    lam_kl = 1.0
    lam_align = 2.0
    lam_cross = 2.0
    lam_cos = 1.0
    normalize_u = True
    random_seed = 5
elif omics == "multiome":
    n_genes = 10000
    latent_dim = 50
    x2u_h_depth = 2
    x2u_h_dim = 512
    u2x_h_depth = 1
    u2x_h_dim = 256
    du_h_depth = 1
    du_h_dim = 256
    dropout = 0.2
    lam_data = 1.0
    lam_kl = 0.3
    lam_align = 0.02
    lam_cross = 1.0
    lam_cos = 0.02
    normalize_u = True
    random_seed = 2
else:
    # Fail fast: with no preset selected, none of the names above would be
    # defined and the first later use would raise an unhelpful NameError.
    raise ValueError(
        f"Unknown omics preset {omics!r}; expected 'cite' or 'multiome'"
    )

In [9]:
# Record the hyperparameters in use as `hyperparams.yaml` in the pretrain
# output directory, creating that directory first if it does not exist.
hyperparams = dict(
    n_genes=n_genes,
    latent_dim=latent_dim,
    x2u_h_depth=x2u_h_depth,
    x2u_h_dim=x2u_h_dim,
    u2x_h_depth=u2x_h_depth,
    u2x_h_dim=u2x_h_dim,
    du_h_depth=du_h_depth,
    du_h_dim=du_h_dim,
    dropout=dropout,
    lam_data=lam_data,
    lam_kl=lam_kl,
    lam_align=lam_align,
    lam_cross=lam_cross,
    lam_cos=lam_cos,
    normalize_u=normalize_u,
    random_seed=random_seed,
)

os.makedirs(par['output_pretrain'], exist_ok=True)
hyperparams_path = os.path.join(par['output_pretrain'], "hyperparams.yaml")
with open(hyperparams_path, "w") as yaml_file:
    yaml.dump(hyperparams, yaml_file)

In [10]:
# Give every training AnnData the same per-row ids ("train-0", "train-1", ...)
# and record its domain name both per row (obs) and on the object (uns).
for adata, domain_name in (
    (ad_mult_rna, 'gex'),
    (ad_mult_adt, 'adt'),
    (ad_mult_atac, 'atac'),
):
    adata.obs["uid"] = [f"train-{row}" for row in range(adata.shape[0])]
    adata.obs['domain'] = domain_name
    adata.uns['domain'] = domain_name

In [23]:
# Fit one preprocessor per modality on the training AnnData objects.
# Each `fit_transform` is called for its effect on the AnnData passed in; its
# return value is not used. Presumably this is where the `X_pca` / `X_lsi`
# representations come from (TODO confirm in `utils`).
# Each step runs directly after its own progress message, so the log shows
# which modality is being processed at any moment.
print('Preprocessing GEX...')
gex_prep = utils.GEXPreprocessing(n_comps=100, n_genes=n_genes, merge_adt=False)
gex_prep.fit_transform(ad_mult_rna)

print('Preprocessing ADT...')
adt_prep = utils.ADTPreprocessing(n_comps=100)
adt_prep.fit_transform(ad_mult_adt)

print('Preprocessing ATAC...')
atac_prep = utils.ATACPreprocessing(n_comps=100)
atac_prep.fit_transform(ad_mult_atac)

Preprocessing GEX...
Preprocessing ADT...
Preprocessing ATAC...


In [24]:
# Persist the three fitted preprocessors together as `prep.pickle` in the
# pretrain output directory.
fitted_preprocessors = {
    "gex_prep": gex_prep,
    "adt_prep": adt_prep,
    "atac_prep": atac_prep,
}
prep_path = os.path.join(par['output_pretrain'], "prep.pickle")
with open(prep_path, "wb") as pickle_file:
    pickle.dump(fitted_preprocessors, pickle_file)



In [25]:
# Register each AnnData with scglue using identical settings, except for the
# representation key: "X_pca" for GEX and ADT, "X_lsi" for ATAC.
for adata, rep_key in (
    (ad_mult_rna, "X_pca"),
    (ad_mult_adt, "X_pca"),
    (ad_mult_atac, "X_lsi"),
):
    scglue.models.configure_dataset(
        adata,
        "NB",
        use_highly_variable=True,
        use_layer="counts",
        use_rep=rep_key,
        use_batch="stim",
        use_uid="uid",
    )

In [26]:
print('Building model...')

# One configured AnnData per domain, keyed by domain name.
clue_inputs = {"gex": ad_mult_rna, "adt": ad_mult_adt, "atac": ad_mult_atac}

# Network settings taken from the selected hyperparameter preset.
network_config = dict(
    latent_dim=latent_dim,
    x2u_h_depth=x2u_h_depth,
    x2u_h_dim=x2u_h_dim,
    u2x_h_depth=u2x_h_depth,
    u2x_h_dim=u2x_h_dim,
    du_h_depth=du_h_depth,
    du_h_dim=du_h_dim,
    dropout=dropout,
    shared_batches=True,
    random_seed=random_seed,
)
model = scglue.models.SCCLUEModel(clue_inputs, **network_config)

# Flag checked by the training cell: when True the model is fitted and saved.
training = True

Building model...
[INFO] autodevice: Using GPU 1 as computation device.


In [27]:
# Optional: to reuse previously saved pretrained weights instead of training, uncomment the two lines below
# model = scglue.models.load_model(os.path.join(par['output_pretrain'], "pretrain.dill"))
# training = False

In [28]:
print('Compiling model...')

# Loss-term weights from the selected preset; every domain gets weight 1.
loss_weights = dict(
    lam_data=lam_data,
    lam_kl=lam_kl,
    lam_align=lam_align,
    lam_cross=lam_cross,
    lam_cos=lam_cos,
)
model.compile(
    **loss_weights,
    normalize_u=normalize_u,
    domain_weight={"gex": 1, "adt": 1, 'atac': 1},
)

Compiling model...


In [29]:
# Fit the model on the three training AnnData objects and save it as
# `pretrain.dill` in the pretrain output directory. Skipped entirely when
# `training` is False.
if training:
    print('Training model...')
    fit_inputs = {"gex": ad_mult_rna, "adt": ad_mult_adt, "atac": ad_mult_atac}
    model.fit(fit_inputs)
    pretrain_path = os.path.join(par['output_pretrain'], "pretrain.dill")
    model.save(pretrain_path)

Training model...
[INFO] SCCLUEModel: Setting `align_burnin` = 112
[INFO] SCCLUEModel: Setting `max_epochs` = 672
[INFO] SCCLUEModel: Setting `patience` = 84
[INFO] SCCLUEModel: Setting `reduce_lr_patience` = 28
[INFO] SCCLUETrainer: Using training directory: "/tmp/GLUETMPzd26uju0"
[INFO] SCCLUETrainer: [Epoch 10] train={'dsc_loss': 1.098, 'gen_loss': 7.076, 'cross_loss': 2.941, 'cos_loss': 0.18, 'x_gex_nll': 0.373, 'x_gex_kl': 0.053, 'x_gex_elbo': 0.426, 'x_adt_nll': 1.651, 'x_adt_kl': 0.208, 'x_adt_elbo': 1.859, 'x_atac_nll': 0.917, 'x_atac_kl': 0.006, 'x_atac_elbo': 0.924}, val={'dsc_loss': 1.1, 'gen_loss': 6.799, 'cross_loss': 2.858, 'cos_loss': 0.167, 'x_gex_nll': 0.364, 'x_gex_kl': 0.052, 'x_gex_elbo': 0.416, 'x_adt_nll': 1.62, 'x_adt_kl': 0.211, 'x_adt_elbo': 1.831, 'x_atac_nll': 0.864, 'x_atac_kl': 0.006, 'x_atac_elbo': 0.87}, 12.5s elapsed
[INFO] SCCLUETrainer: [Epoch 20] train={'dsc_loss': 1.097, 'gen_loss': 5.763, 'cross_loss': 2.514, 'cos_loss': 0.158, 'x_gex_nll': 0.313, '

2023-08-02 13:16:46,717 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


[INFO] EarlyStopping: Restoring checkpoint "305"...
[INFO] EarlyStopping: Restoring checkpoint "305"...
