# Config

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0" 

In [3]:
import conceptlab as clab
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import anndata as ad
import scanpy as sc
import scanpy as sc
import torch
import scipy.spatial
import matplotlib
import matplotlib.patches as mpatches
import string

from omegaconf import OmegaConf
import pytorch_lightning as pl

In [4]:
DATA_PATH = '/braid/havivd/immune_dictionary/lig_seurat_with_concepts.h5ad'
OBSM_KEY = 'X_pca'
Z_SCORE = False
CONCEPT_KEY = 'concepts'
RANDOM_SEED = 0

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7f12b448ebf0>

In [5]:

import itertools
from collections import defaultdict

# DATA LOADING AND PREPARATION

In [6]:
def split_data(adata, hold_out_label, mod_label, label_key = 'L2_stim'):
    """
    Splits data into train, intervention, and ground truth sets.

    - Ground Truth: All cells with the `hold_out_label`.
    - Intervention: All cells with the `mod_label`.
    - Train: All remaining cells.
    """
    print("Splitting data with simplified logic...")
    labels = adata.obs[label_key]

    # Define the three disjoint sets based on their labels
    is_test = (labels == hold_out_label)
    is_inter = (labels == mod_label)
    is_train = ~is_test

    # Create AnnData objects for each split
    adata_train = adata[is_train].copy()
    adata_test = adata[is_test].copy()
    adata_inter = adata[is_inter].copy()

    # Store split identifiers in the original object
    ident_vec = np.array(['train'] * len(adata)).astype('<U32')
    ident_vec[is_test] = 'held out as GT'
    ident_vec[is_inter] = 'intervention'
    adata.obs['ident'] = ident_vec
    

    return adata, adata_train, adata_test, adata_inter



# MODELING & PREDICTION METHODS

## Method 1: scCBGM

In [7]:
def train_cbgm(adata_train, concept_key = 'concepts', obsm_key = 'X_pca'):
    """Trains and returns the scCBGM model."""
    print("Training scCBGM model...")

    # --- MODIFICATION START ---
    # Conditionally set the data source and input dimension based on the 'pca' flag
    if(obsm_key != 'X'):
        data_matrix = adata_train.obsm[obsm_key]
    else:
        data_matrix = adata_train.X
    # --- MODIFICATION END ---

    torch.set_flush_denormal(True)

    config = OmegaConf.create(dict(
        has_cbm=True, 
        lr=5e-4, 
        hidden_dim=1024, 
        n_layers = 4,
        beta=1e-5,
        input_dim=data_matrix.shape[-1],  # <-- Use the dynamically set input dimension
        latent_dim=128,
        n_concepts=adata_train.obsm[concept_key].shape[1],
        n_unknown=128, 
        concepts_hp=0.1, 
        orthogonality_hp=0.5, 
        use_soft_concepts=False
    ))
    model = clab.models.scCBGM(config)

    model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[concept_key].to_numpy().astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return model


def pred_cbgm(model, adata_inter,  ctrl_index, stim_index, concept_key = 'concepts', obsm_key = 'X_pca'):
    """Performs intervention using a trained scCBGM model."""
    print("Performing intervention with scCBGM...")
    if(obsm_key != 'X'):
        x_intervene_on = torch.tensor(adata_inter.obsm[obsm_key], dtype=torch.float32)
    else:
        x_intervene_on = torch.tensor(adata_inter.X, dtype=torch.float32)
    c_intervene_on = adata_inter.obsm[concept_key].to_numpy().astype(np.float32)

    # Define the intervention by creating a mask and new concept values
    mask = torch.zeros(c_intervene_on.shape, dtype=torch.float32)
    mask[:, ctrl_index] = 1  # Intervene on the last concept (stim)
    mask[:, stim_index] = 1  # Intervene on the last concept (stim)
    
    inter_concepts = torch.tensor(c_intervene_on, dtype=torch.float32)
    inter_concepts[:, ctrl_index] = 0 # Set ctrl concept to 0
    inter_concepts[:, stim_index] = 1 # Set stim concept to 1

    with torch.no_grad():
        inter_preds = model.intervene(x_intervene_on.to('cuda'), mask=mask.to('cuda'), concepts=inter_concepts.to('cuda'))
    
    inter_preds = inter_preds['x_pred'].cpu().numpy()

    if(obsm_key != 'X'):
        x_inter_preds = np.zeros_like(adata_inter.X)
    else:
        x_inter_preds = inter_preds

    pred_adata = adata_inter.copy()
    pred_adata.X = x_inter_preds
    pred_adata.obs['ident'] = 'intervened on'

    if(obsm_key != 'X'):
        pred_adata.obsm[obsm_key] = inter_preds
    return pred_adata


