# Running CBM, CBM-FM and others on sc data

reference configs are in `fm_config`

One can easily override the configs as shown in the cell below

In [12]:
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
import conceptlab as clab

In [None]:
with initialize(config_path="../../fm_config/", version_base=None):
    cfg = compose(config_name="general", overrides=["data=kang",
                                                    "model=cbmfm",
                                                     "model.cbm_mod.max_epochs=10",
                                                       "model.num_epochs=10"])
    print(OmegaConf.to_yaml(cfg))

wandb:
  experiment: conceptlab
  project: conceptlab
  entity: null
seed: 42
intervention_labels:
  hold_out_label: CD4 T cells_stim
  mod_label: CD4 T cells_ctrl
  concepts_to_flip:
  - stim
data:
  _target_: conceptlab.data.dataset.InterventionDataset
  data_path: /braid/havivd/scgen/kang.h5ad
  concept_key: concepts
  label_variable: cell_stim
  intervention_labels: ${intervention_labels}
model:
  _target_: conceptlab.models.cb_fm.CBMFM_MetaTrainer
  num_epochs: 10
  batch_size: 128
  lr: 0.0003
  p_drop: 0.1
  raw: false
  concept_key: ${data.concept_key}
  num_workers: 4
  pca: false
  z_score: false
  fm_mod_cfg:
    hidden_dim: 1024
    latent_dim: 128
    n_layers: 4
    dropout: 0.1
    p_uncond: 0.1
  cbm_mod:
    _target_: conceptlab.models.cb_vae.CBM_MetaTrainer
    concept_key: ${data.concept_key}
    max_epochs: 10
    log_every_n_steps: 10
    num_workers: 4
    pca: ${model.pca}
    z_score: ${model.z_score}
    cbm_config:
      has_cbm: true
      lr: 0.0003
      n_

In [10]:
dataset = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

Loading and preprocessing data...
{'hold_out_label': 'CD4 T cells_stim', 'mod_label': 'CD4 T cells_ctrl', 'concepts_to_flip': ['stim']}
Splitting data for counterfactual experiment...
Train set: 19173 cells
Intervention set: 5154 cells
Ground Truth set: 5091 cells


In [11]:
adata, adata_train, adata_test, adata_inter =  dataset.get_anndatas()
model.train(adata_train.copy())
adata_preds = model.predict_intervention(adata_inter.copy(), hold_out_label = dataset.hold_out_label, concepts_to_flip = dataset.concepts_to_flip)

Training CB-FM model with learned concepts...
Training scCBGM model...
Starting training on cuda for 10 epochs...


Training Progress: 100%|████████| 10/10 [00:11<00:00,  1.13s/it, avg_loss=2.878e-02, lr=2.91997e-04]


Training finished.
Generating learned concepts from scCBGM...
Starting training on cuda for 10 epochs...


Training Progress: 100%|████████| 10/10 [00:08<00:00,  1.16it/s, avg_loss=6.273e-01, lr=2.91997e-04]


Performing intervention with CB-FM (learned)...
Generating learned concepts from scCBGM...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

In [13]:
x_baseline = adata_train.X
x_target = adata_test.X

In [14]:
mmd_score = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train = x_baseline,
    x_ivn = adata_preds.X,
    x_target = x_target,
    labels_train = adata_train.obs[dataset.label_variable].values
    )

de_score = clab.evaluation.interventions.evaluate_intervention_DE_with_target(
    x_train = x_baseline,
    x_ivn = adata_preds.X,
    x_target = x_target,
    genes_list = adata_train.var.index
) 

  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "logfoldchanges"] = np.log2(


In [15]:
print(mmd_score)
print(de_score)

{'mmd_ratio': np.float64(0.7381587904372132), 'pre_computed_mmd_train': np.float64(0.015308349343960526)}
{'recall_pos': 0.9700598801814335, 'recall_negs': 0.5445544554185864}
