# Simulation using Concord

In [1]:
# Enable IPython's autoreload extension in mode 2 for this kernel session.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import scanpy as sc
import time
from pathlib import Path
import torch
import concord as ccd
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

from matplotlib import font_manager, rcParams
# rc overrides applied per figure via plt.rc_context(rc=custom_rc) in the plotting cells.
custom_rc = {'font.family': 'Arial'}

# Global font-embedding settings for SVG and PDF export.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
})

In [3]:
# Project paths: figures go to a date-stamped save directory, data to a per-project folder.
proj_name = "simulation_treehard_singlebatch"
save_dir = Path(f"../save/dev_{proj_name}-{time.strftime('%b%d')}/")
save_dir.mkdir(parents=True, exist_ok=True)

data_dir = Path(f"../data/{proj_name}/")
data_dir.mkdir(parents=True, exist_ok=True)

# Prefer GPU 3 when it exists; on hosts with fewer GPUs use GPU 0 instead of
# building an invalid 'cuda:3' device, and use the CPU when CUDA is unavailable.
preferred_gpu = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print(device)

seed = 0
ccd.ul.set_seed(seed)

# Month-day-hour-minute stamp appended to every file written in this session.
file_suffix = time.strftime('%b%d-%H%M')
file_suffix

cpu


'Aug16-2021'

In [4]:
# Key names and simulation descriptors referenced by later cells.
group_key, batch_key, state_key = 'depth', 'batch', 'branch'
leiden_key, time_key = 'leiden_no_noise', 'time'
state_type, batch_type, distribution = 'tree', 'batch_specific_features', 'normal'

# Methods handed to the dimensionality-reduction pipeline.
concord_methods = ["contrastive", "concord_hcl", "concord_knn"]
other_methods = [
    "PCA", "UMAP", "t-SNE", "DiffusionMap", "NMF",
    "FactorAnalysis", "FastICA", "LDA", "scVI", "PHATE",
]
run_methods = [*concord_methods, *other_methods]

# Embeddings shown side by side: the two reference layers plus every run
# method that is not listed in exclude_keys.
exclude_keys = ["PCA", "UMAP", "t-SNE"]
combined_methods = ['no_noise', 'wt_noise']
combined_methods.extend(method for method in run_methods if method not in exclude_keys)

In [None]:
from concord.simulation import BatchConfig, SimConfig, Simulation, TreeConfig

# Global simulation settings. The simulator is given its own seed (42) here.
sim_settings = dict(
    n_cells=6_000,
    n_genes=3_000,
    seed=42,
    non_neg=True,
    to_int=True,
)
sim_cfg = SimConfig(**sim_settings)

# State settings for the tree structure. Only the options below are passed;
# any other TreeConfig option is left at whatever default the library defines.
tree_settings = dict(
    distribution="normal",
    level=10,
    min_level=1,
    dispersion=5.0,
    branching_factor=[3, 5, 8],
    depth=3,
    program_structure="linear_increasing",
    program_on_time_fraction=0.2,
    program_decay=0.4,
    cellcount_decay=0.5,
)
tree_cfg = TreeConfig(**tree_settings)

# Batch settings for a single batch; cell_proportion is not passed.
batch_settings = dict(
    n_batches=1,
    effect_type="batch_specific_features",
    distribution="normal",
    level=10,
    dispersion=5.0,
    feature_frac=0.1,
)
batch_cfg = BatchConfig(**batch_settings)

# Run the simulation and keep a copy of the simulated matrix in layers["counts"].
sim = Simulation(sim_cfg, tree_cfg, batch_cfg)
adata, adata_state = sim.simulate_data()

adata.layers["counts"] = adata.X.copy()