##  Method 2: Flow Matching with Learned Concepts

In [8]:
def get_learned_concepts(scCBGM_model, adata_full, obsm_key = 'X_pca'):
    """Uses a trained scCBGM to generate learned concepts for all data."""
    print("Generating learned concepts from scCBGM...")

    if(obsm_key != 'X'):
        all_x = torch.tensor(adata_full.obsm[obsm_key], dtype=torch.float32).to('cuda')
    else:
        all_x = torch.tensor(adata_full.X, dtype=torch.float32).to('cuda')

    with torch.no_grad():
        enc = scCBGM_model.encode(all_x)
        adata_full.obsm['scCBGM_concepts_known'] = scCBGM_model.cb_concepts_layers(enc['mu']).cpu().numpy()
        adata_full.obsm['scCBGM_concepts_unknown'] = scCBGM_model.cb_unk_layers(enc['mu']).cpu().numpy()

    adata_full.obsm['scCBGM_concepts'] = np.concatenate([adata_full.obsm['scCBGM_concepts_known'], adata_full.obsm['scCBGM_concepts_unknown']], axis=1)
    return adata_full

def train_cb_fm(adata_train, concept_key = 'scCBGM_concepts', obsm_key = 'X_pca'):
    """Trains and returns the CB-FM model using learned concepts."""
    print("Training Concept Bottleneck Flow Model")

    if(obsm_key != 'X'):
        data_matrix = adata_train.obsm[obsm_key]
    else:
        data_matrix = adata_train.X
    
    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=adata_train.obsm[concept_key].shape[1],
        n_layers=4,
        dropout=0.1,
        p_uncond = 0.0)

    fm_model = clab.models.cond_fm.Cond_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[concept_key].astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return fm_model



def pred_cb_fm(model, adata_inter, ctrl_index, stim_index, concept_key = 'scCBGM_concepts', obsm_key = 'X_pca', edit = True):
    """Performs intervention using a trained learned-concept CB-FM model."""
    print("Performing intervention with CB-FM (learned)...")

    if(obsm_key != 'X'):
        x_inter = adata_inter.obsm[obsm_key]
    else:
        x_inter = adata_inter.X
    

    init_concepts = adata_inter.obsm[concept_key].astype(np.float32)

    edit_concepts = init_concepts.copy()
    edit_concepts[:, ctrl_index] = 0 # Set ctrl concept to 0
    edit_concepts[:, stim_index] = 1 # Set stim concept to 1


    if(edit):
        inter_preds = model.edit(
                x = torch.from_numpy(x_inter.astype(np.float32)).to('cuda'),
                c = torch.from_numpy(init_concepts.astype(np.float32)).to('cuda'),
                c_prime = torch.from_numpy(edit_concepts.astype(np.float32)).to('cuda'),
                t_edit = 0.0,
                n_steps = 1000,
                w_cfg_forward = 1.0,
                w_cfg_backward = 1.0,
                noise_add = 0.0)
    else:
        inter_preds = model.decode(
                h = torch.from_numpy(edit_concepts.astype(np.float32)).to('cuda'),
                n_steps = 1000,
                w_cfg = 1.0)
        
    inter_preds = inter_preds.detach().cpu().numpy()

    if(obsm_key != 'X'):
        x_inter_preds = np.zeros_like(adata_inter.X)
    else:
        x_inter_preds = inter_preds

    pred_adata = adata_inter.copy()
    pred_adata.X = x_inter_preds
    pred_adata.obs['ident'] = 'intervened on'

    if(obsm_key != 'X'):
        pred_adata.obsm[obsm_key] = inter_preds
    return pred_adata



## Method 3: Flow Matching with Raw Concepts 

In [9]:
def train_raw_fm(adata_train, concept_key = 'concepts', obsm_key = 'X_pca'):
    """Trains and returns the CB-FM model using learned concepts."""
    print("Training Conditonal Flow Model")

    if(obsm_key != 'X'):
        data_matrix = adata_train.obsm[obsm_key]
    else:
        data_matrix = adata_train.X
    
    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=adata_train.obsm[concept_key].to_numpy().shape[1],
        n_layers=4,
        dropout=0.1,
        p_uncond = 0.0)

    fm_model = clab.models.cond_fm.Cond_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[concept_key].to_numpy().astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return fm_model



