# Config

In [28]:
import os

# Restrict this process to GPU 0. This has to happen before torch initialises
# CUDA, which is why it runs ahead of the main import cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [29]:
# Automatically re-import edited modules (e.g. conceptlab) before each cell
# executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
import conceptlab as clab
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import anndata as ad
import scanpy as sc
import torch
import scipy.spatial
import matplotlib.patches as mpatches
import string
import colorcet 

from omegaconf import OmegaConf
import pytorch_lightning as pl

In [31]:
import random

# Path to the AnnData file with concept annotations. Override it through the
# CLAB_DATA_PATH environment variable to run outside the original cluster
# without editing this cell.
DATA_PATH = os.environ.get(
    'CLAB_DATA_PATH',
    '/braid/havivd/immune_dictionary/lig_seurat_with_concepts.h5ad',
)
OBSM_KEY = 'X_pca'        # obsm key of the representation the models are trained/evaluated on
CONCEPT_KEY = 'concepts'  # obsm key of the one-hot concept matrix
RANDOM_SEED = 0

# Seed every RNG the notebook touches for reproducibility.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
_ = torch.manual_seed(RANDOM_SEED)  # assigned to keep the Generator repr out of the output

<torch._C.Generator at 0x7f3d111aebf0>

In [32]:
# Experiment definition: predict how CTYPE cells respond to PERT by intervening
# on the CTYPE cells observed under the control condition CTRL.
CTYPE, PERT, CTRL = 'T_cell_gd', 'IL15', 'PBS'

In [33]:
# --- Plotting Configuration ---
# One fixed colour per fine-grained cell type so every figure shares a palette.
CT_CMAP = {'B_cell': '#1f77b3',
 'Basophil': '#ff7e0e',
 'FRC': '#2ba02b',
 'ILC': '#d62628',
 'LEC': '#9367bc',
 'Langerhans': '#8c564b',
 'Macrophage': '#e277c1',
 'Mast_cell': '#7e7e7e',
 'MigDC': '#bcbc21',
 'Monocyte': '#16bdcf',
 'NK_cell': '#3a0182',
 'Neutrophil': '#004201',
 'T_cell_CD4': '#0fffa8',
 'T_cell_CD8': '#5d003f',
 'T_cell_Ki67': '#bcbcff',
 'T_cell_gd': '#d8afa1',
 'Treg': '#b80080',
 'cDC1': '#004d52',
 'cDC2': '#6b6400',
 'eTAC': '#7c0100',
 'pDC': '#6026ff'}

STIM_CMAP = {'ctrl': '#b80080', 'stim': '#e277c1', 'other': '#7e7e7e'}

# Colours for the split labels stored in obs['ident'] ('train', 'intervention',
# 'held out as GT', 'intervened on'). 'held out for intervention' shares the
# 'intervention' colour; presumably kept for figures using the older label.
IDENT_CMAP = {
    'train': '#676765', 'intervention': '#c84639','held out for intervention': '#c84639',
    'held out as GT': '#048757', 'intervened on': '#06d400'
}

TITLE_MAP = {'celltype': 'Cell Type', 'stim': 'State', 'ident': 'Split'}


cell_type_clusters = {
    # Lymphoid cells are crucial for the adaptive immune response.
    'Lymphoid': {
        'T_cell': [
            'T_cell_CD4',
            'T_cell_CD8',
            'T_cell_Ki67',
            'T_cell_gd',
            'Treg'
        ],
        'B_cell': [
            'B_cell'
        ],
        'Innate_lymphoid_cell': [
            'ILC',
            'NK_cell'
        ]
    },
    # Myeloid cells are primarily part of the innate immune system.
    'Myeloid': {
        'Dendritic_cell': [
            'cDC1',
            'cDC2',
            'Langerhans',
            'MigDC',
            'pDC'
        ],
        'Monocyte_Macrophage': [
            'Macrophage',
            'Monocyte'
        ],
        'Granulocyte': [
            'Basophil',
            'Mast_cell',
            'Neutrophil'
        ]
    },
    # Stromal cells provide structure and support within tissues.
    'Stromal_Other': {
        'Stromal': [
            'FRC',
            'LEC',
            'eTAC'
        ]
    }
}

# Example of how to access a specific group:
# print(cell_type_clusters['Myeloid']['Granulocyte'])

# --- Look-Up Table (LUT) Generation ---

