# Benchmark CBCE

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# every cell execution (useful while the imported Concord package is being edited).
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import scanpy as sc
import time
from pathlib import Path
import torch
import Concord as ccd
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

from matplotlib import font_manager, rcParams
# Font override used only inside the `plt.rc_context(rc=custom_rc)` blocks.
custom_rc = {'font.family': 'Arial'}

# Keep text editable in exported vector figures:
# SVG text stays as text elements, PDF embeds TrueType (Type 42) fonts.
mpl.rcParams.update({
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
})

In [None]:
# Project identity, input location and a date-stamped output directory.
proj_name = "cellcycle_CBCE"
data_dir = Path('../data/CE_CB/')
save_dir = Path(f"../save/dev_{proj_name}-{time.strftime('%b%d')}/")
save_dir.mkdir(parents=True, exist_ok=True)

# Compute device: a fixed GPU index when CUDA is available, CPU otherwise.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
print(device)

# Global seed, set through Concord's utility and reused for UMAP / palettes below.
seed = 0
ccd.ul.set_seed(seed)

# Timestamp appended to output filenames written by this notebook.
file_suffix = time.strftime('%b%d-%H%M')

In [None]:
# Load the saved AnnData snapshot and show which embeddings are stored in .obsm.
input_h5ad = data_dir / "adata_cbce_Dec21-0244.h5ad"
adata = sc.read(input_h5ad)
adata.obsm.keys()

In [5]:
# .obsm embedding keys to process: other methods (and the unintegrated
# baseline) first, then the Concord variants.
other_keys = ["Unintegrated", "Scanorama", "Liger", "Harmony", "scVI", "Seurat"]
concord_keys = ["Concord", 'Concord-decoder']
combined_keys = [*other_keys, *concord_keys]

In [6]:
# Explicit category->colour mappings for four categorical columns
# (column, palette name, seed); the third value returned by
# get_color_mapping is the mapping. broad_cell_type_qz uses seed=2 instead of
# the global seed -- presumably to get a preferred colour assignment (confirm).
_palette_specs = [
    ('broad_cell_type_qz', 'Paired', 2),
    ('species', 'Set1', seed),
    ('broad_lineage', 'Paired', seed),
    ('dataset3', 'Set1', seed),
]
tissue_pal, species_pal, broadlin_pal, batch_pal = (
    ccd.pl.get_color_mapping(adata, column, pal=palette_name, seed=palette_seed)[2]
    for column, palette_name, palette_seed in _palette_specs
)

# Per-column palette passed to plot_embedding: either a named palette/colormap
# string or one of the explicit mappings computed above.
pal = {
    'embryo.time': 'BlueGreenRed',
    "cell_type": 'Paired',
    'species': species_pal,
    'dataset3': batch_pal,
    'lineage_complete': 'Paired',
    'broad_lineage': broadlin_pal,
    'ct_or_lin': 'Paired',
    'broad_cell_type_qz': tissue_pal,
    'ct_or_broad_lin': 'Paired',
    'plot_cell_type': 'Paired',
}


In [None]:
# Overview of the precomputed '<basis>_UMAP' embedding after dropping cells
# annotated 'doublet/debris', coloured by annotation and time columns with
# labels drawn on the data. (UMAP coordinates are read from the loaded file;
# no UMAP is computed here.)
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP'
adata_clean = adata[adata.obs['broad_cell_type_qz'] != 'doublet/debris']
show_cols = [
    'broad_cell_type_qz', 'cell_type', 'ct_or_broad_lin',
    'broad_lineage', 'embryo.time', 'species',
]
overview_pdf = save_dir / f"{show_basis}_{file_suffix}_cleaned_wttext.pdf"
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_clean, show_basis, show_cols,
        figsize=(10, 6.7), dpi=600, ncols=3, font_size=3, point_size=.8,
        legend_loc='on data', pal=pal, seed=seed,
        save_path=overview_pdf,
    )

