# Running CBM, CBM-FM, and other models on single-cell (sc) data

Reference configs are in `fm_config`.

Any config value can be overridden with Hydra override strings (e.g. `model.num_epochs=10`), as shown in the cell below.

In [2]:
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
import conceptlab as clab

In [3]:
# Epoch budgets for this (short) run. Named here instead of being buried in
# override strings, so they are easy to find and change.
N_EPOCHS = 10      # written to `model.num_epochs`
N_CBM_EPOCHS = 10  # written to `model.cbm_mod.max_epochs` (nested CBM model)

# Hydra override strings applied on top of the `general` reference config.
CONFIG_OVERRIDES = [
    "data=kang",
    "model=cbmfm",
    f"model.cbm_mod.max_epochs={N_CBM_EPOCHS}",
    f"model.num_epochs={N_EPOCHS}",
    "data/intervention_labels=bcells",
]

with initialize(config_path="../../fm_config/", version_base=None):
    cfg = compose(config_name="general", overrides=CONFIG_OVERRIDES)
    # print() (not rich display) so the YAML keeps its line breaks.
    print(OmegaConf.to_yaml(cfg))

wandb:
  experiment: conceptlab
  project: conceptlab
  entity: null
seed: 42
intervention_labels:
  label_variable: cell_stim
  hold_out_label: B cells_stim
  mod_label: B cells_ctrl
  concepts_to_flip:
  - stim
  reference:
  - 0
data:
  _target_: conceptlab.data.dataset.InterventionDataset
  data_path: /braid/havivd/scgen/kang_scimilarity_v3_with_concepts.h5ad
  concept_key: concepts
  intervention_labels: ${intervention_labels}
  mmd_label: cell_stim
model:
  _target_: conceptlab.models.cb_fm.CBMFM_MetaTrainer
  num_epochs: 10
  batch_size: 128
  lr: 0.0003
  raw: false
  concept_key: ${data.concept_key}
  num_workers: 4
  obsm_key: X_pca
  z_score: false
  edit: true
  fm_mod_cfg:
    hidden_dim: 1024
    latent_dim: 128
    n_layers: 4
    dropout: 0.1
    p_uncond: 0.0
  cbm_mod:
    _target_: conceptlab.models.cb_vae.CBM_MetaTrainer
    concept_key: ${data.concept_key}
    max_epochs: 10
    log_every_n_steps: 10
    num_workers: 4
    obsm_key: ${model.obsm_key}
    z_score: $

In [None]:
# Hydra builds each object from the `_target_` class named in its config node.
# The dataset is built first; the log below shows it loading and splitting the data.
dataset = hydra.utils.instantiate(
    cfg.data
)
model = hydra.utils.instantiate(
    cfg.model
)

Loading and preprocessing data...
{'label_variable': 'cell_stim', 'hold_out_label': 'B cells_stim', 'mod_label': 'B cells_ctrl', 'concepts_to_flip': ['stim'], 'reference': [0]}
Splitting data for counterfactual experiment...
Train set: 22823 cells
Intervention set: 1306 cells
Ground Truth set: 1231 cells


In [None]:
# Splits returned by the dataset. Presumably: full data, train split,
# ground-truth target split, and intervention split (compare the split log above).
adata, adata_train, adata_test, adata_inter = dataset.get_anndatas()

# Pass copies so any in-place changes the model makes do not leak back into the splits.
model.train(adata_train.copy())

adata_preds = model.predict_intervention(
    adata_inter.copy(),
    hold_out_label=dataset.hold_out_label,
    concepts_to_flip=dataset.concepts_to_flip,
)

Training CB-FM model with learned concepts...
Training scCBGM model...
Starting training on cuda for 10 epochs...


Training Progress: 100%|███████████████████████████████████████| 10/10 [00:10<00:00,  1.07s/it, avg_loss=9.841e-02, concept_f1=0.7263, lr=2.91997e-04]


Training finished.
Generating learned concepts from scCBGM...
Starting training on cuda for 10 epochs...


Training Progress: 100%|████████| 10/10 [00:07<00:00,  1.33it/s, avg_loss=4.420e-01, lr=2.91997e-04]


Performing intervention with CB-FM (learned)...
Generating learned concepts from scCBGM...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

In [7]:
# Baseline = training-split expression matrix; target = held-out split the predictions are compared against.
x_baseline, x_target = adata_train.X, adata_test.X

In [8]:
# Distribution-level comparison (MMD) of the intervened cells against the target.
mmd_score = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train=x_baseline,
    x_ivn=adata_preds.X,
    x_target=x_target,
    labels_train=adata_train.obs[dataset.label_variable].values,
)

# Differential-expression comparison of the intervened cells against the target.
de_score = clab.evaluation.interventions.evaluate_intervention_DE_with_target(
    x_train=x_baseline,
    x_ivn=adata_preds.X,
    x_target=x_target,
    genes_list=adata_train.var.index,
)

  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)


In [15]:
# One line per metric dict: MMD scores first, then DE scores.
for score_dict in (mmd_score, de_score):
    print(score_dict)

{'mmd_ratio': np.float64(0.7381587904372132), 'pre_computed_mmd_train': np.float64(0.015308349343960526)}
{'recall_pos': 0.9700598801814335, 'recall_negs': 0.5445544554185864}