# Map each specific subtype to its parent group (e.g. 'T_cell_CD4' -> 'T_cell').
# The major group (Lymphoid/Myeloid/...) is not needed, so iterate its values only.
subtype_to_group_lut = {
    subtype: group
    for sub_groups in cell_type_clusters.values()
    for group, subtypes in sub_groups.items()
    for subtype in subtypes
}

# Every coloured cell type must have a group and vice versa: a subtype missing
# from the LUT would silently become NaN when the LUT is applied with Series.map.
assert set(subtype_to_group_lut) == set(CT_CMAP), (
    "cell_type_clusters and CT_CMAP disagree on: "
    f"{sorted(set(subtype_to_group_lut) ^ set(CT_CMAP))}"
)

# DATA LOADING AND PREPARATION

In [34]:
def split_data(adata, hold_out_label, mod_label, label_key = 'L2_stim',
               exclude_inter_from_train = False):
    """
    Split cells into train, intervention and ground-truth sets by a label column.

    - Ground Truth: all cells whose ``obs[label_key]`` equals ``hold_out_label``.
    - Intervention: all cells whose ``obs[label_key]`` equals ``mod_label``.
    - Train: every cell that is NOT ground truth. By default this *includes* the
      intervention cells, so train and intervention overlap. Pass
      ``exclude_inter_from_train=True`` to make the three sets disjoint.

    Side effect: writes a per-cell split label to ``adata.obs['ident']``
    ('train', 'held out as GT' or 'intervention').

    Parameters
    ----------
    adata : AnnData
        Full dataset; ``adata.obs[label_key]`` must exist.
    hold_out_label : str
        Label of the cells held out as ground truth.
    mod_label : str
        Label of the cells to intervene on.
    label_key : str
        ``obs`` column holding the labels.
    exclude_inter_from_train : bool
        If True, also remove the intervention cells from the train set.

    Returns
    -------
    tuple
        ``(adata, adata_train, adata_test, adata_inter)``. The last three are copies.
    """
    import warnings

    print("Splitting data with simplified logic...")
    labels = adata.obs[label_key].to_numpy()

    is_test = labels == hold_out_label
    is_inter = labels == mod_label
    is_train = ~is_test
    if exclude_inter_from_train:
        is_train &= ~is_inter

    # An empty split almost always means a mistyped label; say so instead of
    # silently training/evaluating on nothing further down.
    for split_name, mask, label in (("ground-truth", is_test, hold_out_label),
                                    ("intervention", is_inter, mod_label)):
        if not mask.any():
            warnings.warn(f"No cells with {label_key} == {label!r}; the {split_name} set is empty.")

    adata_train = adata[is_train].copy()
    adata_test = adata[is_test].copy()
    adata_inter = adata[is_inter].copy()

    # Later assignments win, so a cell matching both labels is tagged 'intervention'.
    ident_vec = np.full(len(adata), 'train', dtype='<U32')
    ident_vec[is_test] = 'held out as GT'
    ident_vec[is_inter] = 'intervention'
    adata.obs['ident'] = ident_vec

    return adata, adata_train, adata_test, adata_inter



# MODELING & PREDICTION METHODS

## Method 1: scCBGM

In [35]:
def train_cbgm(adata_train, concept_key = 'concepts', obsm_key = 'X_pca',
               num_epochs = 200, batch_size = 128, lr = 3e-4):
    """Train and return an scCBGM model on ``adata_train``.

    Parameters
    ----------
    adata_train : AnnData
        Training cells.
    concept_key : str
        ``obsm`` key of the concept matrix, one column per concept. May be a
        DataFrame (as built by ``pd.get_dummies``) or an array.
    obsm_key : str
        ``obsm`` key of the input representation, or ``'X'`` to train on
        ``adata_train.X``.
    num_epochs, batch_size, lr : int, int, float
        Forwarded to ``model.train_loop``. Defaults match the previously
        hard-coded values.

    Returns
    -------
    The trained ``clab.models.scCBGM`` instance.
    """
    print("Training scCBGM model...")

    data_matrix = adata_train.X if obsm_key == 'X' else adata_train.obsm[obsm_key]
    # torch.from_numpy needs a dense ndarray; densify sparse matrices.
    if hasattr(data_matrix, 'toarray'):
        data_matrix = data_matrix.toarray()
    # np.array copies, so the tensors never alias arrays stored on the AnnData.
    data_matrix = np.array(data_matrix, dtype=np.float32)
    # Works for both DataFrame and ndarray concept matrices.
    concepts = np.array(adata_train.obsm[concept_key], dtype=np.float32)

    # Process-wide torch setting: flush denormal floats to zero on CPU.
    torch.set_flush_denormal(True)

    config = OmegaConf.create(dict(
        has_cbm=True,
        # NOTE(review): train_loop below is also given an lr; confirm which one the model uses.
        lr=5e-4,
        hidden_dim=1024,
        n_layers=4,
        beta=1e-5,
        input_dim=data_matrix.shape[-1],
        latent_dim=128,
        n_concepts=concepts.shape[1],
        min_bottleneck_size=128, independent_training=True,
        concepts_hp=0.1, orthogonality_hp=0.5, use_soft_concepts=False
    ))

    model = clab.models.scCBGM(config)

    model.train_loop(
        data=torch.from_numpy(data_matrix),
        concepts=torch.from_numpy(concepts),
        num_epochs=num_epochs, batch_size=batch_size, lr=lr,
    )
    return model