In [None]:
# Draw un-clustered heatmaps of the noise-free layer, the noisy layer, and the
# final matrix X, saving each one as an SVG in save_dir.
heatmap_specs = [
    ('no_noise', [state_key], 'True state', 'true_state_heatmap'),
    ('wt_noise', [state_key], 'True state with noise', 'true_state_with_noise_heatmap'),
    ('X', [state_key, batch_key], 'Simulated data with batch signal', 'simulated_data_heatmap'),
]
for layer_name, annot_keys, panel_title, file_stem in heatmap_specs:
    ccd.pl.heatmap_with_annotations(
        adata,
        val=layer_name,
        obs_keys=annot_keys,
        yticklabels=False,
        cluster_cols=False,
        cluster_rows=False,
        value_annot=False,
        cmap='viridis',
        title=panel_title,
        save_path=save_dir/f'{file_stem}_{file_suffix}.svg',
        figsize=(6, 4),
        dpi=300,
    )

### Run dimensionality reduction

In [None]:
# Reload a previously saved run instead of the freshly simulated objects.
# Earlier snapshot: adata_Feb07-1524.h5ad / adata_state_Feb07-1524.h5ad
snapshot = "Jul20-2008"
adata = sc.read(data_dir / f"adata_{snapshot}.h5ad")
# The state object is saved under its own "adata_state_" prefix (see the save
# cells); reading the "adata_" file here would duplicate adata.
adata_state = sc.read(data_dir / f"adata_state_{snapshot}.h5ad")

In [None]:
import matplotlib.pyplot as plt

# Shared panel layout settings.
figsize, ncols = (2.3, 1.8), 6
title_fontsize, dpi = 9, 600

# get_color_mapping returns three values; only the last one is kept as the palette.
_, _, state_pal = ccd.pl.get_color_mapping(
    adata, state_key, pal='Paired', seed=seed
)
_, _, batch_pal = ccd.pl.get_color_mapping(
    adata, batch_key, pal='Set1', seed=seed
)

# obs columns to colour by, and the palette to use for each of them.
color_bys = [state_key, batch_key]
pal = {}
pal[state_key] = state_pal
pal[batch_key] = batch_pal

In [None]:
latent_dim = 100

# Noise-free reference: PCA and UMAP of the 'no_noise' layer, then a
# neighbor graph computed on that PCA.
clean_layer = 'no_noise'
ccd.ul.run_pca(
    adata, source_key=clean_layer, result_key=f'PCA_{clean_layer}',
    n_pc=latent_dim, random_state=seed,
)
ccd.ul.run_umap(
    adata, source_key=clean_layer, result_key=f'UMAP_{clean_layer}',
    random_state=seed,
)
sc.pp.neighbors(adata, use_rep=f'PCA_{clean_layer}', n_neighbors=30, random_state=seed)

# Expose both layers in obsm under the same names so they can be used as embeddings.
for layer_name in (clean_layer, 'wt_noise'):
    adata.obsm[layer_name] = adata.layers[layer_name]

In [None]:
# Settings forwarded to the pipeline through its concord_kwargs argument.
concord_kwargs = dict(
    latent_dim=latent_dim,
    batch_size=32,
    n_epochs=10,
    element_mask_prob=0.3,
    feature_mask_prob=0.0,
    load_data_into_memory=True,
    verbose=False,
)

# Run every method in run_methods on adata.X; the return value is kept as bench_res.
pipeline_options = dict(
    source_key="X",
    methods=run_methods,
    n_components=latent_dim,
    seed=seed,
    device=device,
    save_dir=save_dir,
    concord_kwargs=concord_kwargs,
)
bench_res = ccd.bm.run_dimensionality_reduction_pipeline(adata, **pipeline_options)

In [None]:
# Use the previous kNN-sweep result for consistency.
# 'Concord_p_intra_knn_0.3' is only written by the Concord-kNN sweep further
# down (or is already present when adata was loaded from a saved run), so on a
# fresh top-to-bottom run it does not exist yet. Guard the lookup so this cell
# no longer raises KeyError and leaves 'concord_knn' untouched in that case.
sweep_knn_key = 'Concord_p_intra_knn_0.3'
if sweep_knn_key in adata.obsm:
    adata.obsm['concord_knn'] = adata.obsm[sweep_knn_key]
else:
    print(f"[WARN] '{sweep_knn_key}' not found in adata.obsm; leaving 'concord_knn' unchanged.")

