# Config

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import conceptlab as clab
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import anndata as ad
import scanpy as sc
import torch
import scipy.spatial
import matplotlib.patches as mpatches
import string
import colorcet 

from omegaconf import OmegaConf
import pytorch_lightning as pl

In [3]:
DATA_PATH = '/braid/havivd/immune_dictionary/lig_seurat_with_concepts.h5ad'
USE_PC = True
Z_SCORE = False
CONCEPT_KEY = 'concepts'
RANDOM_SEED = 0

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7f52909ae7b0>

In [4]:
CTYPE = 'Dendritic_cell'
PERT = 'IL15'
CTRL = 'PBS'

L2_INT = 'pDC'

In [5]:
# --- Plotting Configuration ---
CT_CMAP = {'B_cell': '#1f77b3',
 'Basophil': '#ff7e0e',
 'FRC': '#2ba02b',
 'ILC': '#d62628',
 'LEC': '#9367bc',
 'Langerhans': '#8c564b',
 'Macrophage': '#e277c1',
 'Mast_cell': '#7e7e7e',
 'MigDC': '#bcbc21',
 'Monocyte': '#16bdcf',
 'NK_cell': '#3a0182',
 'Neutrophil': '#004201',
 'T_cell_CD4': '#0fffa8',
 'T_cell_CD8': '#5d003f',
 'T_cell_Ki67': '#bcbcff',
 'T_cell_gd': '#d8afa1',
 'Treg': '#b80080',
 'cDC1': '#004d52',
 'cDC2': '#6b6400',
 'eTAC': '#7c0100',
 'pDC': '#6026ff'}

STIM_CMAP = {'ctrl': '#b80080', 'stim': '#e277c1', 'other': '#7e7e7e'}
# IDENT_CMAP = {
#     'train': '#676765', 'held out for intervention': '#c84639',
#     'held out as GT': '#048757', 'intervened on': '#06d400'
# }

IDENT_CMAP = {
    'train': '#676765', 'intervention': '#c84639','held out for intervention': '#c84639',
    'held out as GT': '#048757', 'intervened on': '#06d400'
}

TITLE_MAP = {'celltype': 'Cell Type', 'stim': 'State', 'ident': 'Split'}


cell_type_clusters = {
    # Lymphoid cells are crucial for the adaptive immune response.
    'Lymphoid': {
        'T_cell': [
            'T_cell_CD4',
            'T_cell_CD8',
            'T_cell_Ki67',
            'T_cell_gd',
            'Treg'
        ],
        'B_cell': [
            'B_cell'
        ],
        'Innate_lymphoid_cell': [
            'ILC',
            'NK_cell'
        ]
    },
    # Myeloid cells are primarily part of the innate immune system.
    'Myeloid': {
        'Dendritic_cell': [
            'cDC1',
            'cDC2',
            'Langerhans',
            'MigDC',
            'pDC'
        ],
        'Monocyte_Macrophage': [
            'Macrophage',
            'Monocyte'
        ],
        'Granulocyte': [
            'Basophil',
            'Mast_cell',
            'Neutrophil'
        ]
    },
    # Stromal cells provide structure and support within tissues.
    'Stromal_Other': {
        'Stromal': [
            'FRC',
            'LEC',
            'eTAC'
        ]
    }
}

# Example of how to access a specific group:
# print(cell_type_clusters['Myeloid']['Granulocyte'])

# --- Look-Up Table (LUT) Generation ---

# Create an LUT to map each specific subtype to its parent group (e.g., 'T_cell_CD4' -> 'T_cell').
# This uses a dictionary comprehension to iterate through the nested structure.
subtype_to_group_lut = {
    subtype: group
    for major_group, sub_groups in cell_type_clusters.items()
    for group, subtypes in sub_groups.items()
    for subtype in subtypes
}

# DATA LOADING AND PREPARATION

In [7]:
# def split_data_for_counterfactuals(adata, hold_out_label, mod_label, p_intervention=0.2):
#     """Splits data into train, intervention, and ground truth sets."""
#     print("Splitting data for counterfactual experiment...")
#     labels = adata.obs['cell_stim']
#     is_test = (labels == hold_out_label)
#     is_inter_pool = (labels == mod_label)