def pred_cbgm(model, adata_inter, ctrl_index, stim_index, concept_key = 'concepts', obsm_key = 'X_pca',
              device = 'cuda'):
    """Predict the perturbed state of ``adata_inter`` by intervening on scCBGM concepts.

    For every cell, the control concept is set to 0 and the stimulation concept
    to 1, and both are marked in the intervention mask. All other concepts keep
    their observed values.

    Parameters
    ----------
    model : scCBGM
        Trained model exposing ``intervene(x, mask=, concepts=)``, which returns
        a dict with an ``'x_pred'`` tensor.
    adata_inter : AnnData
        Cells to intervene on (not modified).
    ctrl_index, stim_index : int
        Column positions of the control / stimulation concepts in
        ``obsm[concept_key]``.
    concept_key : str
        ``obsm`` key of the concept DataFrame. Its column names label the
        prediction.
    obsm_key : str
        ``obsm`` key of the model input, or ``'X'``.
    device : str or torch.device
        Device the model lives on (default ``'cuda'``).

    Returns
    -------
    AnnData
        Copy of ``adata_inter`` with the predictions in ``obsm[obsm_key]`` (or
        ``X``), ``obs['ident'] = 'intervened on'``, and ``obs['sample']`` set to
        the stimulation concept's name. When predicting in an obsm space, ``X``
        is zero-filled as a placeholder.
    """
    print("Performing intervention with scCBGM...")
    x_source = adata_inter.X if obsm_key == 'X' else adata_inter.obsm[obsm_key]
    x_intervene_on = torch.tensor(x_source, dtype=torch.float32)

    concept_table = adata_inter.obsm[concept_key]
    c_intervene_on = torch.tensor(np.asarray(concept_table, dtype=np.float32))

    # Only the two edited concepts are flagged as intervened on.
    mask = torch.zeros_like(c_intervene_on)
    mask[:, ctrl_index] = 1
    mask[:, stim_index] = 1

    inter_concepts = c_intervene_on.clone()
    inter_concepts[:, ctrl_index] = 0
    inter_concepts[:, stim_index] = 1  # Set stim concept to 1

    with torch.no_grad():
        outputs = model.intervene(x_intervene_on.to(device), mask=mask.to(device), concepts=inter_concepts.to(device))
    inter_preds = outputs['x_pred'].cpu().numpy()

    pred_adata = adata_inter.copy()
    if obsm_key == 'X':
        pred_adata.X = inter_preds
    else:
        # Predictions exist only in the obsm space. Zero X so it is not mistaken
        # for real expression (np.zeros also works when X is sparse, unlike zeros_like).
        pred_adata.X = np.zeros(adata_inter.shape, dtype=adata_inter.X.dtype)
        pred_adata.obsm[obsm_key] = inter_preds
    pred_adata.obs['ident'] = 'intervened on'
    # Name the sample from the same table the indices refer to.
    pred_adata.obs['sample'] = concept_table.columns[stim_index]
    return pred_adata


## Method 2: Flow Matching with Learned Concepts

