# Simulation using Concord

In [None]:
# IPython magics: enable the autoreload extension in mode 2 so that edited
# modules are re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import scanpy as sc
import time
from pathlib import Path
import torch
import concord as ccd
import warnings
# Suppress every Python warning for the rest of the session (keeps cell output short).
warnings.filterwarnings('ignore')
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
import matplotlib as mpl

from matplotlib import font_manager, rcParams
# Font override handed to `plt.rc_context` when figures are drawn.
custom_rc = {'font.family': 'Arial'}

# Keep text as text in SVG output and embed TrueType (type 42) fonts in PDFs.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
})

In [None]:
# The project name determines both the output and the data directory.
proj_name = "simulation_all_topo"

# Results go to a folder stamped with the current month and day.
save_dir = Path(f"../save/dev_{proj_name}-{time.strftime('%b%d')}/")
save_dir.mkdir(parents=True, exist_ok=True)

data_dir = Path(f"../data/{proj_name}/")
data_dir.mkdir(parents=True, exist_ok=True)

# GPU index 3 is hard-coded; without CUDA the CPU is used instead.
if torch.cuda.is_available():
    device = torch.device('cuda:3')
else:
    device = torch.device('cpu')
print(device)

# Seed used for every stochastic step below.
seed = 0
ccd.ul.set_seed(seed)

# Month/day-hour/minute stamp appended to saved file names; shown as the cell output.
file_suffix = time.strftime('%b%d-%H%M')
file_suffix

In [None]:
# Methods to benchmark: the Concord variants first, then the baselines.
concord_methods = ['concord_hcl', 'concord_knn', 'contrastive']
other_methods = ["unintegrated", "scanorama", "liger", "harmony", "scvi"]
run_methods = [*concord_methods, *other_methods]

# combined_keys = the two ground-truth references followed by every run method
# that is not one of the plain projections "PCA", "UMAP" or "t-SNE".
exclude_keys = ["PCA", "UMAP", "t-SNE"]
combined_keys = ['no_noise', 'wt_noise']
combined_keys.extend(key for key in run_methods if key not in exclude_keys)

In [None]:
# Names of the `.obs` columns holding the topology label and the batch label.
state_key, batch_key = 'structure', 'batch'

# ─────────────── General parameters ───────────────
# Default gene / cell counts passed to the per-topology SimConfig objects.
N_GENES_BASE, N_CELLS_BASE = 300, 500
# Expression level passed as `level=` to the state and batch configs.
MEAN_EXPRESSION = 10
# Dispersion of the global noise (also reused as the batch-effect dispersion).
GLOBAL_NOISE_DISP = 5

In [None]:
from concord.simulation import (
    Simulation, SimConfig,
    ClusterConfig, TrajectoryConfig, TreeConfig,
    BatchConfig,
)
from concord.utils.anndata_utils import ordered_concat
import numpy as np
import pandas as pd
import anndata as ad

# ─────────────── 1) Cluster ───────────────
# One uniform state on 100 genes, simulated as a single batch (seed 1).
sim_cfg_cluster = SimConfig(
    n_cells=N_CELLS_BASE,
    n_genes=100,
    seed=1,
    non_neg=True,
    to_int=True,
)
cluster_state_params = {
    "n_states": 1,
    "program_structure": "uniform",
    "program_on_time_fraction": 0.5,
    "level": MEAN_EXPRESSION,
    "dispersion": 0,
    "min_level": 0,
}
cluster_cfg = ClusterConfig(**cluster_state_params)
sim_cluster = Simulation(
    sim_cfg_cluster,
    cluster_cfg,
    BatchConfig(n_batches=1),
)
adata_cluster = sim_cluster.simulate_state()
# Tag every cell with the topology it came from.
adata_cluster.obs["structure"] = "cluster"