def pred_raw_fm(model, adata_inter, ctrl_index, stim_index, concept_key = 'concepts', obsm_key = 'X_pca', edit = False):
    """Performs intervention using a trained learned-concept CB-FM model."""
    print("Performing intervention with Raw Flow Matching(learned)...")

    if(obsm_key != 'X'):
        x_inter = adata_inter.obsm[obsm_key]
    else:
        x_inter = adata_inter.X
    

    init_concepts = adata_inter.obsm[concept_key].to_numpy().astype(np.float32)

    edit_concepts = init_concepts.copy()
    edit_concepts[:, ctrl_index] = 0 # Set ctrl concept to 0
    edit_concepts[:, stim_index] = 1 # Set stim concept to 1
    


    if(edit):
        inter_preds = model.edit(
                x = torch.from_numpy(x_inter.astype(np.float32)).to('cuda'),
                c = torch.from_numpy(init_concepts).to('cuda'),
                c_prime = torch.from_numpy(edit_concepts).to('cuda'),
                t_edit = 0.0,
                n_steps = 1000,
                w_cfg_forward = 1.0,
                w_cfg_backward = 1.0,
                noise_add = 0.0)
    else:
        inter_preds = model.decode(
                h = torch.from_numpy(edit_concepts).to('cuda'),
                n_steps = 1000,
                w_cfg = 1.0)
    
    inter_preds = inter_preds.detach().cpu().numpy()

    if(obsm_key != 'X'):
        x_inter_preds = np.zeros_like(adata_inter.X)
    else:
        x_inter_preds = inter_preds

    pred_adata = adata_inter.copy()
    pred_adata.X = x_inter_preds
    pred_adata.obs['ident'] = 'intervened on'

    if(obsm_key != 'X'):
        pred_adata.obsm[obsm_key] = inter_preds
    return pred_adata




# Method 4: CB VAE FM

In [10]:
def train_cb_fm_vae(adata_train, concept_key = 'concepts', obsm_key = 'X_pca', kl_hp = 0.1, concepts_hp = 0.2, orthogonality_hp = 0.5):
    """Trains and returns the scCBGM model."""
    print("Training Concept Flow Model")

    if(obsm_key != 'X'):
        data_matrix = adata_train.obsm[obsm_key]
    else:
        data_matrix = adata_train.X

    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=adata_train.obsm[concept_key].to_numpy().shape[1],
        n_unknown=128,
        n_layers=4,
        dropout=0.1,
        p_uncond=0.0,
        unknown_activation = 'relu',
        kl_hp = kl_hp,
        concepts_hp=concepts_hp, 
        orthogonality_hp=orthogonality_hp,
        flow_hp = 1.0)

    fm_model = clab.models.concept_fm.Concept_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[concept_key].to_numpy().astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return fm_model


def pred_cb_fm_vae(model, adata_inter, ctrl_index, stim_index, concept_key = 'concepts', obsm_key = 'X_pca'):
    """Performs intervention using a trained CB_FM_VAE model."""
    print("Performing intervention with CB_FM_VAE...")
    if(obsm_key != 'X'):
        x_intervene_on = torch.tensor(adata_inter.obsm[obsm_key], dtype=torch.float32)
    else:
        x_intervene_on = torch.tensor(adata_inter.X, dtype=torch.float32)
    c_intervene_on = adata_inter.obsm[concept_key].to_numpy().astype(np.float32)

    # Define the intervention by creating a mask and new concept values
    mask = torch.zeros(c_intervene_on.shape, dtype=torch.float32)
    mask[:, ctrl_index] = 1  # Intervene on the last concept (stim)
    mask[:, stim_index] = 1  # Intervene on the last concept (stim)
    
    inter_concepts = torch.tensor(c_intervene_on, dtype=torch.float32)
    inter_concepts[:, ctrl_index] = 0 # Set ctrl concept to 0
    inter_concepts[:, stim_index] = 1 # Set stim concept to 1

    with torch.no_grad():
        inter_preds = model.intervene(x_intervene_on.to('cuda'), mask=mask.to('cuda'), concepts=inter_concepts.to('cuda'))
    
    inter_preds = inter_preds.cpu().numpy()

    if(obsm_key != 'X'):
        x_inter_preds = np.zeros_like(adata_inter.X)
    else:
        x_inter_preds = inter_preds

    pred_adata = adata_inter.copy()
    pred_adata.X = x_inter_preds
    pred_adata.obs['ident'] = 'intervened on'

    if(obsm_key != 'X'):
        pred_adata.obsm[obsm_key] = inter_preds
    return pred_adata