In [36]:
def get_learned_concepts(scCBGM_model, adata_full, obsm_key = 'X_pca', device = 'cuda', batch_size = None):
    """Encode every cell with a trained scCBGM and store its concept activations.

    Writes three entries to ``adata_full.obsm``. The object is modified in place
    and also returned.

    - ``'scCBGM_concepts_known'``: ``cb_concepts_layers`` applied to the encoder mean.
    - ``'scCBGM_concepts_unknown'``: ``cb_unk_layers`` applied to the encoder mean.
    - ``'scCBGM_concepts'``: both concatenated column-wise, known concepts first.

    Parameters
    ----------
    scCBGM_model : scCBGM
        Trained model exposing ``encode`` (returning a dict with ``'mu'``),
        ``cb_concepts_layers`` and ``cb_unk_layers``.
    adata_full : AnnData
        Cells to encode.
    obsm_key : str
        ``obsm`` key of the model input, or ``'X'``.
    device : str or torch.device
        Device the model lives on (default ``'cuda'``).
    batch_size : int or None
        If given, encode this many cells at a time to bound device memory.
        ``None`` (default) encodes all cells in one pass.
    """
    print("Generating learned concepts from scCBGM...")

    x_all = adata_full.X if obsm_key == 'X' else adata_full.obsm[obsm_key]

    def _encode(x_chunk):
        # Known and unknown concept activations for one chunk of cells.
        x = torch.tensor(x_chunk, dtype=torch.float32).to(device)
        mu = scCBGM_model.encode(x)['mu']
        return (scCBGM_model.cb_concepts_layers(mu).cpu().numpy(),
                scCBGM_model.cb_unk_layers(mu).cpu().numpy())

    with torch.no_grad():
        if batch_size is None:
            known, unknown = _encode(x_all)
        else:
            if batch_size < 1:
                raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
            n_cells = x_all.shape[0]
            # max(n_cells, 1) guarantees one (possibly empty) chunk so concatenate has input.
            chunks = [_encode(x_all[start:start + batch_size])
                      for start in range(0, max(n_cells, 1), batch_size)]
            known = np.concatenate([k for k, _ in chunks], axis=0)
            unknown = np.concatenate([u for _, u in chunks], axis=0)

    adata_full.obsm['scCBGM_concepts_known'] = known
    adata_full.obsm['scCBGM_concepts_unknown'] = unknown
    adata_full.obsm['scCBGM_concepts'] = np.concatenate([known, unknown], axis=1)
    return adata_full

def train_cb_fm(adata_train, concpet_key = 'scCBGM_concepts', obsm_key = 'X_pca',
                num_epochs = 200, batch_size = 128, lr = 3e-4):
    """Train and return a conditional flow-matching model on learned concepts.

    Parameters
    ----------
    adata_train : AnnData
        Training cells.
    concpet_key : str
        ``obsm`` key of the conditioning matrix (learned scCBGM concepts by
        default; array or DataFrame). The misspelled name is kept because
        callers pass it by keyword.
    obsm_key : str
        ``obsm`` key of the data to model, or ``'X'``.
    num_epochs, batch_size, lr : int, int, float
        Forwarded to ``train_loop``. Defaults match the previously hard-coded values.

    Returns
    -------
    The trained ``clab.models.cond_fm.Cond_FM`` instance.
    """
    print("Training Concept Bottleneck Flow Model")

    data_matrix = adata_train.X if obsm_key == 'X' else adata_train.obsm[obsm_key]
    # torch.from_numpy needs a dense ndarray; densify sparse matrices.
    if hasattr(data_matrix, 'toarray'):
        data_matrix = data_matrix.toarray()
    data_matrix = np.array(data_matrix, dtype=np.float32)
    concepts = np.array(adata_train.obsm[concpet_key], dtype=np.float32)

    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=concepts.shape[1],
        n_layers=4,
        dropout=0.1,
        # presumably the classifier-free-guidance condition-drop rate; confirm in Cond_FM
        p_uncond=0.0)

    fm_model = clab.models.cond_fm.Cond_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix),
        concepts=torch.from_numpy(concepts),
        num_epochs=num_epochs, batch_size=batch_size, lr=lr,
    )
    return fm_model