#     # Create a random mask to select a subset for intervention
#     inter_mask = np.random.binomial(1, p=p_intervention, size=is_inter_pool.sum()).astype(bool)
#     is_inter = np.zeros_like(labels, dtype=bool)
#     is_inter[is_inter_pool] = inter_mask
    
#     is_train = ~is_test & ~is_inter

#     # Create AnnData objects for each split
#     adata_train = adata[is_train].copy()
#     adata_test = adata[is_test].copy()
#     adata_inter = adata[is_inter].copy()

#     # Store split identifiers in the original object for later merging
#     ident_vec = np.array(['train'] * len(adata)).astype('<U32')
#     ident_vec[is_test] = 'held out as GT'
#     ident_vec[is_inter] = 'held out for intervention'
#     adata.obs['ident'] = ident_vec
    
#     print(f"Train set: {len(adata_train)} cells")
#     print(f"Intervention set: {len(adata_inter)} cells")
#     print(f"Ground Truth set: {len(adata_test)} cells")

#     return adata, adata_train, adata_test, adata_inter


def split_data(adata, hold_out_label, mod_label):
    """
    Splits data into train, intervention, and ground truth sets.

    - Ground Truth: All cells with the `hold_out_label`.
    - Intervention: All cells with the `mod_label`.
    - Train: All remaining cells.
    """
    print("Splitting data with simplified logic...")
    labels = adata.obs['cell_stim']

    # Define the three disjoint sets based on their labels
    is_test = (labels == hold_out_label)
    is_inter = (labels == mod_label)
    is_train = ~is_test

    # Create AnnData objects for each split
    adata_train = adata[is_train].copy()
    adata_test = adata[is_test].copy()
    adata_inter = adata[is_inter].copy()

    # Store split identifiers in the original object
    ident_vec = np.array(['train'] * len(adata)).astype('<U32')
    ident_vec[is_test] = 'held out as GT'
    ident_vec[is_inter] = 'intervention'
    adata.obs['ident'] = ident_vec
    

    return adata, adata_train, adata_test, adata_inter



# MODELING & PREDICTION METHODS

## Method 1: scCBGM

In [8]:


def train_method_1_cbgm(adata_train, pc = USE_PC, z_score = Z_SCORE):
    """Trains and returns the scCBGM model."""
    print("Training scCBGM model...")

    # --- MODIFICATION START ---
    # Conditionally set the data source and input dimension based on the 'pca' flag
    if(pc):
        data_matrix = adata_train.obsm['X_pca']
    else:
        data_matrix = adata_train.X
        if(z_score):
            data_matrix = (data_matrix - adata_train.var['mean'].to_numpy()[None, :]) / adata_train.var['std'].to_numpy()[None, :]  # Z-score normalization

    # --- MODIFICATION END ---

    torch.set_flush_denormal(True)

    
    config = OmegaConf.create(dict(
        has_cbm=True, 
        lr=5e-4, 
        hidden_dim=1024, 
        n_layers = 4,
        beta=1e-5,
        input_dim=data_matrix.shape[-1],  # <-- Use the dynamically set input dimension
        latent_dim=128,
        n_concepts=adata_train.obsm[CONCEPT_KEY].shape[1],
        min_bottleneck_size=128, independent_training=True,
        concepts_hp=0.1, orthogonality_hp=0.5, use_soft_concepts=False
    ))
    model = clab.models.scCBGM(config)

    model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[CONCEPT_KEY].to_numpy().astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return model