In [None]:
# Compute a 2-component UMAP for every embedding in combined_methods and
# store it as '<basis>_UMAP'. Entries that are themselves UMAP / t-SNE results
# are passed over.
for basis in combined_methods:
    already_projected = any(tag in basis for tag in ('UMAP', 't-SNE'))
    if not already_projected:
        print("Running UMAP for", basis)
        ccd.ul.run_umap(
            adata,
            source_key=basis,
            result_key=f'{basis}_UMAP',
            n_components=2,
            n_neighbors=30,
            min_dist=0.5,
            metric='euclidean',
            random_state=seed,
        )

In [None]:
# Persist both AnnData objects and the pipeline results, stamped with file_suffix.
adata_path = data_dir / f"adata_{file_suffix}.h5ad"
adata.write_h5ad(adata_path)
print("Data saved to", adata_path)

adata_state_path = data_dir / f"adata_state_{file_suffix}.h5ad"
adata_state.write_h5ad(adata_state_path)
print("State data saved to", adata_state_path)

bench_res_path = save_dir / f"benchmark_results_{file_suffix}.csv"
bench_res.to_csv(bench_res_path, index=False)
print("Benchmark results saved to", bench_res_path)

In [None]:
# Highlight the cells whose branch label starts with `show_branch` in the
# full-data KNN view of every method in combined_methods.
show_branch = '2_0'
hl_cells = adata.obs['branch'].str.startswith(show_branch)
hl_cell_indices = np.where(hl_cells)[0]
# basis_types must be set here: this is the first plotting cell that uses it,
# and no earlier cell assigns it (a fresh kernel would raise NameError).
basis_types = ['KNN']
font_size=8
point_size=1
alpha=0.8
figsize=(0.9*len(combined_methods),1)
ncols = len(combined_methods)
nrows = 1
k=15
edges_color='grey'
edges_width=0.01
layout='kk'
threshold = 0.1
node_size_scale=0.1
edge_width_scale=0.1
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata,
        combined_methods,
        color_bys=color_bys,
        basis_types=basis_types,
        pal=pal,
        k=k,
        edges_color=edges_color,
        edges_width=edges_width,
        layout=layout,
        threshold=threshold,
        node_size_scale=node_size_scale,
        edge_width_scale=edge_width_scale,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        leiden_key='leiden',
        save_dir=save_dir,
        highlight_indices = hl_cell_indices,
        highlight_size=point_size,
        file_suffix=file_suffix+f'hl_cells_branch{show_branch}',
        save_format='pdf'
    )


In [None]:
# One KNN-graph figure per depth-1 branch, comparing all methods on the cells
# whose branch label starts with that branch name.
show_branches = adata[adata.obs['depth']==1].obs['branch'].unique()

# Display names: 'p_intra_knn' keys would be shortened to 'p_knn=<value>';
# every other key is used unchanged.
show_keys = []
for key in combined_methods:
    if 'p_intra_knn' in key:
        key_new = f"p_knn={key.split('_')[-1]}"
    else:
        key_new = key
    show_keys.append(key_new)

for show_branch in show_branches:
    in_branch = adata.obs['branch'].str.startswith(show_branch)
    adata_sub = adata[in_branch].copy()

    basis_types = ['KNN']
    font_size, point_size, alpha = 8, 20, 0.8
    figsize = (0.9*len(combined_methods), 1)
    ncols, nrows = len(combined_methods), 1
    k = 15

    branch_plot_kwargs = dict(
        color_bys=color_bys,
        basis_types=basis_types,
        pal=pal,
        k=k,
        edges_color=edges_color,
        edges_width=edges_width,
        layout=layout,
        threshold=threshold,
        node_size_scale=node_size_scale,
        edge_width_scale=edge_width_scale,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        leiden_key='leiden',
        save_dir=save_dir,
        file_suffix=f'allmethods_{file_suffix}_branch{show_branch}',
        save_format='pdf',
    )
    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_all_embeddings(adata_sub, show_keys, **branch_plot_kwargs)


### Concord-kNN

In [None]:
import copy