# ─────────────── 2) Loop ───────────────
# Four-program trajectory configured with `loop_to=[0]`, single batch (seed 2).
sim_cfg_loop = SimConfig(
    n_cells=N_CELLS_BASE,
    n_genes=N_GENES_BASE,
    seed=2,
    non_neg=True,
    to_int=True,
)
loop_state_params = {
    "program_num": 4,
    "loop_to": [0],
    "cell_block_size_ratio": 0.5,
    "program_structure": "linear_bidirectional",
    "program_on_time_fraction": 0.1,
    "distribution": "normal",
    "level": MEAN_EXPRESSION,
    "dispersion": 0,
    "min_level": 0,
}
loop_cfg = TrajectoryConfig(**loop_state_params)
sim_loop = Simulation(
    sim_cfg_loop,
    loop_cfg,
    BatchConfig(n_batches=1),
)
adata_loop = sim_loop.simulate_state()
# Tag every cell with the topology it came from.
adata_loop.obs["structure"] = "loop"

# ─────────────── 3) Tree ───────────────
# Tree with branching factor 2 and depth 2 on 400 genes, single batch (seed 3).
sim_cfg_tree = SimConfig(
    n_cells=N_CELLS_BASE,
    n_genes=400,
    seed=3,
    non_neg=True,
    to_int=True,
)
tree_state_params = {
    "branching_factor": 2,
    "depth": 2,
    "program_structure": "dimension_increase",
    "program_on_time_fraction": 0.1,
    "program_gap_size": 1,
    "program_decay": 1.0,
    "cellcount_decay": 1.0,
    "level": MEAN_EXPRESSION,
    "dispersion": 0,
    "min_level": 0,
    "noise_in_block": False,
}
tree_cfg = TreeConfig(**tree_state_params)
sim_tree = Simulation(
    sim_cfg_tree,
    tree_cfg,
    BatchConfig(n_batches=1),
)
adata_tree = sim_tree.simulate_state()
# Tag every cell with the topology it came from.
adata_tree.obs["structure"] = "tree"

# ─────────────── 4) Trajectory ───────────────
# Five-program trajectory without a loop-back, single batch (seed 4).
sim_cfg_traj = SimConfig(
    n_cells=N_CELLS_BASE,
    n_genes=N_GENES_BASE,
    seed=4,
    non_neg=True,
    to_int=True,
)
traj_state_params = {
    "program_num": 5,
    "cell_block_size_ratio": 0.5,
    "program_structure": "linear_bidirectional",
    "program_on_time_fraction": 0.1,
    "distribution": "normal",
    "level": MEAN_EXPRESSION,
    "dispersion": 0,
    "min_level": 0,
}
traj_cfg = TrajectoryConfig(**traj_state_params)
sim_traj = Simulation(
    sim_cfg_traj,
    traj_cfg,
    BatchConfig(n_batches=1),
)
adata_traj = sim_traj.simulate_state()
# Tag every cell with the topology it came from.
adata_traj.obs["structure"] = "trajectory"

def rename_genes(adata: ad.AnnData, prefix: str) -> ad.AnnData:
    """Relabel the genes of *adata* as ``<prefix>_Gene_<k>`` with k starting at 1.

    The prefix keeps gene names from different topologies from colliding when
    the objects are concatenated. The object is modified in place and also
    returned.
    """
    numbered_names = [
        f"{prefix}_Gene_{position}"
        for position in range(1, adata.n_vars + 1)
    ]
    adata.var_names = numbered_names
    # Mirror the new names onto the index of the `.var` table.
    adata.var.index = adata.var_names
    return adata

# Prefix each topology's genes so the four gene sets do not overlap.
adata_cluster, adata_loop, adata_tree, adata_traj = (
    rename_genes(topology_adata, gene_prefix)
    for topology_adata, gene_prefix in (
        (adata_cluster, "cluster"),
        (adata_loop, "loop"),
        (adata_tree, "tree"),
        (adata_traj, "traj"),
    )
)

# Stack the four topologies; the outer join keeps every gene of every topology.
topology_adatas = [adata_cluster, adata_loop, adata_tree, adata_traj]
adata_state = ordered_concat(
    topology_adatas,
    join="outer",
    label="topology",
    index_unique="-",
)

# Entries that are NaN after the outer join become zeros.
adata_state.X = np.nan_to_num(adata_state.X, nan=0.0)
# Keep the clean expression matrix before any noise is added.
adata_state.layers['no_noise'] = adata_state.X.copy()

