# Running CBM, CBM-FM, and others on single-cell (sc) data

Reference configs are in `fm_config`.

The configs can easily be overridden, as shown in the cell below.

In [2]:
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
import conceptlab as clab

In [6]:
# Overrides layered on top of the reference "general" config: select the
# `kang` data config and the `cbm` model config, set the model's
# `max_epochs` to 10, and pick the `bcells` option for
# `data/intervention_labels`.
config_overrides = [
    "data=kang",
    "model=cbm",
    "model.max_epochs=10",
    "data/intervention_labels=bcells",
]

with initialize(config_path="../../fm_config/", version_base=None):
    cfg = compose(config_name="general", overrides=config_overrides)
    # Dump the composed configuration as YAML so the effective values are visible.
    print(OmegaConf.to_yaml(cfg))

wandb:
  experiment: conceptlab
  project: conceptlab
  entity: null
seed: 42
intervention_labels:
  label_variable: cell_stim
  hold_out_label: B cells_stim
  mod_label: B cells_ctrl
  concepts_to_flip:
  - stim
  reference:
  - 0
data:
  _target_: conceptlab.data.dataset.InterventionDataset
  data_path: /braid/havivd/scgen/kang_scimilarity_v3_with_concepts.h5ad
  concept_key: concepts
  intervention_labels: ${intervention_labels}
  mmd_label: cell_stim
model:
  _target_: conceptlab.models.cb_vae.CBM_MetaTrainer
  concept_key: ${data.concept_key}
  max_epochs: 10
  log_every_n_steps: 10
  num_workers: 4
  obsm_key: X_pca
  z_score: false
  cbm_config:
    has_cbm: true
    lr: 0.0003
    n_layers: 4
    hidden_dim: 1024
    beta: 1.0e-05
    latent_dim: 128
    n_unknown: 128
    min_bottleneck_size: 128
    concepts_hp: 0.01
    orthogonality_hp: 0.5
    use_soft_concepts: false



In [7]:
# Build the dataset and the model from their respective config sections.
# Each section carries a `_target_` entry naming the class to construct.
_instantiate = hydra.utils.instantiate
dataset = _instantiate(cfg.data)
model = _instantiate(cfg.model)

Loading and preprocessing data...
{'label_variable': 'cell_stim', 'hold_out_label': 'B cells_stim', 'mod_label': 'B cells_ctrl', 'concepts_to_flip': ['stim'], 'reference': [0]}
Splitting data for counterfactual experiment...
Train set: 22823 cells
Intervention set: 1306 cells
Ground Truth set: 1231 cells


In [None]:
# Fetch the four AnnData objects exposed by the dataset.
adata, adata_train, adata_test, adata_inter = dataset.get_anndatas()

# Train on a copy of the training split.
# NOTE(review): presumably the copy guards against in-place changes during
# training — confirm against the trainer implementation.
model.train(adata_train.copy())

# Predict on a copy of the intervention split, passing the hold-out label and
# the concepts to flip that the dataset exposes.
adata_preds = model.predict_intervention(
    adata_inter.copy(),
    hold_out_label=dataset.hold_out_label,
    concepts_to_flip=dataset.concepts_to_flip,
)

Training scCBGM model...
Starting training on cuda for 10 epochs...


Training Progress:  10%|████                                    | 1/10 [00:01<00:15,  1.74s/it, avg_loss=2.318e-01, concept_f1=0.5034, lr=3.00000e-04]

In [7]:
# Data matrices passed to the evaluation calls: the training split is the
# baseline and the test split is the target.
x_baseline, x_target = adata_train.X, adata_test.X

In [8]:
intervention_eval = clab.evaluation.interventions

# Keyword arguments common to both metrics: baseline (train) matrix,
# predicted-intervention matrix, and target matrix.
shared_eval_inputs = dict(
    x_train=x_baseline,
    x_ivn=adata_preds.X,
    x_target=x_target,
)

# MMD-based evaluation; additionally receives the training labels.
mmd_score = intervention_eval.evaluate_intervention_mmd_with_target(
    **shared_eval_inputs,
    labels_train=adata_train.obs[dataset.label_variable].values,
)

# DE-based evaluation; additionally receives the gene names (the `var` index).
de_score = intervention_eval.evaluate_intervention_DE_with_target(
    **shared_eval_inputs,
    genes_list=adata_train.var.index,
)

  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)


In [15]:
# Report both result dictionaries, one per line: MMD first, then DE.
print(mmd_score, de_score, sep="\n")

{'mmd_ratio': np.float64(0.7381587904372132), 'pre_computed_mmd_train': np.float64(0.015308349343960526)}
{'recall_pos': 0.9700598801814335, 'recall_negs': 0.5445544554185864}