# Baseline Concord arguments for the p_intra_knn sweep; the sweep overrides
# one entry at a time on a copy of this dict.
concord_args_base = dict(
    latent_dim=100,
    encoder_dims=[300],
    decoder_dims=[300],
    batch_size=16,
    n_epochs=10,
    p_intra_knn=0.0,
    element_mask_prob=0.3,
    feature_mask_prob=0.0,
    clr_temperature=0.2,
    sampler_knn=100,
    domain_key=batch_key,
    seed=seed,
    device=device,
    save_dir=save_dir,
    load_data_into_memory=True,
    verbose=False,
)

In [None]:
# Values to sweep, keyed by the Concord argument they replace.
# (A coarser alternative grid used previously: 0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0.)
param_variations = {
    'p_intra_knn': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
}
# Names of the embeddings written by the sweep, in run order.
param_keys = []

for param_name, values in param_variations.items():
    print(f"\n[INFO] Varying '{param_name}' with possible values: {values}\n")

    for value in values:
        # Each run's embedding is stored under a key naming the parameter and value.
        output_key = f"Concord_{param_name}_{value}"
        umap_key = f"{output_key}_UMAP"
        param_keys.append(output_key)
        print(f"[INFO] Running Concord with {param_name} = {value}")

        # Start from an independent copy of the baseline and override one entry.
        concord_args = {**copy.deepcopy(concord_args_base), param_name: value}

        cur_ccd = ccd.Concord(adata, **concord_args)
        cur_ccd.fit_transform(output_key=output_key)

print("[DONE] Finished varying each parameter individually.")


In [None]:
# UMAP for the two reference layers and each sweep embedding, stored as '<key>_UMAP'.
combined_keys = ['no_noise', 'wt_noise', *param_keys]
for key in combined_keys:
    ccd.ul.run_umap(
        adata,
        source_key=key,
        result_key=f'{key}_UMAP',
        random_state=seed,
        min_dist=0.5,
    )

In [None]:
# One KNN-graph figure per depth-1 branch for the p_intra_knn sweep.
show_branches = adata[adata.obs['depth']==1].obs['branch'].unique()

# Give each sweep embedding a short display name
# (e.g. Concord_p_intra_knn_0.0 -> p_knn=0.0) and register that name in
# adata.obsm. The alias has to exist: show_keys is passed to
# plot_all_embeddings as the list of embedding keys, the Concord-hcl section
# aliases its 'beta=' keys the same way, and the Evaluation section
# benchmarks keys such as 'p_knn=0.1'.
show_keys = []
for key in combined_keys:
    key_new = f"p_knn={key.split('_')[-1]}" if 'p_intra_knn' in key else key
    adata.obsm[key_new] = adata.obsm[key]
    show_keys.append(key_new)

for show_branch in show_branches:
    # Subset after aliasing so the copy carries the 'p_knn=' keys as well.
    adata_sub = adata[adata.obs['branch'].str.startswith(show_branch)].copy()

    basis_types = ['KNN']
    font_size=8
    point_size=20
    alpha=0.8
    figsize=(0.9*len(combined_keys),1)
    ncols = len(combined_keys)
    nrows = int(np.ceil(len(param_keys) / ncols))
    k=15

    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_all_embeddings(
            adata_sub,
            show_keys,
            color_bys=color_bys,
            basis_types=basis_types,
            pal=pal,
            k=k,
            edges_color=edges_color,
            edges_width=edges_width,
            layout=layout,
            threshold=threshold,
            node_size_scale=node_size_scale,
            edge_width_scale=edge_width_scale,
            font_size=font_size,
            point_size=point_size,
            alpha=alpha,
            figsize=figsize,
            ncols=ncols,
            seed=seed,
            leiden_key='leiden',
            save_dir=save_dir,
            file_suffix=file_suffix + f'_branch{show_branch}',
            save_format='pdf'
        )


In [None]:
# Highlight the cells whose branch label starts with `show_branch` in the
# full-data view of each embedding in combined_keys.
show_branch = '2_0'
hl_cells = adata.obs['branch'].str.startswith(show_branch)
hl_cell_indices = np.where(hl_cells)[0]

# Appearance settings.
font_size, point_size, alpha = 8, 1, 0.8
figsize = (0.9*len(combined_keys), 1)
ncols = len(combined_keys)
nrows = int(np.ceil(len(param_keys) / ncols))
k = 15
edges_color, edges_width = 'grey', 0.01
layout, threshold = 'kk', 0.1
node_size_scale, edge_width_scale = 0.1, 0.1