# Main

## Proccesing

In [11]:
import sklearn.decomposition


print("Loading and preprocessing data...")
adata = ad.read_h5ad(DATA_PATH)

# adata.obs['L2_stim'] = [l1_ctype + '_' + stim for l1_ctype, stim in zip(adata.obs['cell_types_L2'], adata.obs['stim'])]

Loading and preprocessing data...


In [30]:
adata.obsm['concepts']

Unnamed: 0,41BBL,APRIL,Adiponectin,BAFF,C3a,C5a,CD27L,CD30L,CD40L,Cardiotrophin-1,...,TSLP,TWEAK,VEGF,B_cell,Dendritic_cell,Granulocyte,Innate_lymphoid_cell,Monocyte_Macrophage,Stromal,T_cell
CATAGACGTAACGATA-18,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
CGTGAATCAGCGAGTA-10,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
ACGGTCGGTAACCCTA-10,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
ACTTTGTGTACGAGTG-42,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
GTGTAACCATGACTTG-18,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CATACTTGTAGTCACT-40,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
GGAACCCAGACATAAC-40,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
ATGGGAGGTAAGATTG-44,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
ATGGATCCAGTCGGAA-40,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


In [29]:
adata.write_h5ad('/braid/havivd/immune_dictionary/lig_seurat_with_concepts_v2.h5ad')

In [12]:
# --- Plotting Configuration ---
CT_CMAP = {'B_cell': '#1f77b3',
 'Basophil': '#ff7e0e',
 'FRC': '#2ba02b',
 'ILC': '#d62628',
 'LEC': '#9367bc',
 'Langerhans': '#8c564b',
 'Macrophage': '#e277c1',
 'Mast_cell': '#7e7e7e',
 'MigDC': '#bcbc21',
 'Monocyte': '#16bdcf',
 'NK_cell': '#3a0182',
 'Neutrophil': '#004201',
 'T_cell_CD4': '#0fffa8',
 'T_cell_CD8': '#5d003f',
 'T_cell_Ki67': '#bcbcff',
 'T_cell_gd': '#d8afa1',
 'Treg': '#b80080',
 'cDC1': '#004d52',
 'cDC2': '#6b6400',
 'eTAC': '#7c0100',
 'pDC': '#6026ff'}

STIM_CMAP = {'ctrl': '#b80080', 'stim': '#e277c1', 'other': '#7e7e7e'}
# IDENT_CMAP = {
#     'train': '#676765', 'held out for intervention': '#c84639',
#     'held out as GT': '#048757', 'intervened on': '#06d400'
# }

IDENT_CMAP = {
    'train': '#676765', 'intervention': '#c84639','held out for intervention': '#c84639',
    'held out as GT': '#048757', 'intervened on': '#06d400'
}

TITLE_MAP = {'celltype': 'Cell Type', 'stim': 'State', 'ident': 'Split'}


cell_type_clusters = {
    # Lymphoid cells are crucial for the adaptive immune response.
    'Lymphoid': {
        'T_cell': [
            'T_cell_CD4',
            'T_cell_CD8',
            'T_cell_Ki67',
            'T_cell_gd',
            'Treg'
        ],
        'B_cell': [
            'B_cell'
        ],
        'Innate_lymphoid_cell': [
            'ILC',
            'NK_cell'
        ]
    },
    # Myeloid cells are primarily part of the innate immune system.
    'Myeloid': {
        'Dendritic_cell': [
            'cDC1',
            'cDC2',
            'Langerhans',
            'MigDC',
            'pDC'
        ],
        'Monocyte_Macrophage': [
            'Macrophage',
            'Monocyte'
        ],
        'Granulocyte': [
            'Basophil',
            'Mast_cell',
            'Neutrophil'
        ]
    },
    # Stromal cells provide structure and support within tissues.
    'Stromal_Other': {
        'Stromal': [
            'FRC',
            'LEC',
            'eTAC'
        ]
    }
}

# Example of how to access a specific group:
# print(cell_type_clusters['Myeloid']['Granulocyte'])