def predict_with_method_1_cbgm(model, adata_inter, hold_out_label, pc = USE_PC):
    """Performs intervention using a trained scCBGM model."""
    print("Performing intervention with scCBGM...")
    if(pc):
        x_intervene_on =  torch.tensor(adata_inter.obsm['X_pca'], dtype=torch.float32)
    else:
        x_intervene_on = torch.tensor(adata_inter.X, dtype=torch.float32)
    c_intervene_on = adata_inter.obsm[CONCEPT_KEY].to_numpy().astype(np.float32)

    # Define the intervention by creating a mask and new concept values
    mask = torch.zeros(c_intervene_on.shape, dtype=torch.float32)
    mask[:, CTRL_INDEX] = 1
    mask[:, STIM_INDEX] = 1 

    inter_concepts = torch.tensor(c_intervene_on, dtype=torch.float32)
    inter_concepts[:, CTRL_INDEX] = 0
    inter_concepts[:, STIM_INDEX] = 1 # Set stim concept to 1

    with torch.no_grad():
        inter_preds = model.intervene(x_intervene_on.to('cuda'), mask=mask.to('cuda'), concepts=inter_concepts.to('cuda'))
    
    inter_preds = inter_preds['x_pred'].cpu().numpy()
    
    if(pc):
        x_inter_preds = adata_inter.uns['pc_transform'].inverse_transform(inter_preds)
    else:
        x_inter_preds = inter_preds

    pred_adata = ad.AnnData(x_inter_preds, var=adata_inter.var)
    pred_adata.obs['ident'] = 'intervened on'
    pred_adata.obs['cell_stim'] = hold_out_label + '*'
    pred_adata.obs['celltype'] = adata_inter.obs['celltype']
    pred_adata.obs['stim'] = 'stim'
    
    if(pc):
        pred_adata.obsm['X_pca'] = inter_preds
    return pred_adata


##  Method 2: Flow Matching with Learned Concepts

In [9]:
def get_learned_concepts(scCBGM_model, adata_full, pc = USE_PC):
    """Uses a trained scCBGM to generate learned concepts for all data."""
    print("Generating learned concepts from scCBGM...")

    if(pc):
        all_x = torch.tensor(adata_full.obsm['X_pca'], dtype=torch.float32).to('cuda')
    else:
        all_x = torch.tensor(adata_full.X, dtype=torch.float32).to('cuda')

    with torch.no_grad():
        enc = scCBGM_model.encode(all_x)
        adata_full.obsm['scCBGM_concepts_known'] = scCBGM_model.cb_concepts_layers(enc['mu']).cpu().numpy()
        adata_full.obsm['scCBGM_concepts_unknown'] = scCBGM_model.cb_unk_layers(enc['mu']).cpu().numpy()

    adata_full.obsm['scCBGM_concepts'] = np.concatenate([adata_full.obsm['scCBGM_concepts_known'], adata_full.obsm['scCBGM_concepts_unknown']], axis=1)
    return adata_full

def train_method_2_fm_learned(adata_train, pc = USE_PC, z_score = Z_SCORE):
    """Trains and returns the CB-FM model using learned concepts."""
    print("Training Conditonal Flow Model")

    if(pc):
        data_matrix = adata_train.obsm['X_pca']
    else:
        data_matrix = adata_train.X
        if(z_score):
            data_matrix = (data_matrix - adata_train.var['mean'].to_numpy()[None, :]) / adata_train.var['std'].to_numpy()[None, :]  # Z-score normalization

    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=adata_train.obsm['scCBGM_concepts'].shape[1],
        n_layers=4,
        dropout=0.1,
        p_uncond = 0.0)

    fm_model = clab.models.cond_fm.Cond_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm['scCBGM_concepts'].astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return fm_model



