In [1]:
import os
import torch
import numpy as np 
import random
import scanpy as sc
import anndata as ad
from icecream import ic

from scDisInFact import scdisinfact, create_scdisinfact_dataset
from scDisInFact import utils

In [2]:
def set_seed(seed):
    ic('Setting seed to', seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
adata_path = '/data/Experiments/Benchmark/SCDISENTANGLE_REPRODUCE/Datasets/preprocessed_datasets/kang.h5ad'
cov_key = "cell_type"
cond_key = "condition"
ood_cov = "B"
control_name = "control"
stim_name = "stimulated"
vars_to_predict = ['stimulated', 'control']
categorical_attributes = ['condition', 'cell_type'] # Should be in this order: cond, cov
seed_nb = 1
device_nb = 1

In [4]:
device =  torch.device(f"cuda:{device_nb}" if torch.cuda.is_available() else "cpu")

In [5]:
# Set seed
set_seed(seed_nb)

[38;5;247mic[39m[38;5;245m|[39m[38;5;245m [39m[38;5;36m'[39m[38;5;36mSetting seed to[39m[38;5;36m'[39m[38;5;245m,[39m[38;5;245m [39m[38;5;247mseed[39m[38;5;245m:[39m[38;5;245m [39m[38;5;36m1[39m


In [6]:
# Read adata
adata = sc.read_h5ad(adata_path)

try:
    adata.X = adata.X.toarray()
except:
    print('adata matrix is already in array format')

adata matrix is already in array format


In [7]:
adata.obs['batch_placeholder'] = [0] * adata.shape[0]

In [8]:
adata.obs[f'{cond_key}_org'] = adata.obs[cond_key].copy()

In [9]:
counts = adata.X
meta_cells = adata.obs

# Convert columns to str
for _col in categorical_attributes:
    meta_cells[_col] = meta_cells[_col].astype(str)

In [10]:
# Train and test ood
test_idx = (meta_cells[f'split_{stim_name}_{ood_cov}'] == 'ood')
train_idx = ~test_idx

# Cells used as sources for counterfactual prediction:
# control condition, specific covariate, and in the train split.
input_indices = (
        (meta_cells[cond_key] == control_name)
        & (meta_cells[cov_key] == ood_cov)
        & (meta_cells[f'split_{stim_name}_{ood_cov}'] == 'train')
    )

In [11]:
data_dict = create_scdisinfact_dataset(
        counts[train_idx,:], 
        meta_cells.loc[train_idx,:], 
        condition_key = categorical_attributes, 
        batch_key = "batch_placeholder"
    )

Sanity check...
Finished.
Create scDisInFact datasets...
Finished.


In [12]:
reg_mmd_comm = 1e-4
reg_mmd_diff = 1e-4
reg_kl_comm = 1e-5
reg_kl_diff = 1e-2
reg_class = 1
reg_gl = 1

In [13]:
Ks = [8] + [2] * len(categorical_attributes)

In [14]:
batch_size = 64
nepochs = 100
interval = 10
lr = 5e-4

lambs = [reg_mmd_comm, reg_mmd_diff, reg_kl_comm, reg_kl_diff, reg_class, reg_gl]

In [15]:
model = scdisinfact(
        data_dict = data_dict, 
        Ks = Ks, 
        batch_size = batch_size, 
        interval = interval, 
        lr = lr, 
        reg_mmd_comm = reg_mmd_comm, 
        reg_mmd_diff = reg_mmd_diff, 
        reg_gl = reg_gl, 
        reg_class = reg_class, 
        reg_kl_comm = reg_kl_comm, 
        reg_kl_diff = reg_kl_diff, 
        seed = seed_nb, 
        device = device,
    )

In [16]:
model.train()

scdisinfact(
  (Enc_c): Encoder(
    (fc): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): Linear(in_features=13404, out_features=128, bias=True)
          (1): None
          (2): None
          (3): ReLU()
          (4): Dropout(p=0.2, inplace=False)
        )
        (Layer 1): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): None
          (2): None
          (3): ReLU()
          (4): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (mean_layer): Linear(in_features=128, out_features=8, bias=True)
    (var_layer): Linear(in_features=128, out_features=8, bias=True)
  )
  (Enc_ds): ModuleList(
    (0-1): 2 x Encoder(
      (fc): FCLayers(
        (fc_layers): Sequential(
          (Layer 0): Sequential(
            (0): Linear(in_features=13404, out_features=128, bias=True)
            (1): None
            (2): None
            (3): ReLU()
            (4): Dropout(p=0.2, inplace=False)
    

In [17]:
losses = model.train_model(nepochs = nepochs, recon_loss = "NB")

Epoch 0, Validating Loss: 4.2188
	 loss reconstruction: 1.94981
	 loss kl comm: 0.18824
	 loss kl diff: 3.77042
	 loss mmd common: 4.33069
	 loss mmd diff: 19.09957
	 loss classification: 2.09541
	 loss group lasso diff: 0.13350
GPU memory usage: 0.000000MB
Epoch 10, Validating Loss: 1.4110
	 loss reconstruction: 0.23145
	 loss kl comm: 21.33522
	 loss kl diff: 21.91511
	 loss mmd common: 13.99707
	 loss mmd diff: 4.32577
	 loss classification: 0.80714
	 loss group lasso diff: 0.15120
GPU memory usage: 0.000000MB
Epoch 20, Validating Loss: 1.2558
	 loss reconstruction: 0.21589
	 loss kl comm: 20.25145
	 loss kl diff: 22.37972
	 loss mmd common: 13.79798
	 loss mmd diff: 3.90825
	 loss classification: 0.66420
	 loss group lasso diff: 0.14994
GPU memory usage: 0.000000MB
Epoch 30, Validating Loss: 1.1807
	 loss reconstruction: 0.20893
	 loss kl comm: 19.94520
	 loss kl diff: 22.70644
	 loss mmd common: 11.80556
	 loss mmd diff: 3.78259
	 loss classification: 0.59683
	 loss group lasso di

In [18]:
_ = model.eval()

In [19]:
meta_input = meta_cells.loc[input_indices, :]
counts_input = counts[input_indices, :]

In [20]:
predict_conds = ['stimulated', 'B']

In [21]:
counts_predict = model.predict_counts(
                input_counts = counts_input,
                meta_cells = meta_input,
                condition_keys = categorical_attributes, 
                batch_key = "batch_placeholder",
                predict_conds = predict_conds, 
                predict_batch = None
            )

adata_pred = sc.AnnData(X = counts_predict, obs=meta_input)

In [22]:
adata_pred.X.max()

np.float32(317.3731)

In [23]:
counts_predict = model.predict_counts(
                input_counts = counts_input,
                meta_cells = meta_input,
                condition_keys = categorical_attributes, 
                batch_key = "batch_placeholder",
                predict_conds = None, 
                predict_batch = None
            )

adata_pred_ctrl = sc.AnnData(X = counts_predict, obs=meta_input)

In [24]:
adata_pred_ctrl.X.max()

np.float32(366.90988)

In [31]:
import scDisInFact
scDisInFact.__version__

'0.1.0'