In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle

import anndata as ad
import numpy as np
import pandas as pd
import yaml
import sys
import scanpy as sc

import scglue
import seaborn as sns

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Repository root of the NeurIPS 2021 multimodal top-methods checkout.
# Overridable through an environment variable so the notebook is not tied to one
# machine; the default reproduces the original hardcoded location.
root_dir = os.environ.get(
    'NEURIPS2021_TOPMETHODS_ROOT',
    '/home/yanxh/gitrepo/multi-omics-matching/neurips2021_multimodal_topmethods-main',
)

In [3]:
# Identifier of the match-modality dataset. It is used both to locate the censored
# training inputs and to name the pretraining output directory, so defining it once
# keeps inputs and outputs from drifting apart when switching datasets.
dataset_id = 'openproblems_bmmc_multiome_phase2_rna'
# Common filename prefix of the censored dataset files (suffix appended below).
dataset_path = os.path.join(root_dir, f'output/datasets/match_modality/{dataset_id}/{dataset_id}.censor_dataset.output_')

par = {
    'input_train_mod1': f'{dataset_path}train_mod1.h5ad',
    'input_train_mod2': f'{dataset_path}train_mod2.h5ad',
    'input_train_sol': f'{dataset_path}train_sol.h5ad',
    'output_pretrain': os.path.join(root_dir, f'output/pretrain/clue/{dataset_id}.clue_train.output_pretrain/')
}

# Directory holding the method's helper module (`utils`), imported in the next cell.
meta = { 'resources_dir': os.path.join(root_dir, 'src/match_modality/methods/clue/resources') }

In [4]:
# Make the method's bundled `utils` module importable. The membership guard keeps
# repeated runs of this cell from piling duplicate entries onto `sys.path`.
if meta['resources_dir'] not in sys.path:
    sys.path.append(meta['resources_dir'])
import utils

In [5]:
print('Reading `h5ad` files...')
input_train_mod1, input_train_mod2, input_train_sol = (
    ad.read_h5ad(par[key])
    for key in ('input_train_mod1', 'input_train_mod2', 'input_train_sol')
)

# Cast both the working matrix and the raw "counts" layer of each modality to float32.
for _adata in (input_train_mod1, input_train_mod2):
    _adata.X = _adata.X.astype(np.float32)
    _adata.layers["counts"] = _adata.layers["counts"].astype(np.float32)
del _adata

# Collect the distinct feature types per modality; each must contain exactly one.
mod1_feature_type, mod2_feature_type = (
    set(_adata.var["feature_types"]) for _adata in (input_train_mod1, input_train_mod2)
)
assert len(mod1_feature_type) == len(mod2_feature_type) == 1

Reading `h5ad` files...


In [6]:
# Display both training AnnData objects (shape, obs/var columns, uns, obsm, layers).
input_train_mod1, input_train_mod2

(AnnData object with n_obs × n_vars = 42492 × 13431
     obs: 'batch', 'size_factors'
     var: 'gene_ids', 'feature_types'
     uns: 'dataset_id', 'organism'
     layers: 'counts',
 AnnData object with n_obs × n_vars = 42492 × 116490
     obs: 'batch'
     var: 'feature_types'
     uns: 'dataset_id', 'gene_activity_var_names', 'organism'
     obsm: 'gene_activity'
     layers: 'counts')

In [7]:
# Extract the single feature type of each modality. It is re-derived from `.var`
# instead of calling `.pop()` on the sets built while loading, because after a
# first run those names hold plain strings and `.pop()` would raise, so this cell
# could not be re-run. Unpacking into a one-element tuple raises ValueError if
# a modality does not have exactly one feature type.
(mod1_feature_type,) = set(input_train_mod1.var["feature_types"])
(mod2_feature_type,) = set(input_train_mod2.var["feature_types"])

mod1_feature_type, mod2_feature_type

('GEX', 'ATAC')

In [8]:
# Map the unordered pair of modality feature types onto the omics technology name.
_OMICS_BY_FEATURE_TYPES = {
    frozenset({"GEX", "ATAC"}): "multiome",
    frozenset({"GEX", "ADT"}): "cite",
}
_feature_type_pair = frozenset({mod1_feature_type, mod2_feature_type})
if _feature_type_pair not in _OMICS_BY_FEATURE_TYPES:
    raise RuntimeError("Unrecognized modality!")
omics = _OMICS_BY_FEATURE_TYPES[_feature_type_pair]

In [9]:
# Omics-specific model / training hyperparameters. They are written to
# hyperparams.yaml in the next cell and then used to build and compile the model.
if omics == "cite":
    n_genes = 5000
    latent_dim = 20
    x2u_h_depth = 2
    x2u_h_dim = 512
    u2x_h_depth = 1
    u2x_h_dim = 128
    du_h_depth = 2
    du_h_dim = 128
    dropout = 0.2
    lam_data = 1.0
    lam_kl = 1.0
    lam_align = 2.0
    lam_cross = 2.0
    lam_cos = 1.0
    normalize_u = True
    random_seed = 5