In [None]:
# Example lists of marker genes by cell-cycle phase (C. elegans).
# These are hand-curated candidates; which of them are measured in the data
# is checked separately against adata.var_names.
G1_genes = [
    "cyd-1",   # Cyclin D
    "cdk-4",   # CDK4 partner
    "cki-1",   # G1 inhibitor
    "cki-2",   # G1 inhibitor
    "lin-35"   # Rb
]

S_genes = [
    "cye-1",   # Cyclin E
    "cdk-2",   # CDK2 partner
    "efl-1",   # E2F
    "dpl-1",   # DP
    "cdc-6",   # Licensing factor
    "cdt-1",   # Licensing factor
    "cdt-2"    # E3 ligase component
]

G2M_genes = [
    "cdk-1",   # CDK1
    # cdc-25 family (G2/M transition), you can expand as needed
    "cdc-25.1",
    "cdc-25.2",
    "cdc-25.3",
    "cdc-25.4",
    "cdc-25.5"
]

M_exit_genes = [
    "fzr-1",   # Cdh1 homolog
    "cdc-14"   # Cdc14 phosphatase
]

# De-duplicate while keeping phase order (G1 -> S -> G2/M -> M exit).
# dict.fromkeys is used instead of set(): set iteration order for strings
# depends on per-process hash randomisation, which would reshuffle the gene
# panels plotted from this list after every kernel restart.
cc_genes = list(dict.fromkeys(G1_genes + S_genes + G2M_genes + M_exit_genes))
print("Number of cell-cycle genes:", len(cc_genes))

In [None]:
# Keep only the marker genes measured in this dataset, preserving cc_genes order.
available_genes = set(adata.var_names)
genes_in = [gene for gene in cc_genes if gene in available_genes]
print("Number of cell-cycle genes in the dataset:", len(genes_in))