def pred_cb_fm(model, adata_inter, ctrl_index, stim_index, concept_key = 'scCBGM_concepts', obsm_key = 'X_pca', edit = True,
               device = 'cuda', t_edit = 0.2, n_steps = 1000):
    """Predict perturbed cells with a flow model conditioned on learned concepts.

    Each cell's conditioning vector is copied, with the control concept set to 0
    and the stimulation concept set to 1. Then one of two paths runs:

    - ``edit=True``: ``model.edit`` starts from the observed cells under the
      original concepts ``c`` and regenerates them under the edited concepts
      ``c_prime``, going back to time ``t_edit``.
    - ``edit=False``: ``model.decode`` generates cells from the edited concepts only.

    Parameters
    ----------
    model : Cond_FM
        Trained flow model exposing ``edit`` and ``decode``.
    adata_inter : AnnData
        Cells to intervene on (not modified). Must hold ``obsm[concept_key]``
        and ``obsm['concepts']``.
    ctrl_index, stim_index : int
        Column positions of the control / stimulation concepts.
    concept_key : str
        ``obsm`` key of the conditioning matrix (array or DataFrame).
    obsm_key : str
        ``obsm`` key of the model's data space, or ``'X'``.
    edit : bool
        True for editing, False for decoding from concepts alone.
    device : str or torch.device
        Device the model lives on (default ``'cuda'``).
    t_edit, n_steps : float, int
        Sampler settings forwarded to the model. Defaults match the previously
        hard-coded values.

    Returns
    -------
    AnnData
        Copy of ``adata_inter`` with predictions in ``obsm[obsm_key]`` (or
        ``X``), ``obs['ident'] = 'intervened on'``, and ``obs['sample']`` set to
        the stimulation concept's name. When predicting in an obsm space, ``X``
        is zero-filled.
    """
    print("Performing intervention with CB-FM (learned)...")

    def _as_tensor(arr):
        # Fresh float32 copy on the model's device.
        return torch.from_numpy(np.array(arr, dtype=np.float32)).to(device)

    init_concepts = np.array(adata_inter.obsm[concept_key], dtype=np.float32)
    edit_concepts = init_concepts.copy()
    edit_concepts[:, ctrl_index] = 0  # switch the control concept off
    edit_concepts[:, stim_index] = 1  # switch the stimulation concept on

    x_inter = adata_inter.X if obsm_key == 'X' else adata_inter.obsm[obsm_key]
    if edit:
        inter_preds = model.edit(
                x = _as_tensor(x_inter),
                c = _as_tensor(init_concepts),
                c_prime = _as_tensor(edit_concepts),
                t_edit = t_edit,
                n_steps = n_steps,
                w_cfg_forward = 1.0,
                w_cfg_backward = 1.0,
                noise_add = 0.0)
    else:
        inter_preds = model.decode(
                h = _as_tensor(edit_concepts),
                n_steps = n_steps,
                w_cfg = 1.0)
    inter_preds = inter_preds.detach().cpu().numpy()

    pred_adata = adata_inter.copy()
    if obsm_key == 'X':
        pred_adata.X = inter_preds
    else:
        # Predictions exist only in the obsm space. Zero X so it is not mistaken
        # for real expression (np.zeros also works when X is sparse).
        pred_adata.X = np.zeros(adata_inter.shape, dtype=adata_inter.X.dtype)
        pred_adata.obsm[obsm_key] = inter_preds
    pred_adata.obs['ident'] = 'intervened on'
    # The learned-concept array has no column names, so the label comes from
    # obsm['concepts']. Presumably its columns line up with the first (known)
    # learned concepts, since get_learned_concepts puts known concepts first.
    pred_adata.obs['sample'] = adata_inter.obsm['concepts'].columns[stim_index]
    return pred_adata



## Method 3: Raw Flows



In [37]:
def train_raw_fm(adata_train, concpet_key = 'concepts', obsm_key = 'X_pca',
                 num_epochs = 200, batch_size = 128, lr = 3e-4):
    """Train and return a conditional flow-matching model on the observed concepts.

    Unlike ``train_cb_fm``, the condition is the observed concept matrix
    (``obsm['concepts']`` by default), not scCBGM-learned concepts.

    Parameters
    ----------
    adata_train : AnnData
        Training cells.
    concpet_key : str
        ``obsm`` key of the conditioning matrix (DataFrame or array). The
        misspelled name is kept because callers pass it by keyword.
    obsm_key : str
        ``obsm`` key of the data to model, or ``'X'``.
    num_epochs, batch_size, lr : int, int, float
        Forwarded to ``train_loop``. Defaults match the previously hard-coded values.

    Returns
    -------
    The trained ``clab.models.cond_fm.Cond_FM`` instance.
    """
    print("Training raw Flow Model")

    data_matrix = adata_train.X if obsm_key == 'X' else adata_train.obsm[obsm_key]
    # torch.from_numpy needs a dense ndarray; densify sparse matrices.
    if hasattr(data_matrix, 'toarray'):
        data_matrix = data_matrix.toarray()
    data_matrix = np.array(data_matrix, dtype=np.float32)
    # np.array accepts both DataFrame and ndarray concept matrices.
    concepts = np.array(adata_train.obsm[concpet_key], dtype=np.float32)

    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=concepts.shape[1],
        n_layers=4,
        dropout=0.1,
        # presumably the classifier-free-guidance condition-drop rate; confirm in Cond_FM
        p_uncond=0.0)

    fm_model = clab.models.cond_fm.Cond_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix),
        concepts=torch.from_numpy(concepts),
        num_epochs=num_epochs, batch_size=batch_size, lr=lr,
    )
    return fm_model