elif omics == "multiome":
    n_genes = 10000
    latent_dim = 50
    x2u_h_depth = 2
    x2u_h_dim = 512
    u2x_h_depth = 1
    u2x_h_dim = 256
    du_h_depth = 1
    du_h_dim = 256
    dropout = 0.2
    lam_data = 1.0
    lam_kl = 0.3
    lam_align = 0.02
    lam_cross = 1.0
    lam_cos = 0.02
    normalize_u = True
    random_seed = 2
else:
    # Fail loudly instead of leaving the hyperparameters undefined or, worse,
    # silently reusing stale values from an earlier kernel state.
    raise RuntimeError(f"No hyperparameters defined for omics={omics!r}!")

In [10]:
# Record the chosen hyperparameters alongside the pretraining outputs.
os.makedirs(par['output_pretrain'], exist_ok=True)
with open(os.path.join(par['output_pretrain'], "hyperparams.yaml"), "w") as f:
    yaml.dump(dict(
        n_genes=n_genes,
        latent_dim=latent_dim,
        x2u_h_depth=x2u_h_depth,
        x2u_h_dim=x2u_h_dim,
        u2x_h_depth=u2x_h_depth,
        u2x_h_dim=u2x_h_dim,
        du_h_depth=du_h_depth,
        du_h_dim=du_h_dim,
        dropout=dropout,
        lam_data=lam_data,
        lam_kl=lam_kl,
        lam_align=lam_align,
        lam_cross=lam_cross,
        lam_cos=lam_cos,
        normalize_u=normalize_u,
        random_seed=random_seed,
    ), f)

In [11]:
print("Unscrambling training cells...")
# This cell rebinds `input_train_mod2` to a permuted copy. Running it a second
# time would permute again and destroy the pairing, so refuse if it already ran
# (the "uid" column is only added at the end of this cell).
assert "uid" not in input_train_mod2.obs, (
    "Training cells already unscrambled; re-run the h5ad loading cell first."
)

sol_csr = input_train_sol.X.tocsr()
# Row i (mod1 cell) must hold exactly one entry, whose column is the matching mod2
# cell. Only then is `indices` a per-row lookup of the matching column.
assert sol_csr.shape == (input_train_mod1.n_obs, input_train_mod2.n_obs), "solution shape mismatch"
assert np.all(sol_csr.getnnz(axis=1) == 1), "each mod1 cell must match exactly one mod2 cell"
# Named `pairing_order` rather than `ord` to avoid shadowing the builtin.
pairing_order = sol_csr.indices
if "pairing_ix" in input_train_sol.uns:
    assert np.all(pairing_order == np.argsort(input_train_sol.uns["pairing_ix"]))

# Reorder mod2 so that row i of both modalities describes the same cell.
input_train_mod2 = input_train_mod2[pairing_order, :].copy()
input_train_mod2.obs_names = input_train_mod1.obs_names
train_uids = [f"train-{i}" for i in range(input_train_mod1.shape[0])]
input_train_mod1.obs["uid"] = train_uids
input_train_mod2.obs["uid"] = train_uids
# Compare raw values so categorical columns with differing category sets do not
# raise a TypeError.
assert np.all(input_train_mod1.obs["batch"].to_numpy() == input_train_mod2.obs["batch"].to_numpy())

Unscrambling training cells...


In [12]:
# Display which feature type the first modality holds; the next cell uses it
# to decide which AnnData is treated as GEX.
mod1_feature_type

'GEX'

In [13]:
# Assign the two modalities to the "gex" and "other" roles. These are aliases
# of the existing AnnData objects (no copy is made).
gex, other = (
    (input_train_mod1, input_train_mod2)
    if mod1_feature_type == "GEX"
    else (input_train_mod2, input_train_mod1)
)

In [14]:
print('Preprocessing GEX...')
gex_prep = utils.GEXPreprocessing(n_comps=100, n_genes=n_genes, merge_adt=omics == "cite", scale=True, clip=True)
gex_prep.fit_transform(gex)

# Cache for the ATAC LSI embedding (fit_transform leaves it in other.obsm['X_lsi']).
# NOTE(review): this path is not tied to the dataset in `par`, so a cache produced
# from a different dataset would be silently reused — consider deriving it from `par`.
atac_lsi_cache = './cache/clue-multiome/atac_X_lsi.npy'

if omics == "cite":
    print('Preprocessing ADT...')
    other_prep = utils.ADTPreprocessing(n_comps=100, scale=True, clip=True)
    # The cached LSI embedding is ATAC-specific, so it is not passed to the ADT preprocessor.
    other_prep.fit_transform(other)
elif omics == "multiome":
    print('Preprocessing ATAC...')
    other_prep = utils.ATACPreprocessing(n_comps=100)
    if os.path.exists(atac_lsi_cache):
        # Reuse the cached LSI embedding.
        other_prep.fit_transform(other, X_lsi=np.load(atac_lsi_cache))
    else:
        # First run: no cache yet, so compute the embedding and store it.
        # NOTE(review): assumes ATACPreprocessing.fit_transform computes LSI itself
        # when `X_lsi` is not supplied — confirm against utils.py.
        other_prep.fit_transform(other)
        os.makedirs(os.path.dirname(atac_lsi_cache), exist_ok=True)
        np.save(atac_lsi_cache, other.obsm['X_lsi'])