In [None]:
# Expression of each cell-cycle marker (genes_in) on the precomputed
# '<basis>_UMAP' embedding, using the doublet/debris-filtered cells,
# laid out `ncols` panels per row.
output_key = 'Concord-decoder'
basis = output_key
show_basis = basis + '_UMAP'
show_cols = genes_in
base_fig_size = (3, 3)  # figure size allotted to each panel
ncols = 6
# Ceiling division: rows needed to hold len(show_cols) panels.
nrows = -(-len(show_cols) // ncols)
figsize = (base_fig_size[0] * ncols, base_fig_size[1] * nrows)
if not show_cols:
    # No requested gene is present: nrows would be 0 (zero-height figure),
    # so skip instead of handing an empty grid to the plotting helper.
    print(f"No genes to plot on {show_basis}; skipping.")
else:
    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_embedding(
            adata_clean, show_basis, show_cols, figsize=figsize, dpi=600, ncols=ncols, font_size=3, point_size=.4, legend_loc='on data',
            pal=pal, seed=seed,
            save_path=save_dir / f"{show_basis}_{file_suffix}_gexpr.pdf"
        )

In [None]:
# Additional genes of interest (presumably asymmetric-division / Wnt-PAR
# regulators -- confirm); keep only those measured in the dataset.
candidate_genes = ['mom-2', 'pop-1', 'sys-1', 'lin-44', 'lin-17', 'par-6', 'par-1']
gene_sym = [gene for gene in candidate_genes if gene in adata.var_names]
print("Number of genes in the dataset:", len(gene_sym))

In [None]:
# Expression of the gene_sym genes on the precomputed '<basis>_UMAP' embedding.
# NOTE(review): unlike the cell-cycle panel cell, this plots the unfiltered
# `adata` (doublet/debris cells included) -- confirm this is intended.
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP'
show_cols = gene_sym
base_fig_size = (3, 3)  # figure size allotted to each panel
ncols = 6
# Ceiling division: rows needed to hold len(show_cols) panels.
nrows = -(-len(show_cols) // ncols)
figsize = (base_fig_size[0] * ncols, base_fig_size[1] * nrows)
if not show_cols:
    # No requested gene is present: nrows would be 0 (zero-height figure).
    print(f"No genes to plot on {show_basis}; skipping.")
else:
    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_embedding(
            adata, show_basis, show_cols, figsize=figsize, dpi=600, ncols=ncols, font_size=3, point_size=2, legend_loc='on data',
            pal=pal, seed=seed,
            save_path=save_dir / f"{show_basis}_{file_suffix}_sym_gexpr.pdf"
        )

### Early embryos (`embryo.time` ≤ 200)

In [18]:
# Restrict to cells with embryo.time <= max_embryo_time (cells with a missing
# embryo.time compare False and are dropped).
max_embryo_time = 200
subset_name = f'early{max_embryo_time}'
# Explicit .copy(): boolean indexing returns an AnnData *view*, and UMAP
# results are written into `adata_subset.obsm` below. Writing to a view makes
# anndata silently convert it to an actual object (its warning is hidden by
# the global warnings filter); copying here makes that step deliberate.
adata_subset = adata[adata.obs['embryo.time'] <= max_embryo_time].copy()

In [None]:
# Compute a 2-D UMAP of the subset for every embedding key present in .obsm;
# each result is stored under '<basis>_UMAP_<subset_name>'. Keys missing from
# .obsm are reported and skipped.
for basis in combined_keys:
    if basis not in adata_subset.obsm:
        print("Skipping", basis, "(not found in obsm)")
        continue
    print("Running UMAP for", basis)
    ccd.ul.run_umap(adata_subset, source_key=basis, result_key=f'{basis}_UMAP_{subset_name}', n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed)


In [20]:
# Persist the subset (including any UMAP results added above) next to the input data.
subset_path = data_dir / f"adata_cbce_{file_suffix}_{subset_name}.h5ad"
adata_subset.write_h5ad(subset_path)

In [None]:
# Cell-cycle marker expression on the subset's own UMAP
# ('<basis>_UMAP_<subset_name>', computed above), `ncols` panels per row.
# NOTE(review): adata_subset is drawn from unfiltered `adata`, so
# doublet/debris cells are included here -- confirm intended.
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = genes_in
base_fig_size = (3, 3)  # figure size allotted to each panel
ncols = 6
# Ceiling division: rows needed to hold len(show_cols) panels.
nrows = -(-len(show_cols) // ncols)
figsize = (base_fig_size[0] * ncols, base_fig_size[1] * nrows)
if not show_cols:
    # No requested gene is present: nrows would be 0 (zero-height figure).
    print(f"No genes to plot on {show_basis}; skipping.")
else:
    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_embedding(
            adata_subset, show_basis, show_cols, figsize=figsize, dpi=600, ncols=ncols, font_size=3, point_size=2, legend_loc='on data',
            pal=pal, seed=seed,
            save_path=save_dir / f"{show_basis}_{file_suffix}_gexpr.pdf"
        )

In [None]:
# Extended gene-of-interest list for the subset; keep only genes measured in
# the dataset (adata and adata_subset share the same var_names).
candidate_genes = [
    'mom-2', 'pop-1', 'sys-1', 'wrm-1', 'lin-44',
    'lin-17', 'par-6', 'par-1', 'F21D5.9', 'sem-2',
]
gene_sym = [gene for gene in candidate_genes if gene in adata.var_names]
print("Number of genes in the dataset:", len(gene_sym))

In [None]:
# Expression of the extended gene_sym list on the subset's own UMAP
# ('<basis>_UMAP_<subset_name>'), `ncols` panels per row.
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = gene_sym
base_fig_size = (3, 3)  # figure size allotted to each panel
ncols = 6
# Ceiling division: rows needed to hold len(show_cols) panels.
nrows = -(-len(show_cols) // ncols)
figsize = (base_fig_size[0] * ncols, base_fig_size[1] * nrows)
if not show_cols:
    # No requested gene is present: nrows would be 0 (zero-height figure).
    print(f"No genes to plot on {show_basis}; skipping.")
else:
    with plt.rc_context(rc=custom_rc):
        ccd.pl.plot_embedding(
            adata_subset, show_basis, show_cols, figsize=figsize, dpi=600, ncols=ncols, font_size=3, point_size=2, legend_loc='on data',
            pal=pal, seed=seed,
            save_path=save_dir / f"{show_basis}_{file_suffix}_sym_gexpr.pdf"
        )