def pred_raw_fm(model, adata_inter, ctrl_index, stim_index, concept_key = 'concepts', obsm_key = 'X_pca', edit = True,
                device = 'cuda', t_edit = 0.2, n_steps = 1000):
    """Predict perturbed cells with a flow model conditioned on observed concepts.

    Each cell's observed concept row is copied, with the control concept set to
    0 and the stimulation concept set to 1. Then one of two paths runs:

    - ``edit=True``: ``model.edit`` starts from the observed cells under the
      original concepts ``c`` and regenerates them under the edited concepts
      ``c_prime``, going back to time ``t_edit``.
    - ``edit=False``: ``model.decode`` generates cells from the edited concepts only.

    Parameters
    ----------
    model : Cond_FM
        Trained flow model exposing ``edit`` and ``decode``.
    adata_inter : AnnData
        Cells to intervene on (not modified).
    ctrl_index, stim_index : int
        Column positions of the control / stimulation concepts in
        ``obsm[concept_key]``.
    concept_key : str
        ``obsm`` key of the concept DataFrame. Its column names label the prediction.
    obsm_key : str
        ``obsm`` key of the model's data space, or ``'X'``.
    edit : bool
        True for editing, False for decoding from concepts alone.
    device : str or torch.device
        Device the model lives on (default ``'cuda'``).
    t_edit, n_steps : float, int
        Sampler settings forwarded to the model. Defaults match the previously
        hard-coded values.

    Returns
    -------
    AnnData
        Copy of ``adata_inter`` with predictions in ``obsm[obsm_key]`` (or
        ``X``), ``obs['ident'] = 'intervened on'``, and ``obs['sample']`` set to
        the stimulation concept's name. When predicting in an obsm space, ``X``
        is zero-filled.
    """
    print("Performing intervention with raw FM ")

    def _as_tensor(arr):
        # Fresh float32 copy on the model's device.
        return torch.from_numpy(np.array(arr, dtype=np.float32)).to(device)

    concept_table = adata_inter.obsm[concept_key]
    init_concepts = np.array(concept_table, dtype=np.float32)
    edit_concepts = init_concepts.copy()
    edit_concepts[:, ctrl_index] = 0  # switch the control concept off
    edit_concepts[:, stim_index] = 1  # switch the stimulation concept on

    x_inter = adata_inter.X if obsm_key == 'X' else adata_inter.obsm[obsm_key]
    if edit:
        inter_preds = model.edit(
                x = _as_tensor(x_inter),
                c = _as_tensor(init_concepts),
                c_prime = _as_tensor(edit_concepts),
                t_edit = t_edit,
                n_steps = n_steps,
                w_cfg_forward = 1.0,
                w_cfg_backward = 1.0,
                noise_add = 0.0)
    else:
        inter_preds = model.decode(
                h = _as_tensor(edit_concepts),
                n_steps = n_steps,
                w_cfg = 1.0)
    inter_preds = inter_preds.detach().cpu().numpy()

    pred_adata = adata_inter.copy()
    if obsm_key == 'X':
        pred_adata.X = inter_preds
    else:
        # Predictions exist only in the obsm space. Zero X so it is not mistaken
        # for real expression (np.zeros also works when X is sparse).
        pred_adata.X = np.zeros(adata_inter.shape, dtype=adata_inter.X.dtype)
        pred_adata.obsm[obsm_key] = inter_preds
    pred_adata.obs['ident'] = 'intervened on'
    # Name the sample from the same table the indices refer to.
    pred_adata.obs['sample'] = concept_table.columns[stim_index]
    return pred_adata

# Main

## Processing

In [38]:
import sklearn.decomposition

In [39]:
# Read the AnnData file configured by DATA_PATH in the config cell.
print("Loading and preprocessing data...")
adata = ad.read_h5ad(filename=DATA_PATH)


Loading and preprocessing data...


In [40]:
# Coarse (L1, parent group from the LUT) and fine (L2, original label) cell types.
adata.obs['cell_type_L1'] = adata.obs['celltype'].map(subtype_to_group_lut)
adata.obs['cell_type_L2'] = adata.obs['celltype']