# Add global normal noise in place, then clip negative values to zero.
global_noise = Simulation.simulate_distribution(
    "normal", adata_state.X, GLOBAL_NOISE_DISP, nonzero_only=False
)
adata_state.X += global_noise
adata_state.X[adata_state.X < 0] = 0
# Keep the noisy expression matrix as its own layer.
adata_state.layers['wt_noise'] = adata_state.X.copy()

# Build a Simulation object only to reuse its batch-effect methods on the
# already-combined `adata_state`; its state config is a placeholder.
batch_cfg = BatchConfig(
    n_batches=2,
    effect_type="batch_specific_features",
    distribution="normal",
    level=[MEAN_EXPRESSION, MEAN_EXPRESSION],
    dispersion=[GLOBAL_NOISE_DISP, GLOBAL_NOISE_DISP],
    feature_frac=0.1,            # batch_feature_frac
    cell_proportion=[0.5, 0.5],  # proportion of cells in each batch
)

sim = Simulation(SimConfig(non_neg=True, to_int=True), ClusterConfig(), batch_cfg)  # state config can be any

# Create the generator ONCE. Re-creating it with the same seed inside the loop
# made every batch draw exactly the same cell indices, so the batches were
# copies of one subset and the remaining cells were never sampled.
# Batch 0 still receives the same draw as before.
rng = np.random.default_rng(sim.sim_config.seed)

batch_list, state_list = [], []
for i in range(sim.batch_config.n_batches):
    # Number of cells assigned to this batch.
    cell_proportion = sim.batch_config.cell_proportion[i]
    n_cells = int(adata_state.n_obs * cell_proportion)

    # Sample without replacement; sorting keeps the original cell order.
    cell_indices = np.sort(rng.choice(adata_state.n_obs, n_cells, replace=False))
    batch_adata_pre = adata_state[cell_indices].copy()

    # Apply the configured effect for batch i to a copy of the selected cells.
    batch_adata = sim.simulate_batch(
        batch_adata_pre,
        batch_idx=i
    )
    batch_list.append(batch_adata)
    state_list.append(batch_adata_pre)

# Both lists are built in lockstep and finalized the same way; later cells rely
# on `adata` and `adata_state` being row-aligned.
adata = sim._finalize_anndata(batch_list, join='outer')
adata_state = sim._finalize_anndata(state_list, join='outer')


In [None]:
# Options common to the three overview heatmaps.
overview_heatmap_opts = dict(
    yticklabels=False,
    cluster_cols=False,
    cluster_rows=False,
    value_annot=False,
    cmap='viridis',
    figsize=(6, 4),
    dpi=300,
)

# 1) Clean state signal.
ccd.pl.heatmap_with_annotations(
    adata_state,
    val='no_noise',
    obs_keys=[state_key],
    title='True state',
    save_path=save_dir / f'true_state_heatmap_{file_suffix}.png',
    **overview_heatmap_opts,
)
# 2) State signal with global noise.
ccd.pl.heatmap_with_annotations(
    adata_state,
    val='wt_noise',
    obs_keys=[state_key],
    title='True state with noise',
    save_path=save_dir / f'true_state_with_noise_heatmap_{file_suffix}.png',
    **overview_heatmap_opts,
)
# 3) Final simulated data, annotated by state and batch.
ccd.pl.heatmap_with_annotations(
    adata,
    val='X',
    obs_keys=[state_key, batch_key],
    title='Simulated data with batch signal',
    save_path=save_dir / f'simulated_data_heatmap_{file_suffix}.png',
    **overview_heatmap_opts,
)

### Concord

In [None]:
# Methods to run in this section; use `concord_methods` alone for a Concord-only run.
run_methods = [*concord_methods, *other_methods]
# Dimensionality used for the PCA references and passed to the integration pipeline.
latent_dim = 30

In [None]:
# Ground-truth references: PCA of the clean and of the noisy state matrices.
reference_layers = ('no_noise', 'wt_noise')

for layer_name in reference_layers:
    ccd.ul.run_pca(
        adata_state,
        source_key=layer_name,
        result_key=f'PCA_{layer_name}',
        n_pc=latent_dim,
        random_state=seed,
    )