def predict_with_method_2_fm_learned(model, adata_inter, hold_out_label, pc = USE_PC, z_score = Z_SCORE, ):
    """Performs intervention using a trained learned-concept CB-FM model."""
    print("Performing intervention with CB-FM (learned)...")
    c_known_inter = adata_inter.obsm['scCBGM_concepts_known'].astype(np.float32)
    c_unknown_inter = torch.from_numpy(adata_inter.obsm['scCBGM_concepts_unknown'].astype(np.float32))


    inter_concepts_known = c_known_inter.copy()
    inter_concepts_known[:, CTRL_INDEX] = 0 # Set crl concept to 0
    inter_concepts_known[:, STIM_INDEX] = 1 # Set stim concept to 1
    
    if(pc):
        x_inter = adata_inter.obsm['X_pca']
    else:
        x_inter = adata_inter.X
        if(z_score):
            x_inter = (x_inter - adata_inter.var['mean'].to_numpy()[None, :]) / adata_inter.var['std'].to_numpy()[None, :]

    init_concepts = np.concatenate([c_known_inter, c_unknown_inter], axis=1)
    edit_concepts = np.concatenate([inter_concepts_known, c_unknown_inter], axis=1)

    inter_preds = model.edit(
             x = torch.from_numpy(x_inter.astype(np.float32)).to('cuda'),
             c = torch.from_numpy(init_concepts.astype(np.float32)).to('cuda'),
             c_prime = torch.from_numpy(edit_concepts.astype(np.float32)).to('cuda'),
             t_edit = 0.1,
             n_steps = 1000,
             w_cfg_forward = 1.0,
             w_cfg_backward = 1.0,
             noise_add = 0.0)
    
    inter_preds = inter_preds.detach().cpu().numpy()

    if(pc):
        x_inter_preds = adata_inter.uns['pc_transform'].inverse_transform(inter_preds)
    else:
        x_inter_preds = inter_preds

    if(z_score):
        x_inter_preds = (x_inter_preds * adata_inter.var['std'].to_numpy()[None, :]) + adata_inter.var['mean'].to_numpy()[None, :]

    pred_adata = ad.AnnData(x_inter_preds, var=adata_inter.var)
    pred_adata.obs['ident'] = 'intervened on'
    pred_adata.obs['cell_stim'] = hold_out_label + '*'
    pred_adata.obs['celltype'] = adata_inter.obs['celltype']
    pred_adata.obs['stim'] = 'stim'
    
    if(pc):
        pred_adata.obsm['X_pca'] = inter_preds

    return pred_adata




## Method 3: Concept Flows



In [10]:
def train_method_3_concept_flows(adata_train, pc = USE_PC, z_score = Z_SCORE):
    """Trains and returns the CB-FM model using learned concepts."""
    print("Training Concept Flow Mode")

    if(pc):
        print("Using PCA-reduced data for training.")
        data_matrix = adata_train.obsm['X_pca']
    else:
        print("Using full data for training.")
        data_matrix = adata_train.X
        if(z_score):
            data_matrix = (data_matrix - adata_train.var['mean'].to_numpy()[None, :]) / adata_train.var['std'].to_numpy()[None, :]  # Z-score normalization
            
    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=adata_train.obsm[CONCEPT_KEY].to_numpy().shape[1],
        n_unknown=128,
        n_layers=4,
        dropout=0.1,
        p_uncond=0.0,
        flow_hp=1.0,
        concepts_hp=0.2,
        orthogonality_hp=0.5)


    fm_model = clab.models.concept_fm.Concept_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[CONCEPT_KEY].to_numpy().astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return fm_model

def predict_with_method_3_concept_flows(model, adata_inter, hold_out_label, pc = USE_PC, z_score = Z_SCORE):
    """Performs intervention using a trained scCBGM model."""
    print("Performing intervention with concept flow...")


    if(pc):
        data_matrix = adata_inter.obsm['X_pca']
    else:
        data_matrix = adata_inter.X
        if(z_score):
            data_matrix = (data_matrix - adata_inter.var['mean'].to_numpy()[None, :]) / adata_inter.var['std'].to_numpy()[None, :]  # Z-score normalization

    x_intervene_on = torch.from_numpy(data_matrix.astype(np.float32))
    c_intervene_on = torch.from_numpy(adata_inter.obsm[CONCEPT_KEY].to_numpy().astype(np.float32))

    mask = torch.zeros(c_intervene_on.shape, dtype=torch.float32)
    mask[:, CTRL_INDEX] = 1
    mask[:, STIM_INDEX] = 1

    inter_concepts = torch.tensor(c_intervene_on, dtype=torch.float32)
    inter_concepts[:, CTRL_INDEX] = 0
    inter_concepts[:, STIM_INDEX] = 1

    with torch.no_grad():
        inter_preds = model.intervene(x_intervene_on.to('cuda'), mask=mask.to('cuda'), concepts=inter_concepts.to('cuda'))

    inter_preds = inter_preds.detach().cpu().numpy()

    if(pc):
        x_inter_preds = adata_inter.uns['pc_transform'].inverse_transform(inter_preds)
    else:
        x_inter_preds = inter_preds

    if(z_score):
        x_inter_preds = (x_inter_preds * adata_inter.var['std'].to_numpy()[None, :]) + adata_inter.var['mean'].to_numpy()[None, :]

    pred_adata = ad.AnnData(x_inter_preds, var=adata_inter.var)
    pred_adata.obs['ident'] = 'intervened on'
    pred_adata.obs['cell_stim'] = hold_out_label + '*'
    pred_adata.obs['celltype'] = adata_inter.obs['celltype']
    pred_adata.obs['stim'] = 'stim'
    
    if(pc):
        pred_adata.obsm['X_pca'] = inter_preds
    return pred_adata