highlight_plot_kwargs = dict(
    color_bys=color_bys,
    basis_types=basis_types,
    pal=pal,
    k=k,
    edges_color=edges_color,
    edges_width=edges_width,
    layout=layout,
    threshold=threshold,
    node_size_scale=node_size_scale,
    edge_width_scale=edge_width_scale,
    font_size=font_size,
    point_size=point_size,
    alpha=alpha,
    figsize=figsize,
    ncols=ncols,
    seed=seed,
    leiden_key='leiden',
    save_dir=save_dir,
    highlight_indices=hl_cell_indices,
    highlight_size=point_size,
    file_suffix=f'{file_suffix}hl_cells_branch{show_branch}',
    save_format='pdf',
)
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(adata, combined_keys, **highlight_plot_kwargs)


### Concord-hcl

In [None]:
import copy

# Baseline Concord arguments for the clr_beta sweep; the sweep overrides
# one entry at a time on a copy of this dict.
concord_args_base = dict(
    latent_dim=100,
    batch_size=32,
    n_epochs=10,
    p_intra_knn=0.0,
    clr_beta=1.0,
    element_mask_prob=0.3,
    feature_mask_prob=0.0,
    domain_key=batch_key,
    seed=seed,
    device=device,
    save_dir=save_dir,
    load_data_into_memory=True,
    verbose=False,
)

In [None]:
# Values to sweep, keyed by the Concord argument they replace.
param_variations = {
    'clr_beta': [0.0, 0.5, 1.0, 2.0, 10.0, 20.0],
}
# Names of the embeddings written by the sweep, in run order.
param_keys = []

for param_name, values in param_variations.items():
    print(f"\n[INFO] Varying '{param_name}' with possible values: {values}\n")

    for value in values:
        # Each run's embedding is stored under a key naming the parameter and value.
        output_key = f"Concord_{param_name}_{value}"
        umap_key = f"{output_key}_UMAP"
        param_keys.append(output_key)
        print(f"[INFO] Running Concord with {param_name} = {value}")

        # Independent copy of the baseline with the swept entry replaced.
        concord_args = {**copy.deepcopy(concord_args_base), param_name: value}

        cur_ccd = ccd.Concord(adata, **concord_args)
        cur_ccd.fit_transform(output_key=output_key)

print("[DONE] Finished varying each parameter individually.")


In [None]:
# UMAP for the two reference layers and each clr_beta embedding, stored as '<key>_UMAP'.
combined_keys = ['no_noise', 'wt_noise', *param_keys]
for key in combined_keys:
    ccd.ul.run_umap(
        adata,
        source_key=key,
        result_key=f'{key}_UMAP',
        random_state=seed,
        min_dist=0.5,
    )

In [None]:
# Short display names for the sweep embeddings ('Concord_clr_beta_<v>' -> 'beta=<v>');
# keys without 'clr_beta' keep their name.
show_keys = [
    f"beta={key.rsplit('_', 1)[-1]}" if 'clr_beta' in key else key
    for key in combined_keys
]
# Register every display name in adata.obsm, pointing at the original embedding.
for key, key_new in zip(combined_keys, show_keys):
    adata.obsm[key_new] = adata.obsm[key]


In [None]:
# One KNN-graph figure per depth-1 branch for the clr_beta sweep.
show_branches = adata[adata.obs['depth']==1].obs['branch'].unique()

for show_branch in show_branches:
    # Restrict to cells whose branch label starts with this branch name.
    in_branch = adata.obs['branch'].str.startswith(show_branch)
    adata_sub = adata[in_branch].copy()

    basis_types = ['KNN']
    font_size, point_size, alpha = 8, 20, 0.8
    figsize = (0.9*len(show_keys), 1)
    ncols = len(show_keys)
    nrows = int(np.ceil(len(param_keys) / ncols))
    k = 15

    branch_plot_kwargs = dict(
        color_bys=color_bys,
        basis_types=basis_types,
        pal=pal,
        k=k,
        edges_color=edges_color,
        edges_width=edges_width,
        layout=layout,
        threshold=threshold,
        node_size_scale=node_size_scale,
        edge_width_scale=edge_width_scale,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        leiden_key='leiden',
        save_dir=save_dir,
        file_suffix=f'hcl{file_suffix}_branch{show_branch}',
        save_format='pdf',
    )
    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_all_embeddings(adata_sub, show_keys, **branch_plot_kwargs)