# Copy the PCA results into `adata` so that only one object is needed downstream.
for layer_name in reference_layers:
    adata.obsm[layer_name] = adata_state.obsm[f'PCA_{layer_name}']

# 2-D UMAP of each reference embedding.
for layer_name in reference_layers:
    ccd.ul.run_umap(
        adata,
        source_key=layer_name,
        result_key=f'{layer_name}_UMAP',
        n_components=2,
        random_state=seed,
    )

In [None]:
# Show the dimensions of the dataset assembled from the simulated batches.
adata.shape

In [None]:
# Training options forwarded to the pipeline through `concord_kwargs`.
concord_kwargs = dict(
    batch_size=128,
    n_epochs=20,
    load_data_into_memory=True,
    verbose=False,
)

# Run every method in `run_methods` on `adata`. UMAPs are computed in a later
# cell, hence compute_umap=False.
profile_logs = ccd.bm.run_integration_methods_pipeline(
    adata=adata,
    methods=run_methods,       # list of methods to run
    batch_key=batch_key,       # column in adata.obs for batch info
    count_layer="counts",      # layer name containing raw counts
    class_key=None,            # column in adata.obs for class labels (used in SCANVI and CONCORD variants)
    latent_dim=latent_dim,     # latent dimensionality for PCA and embeddings
    # NOTE(review): runs on the CPU; the `device` selected earlier in the
    # notebook is not used here. Confirm this is intended.
    device="cpu",
    return_corrected=False,    # whether to store corrected expression matrices
    transform_batch=None,      # optionally specify a batch to transform to in scVI
    seed=seed,                 # random seed for reproducibility
    compute_umap=False,        # whether to run UMAP for all output embeddings
    umap_n_components=2,
    umap_n_neighbors=30,
    umap_min_dist=0.5,
    verbose=True,              # print progress messages
    save_dir=save_dir,
    concord_kwargs=concord_kwargs,
)


In [None]:
# 2-D UMAP of every method's embedding, stored as "<method>_UMAP".
method_umap_opts = dict(
    n_components=2,
    n_neighbors=30,
    min_dist=0.8,
    metric='euclidean',
    random_state=seed,
)
for basis in run_methods:
    print("Running UMAP for", basis)
    ccd.ul.run_umap(
        adata,
        source_key=basis,
        result_key=f'{basis}_UMAP',
        **method_umap_opts,
    )

In [None]:
# plot everything
import matplotlib.pyplot as plt

show_keys = run_methods
# check which methods are run successfully
color_bys = [state_key, batch_key]
basis_types = ['KNN', 'UMAP']
font_size=8
point_size=2
alpha=0.8
figsize=(0.9*len(show_keys),1)
ncols = len(show_keys)
nrows = int(np.ceil(len(show_keys) / ncols))
pal = {'time':'viridis', 'batch':'Set1'}
k=30
edges_color='grey'
edges_width=0
layout='kk'
threshold = 0.1
node_size_scale=0.1
edge_width_scale=0.1

rasterized = True
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata,
        show_keys,
        color_bys=color_bys,
        basis_types=basis_types,
        pal=pal,
        k=k,
        edges_color=edges_color,
        edges_width=edges_width,
        layout=layout,
        threshold=threshold,
        node_size_scale=node_size_scale,
        edge_width_scale=edge_width_scale,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        rasterized=rasterized,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        leiden_key='leiden',
        save_dir=save_dir,
        legend_loc = 'on data',
        file_suffix=file_suffix+f'rasterized_{rasterized}',
        save_format='svg'
    )


In [None]:
# Arguments for a stand-alone Concord run on the simulated data.
concord_args = {
    'adata': adata,
    'input_feature': None,
    'domain_key': None,
    #'p_intra_knn':0.3,
    # Fixed: the original referenced `LATENT_DIM`, which is never defined in
    # this notebook (NameError); the latent size is stored in `latent_dim`.
    'latent_dim': latent_dim, # latent dimension size
    'seed': seed, # random seed
    'verbose': False, # print training progress
    'device': device, # device to run on
    'load_data_into_memory': True,
    'save_dir': save_dir # directory to save model checkpoints
}
cur_ccd = ccd.Concord(**concord_args)