## Method 4: Flow Matching with Raw Concepts 

In [11]:

def train_method_4_fm_raw(adata_train, pc = USE_PC, z_score = Z_SCORE):
    """Trains and returns the CB-FM model using learned concepts."""

    print("Training Conditonal Flow Model")

    if(pc):
        data_matrix = adata_train.obsm['X_pca']
    else:
        data_matrix = adata_train.X
        if(z_score):
            data_matrix = (data_matrix - adata_train.var['mean'].to_numpy()[None, :]) / adata_train.var['std'].to_numpy()[None, :]  # Z-score normalization

    config = dict(
        input_dim=data_matrix.shape[1],
        hidden_dim=1024,
        latent_dim=128,
        n_concepts=adata_train.obsm[CONCEPT_KEY].to_numpy().shape[1],
        n_layers=4,
        p_uncond = 0.0,
        dropout=0.1)

    fm_model = clab.models.cond_fm.Cond_FM(config=config)

    fm_model.train_loop(
        data=torch.from_numpy(data_matrix.astype(np.float32)),
        concepts=torch.from_numpy(adata_train.obsm[CONCEPT_KEY].to_numpy().astype(np.float32)),
        num_epochs=200, batch_size=128, lr=3e-4,
    )
    return fm_model


  

def predict_with_method_4_fm_raw(model, adata_inter, hold_out_label, pc = USE_PC, z_score = Z_SCORE):
    """Performs intervention using a trained raw-concept CB-FM model."""
    print("Performing intervention with CB-FM (raw)...")

    if(pc):
        data_matrix = adata_inter.obsm['X_pca']
    else:
        data_matrix = adata_inter.X
        if(z_score):
            data_matrix = (data_matrix - adata_inter.var['mean'].to_numpy()[None, :]) / adata_inter.var['std'].to_numpy()[None, :]  # Z-score normalization

    x_inter = torch.from_numpy(data_matrix.astype(np.float32))
    
    c_intervene_on = torch.from_numpy(adata_inter.obsm[CONCEPT_KEY].to_numpy().astype(np.float32))
    inter_concepts = c_intervene_on.clone()
    
    inter_concepts[:, CTRL_INDEX] = 0
    inter_concepts[:, STIM_INDEX] = 1

    with torch.no_grad():


        # inter_preds = model.edit(
        #         x = x_inter.to('cuda'),
        #         c = c_intervene_on.to('cuda'),
        #         c_prime = inter_concepts.to('cuda'),
        #         t_edit = 0.01,
        #         n_steps = 1000,
        #         w_cfg_forward = 1.0,
        #         w_cfg_backward = 1.0,
        #         noise_add = 0.0)
    
        inter_preds = model.decode(
            h=inter_concepts.to('cuda'),
            n_steps=1000,
            w_cfg = 1.0
        )
        
    inter_preds = inter_preds.detach().cpu().numpy()

    if(pc):
        x_inter_preds = adata_inter.uns['pc_transform'].inverse_transform(inter_preds)
    else:
        x_inter_preds = inter_preds

    if(z_score):
        x_inter_preds = (x_inter_preds * adata_inter.var['std'].to_numpy()[None, :]) + adata_inter.var['mean'].to_numpy()[None, :]

    pred_adata = ad.AnnData(x_inter_preds, var=adata_inter.var)
    pred_adata.obs['ident'] = 'intervened on'
    pred_adata.obs['cell_stim'] = hold_out_label + '*'
    pred_adata.obs['celltype'] = adata_inter.obs['celltype']
    pred_adata.obs['stim'] = 'stim'
    
    if(pc):
        pred_adata.obsm['X_pca'] = inter_preds
    return pred_adata


# Plotting