# --- Look-Up Table (LUT) Generation ---

# Create an LUT to map each specific subtype to its parent group (e.g., 'T_cell_CD4' -> 'T_cell').
# This uses a dictionary comprehension to iterate through the nested structure.
subtype_to_group_lut = {
    subtype: group
    for major_group, sub_groups in cell_type_clusters.items()
    for group, subtypes in sub_groups.items()
    for subtype in subtypes 
}

In [13]:
adata.obs['cell_type_L1'] = adata.obs['celltype'].map(subtype_to_group_lut)
adata.obs['cell_type_L2'] = adata.obs['celltype']

adata.obsm['concepts'] = pd.get_dummies(adata.obs[['sample', 'cell_type_L1']]).astype(np.float32)
adata.obsm['concepts'].columns = [col.replace('sample_', '').replace('cell_type_L1_', '') for col in adata.obsm['concepts'].columns]



In [14]:
adata.obs['L1_stim'] = [ctype + "_" + stim for ctype, stim in zip(adata.obs['cell_type_L1'], adata.obs['sample'])]
adata.obs['L2_stim'] = [ctype + "_" + stim for ctype, stim in zip(adata.obs['cell_type_L2'], adata.obs['sample'])]

In [15]:
set(adata.obs['cell_type_L2'])

{'B_cell',
 'Basophil',
 'FRC',
 'ILC',
 'LEC',
 'Langerhans',
 'Macrophage',
 'Mast_cell',
 'MigDC',
 'Monocyte',
 'NK_cell',
 'Neutrophil',
 'T_cell_CD4',
 'T_cell_CD8',
 'T_cell_Ki67',
 'T_cell_gd',
 'Treg',
 'cDC1',
 'cDC2',
 'eTAC',
 'pDC'}

In [None]:
# val = adata.obs['L2_stim'].value_counts()

# for ind in val.index:
#     print(ind, val[ind])

# Run Benchmarking

In [None]:
# --- Results Storage Initialization ---
benchmark_instances = [{'cell_type': 'T_cell_CD4', 'ctrl': 'PBS', 'stim': 'TGF-beta-1'},
                       {'cell_type': 'T_cell_CD8', 'ctrl': 'PBS', 'stim': 'TNFa'},
                       {'cell_type': 'T_cell_gd', 'ctrl': 'PBS', 'stim': 'IL17E'},
                       {'cell_type': 'NK_cell', 'ctrl': 'PBS', 'stim': 'IL15'},
                       {'cell_type': 'Macrophage', 'ctrl': 'PBS', 'stim': 'M-CSF'},
                       {'cell_type': 'cDC2', 'ctrl': 'PBS', 'stim': 'IFNa1'},
                       {'cell_type': 'Langerhans', 'ctrl': 'PBS', 'stim': 'IFNg'}]

benchmark_results = defaultdict(lambda: defaultdict(list))



## Main Benchmark Loop
# =================================================================
# In this loop, we evaluate all models with their default parameters
# across all cell types to get a baseline performance comparison.
# =================================================================