# Encode data, saving the latent embedding in adata.obsm['Concord']
output_key = 'Concord'
cur_ccd.fit_transform(output_key=output_key)

In [None]:
# Recompute each method's 2-D UMAP with min_dist=0.5, overwriting "<method>_UMAP".
method_umap_opts = dict(
    n_components=2,
    n_neighbors=30,
    min_dist=0.5,
    metric='euclidean',
    random_state=seed,
)
for basis in run_methods:
    print("Running UMAP for", basis)
    ccd.ul.run_umap(
        adata,
        source_key=basis,
        result_key=f'{basis}_UMAP',
        **method_umap_opts,
    )

In [None]:
# plot everything
import matplotlib.pyplot as plt
import pandas as pd
combined_keys=['PCA_no_noise', 'PCA_noise', 'Concord']

# Set Arial as the default font
custom_rc = {
    'font.family': 'Arial',  # Set the desired font for this plot
}

color_bys = ['structure', 'time']
basis_types = ['KNN']
#basis_types = ['PCA']
#basis_types = ['KNN']
font_size=8
point_size=2.5
alpha=0.8
figsize=(0.9*len(combined_keys),1)
ncols = len(combined_keys)
nrows = int(np.ceil(len(combined_keys) / ncols))
pal = {'time':'viridis', 'batch':'Set1'}
k=40
edges_color='grey'
edges_width=0.0
layout='kk'
threshold = 0
node_size_scale=0.1
edge_width_scale=0.1

rasterized = True
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata,
        combined_keys,
        color_bys=color_bys,
        basis_types=basis_types,
        pal=pal,
        k=k,
        edges_color=edges_color,
        edges_width=edges_width,
        layout=layout,
        threshold=threshold,
        node_size_scale=node_size_scale,
        edge_width_scale=edge_width_scale,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        rasterized=rasterized,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        leiden_key='leiden',
        save_dir=save_dir,
        file_suffix=file_suffix +f'rasterized_{rasterized}'+ f'ngenestraj_{N_GENES_TRAJ}_ngenestree_{N_GENES_TREE}',
        save_format='svg'
    )


In [None]:
state_key = 'structure'

# Order cells by their `time` annotation, then use that path to obtain an
# ordering of the Concord latent dimensions for display.
cell_order = adata.obs['time'].argsort()
_, _, _, feature_order = ccd.ul.sort_and_smooth_signal_along_path(adata, signal_key='Concord', path=cell_order, sigma=2)
adata.obsm['Concord_sorted'] = adata.obsm['Concord'][:, feature_order]

# Plot heatmap of original data and Concord latent
import matplotlib.pyplot as plt
figsize = (2.3, 1.8)
ncols = 3
title_fontsize = 8
dpi = 600

# Fixed: `N_GENES_TRAJ` / `N_GENES_TREE` are not defined anywhere in this
# notebook (NameError on save). Take the gene counts from the simulated
# trajectory and tree objects instead.
n_genes_traj = adata_traj.n_vars
n_genes_tree = adata_tree.n_vars

# Options common to the three panels.
panel_opts = dict(
    obs_keys=[state_key],
    use_clustermap=False,
    yticklabels=False,
    cluster_cols=False,
    cluster_rows=False,
    value_annot=False,
    cmap='viridis',
    save_path=None,
    figsize=figsize,
    dpi=dpi,
    title_fontsize=title_fontsize,
)

fig, axes = plt.subplots(1, ncols, figsize=(figsize[0] * ncols, figsize[1]), dpi=dpi)
ccd.pl.heatmap_with_annotations(adata, val='no_noise', ax=axes[0], title='State', **panel_opts)
ccd.pl.heatmap_with_annotations(adata, val='X', ax=axes[1], title='State+noise', **panel_opts)
ccd.pl.heatmap_with_annotations(adata, val='Concord_sorted', ax=axes[2], title='Concord latent', **panel_opts)
plt.tight_layout(w_pad=0.0, h_pad=0.1)
plt.savefig(save_dir / f"all_heatmaps_{file_suffix}_ngenestraj_{n_genes_traj}_ngenestree_{n_genes_tree}.svg", dpi=dpi, bbox_inches='tight')