In [12]:
def analyze_and_plot_results(adata_original, pred_adata, method_name, score):
    """Merges data, runs UMAP, calculates error, and plots the results."""
    print(f"Analyzing and plotting results for {method_name}...")
    
    # Keep only the original data splits, not the predicted one
    adata_to_merge = adata_original[adata_original.obs['ident'] != 'intervened on'].copy()
    adata_merged = ad.concat([adata_to_merge, pred_adata], join='inner', merge='unique')

    # Add other metadata for coloring
    # Dimensionality Reduction for visualization
    #sc.pp.pca(adata_merged)
    
    #adata_merged.obsm['X_pca'] = adata_merged.X
    # check if obsm has X_pca:

    if 'X_pca' in adata_merged.obsm:
        print("using pca already computed")
    else:
        print("computing pca")
        sc.pp.pca(adata_merged)
    

    sc.pp.neighbors(adata_merged)
    sc.tl.umap(adata_merged, random_state=RANDOM_SEED)

    # # Set plotting order
    # adata_merged.obs['ident'] = pd.Categorical(adata_merged.obs['ident'],
    #     categories=['train', 'intervention', 'held out as GT', 'intervened on'])

    adata_merged.obs['ident'] = pd.Categorical(adata_merged.obs['ident'])
    adata_merged.obs['stim'] = pd.Categorical(adata_merged.obs['stim'])
    # --- Plotting ---
    fig, axes = plt.subplots(1, 3, figsize=(21, 5), constrained_layout=True)
    fig.suptitle(f"Counterfactual Prediction Results: {method_name}", fontsize=20)
    
    cmaps = [CT_CMAP, STIM_CMAP, IDENT_CMAP]
    color_keys = ['celltype', 'stim', 'ident']

    for i, (ax, cmap, key) in enumerate(zip(axes, cmaps, color_keys)):
        sc.pl.umap(adata_merged, color=key, ax=ax, show=False, palette=cmap, s=10,
                   title=f"{string.ascii_uppercase[i]}) {TITLE_MAP[key]}")
        
        # Add error to the title of the last plot
        if key == 'ident':
            ax.set_title(f"{ax.get_title()}, rMMD score {score:.2f}")

            # Add intervention arrow
            source_coords = adata_merged[adata_merged.obs['ident'] == 'intervention'].obsm['X_umap'].mean(0)
            #source_coords = adata_merged[adata_merged.obs['ident'] == 'held out for intervention'].obsm['X_umap'].mean(0)
            target_coords = adata_merged[adata_merged.obs['ident'] == 'intervened on'].obsm['X_umap'].mean(0)
            arrow = mpatches.FancyArrowPatch(
                source_coords, target_coords,
                connectionstyle="arc3,rad=0.3", arrowstyle="-|>",
                linewidth=2, linestyle='dashed', color="black", mutation_scale=20
            )
            ax.add_patch(arrow)
            
    plt.show()


# Main

## Proccesing

In [13]:
import sklearn.decomposition

In [14]:
print("Loading and preprocessing data...")
adata = ad.read_h5ad(DATA_PATH)


Loading and preprocessing data...


In [15]:
adata.obs['cell_type_L1'] = adata.obs['celltype'].map(subtype_to_group_lut)
adata.obs['cell_type_L2'] = adata.obs['celltype']

In [16]:
adata.obsm['concepts'] = pd.get_dummies(adata.obs[['sample', 'cell_type_L1']]).astype(np.float32)
adata.obsm['concepts'].columns = [col.replace('sample_', '').replace('cell_type_L1_', '') for col in adata.obsm['concepts'].columns]

In [17]:
adata.obs['cell_stim'] = [ctype + "_" + stim for ctype, stim in zip(adata.obs['cell_type_L1'], adata.obs['sample'])]

In [18]:
hold_out_label = CTYPE + '_' + PERT
mod_label =  CTYPE + '_' + CTRL


adata, adata_train, adata_test, adata_inter = split_data(
    adata, hold_out_label, mod_label
)

adata_inter = adata_inter[adata_inter.obs['cell_type_L2'] == L2_INT].copy()
adata_test = adata_test[adata_test.obs['cell_type_L2'] == L2_INT].copy()
# adata_test = adata_test[adata_test.obs['stim'] == 'stim'].copy()