In [None]:
# Highlight the cells whose branch label starts with `show_branch` in the
# full-data KNN view of each embedding in show_keys.
show_branch = '2_0'
hl_cells = adata.obs['branch'].str.startswith(show_branch)
hl_cell_indices = np.where(hl_cells)[0]

# Appearance settings.
font_size, point_size, alpha = 8, 1, 0.8
figsize = (0.9*len(show_keys), 1)
ncols, nrows = len(show_keys), 1
k = 15
edges_color, edges_width = 'grey', 0.01
layout, threshold = 'kk', 0.1
node_size_scale, edge_width_scale = 0.1, 0.1
basis_types = ['KNN']

highlight_plot_kwargs = dict(
    color_bys=color_bys,
    basis_types=basis_types,
    pal=pal,
    k=k,
    edges_color=edges_color,
    edges_width=edges_width,
    layout=layout,
    threshold=threshold,
    node_size_scale=node_size_scale,
    edge_width_scale=edge_width_scale,
    font_size=font_size,
    point_size=point_size,
    alpha=alpha,
    figsize=figsize,
    ncols=ncols,
    seed=seed,
    leiden_key='leiden',
    save_dir=save_dir,
    highlight_indices=hl_cell_indices,
    highlight_size=point_size,
    file_suffix=f'{file_suffix}clr_hl_cells_branch{show_branch}',
    save_format='pdf',
)
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(adata, show_keys, **highlight_plot_kwargs)


### Visualize heatmap

In [None]:
# Reorder the latent dimensions of each Concord embedding for display.
# The path handed to the helper is the set of cells in the first batch label found.
batch_id = adata.obs['batch'].unique()[0]
batch_indices = np.where(adata.obs['batch'] == batch_id)[0]

for latent_key in ('contrastive', 'concord_hcl', 'concord_knn'):
    # Only the fourth return value (the feature order) is used.
    _, _, _, feature_order = ccd.ul.sort_and_smooth_signal_along_path(
        adata, signal_key=latent_key, path=batch_indices, sigma=2
    )
    adata.obsm[f'{latent_key}_sorted'] = adata.obsm[latent_key][:, feature_order]

In [None]:
# Six-panel figure: the three data matrices followed by the three reordered
# Concord latents, all drawn with the same heatmap settings.
import matplotlib.pyplot as plt

figsize, ncols = (2.3, 1.8), 6
title_fontsize, dpi = 9, 600

_, _, state_pal = ccd.pl.get_color_mapping(adata, state_key, pal='Paired', seed=seed)
_, _, batch_pal = ccd.pl.get_color_mapping(adata, batch_key, pal='Set1', seed=seed)
pal = {state_key: state_pal, batch_key: batch_pal}

# (matrix key, obs annotations, panel title), in left-to-right order.
panel_specs = [
    ('no_noise', [state_key], 'State'),
    ('wt_noise', [state_key], 'State+noise'),
    ('X', [state_key, batch_key], 'State+noise+batch'),
    ('contrastive_sorted', [state_key, batch_key], 'Contrastive'),
    ('concord_hcl_sorted', [state_key, batch_key], 'Concord_hcl'),
    ('concord_knn_sorted', [state_key, batch_key], 'Concord_knn'),
]

with plt.rc_context(rc=custom_rc):
    fig, axes = plt.subplots(1, ncols, figsize=(figsize[0] * ncols, figsize[1]), dpi=dpi)
    for panel_index, (matrix_key, annot_keys, panel_title) in enumerate(panel_specs):
        ccd.pl.heatmap_with_annotations(
            adata,
            val=matrix_key,
            obs_keys=annot_keys,
            pal=pal,
            ax=axes[panel_index],
            use_clustermap=False,
            yticklabels=False,
            cluster_cols=False,
            cluster_rows=False,
            value_annot=False,
            cmap='viridis',
            title=panel_title,
            save_path=None,
            figsize=figsize,
            dpi=dpi,
            title_fontsize=title_fontsize,
        )
    plt.tight_layout(w_pad=0.0, h_pad=0.1)
    plt.savefig(save_dir / f"all_heatmaps_{file_suffix}.svg", dpi=dpi, bbox_inches='tight')