# Cell types missing from the LUT map to NaN. pd.get_dummies would then silently
# give those cells an all-zero cell-type concept, so fail here instead.
_unmapped_celltypes = adata.obs.loc[adata.obs['cell_type_L1'].isna(), 'celltype'].unique()
assert len(_unmapped_celltypes) == 0, (
    f"cell types missing from subtype_to_group_lut: {list(_unmapped_celltypes)}"
)

In [41]:
# One-hot concept matrix: one column per sample (perturbation) and per L1 group.
# An empty prefix/prefix_sep keeps the raw category values as column names,
# instead of stripping 'sample_' / 'cell_type_L1_' substrings wherever they occur.
adata.obsm['concepts'] = pd.get_dummies(
    adata.obs[['sample', 'cell_type_L1']], prefix='', prefix_sep=''
).astype(np.float32)

# Concept indices are later looked up by column name, so names must be unique
# (a sample name equal to an L1 group name would collide).
assert adata.obsm['concepts'].columns.is_unique, (
    "duplicate concept column names: "
    f"{list(adata.obsm['concepts'].columns[adata.obsm['concepts'].columns.duplicated()])}"
)

In [42]:
# Joint fine-cell-type x sample label, e.g. 'T_cell_gd_IL15'.
adata.obs['L2_stim'] = list(map('_'.join, zip(adata.obs['cell_type_L2'], adata.obs['sample'])))

In [43]:
# CTYPE cells under PERT are held out as ground truth; CTYPE cells under CTRL
# are the ones we intervene on.
hold_out_label = f"{CTYPE}_{PERT}"
mod_label = f"{CTYPE}_{CTRL}"

adata, adata_train, adata_test, adata_inter = split_data(
    adata, hold_out_label, mod_label, label_key='L2_stim'
)

for split_title, split_adata in (("Train set", adata_train),
                                 ("Intervention set", adata_inter),
                                 ("Ground Truth set", adata_test)):
    print(f"{split_title}: {len(split_adata)} cells")

Splitting data with simplified logic...
Train set: 110278 cells
Intervention set: 100 cells
Ground Truth set: 100 cells


In [44]:
# Column positions of the control and perturbation concepts. Index.get_loc raises
# a KeyError naming the missing label, whereas np.where(...)[0][0] raised a bare
# IndexError.
concept_columns = adata.obsm[CONCEPT_KEY].columns
ctrl_index = concept_columns.get_loc(CTRL)
stim_index = concept_columns.get_loc(PERT)
# get_loc returns a slice/mask instead of an int when a label is duplicated.
assert isinstance(ctrl_index, int) and isinstance(stim_index, int), "concept column names must be unique"

In [45]:
# Fit PCA on the training cells only (the held-out ground-truth cells stay unseen),
# then project every split into the same 'X_pca' space.
N_PCS = 256
pc_transform = sklearn.decomposition.PCA(n_components=N_PCS, random_state=RANDOM_SEED).fit(adata_train.X)

for x_data in (adata, adata_train, adata_test, adata_inter):
    # NOTE(review): a fitted sklearn object in .uns is presumably not serialisable
    # by AnnData.write_h5ad; drop it before saving — confirm.
    x_data.uns['pc_transform'] = pc_transform
    x_data.obsm['X_pca'] = pc_transform.transform(x_data.X)


## --- Method 1: scCBGM ---


In [46]:
# Use the config-cell keys so training stays in sync with the evaluation cell.
cbgm_model = train_cbgm(adata_train, concept_key=CONCEPT_KEY, obsm_key=OBSM_KEY)

Training scCBGM model...
Starting training on cuda for 200 epochs...


Training Progress: 100%|█████████████████████████████████████| 200/200 [20:14<00:00,  6.07s/it, avg_loss=2.341e-01, concept_f1=1.0000, lr=1.64990e-04]

Training finished.





In [47]:
pred_adata_cbgm = pred_cbgm(
    cbgm_model,
    adata_inter=adata_inter.copy(),
    ctrl_index=ctrl_index,
    stim_index=stim_index,
    concept_key=CONCEPT_KEY,
    obsm_key=OBSM_KEY,
)

Performing intervention with scCBGM...


## --- Method 2: CB-FM with Learned Concepts ---


In [48]:
# Learned concepts are copied to the splits by obs name, which is ambiguous if
# names repeat.
assert adata.obs_names.is_unique, "obs_names must be unique to transfer learned concepts by name"