adata.obs['stim'] = np.where(adata.obs['cell_stim'] ==  CTYPE + '_' + CTRL, 'ctrl', 'other')
adata.obs['stim'] = np.where(adata.obs['cell_stim'] ==  CTYPE + '_' + PERT, 'stim', adata.obs['stim'])


print(f"Train set: {len(adata_train)} cells")
print(f"Intervention set: {len(adata_inter)} cells")
print(f"Ground Truth set: {len(adata_test)} cells")

Splitting data with simplified logic...
Train set: 109888 cells
Intervention set: 100 cells
Ground Truth set: 100 cells


In [19]:
CTRL_INDEX = np.where(adata.obsm['concepts'].columns == CTRL)[0][0]
STIM_INDEX = np.where(adata.obsm['concepts'].columns == PERT)[0][0]

In [20]:
adata.uns['pc_transform'] = sklearn.decomposition.PCA(n_components=256).fit(adata_train.X)

for x_data in [adata, adata_train, adata_test, adata_inter]:
    x_data.uns['pc_transform'] = adata.uns['pc_transform']
    x_data.obsm['X_pca'] = x_data.uns['pc_transform'].transform(x_data.X)


## --- Method 1: scCBGM ---


In [21]:
cbgm_model = train_method_1_cbgm(adata_train.copy())

Training scCBGM model...
Starting training on cuda for 200 epochs...


Training Progress: 100%|█████████████████████████████████████| 200/200 [13:26<00:00,  4.03s/it, avg_loss=2.314e-01, concept_f1=1.0000, lr=1.64990e-04]

Training finished.





In [22]:
pred_adata_cbgm = predict_with_method_1_cbgm(cbgm_model, adata_inter.copy(), hold_out_label)

Performing intervention with scCBGM...


In [23]:
score_cbfm_learned_mmd = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_cbgm.obsm['X_pca'] if USE_PC else pred_adata_cbgm.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values)


In [24]:
score_cbfm_learned_mmd

{'mmd_ratio': np.float64(0.3650403155711375),
 'pre_computed_mmd_train': np.float64(175.81577079928547)}

## --- Method 2: CB-FM with Learned Concepts ---


In [25]:
adata_with_concepts = get_learned_concepts(cbgm_model, adata.copy())

# adata_train.obsm['scCBGM_concepts'] = adata_with_concepts[adata_train.obs.index].obsm['scCBGM_concepts_known']
# adata_inter.obsm['scCBGM_concepts'] = adata_with_concepts[adata_inter.obs.index].obsm['scCBGM_concepts_known']

# Distribute the newly generated concepts to the training set
adata_train.obsm['scCBGM_concepts'] = adata_with_concepts[adata_train.obs.index].obsm['scCBGM_concepts']
adata_inter.obsm['scCBGM_concepts'] = adata_with_concepts[adata_inter.obs.index].obsm['scCBGM_concepts']

adata_train.obsm['scCBGM_concepts_known'] = adata_with_concepts[adata_train.obs.index].obsm['scCBGM_concepts_known']
adata_train.obsm['scCBGM_concepts_unknown'] = adata_with_concepts[adata_train.obs.index].obsm['scCBGM_concepts_unknown']

adata_inter.obsm['scCBGM_concepts_known'] = adata_with_concepts[adata_inter.obs.index].obsm['scCBGM_concepts_known']
adata_inter.obsm['scCBGM_concepts_unknown'] = adata_with_concepts[adata_inter.obs.index].obsm['scCBGM_concepts_unknown']

Generating learned concepts from scCBGM...


In [26]:
fm_learned_model = train_method_2_fm_learned(adata_train.copy())

Training Conditonal Flow Model
Starting training on cuda for 200 epochs...


Training Progress: 100%|██████| 200/200 [08:08<00:00,  2.44s/it, avg_loss=8.574e-01, lr=1.64990e-04]


In [27]:
pred_adata_fm_learned = predict_with_method_2_fm_learned(fm_learned_model, adata_inter.copy(), hold_out_label)

Performing intervention with CB-FM (learned)...
Editing from t=1.0 back to t=0.10, then forward with new condition.


                                                                                        