else:
    raise RuntimeError(f"No preprocessing defined for omics={omics!r}!")

Preprocessing GEX...
Preprocessing ATAC...


In [17]:
# Persist the ATAC LSI embedding for reuse by the preprocessing cell.
# np.save does not create missing parent directories, so ensure the directory exists.
os.makedirs('./cache/clue-multiome', exist_ok=True)
np.save('./cache/clue-multiome/atac_X_lsi.npy', other.obsm['X_lsi'])

In [19]:
# Persist both fitted preprocessing objects into the pretraining output directory.
with open(os.path.join(par['output_pretrain'], "prep.pickle"), "wb") as f:
    pickle.dump(dict(gex_prep=gex_prep, other_prep=other_prep), f)


In [24]:
# Register each modality with scGLUE using the "NB" (negative binomial) model on
# the "counts" layer of highly variable features, together with its own
# precomputed representation and the shared "batch" / "uid" obs columns.
for _adata, _rep in ((gex, "X_pca"), (other, "X_lsi")):
    scglue.models.configure_dataset(
        _adata, "NB", use_highly_variable=True,
        use_layer="counts", use_rep=_rep,
        use_batch="batch", use_uid="uid",
    )
del _adata, _rep

In [25]:
print('Building model...')
# Architecture sizes and the seed come from the omics-specific hyperparameter cell.
model = scglue.models.SCCLUEModel(
    dict(gex=gex, other=other),
    random_seed=random_seed,
    shared_batches=True,
    latent_dim=latent_dim,
    dropout=dropout,
    x2u_h_depth=x2u_h_depth,
    x2u_h_dim=x2u_h_dim,
    u2x_h_depth=u2x_h_depth,
    u2x_h_dim=u2x_h_dim,
    du_h_depth=du_h_depth,
    du_h_dim=du_h_dim,
)

# A freshly built model still needs to be fitted by the training cell.
training = True

Building model...
[INFO] autodevice: Using GPU 1 as computation device.


In [22]:
# Optionally resume from pretrained weights. The checkpoint is loaded (and training
# skipped) only when it exists. Otherwise a fresh Restart & Run All would crash on
# the missing file. Note that this cell was originally executed before the
# model-building cell, which then reset `training` back to True.
pretrain_path = os.path.join(par['output_pretrain'], "pretrain.dill")
if os.path.exists(pretrain_path):
    model = scglue.models.load_model(pretrain_path)
    training = False

In [26]:
print('Compiling model...')
# lam_* / normalize_u come from the omics-specific hyperparameter cell;
# both domains get the same weight.
model.compile(
    lam_data=lam_data,
    lam_kl=lam_kl,
    lam_cross=lam_cross,
    lam_align=lam_align,
    lam_cos=lam_cos,
    normalize_u=normalize_u,
    domain_weight=dict(gex=1, other=1),
)

Compiling model...


In [27]:
# `training` is False when the pretrained checkpoint was loaded instead of fitting.
# After fitting, the model is checkpointed to pretrain.dill in the output directory.
if training:
    print('Training model...')
    model.fit(dict(gex=gex, other=other))
    model.save(os.path.join(par['output_pretrain'], "pretrain.dill"))

Training model...
[INFO] SCCLUEModel: Setting `align_burnin` = 21
[INFO] SCCLUEModel: Setting `max_epochs` = 121
[INFO] SCCLUEModel: Setting `patience` = 16
[INFO] SCCLUEModel: Setting `reduce_lr_patience` = 6
[INFO] SCCLUETrainer: Using training directory: "/tmp/GLUETMPlkt0ojvl"
[INFO] SCCLUETrainer: [Epoch 10] train={'dsc_loss': 0.685, 'gen_loss': 1.313, 'cross_loss': 0.66, 'cos_loss': 0.096, 'x_gex_nll': 0.282, 'x_gex_kl': 0.011, 'x_gex_elbo': 0.285, 'x_other_nll': 0.378, 'x_other_kl': 0.005, 'x_other_elbo': 0.379}, val={'dsc_loss': 0.69, 'gen_loss': 1.301, 'cross_loss': 0.655, 'cos_loss': 0.087, 'x_gex_nll': 0.277, 'x_gex_kl': 0.011, 'x_gex_elbo': 0.28, 'x_other_nll': 0.377, 'x_other_kl': 0.005, 'x_other_elbo': 0.378}, 11.6s elapsed
[INFO] SCCLUETrainer: [Epoch 20] train={'dsc_loss': 0.681, 'gen_loss': 1.307, 'cross_loss': 0.657, 'cos_loss': 0.093, 'x_gex_nll': 0.28, 'x_gex_kl': 0.011, 'x_gex_elbo': 0.283, 'x_other_nll': 0.376, 'x_other_kl': 0.005, 'x_other_elbo': 0.378}, val={'dsc

2023-07-14 12:03:34,697 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


[INFO] EarlyStopping: Restoring checkpoint "57"...
[INFO] EarlyStopping: Restoring checkpoint "57"...