adata_with_concepts = get_learned_concepts(cbgm_model, adata.copy(), obsm_key=OBSM_KEY)

# The CB-FM model is trained on adata_train and applied to adata_inter.
for split_adata in (adata_train, adata_inter):
    split_adata.obsm['scCBGM_concepts'] = adata_with_concepts[split_adata.obs.index].obsm['scCBGM_concepts']

Generating learned concepts from scCBGM...


In [49]:
cb_fm_model = train_cb_fm(adata_train, concpet_key='scCBGM_concepts', obsm_key=OBSM_KEY)

Training Concept Bottleneck Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [15:02<00:00,  4.51s/it, avg_loss=8.612e-01, lr=1.64990e-04]


In [50]:
# Two ways to apply the same concept intervention:
#   edit=True  -> start from the observed control cells and regenerate them under the edited concepts;
#   edit=False -> generate cells from the edited concepts alone.
pred_adata_fm_edit, pred_adata_fm_guid = (
    pred_cb_fm(cb_fm_model,
               adata_inter=adata_inter.copy(),
               ctrl_index=ctrl_index,
               stim_index=stim_index,
               concept_key='scCBGM_concepts',
               obsm_key=OBSM_KEY,
               edit=use_edit)
    for use_edit in (True, False)
)

Performing intervention with CB-FM (learned)...
Editing from t=1.0 back to t=0.20, then forward with new condition.


                                                                                        

Performing intervention with CB-FM (learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

## --- Method 3: Raw Flows ---

In [51]:
# train_raw_fm only reads from adata_train, so no defensive full copy is needed.
fm_raw_model = train_raw_fm(adata_train, concpet_key=CONCEPT_KEY, obsm_key=OBSM_KEY)

Training raw Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [14:19<00:00,  4.30s/it, avg_loss=1.658e+00, lr=1.64990e-04]


In [52]:
# Same two intervention modes as for CB-FM, conditioned on the observed concepts.
pred_adata_raw_fm_edit, pred_adata_raw_fm_guid = (
    pred_raw_fm(fm_raw_model,
                adata_inter=adata_inter.copy(),
                ctrl_index=ctrl_index,
                stim_index=stim_index,
                concept_key=CONCEPT_KEY,
                obsm_key=OBSM_KEY,
                edit=use_edit)
    for use_edit in (True, False)
)

Performing intervention with raw FM 
Editing from t=1.0 back to t=0.20, then forward with new condition.


                                                                                        

Performing intervention with raw FM 
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

In [53]:
# Compare each method's predicted cells with the held-out ground-truth cells via MMD.
# NOTE(review): with pre_computed_mmd_train=1.0, 'mmd_ratio' is presumably the
# raw MMD (no train-set normalisation) — confirm that is intended.
predictions = {
    'scCBGM': pred_adata_cbgm,
    'CB-FM (edit)': pred_adata_fm_edit,
    'CB-FM (guided)': pred_adata_fm_guid,
    'Raw-FM (edit)': pred_adata_raw_fm_edit,
    'Raw-FM (guided)': pred_adata_raw_fm_guid,
}

mmd_results = {}
for name, pred_adata in predictions.items():
    print(name)
    result = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
        x_train=adata_train.obsm[OBSM_KEY],
        x_ivn=pred_adata.obsm[OBSM_KEY],
        x_target=adata_test.obsm[OBSM_KEY],
        labels_train=adata_train.obs['L2_stim'].values,
        pre_computed_mmd_train=1.0
    )
    print(result)
    mmd_results[name] = result

# One row per method, so results can be compared, sorted or saved.
mmd_summary = pd.DataFrame.from_dict(mmd_results, orient='index')
mmd_summary



scCBGM
Using provided gamma: 0.00390625
{'mmd_ratio': np.float64(0.09866818087487926), 'pre_computed_mmd_train': 1.0}
CB-FM (edit)
Using provided gamma: 0.00390625
{'mmd_ratio': np.float64(0.07654700548599336), 'pre_computed_mmd_train': 1.0}
CB-FM (guided)
Using provided gamma: 0.00390625
{'mmd_ratio': np.float64(0.07922485765142015), 'pre_computed_mmd_train': 1.0}
Raw-FM (edit)
Using provided gamma: 0.00390625
{'mmd_ratio': np.float64(0.053031511095759175), 'pre_computed_mmd_train': 1.0}
Raw-FM (guided)
Using provided gamma: 0.00390625
{'mmd_ratio': np.float64(0.10516365090963684), 'pre_computed_mmd_train': 1.0}