In [28]:


score_fm_learned_mmd = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_fm_learned.obsm['X_pca'] if USE_PC else pred_adata_fm_learned.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values,
    pre_computed_mmd_train=score_cbfm_learned_mmd['pre_computed_mmd_train']
)


In [29]:
score_fm_learned_mmd

{'mmd_ratio': np.float64(0.348241692114373),
 'pre_computed_mmd_train': np.float64(175.81577079928547)}

## --- Method 3: Concept Flows ---

In [30]:
USE_PC = True

In [31]:
concept_flow_model = train_method_3_concept_flows(adata_train.copy())

Training Concept Flow Mode
Using PCA-reduced data for training.
Starting training on cuda for 200 epochs...


Training Progress: 100%|███████████████████████████████| 200/200 [16:53<00:00,  5.07s/it, avg_loss=1.046e+00, fm_loss=8.775e-01, concept_f1=9.830e-01]


In [32]:
pred_adata_concept_flow = predict_with_method_3_concept_flows(concept_flow_model, adata_inter.copy(), hold_out_label)

#pred_adata_concept_flow = predict_with_method_3_concept_flows_edit(concept_flow_model, adata_inter.copy(), hold_out_label)




  inter_concepts = torch.tensor(c_intervene_on, dtype=torch.float32)


Performing intervention with concept flow...
Decoding with 1000 steps and cfg_strength 1.0


                                                                                        

In [33]:


score_fm_cf_mmd = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_concept_flow.obsm['X_pca'] if USE_PC else pred_adata_concept_flow.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values,
    pre_computed_mmd_train=score_cbfm_learned_mmd['pre_computed_mmd_train']
)


In [34]:
score_fm_cf_mmd

{'mmd_ratio': np.float64(0.3396590624016524),
 'pre_computed_mmd_train': np.float64(175.81577079928547)}

In [35]:
# analyze_and_plot_results(adata, pred_adata_concept_flow, "Method 3: Concept Flow", score_fm_cf_mmd['mmd_ratio'])

## --- Method 4: CB-FM with Raw Concepts ---


In [None]:
fm_raw_model = train_method_4_fm_raw(adata_train.copy())

Training Conditonal Flow Model
Starting training on cuda for 200 epochs...


Training Progress:  26%|█▊     | 51/200 [02:08<06:23,  2.58s/it, avg_loss=1.754e+00, lr=2.58154e-04]

In [None]:
pred_adata_fm_raw = predict_with_method_4_fm_raw(fm_raw_model, adata_inter.copy(), hold_out_label)

In [None]:
score_fm_raw_mmd = clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_fm_raw.obsm['X_pca'] if USE_PC else pred_adata_fm_raw.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values,
    #pre_computed_mmd_train=score_cbfm_learned_mmd['pre_computed_mmd_train']
)



In [None]:
score_fm_raw_mmd

In [None]:
score_fm_raw_mmd = clab.evaluation.interventions.evaluate_intervention_emd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_fm_raw.obsm['X_pca'] if USE_PC else pred_adata_fm_raw.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values,
    #pre_computed_mmd_train=score_cbfm_learned_mmd['pre_computed_mmd_train']
)



In [None]:
USE_PC = True

In [None]:
score_fm_raw_emd = clab.evaluation.interventions.evaluate_intervention_emd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_fm_raw.obsm['X_pca'] if USE_PC else pred_adata_fm_raw.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values)

In [None]:
score_fm_raw_emd

In [None]:
USE_PC = False

In [None]:
pred_adata_cbgm
pred_adata_concept_flow
pred_adata_fm_learned
pred_adata_fm_raw

In [None]:
clab.evaluation.interventions.evaluate_intervention_mmd_with_target(
    x_train=adata_train.obsm['X_pca'] if USE_PC else adata_train.X,
    x_ivn=pred_adata_concept_flow.obsm['X_pca'] if USE_PC else pred_adata_fm_learned.X,
    x_target=adata_test.obsm['X_pca'] if USE_PC else adata_test.X,
    labels_train=adata_train.obs['cell_stim'].values,
    pre_computed_mmd_train=np.float64(0.0857313717126611))