print("--- Starting Main Benchmark Across All Labels ---")
for instance in benchmark_instances:

    hold_out_label = instance['cell_type'] + '_' + instance['stim']
    mod_label = instance['cell_type'] +'_' + instance['ctrl']

    adata, adata_train, adata_test, adata_inter = split_data(
        adata, hold_out_label, mod_label, label_key = 'L2_stim'
    )

    stim_index = np.where(adata.obsm['concepts'].columns == instance['stim'])[0][0]

    ctrl_index = np.where(np.isin(adata.obsm['concepts'].columns, list(set(adata.obs['sample']))))[0]
    ctrl_index = np.delete(ctrl_index, np.where(ctrl_index == stim_index)[0][0])

    print(f"\n--- Processing Label: {hold_out_label} ---")
    print(f"Train set: {len(adata_train)} cells, Test set: {len(adata_test)} cells, Intervention set: {len(adata_inter)} cells")

    # PCA Transformation
    adata.uns['pc_transform'] = sklearn.decomposition.PCA(n_components=128).fit(adata_train.X)
    for x_data in [adata, adata_train, adata_test, adata_inter]:
        x_data.uns['pc_transform'] = adata.uns['pc_transform']
        x_data.obsm['X_pca'] = x_data.uns['pc_transform'].transform(x_data.X)

    # --- Train All Models ---

    print("Training and Evaluating CBGM")

    cbgm_model = train_cbgm(adata_train.copy())
    pred_adata_cbgm = pred_cbgm(cbgm_model, adata_inter.copy(), ctrl_index, stim_index)

    print("Training and Evaluating FM")

    adata_with_concepts = get_learned_concepts(cbgm_model, adata.copy())
    adata_train.obsm['scCBGM_concepts'] = adata_with_concepts[adata_train.obs.index].obsm['scCBGM_concepts']
    adata_inter.obsm['scCBGM_concepts'] = adata_with_concepts[adata_inter.obs.index].obsm['scCBGM_concepts']

    
    cb_fm_model = train_cb_fm(adata_train.copy())
    pred_adata_fm_edit = pred_cb_fm(cb_fm_model, adata_inter.copy(), ctrl_index, stim_index, edit = True)
    pred_adata_fm_guid = pred_cb_fm(cb_fm_model, adata_inter.copy(), ctrl_index, stim_index, edit = False)

    print("Training and Evaluating Raw FM")
    
    fm_raw_model = train_raw_fm(adata_train.copy())
    pred_adata_raw_fm_edit = pred_raw_fm(fm_raw_model, adata_inter.copy(), ctrl_index, stim_index, edit = True)
    pred_adata_raw_fm_guid = pred_raw_fm(fm_raw_model, adata_inter.copy(), ctrl_index, stim_index, edit = False)

    print("Training and Evaluating Concept Flow VAE")

    # Train VAE model with default hyperparameters
    cb_fm_vae_model = train_cb_fm_vae(adata_train.copy(), kl_hp = 0.1, concepts_hp = 0.2, orthogonality_hp = 0.5)
    pred_adata_cb_fm_vae = pred_cb_fm_vae(cb_fm_vae_model, adata_inter.copy(), ctrl_index, stim_index)

    # --- Benchmark All Models for this Label ---
    pre_computed_mmd_train = -1
    all_models = [pred_adata_cbgm, pred_adata_fm_edit, pred_adata_fm_guid, pred_adata_raw_fm_edit,  pred_adata_raw_fm_guid, pred_adata_cb_fm_vae]
    all_names = ['scCBGM', 'CB-FM (edit)', 'CB-FM (guided)', 'Raw-FM (edit)', 'Raw-FM (guided)', 'CB-FM (VAE)']

    for pred_adata, name in zip(all_models, all_names):
        # Check if the baseline MMD has been computed for this label yet
        if pre_computed_mmd_train == -1:
            # First run: compute everything and store the baseline MMD
            evaluation_output = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
                x_train=adata_train.obsm[OBSM_KEY], 
                x_ivn=pred_adata.obsm[OBSM_KEY], 
                x_target=adata_test.obsm[OBSM_KEY],
                labels_train=adata_train.obs['L1_stim'].values
            )
            mmd_ratio = evaluation_output['mmd_ratio']
            pre_computed_mmd_train = evaluation_output['pre_computed_mmd_train']
        else:
            # Subsequent runs: provide the pre-computed baseline MMD to save time
            evaluation_output = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
                x_train=adata_train.obsm[OBSM_KEY], 
                x_ivn=pred_adata.obsm[OBSM_KEY], 
                x_target=adata_test.obsm[OBSM_KEY],
                labels_train=adata_train.obs['L1_stim'].values,
                pre_computed_mmd_train=pre_computed_mmd_train
            )
            mmd_ratio = evaluation_output['mmd_ratio']

        # Now, calculate true MMD and store results
        true_mmd = mmd_ratio * pre_computed_mmd_train
        
        benchmark_results[name]['mmd_ratio'].append(mmd_ratio)
        benchmark_results[name]['true_mmd'].append(true_mmd)
        print(f"  > {name}: mmd_ratio={mmd_ratio:.4f}")

print("\n--- Main Benchmark Finished ---")

--- Starting Main Benchmark Across All Labels ---
Splitting data with simplified logic...

--- Processing Label: T_cell_CD4_TGF-beta-1 ---
Train set: 110278 cells, Test set: 100 cells, Intervention set: 100 cells
Training and Evaluating CBGM
Training scCBGM model...
Starting training on cuda for 200 epochs...


Training Progress: 100%|█████████████████████████████████████| 200/200 [13:14<00:00,  3.97s/it, avg_loss=9.178e-02, concept_f1=0.9984, lr=1.64990e-04]