In [None]:
# Write both AnnData objects again, now including the sweep embeddings, then report the paths.
adata_path = data_dir / f"adata_{file_suffix}.h5ad"
adata_state_path = data_dir / f"adata_state_{file_suffix}.h5ad"
adata.write_h5ad(adata_path)
adata_state.write_h5ad(adata_state_path)

print("Data saved to", adata_path)
print("State data saved to", adata_state_path)

## Evaluation

In [5]:
# Evaluation starts from the saved Jul20-2008 snapshot rather than in-memory state.
adata = sc.read(data_dir / "adata_Jul20-2008.h5ad")

In [6]:
# Display the shape of the object that was just loaded.
adata.shape

(4620, 1471)

In [None]:
# Embedding keys to benchmark: the main methods plus selected sweep settings.
concord_methods = ["contrastive", "concord_hcl", "concord_knn"]
other_methods = [
    "PCA", "UMAP", "t-SNE", "DiffusionMap", "NMF",
    "FactorAnalysis", "FastICA", "LDA", "scVI", "PHATE",
]
run_methods = [*concord_methods, *other_methods]

# Main methods: the two reference layers plus every run method not in exclude_keys.
exclude_keys = ["PCA", "UMAP", "t-SNE"]
main_methods = ['no_noise', 'wt_noise']
main_methods.extend(method for method in run_methods if method not in exclude_keys)

# Display names of the sweep embeddings to include.
hcl_param_keys = [f'beta={beta}' for beta in ('0.5', '1.0', '2.0', '10.0')]
knn_param_keys = [f'p_knn={p}' for p in ('0.1', '0.3', '0.5', '0.7', '1.0')]

all_keys = [*main_methods, *hcl_param_keys, *knn_param_keys]
all_keys

In [None]:
# Benchmark every embedding in all_keys; only the geometry block is requested.
results = ccd.bm.run_benchmark_pipeline(
    adata,
    embedding_keys=all_keys,
    state_key=state_key,
    batch_key=batch_key,
    groundtruth_key="no_noise",  # the noise-free data is the ground truth
    expected_betti_numbers=[0, 0, 0],
    save_dir=save_dir / "benchmarks_out",
    file_suffix=file_suffix,  # session stamp, e.g. "Aug16-2021"
    # One-element tuple: the trailing comma matters, since ("geometry") is
    # just the string "geometry" and not a collection of block names.
    run=("geometry",),
    plot_individual=False,          # skip the intermediate PDFs
    combine_plots=True,  # combine the plots into a single PDF
)

In [None]:
import pickle

# Geometry results of the saved Jul20-2008 run. The path is pinned to that run
# and is independent of this session's save_dir / file_suffix.
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by this project.
geometry_results_path = '../save/dev_simulation_treehard_singlebatch-Jul20/benchmarks_out/geometry_results_Jul20-2008.pkl'
# Context manager so the file handle is closed after loading.
with open(geometry_results_path, 'rb') as results_file:
    geometry_full = pickle.load(results_file)


In [None]:
# Keep only the rows for the clr_beta sweep embeddings plus the 'contrastive'
# baseline, then plot their trustworthiness curves.
hcl_embeddings = hcl_param_keys + ['contrastive']
trustworthiness_scores = geometry_full['trustworthiness']['scores']
keep_rows = trustworthiness_scores['Embedding'].isin(hcl_embeddings)
trustworthiness_scores = trustworthiness_scores[keep_rows]

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_trustworthiness(
        trustworthiness_scores,
        text_shift=0.2,
        min_gap=0.002,
        legend=False,
        save_path=save_dir / f"hcl_trustworthiness_{file_suffix}.pdf",
        y_range=(0.68, 1),
        figsize=(2.8,1.9),
    )

