# Running CBM, CBM-FM and others on sc data

reference configs are in `fm_config`

One can easily override the configs as shown in the cell below

In [2]:
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
import conceptlab as clab

In [6]:
with initialize(config_path="../../fm_config/", version_base=None):
    cfg = compose(config_name="general", overrides=["data=kang",
                                                    "model=cbm",
                                                     "model.max_epochs=10",
                                                       "data/intervention_labels=bcells"])
    print(OmegaConf.to_yaml(cfg))

wandb:
  experiment: conceptlab
  project: conceptlab
  entity: null
seed: 42
intervention_labels:
  label_variable: cell_stim
  hold_out_label: B cells_stim
  mod_label: B cells_ctrl
  concepts_to_flip:
  - stim
  reference:
  - 0
data:
  _target_: conceptlab.data.dataset.InterventionDataset
  data_path: /braid/havivd/scgen/kang_scimilarity_v3_with_concepts.h5ad
  concept_key: concepts
  intervention_labels: ${intervention_labels}
  mmd_label: cell_stim
model:
  _target_: conceptlab.models.cb_vae.CBM_MetaTrainer
  concept_key: ${data.concept_key}
  max_epochs: 10
  log_every_n_steps: 10
  num_workers: 4
  obsm_key: X_pca
  z_score: false
  cbm_config:
    has_cbm: true
    lr: 0.0003
    n_layers: 4
    hidden_dim: 1024
    beta: 1.0e-05
    latent_dim: 128
    n_unknown: 128
    min_bottleneck_size: 128
    concepts_hp: 0.01
    orthogonality_hp: 0.5
    use_soft_concepts: false



In [7]:
dataset = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

Loading and preprocessing data...
{'label_variable': 'cell_stim', 'hold_out_label': 'B cells_stim', 'mod_label': 'B cells_ctrl', 'concepts_to_flip': ['stim'], 'reference': [0]}
Splitting data for counterfactual experiment...
Train set: 22823 cells
Intervention set: 1306 cells
Ground Truth set: 1231 cells


In [9]:
adata, adata_train, adata_test, adata_inter =  dataset.get_anndatas()
model.train(adata_train.copy())
adata_preds = model.predict_intervention(adata_inter.copy(), hold_out_label = dataset.hold_out_label, concepts_to_flip = dataset.concepts_to_flip)

Training scCBGM model...
Starting training on cuda for 10 epochs...


Training Progress: 100%|███████████████████████████████████████| 10/10 [00:14<00:00,  1.48s/it, avg_loss=5.085e-02, concept_f1=0.7219, lr=2.91997e-04]

Training finished.
Performing intervention with scCBGM...





In [10]:
if cfg.model.normalize_from_og:
        dataset.normalize_from_og(adata_preds)

if cfg.model.obsm_key == "X_pca":
    x_baseline_rec = adata_train.X
    x_target_rec = adata_test.X
    x_ivn_rec = adata_train.uns["pc_transform"].inverse_transform(adata_preds.obsm["X_pca"])

ConfigAttributeError: Key 'normalize_from_og' is not in struct
    full_key: model.normalize_from_og
    object_type=dict

In [None]:
mmd_score = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
        x_train = adata_train.obsm[cfg.model.obsm_key] if cfg.model.obsm_key !="X" else adata_train.X,
        x_ivn = adata_preds.obsm[cfg.model.obsm_key] if cfg.model.obsm_key !="X" else adata_preds.X,
        x_target = adata_test.obsm[cfg.model.obsm_key] if cfg.model.obsm_key !="X" else adata_test.X,
        labels_train = adata_train.obs[dataset.mmd_label].values
        )
    
# The DE metric is only evaluated in gene space (reconstructions)
de_score = clab.evaluation.interventions.evaluate_intervention_DE_with_target(
    x_train = adata_train.X if cfg.model.obsm_key =="X" else x_baseline_rec,
    x_ivn = adata_preds.X if cfg.model.obsm_key =="X" else x_ivn_rec,
    x_target = adata_test.X if cfg.model.obsm_key =="X" else x_target_rec,
    genes_list = adata_train.var.index.tolist()
) 

In [15]:
print(mmd_score)
print(de_score)

{'mmd_ratio': np.float64(0.7381587904372132), 'pre_computed_mmd_train': np.float64(0.015308349343960526)}
{'recall_pos': 0.9700598801814335, 'recall_negs': 0.5445544554185864}