Training finished.
Performing intervention with scCBGM...
Training and Evaluating FM
Generating learned concepts from scCBGM...
Training Concept Bottleneck Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [07:59<00:00,  2.40s/it, avg_loss=4.984e-01, lr=1.64990e-04]


Performing intervention with CB-FM (learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with CB-FM (learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Raw FM
Training Conditonal Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:07<00:00,  2.44s/it, avg_loss=1.948e+00, lr=1.64990e-04]


Performing intervention with Raw Flow Matching(learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with Raw Flow Matching(learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Concept Flow VAE
Training Concept Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|███████████████████████████████| 200/200 [20:27<00:00,  6.14s/it, avg_loss=4.328e-01, fm_loss=1.311e-01, concept_f1=9.700e-01]


Performing intervention with CB_FM_VAE...
Decoding with 1000 steps and cfg_strength 1.0


                                                                                        

  > scCBGM: mmd_ratio=0.0689
  > CB-FM (edit): mmd_ratio=0.0478
  > CB-FM (guided): mmd_ratio=0.0661
  > Raw-FM (edit): mmd_ratio=0.1024
  > Raw-FM (guided): mmd_ratio=1.7863
  > CB-FM (VAE): mmd_ratio=4.8680
Splitting data with simplified logic...

--- Processing Label: T_cell_CD8_TNFa ---
Train set: 110278 cells, Test set: 100 cells, Intervention set: 100 cells
Training and Evaluating CBGM
Training scCBGM model...
Starting training on cuda for 200 epochs...


Training Progress: 100%|█████████████████████████████████████| 200/200 [13:16<00:00,  3.98s/it, avg_loss=8.814e-02, concept_f1=0.9948, lr=1.64990e-04]


Training finished.
Performing intervention with scCBGM...
Training and Evaluating FM
Generating learned concepts from scCBGM...
Training Concept Bottleneck Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:04<00:00,  2.42s/it, avg_loss=4.771e-01, lr=1.64990e-04]


Performing intervention with CB-FM (learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with CB-FM (learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Raw FM
Training Conditonal Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:14<00:00,  2.47s/it, avg_loss=1.926e+00, lr=1.64990e-04]


Performing intervention with Raw Flow Matching(learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with Raw Flow Matching(learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Concept Flow VAE
Training Concept Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|███████████████████████████████| 200/200 [20:43<00:00,  6.22s/it, avg_loss=4.346e-01, fm_loss=1.294e-01, concept_f1=9.675e-01]


Performing intervention with CB_FM_VAE...
Decoding with 1000 steps and cfg_strength 1.0


                                                                                        

  > scCBGM: mmd_ratio=0.2633
  > CB-FM (edit): mmd_ratio=0.2460
  > CB-FM (guided): mmd_ratio=0.2419
  > Raw-FM (edit): mmd_ratio=0.1705
  > Raw-FM (guided): mmd_ratio=1.8172
  > CB-FM (VAE): mmd_ratio=4.5972
Splitting data with simplified logic...

--- Processing Label: T_cell_gd_IL17E ---
Train set: 110278 cells, Test set: 100 cells, Intervention set: 100 cells
Training and Evaluating CBGM
Training scCBGM model...
Starting training on cuda for 200 epochs...


Training Progress: 100%|█████████████████████████████████████| 200/200 [13:26<00:00,  4.03s/it, avg_loss=7.915e-02, concept_f1=0.9916, lr=1.64990e-04]


Training finished.
Performing intervention with scCBGM...
Training and Evaluating FM
Generating learned concepts from scCBGM...
Training Concept Bottleneck Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:06<00:00,  2.43s/it, avg_loss=4.340e-01, lr=1.64990e-04]


Performing intervention with CB-FM (learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with CB-FM (learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Raw FM
Training Conditonal Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:18<00:00,  2.49s/it, avg_loss=1.933e+00, lr=1.64990e-04]


Performing intervention with Raw Flow Matching(learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with Raw Flow Matching(learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Concept Flow VAE
Training Concept Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|███████████████████████████████| 200/200 [20:42<00:00,  6.21s/it, avg_loss=4.428e-01, fm_loss=1.321e-01, concept_f1=9.648e-01]


Performing intervention with CB_FM_VAE...
Decoding with 1000 steps and cfg_strength 1.0


                                                                                        

  > scCBGM: mmd_ratio=0.2314
  > CB-FM (edit): mmd_ratio=0.2002
  > CB-FM (guided): mmd_ratio=0.2125
  > Raw-FM (edit): mmd_ratio=1.1409
  > Raw-FM (guided): mmd_ratio=1.7329
  > CB-FM (VAE): mmd_ratio=2.1640
Splitting data with simplified logic...

--- Processing Label: NK_cell_IL15 ---
Train set: 110278 cells, Test set: 100 cells, Intervention set: 100 cells
Training and Evaluating CBGM
Training scCBGM model...
Starting training on cuda for 200 epochs...


Training Progress: 100%|█████████████████████████████████████| 200/200 [13:20<00:00,  4.00s/it, avg_loss=1.035e-01, concept_f1=0.9848, lr=1.64990e-04]


Training finished.
Performing intervention with scCBGM...
Training and Evaluating FM
Generating learned concepts from scCBGM...
Training Concept Bottleneck Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:11<00:00,  2.46s/it, avg_loss=4.830e-01, lr=1.64990e-04]


Performing intervention with CB-FM (learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with CB-FM (learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Raw FM
Training Conditonal Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:19<00:00,  2.50s/it, avg_loss=1.934e+00, lr=1.64990e-04]


Performing intervention with Raw Flow Matching(learned)...
Editing from t=1.0 back to t=0.00, then forward with new condition.


                                                                                        

Performing intervention with Raw Flow Matching(learned)...
Decoding with 1000 steps and CFG scale w=1.0


                                                                                        

Training and Evaluating Concept Flow VAE
Training Concept Flow Model
Starting training on cuda for 200 epochs...


Training Progress:  38%|████████████▏                   | 76/200 [07:56<12:58,  6.28s/it, avg_loss=6.927e-01, fm_loss=1.468e-01, concept_f1=8.771e-01]/pytorch/aten/src/ATen/native/cuda/Loss.cu:90: operator(): block: [4,0,0], thread: [0,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:90: operator(): block: [4,0,0], thread: [1,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:90: operator(): block: [4,0,0], thread: [2,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:90: operator(): block: [4,0,0], thread: [3,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:90: operator(): block: [4,0,0], thread: [4,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/pytorch/aten/src/ATen/native/cuda/Loss.cu:90: operator(): block: [4,0,0], thread: [5,0,0] Assertion `inp

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
labels = [f"{instance['cell_type']}\n{instance['ctrl']}->{instance['stim']}" for instance in benchmark_instances[:len(benchmark_results['scCBGM']['mmd_ratio'])]]

In [20]:
long_form_data = []
for model_name, metrics in benchmark_results.items():
    # Ensure the number of scores matches the number of labels for robustness
    if len(metrics.get('true_mmd', [])) != len(labels):
        print(f"Warning: Mismatch between score count and label count for model '{model_name}'. Skipping.")
        continue
        
    for i, label in enumerate(labels):
        long_form_data.append({
            'label': label,
            'model': model_name,
            'true_mmd': metrics['true_mmd'][i],
            'mmd_ratio': metrics['mmd_ratio'][i]
        })

df = pd.DataFrame(long_form_data)


# --- 2. Create Plots ---
fig, ax = plt.subplots(1, 1, figsize=(24, 12)) # Adjusted figsize for a single plot

sorted_labels = sorted(list(set(labels)))

# --- Plot MMD Ratio, ensuring the x-axis is sorted ---
sns.barplot(data=df, x='label', y='mmd_ratio', hue='model', ax=ax, order=sorted_labels)

# --- SET Y-AXIS TO LOG SCALE ---
ax.set_yscale('log')

ax.set_ylabel('MMD Ratio (Log Scale, Lower is Better)', fontsize=14)
ax.set_xlabel('Cell Type', fontsize=14)
ax.set_title('Normalized Prediction Error (MMD Ratio)', fontsize=18)

# Format x-tick labels to be on two lines based on the sorted list
# new_labels = [label.replace('_', '\n') for label in sorted_labels]
# new_labels = [label.replace('Monocytes', 'Mono') for label in new_labels]
# new_labels = [label.replace('cells', '') for label in new_labels]
ax.set_xticklabels(sorted_labels, fontsize=12)

ax.legend(title='Model', bbox_to_anchor=(1.02, 1), loc='upper left')

# Adjust layout to make space for the external legend and the main title
plt.tight_layout()
plt.show()