In [None]:

# Zoomed inset: rows with at most x_cut neighbors and trustworthiness above y_cut.
x_cut, y_cut = 20, 0.95
few_neighbors = trustworthiness_scores['n_neighbors'] <= x_cut
high_trust = trustworthiness_scores['Trustworthiness'] > y_cut
trustworthiness_zoom = trustworthiness_scores[few_neighbors & high_trust]

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_trustworthiness(
        trustworthiness_zoom,
        text_shift=0.2,
        min_gap=0.0,
        fontsize=3,
        legend=False,
        save_path=save_dir / f"hcl_trustworthiness_zoom{x_cut}_{file_suffix}.pdf",
        figsize=(1,1),
    )

In [None]:
# Keep only the rows for the p_knn sweep embeddings plus the 'contrastive'
# baseline, then plot their trustworthiness curves.
knn_embeddings = knn_param_keys + ['contrastive']
trustworthiness_scores = geometry_full['trustworthiness']['scores']
keep_rows = trustworthiness_scores['Embedding'].isin(knn_embeddings)
trustworthiness_scores = trustworthiness_scores[keep_rows]

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_trustworthiness(
        trustworthiness_scores,
        text_shift=0.2,
        min_gap=0.002,
        legend=False,
        save_path=save_dir / f"knn_trustworthiness_{file_suffix}.pdf",
        y_range=(0.68, 1),
        figsize=(2.8,1.9),
    )

In [None]:

# Zoomed inset: rows with at most x_cut neighbors and trustworthiness above y_cut.
x_cut, y_cut = 20, 0.95
few_neighbors = trustworthiness_scores['n_neighbors'] <= x_cut
high_trust = trustworthiness_scores['Trustworthiness'] > y_cut
trustworthiness_zoom = trustworthiness_scores[few_neighbors & high_trust]

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_trustworthiness(
        trustworthiness_zoom,
        text_shift=0.2,
        min_gap=0.0,
        fontsize=3,
        legend=False,
        save_path=save_dir / f"knn_trustworthiness_zoom{x_cut}_{file_suffix}.pdf",
        figsize=(1,1),
    )

## Combine all metrics

In [None]:
# Join the geometry and connectivity score tables side by side.
# NOTE(review): geometry_scores and connectivity_scores are not assigned in any
# cell visible in this notebook; they must be defined before this cell runs.
import pandas as pd

score_tables = [geometry_scores, connectivity_scores]
all_scores = pd.concat(score_tables, axis="columns")
all_scores

In [None]:
# Move the per-category summary scores into an 'Aggregate score' column group,
# add their mean, and rank the embeddings by that mean.
#all_scores[('Aggregate score', 'Topology')] = all_scores[('', 'Topology Score')]
all_scores[('Aggregate score', 'Geometry')] = all_scores[('', 'Geometric Score')]
all_scores[('Aggregate score', 'Graph')] = all_scores[('', 'Graph Score')]
all_scores[('Aggregate score', 'Average Score')] = all_scores["Aggregate score"][["Geometry", "Graph"]].mean(axis=1)
# sort by average score
all_scores.sort_values(by=[('Aggregate score', 'Average Score')], ascending=False, inplace=True)
# The source columns are now duplicated under 'Aggregate score'; drop them.
all_scores.drop(
    columns=[
        ('', 'Graph Score'),
        ('', 'Geometric Score')
    ],
    inplace=True
)
# Save table. Use a distinct file name: the dimensionality-reduction cell
# already writes bench_res to benchmark_results_{file_suffix}.csv in save_dir,
# and reusing that name here would silently overwrite it.
all_scores.to_csv(save_dir / f"benchmark_scores_{file_suffix}.csv")
all_scores

In [None]:
# Render the combined score table as a figure and save it as a PDF.
table_plot_kwargs = dict(
    pal='PRGn',
    pal_agg='RdBu_r',
    cmap_method='minmax',
    agg_name='Aggregate score',
    save_path=save_dir / f"all_results_{file_suffix}.pdf",
    figsize=(15, 7),
    dpi=300,
)
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_benchmark_table(all_scores, **table_plot_kwargs)