# Benchmark CBCE

In [None]:
# Re-import edited modules automatically before each cell executes, so changes
# to locally developed packages (e.g. Concord) take effect without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Numerical / single-cell stack plus the Concord package being benchmarked.
import numpy as np
import scanpy as sc
import time
from pathlib import Path
import torch
import Concord as ccd
import warnings
# Silence all Python warnings to keep notebook output readable.
warnings.filterwarnings('ignore')
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

from matplotlib import font_manager, rcParams
# rc overrides passed to plt.rc_context(...) by the plotting cells.
custom_rc = {
    'font.family': 'Arial',  # Set the desired font for this plot
}

# Keep text as real text in exported SVGs ('none' = don't convert glyphs to
# paths) and embed TrueType (Type 42) fonts in PDFs, so figures stay editable.
mpl.rcParams['svg.fonttype'] = 'none'
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
proj_name = "benchmark_CBCE"
data_dir = Path('../data/CE_CB/')
# Outputs of this session go to a date-stamped folder under ../save/.
save_dir = Path(f"../save/dev_{proj_name}-{time.strftime('%b%d')}/")
save_dir.mkdir(parents=True, exist_ok=True)

# Choose the best available accelerator instead of hard-coding Apple's MPS
# backend, which does not exist on Linux/CUDA machines.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

seed = 0
ccd.ul.set_seed(seed)
# Timestamp appended to every file written by this run.
file_suffix = f"{time.strftime('%b%d-%H%M')}"

In [None]:
#file_suffix = 'Jan30-1028'

In [None]:
# Load the AnnData snapshot written by an earlier run of this notebook (the
# timestamp in the file name pins the run; an older snapshot is kept commented
# out) and list the embeddings it already carries.
#adata = sc.read(data_dir / "adata_cbce_Dec21-0244.h5ad")
adata = sc.read(data_dir / "adata_cbce_Jan30-1028.h5ad")
adata.obsm.keys()

In [None]:
# Add Contrastive learning result
# The embedding is read from a separate HDF5 export and given a 2-D UMAP.
basis = 'Contrastive'
contrastive_file = Path("../save/dev_cbce_1217-Jan22/") / "obsm_Contrastive_Jan22-1932_Jan22-1932.h5"
new_obsm = ccd.ul.load_obsm_from_hdf5(contrastive_file)
adata.obsm[basis] = new_obsm['Contrastive_Jan22-1932']
ccd.ul.run_umap(
    adata,
    source_key=basis,
    result_key=f'{basis}_UMAP',
    n_components=2,
    n_neighbors=30,
    min_dist=0.1,
    metric='euclidean',
    random_state=seed,
)

In [None]:
# Embeddings compared in the benchmark: baseline methods first, then Concord variants.
concord_keys = ["Concord", 'Concord-decoder']
other_keys = ["Unintegrated", "Scanorama", "Liger", "Harmony", "scVI", "Seurat", "Contrastive"]
combined_keys = [*other_keys, *concord_keys]

In [None]:
# keep keys only contain combined_keys
# Retain each method's embedding plus its '_UMAP' and '_UMAP_3D' layouts.
keep_keys = [key + suffix for suffix in ('', '_UMAP', '_UMAP_3D') for key in combined_keys]
adata.obsm = {name: matrix for name, matrix in adata.obsm.items() if name in keep_keys}
adata.obsm

In [None]:
# Define color palette for broad_cell_type_qz
# Fixed color mappings (third value returned by get_color_mapping) for some
# categorical annotations; the remaining entries name a palette directly.
_, _, tissue_pal = ccd.pl.get_color_mapping(adata, 'broad_cell_type_qz', pal='Paired', seed=2)
_, _, species_pal = ccd.pl.get_color_mapping(adata, 'species', pal='Set1', seed=seed)
_, _, broadlin_pal = ccd.pl.get_color_mapping(adata,'broad_lineage', pal='Paired', seed=seed)
_, _, batch_pal = ccd.pl.get_color_mapping(adata,'dataset3', pal='Set1', seed=seed)
# Column name -> palette, passed as `pal` to the ccd.pl plotting calls.
pal = {'embryo.time': 'BlueGreenRed', 
       "cell_type": 'Paired', 
       'species': species_pal, 
       'dataset3': batch_pal,
       'lineage_complete': 'Paired',
       'broad_lineage': broadlin_pal,
       'ct_or_lin': 'Paired', 
       'broad_cell_type_qz': tissue_pal, 
       'ct_or_broad_lin': 'Paired',
       'plot_cell_type': 'Paired',}


In [None]:
### SKIP if loaded obsm
# Expose the chosen dated Concord runs and the Seurat RPCA results under the
# method names used in combined_keys. Raises KeyError if the dated source
# keys are no longer present (e.g. after the keep_keys filter).
obsm_aliases = {
    'Concord': 'Concord_Dec17-0930',
    'Concord_UMAP': 'Concord_Dec17-0930_UMAP',
    'Concord_UMAP_3D': 'Concord_Dec17-0930_UMAP_3D',
    'Concord-decoder': 'Concord-decoder_Dec18-1358',
    'Concord-decoder_UMAP': 'Concord-decoder_Dec18-1358_UMAP',
    'Concord-decoder_UMAP_3D': 'Concord-decoder_Dec18-1358_UMAP_3D',
    'Seurat': 'integrated.rpca',
    'Seurat_UMAP': 'umap.rpca',
}
for alias, source_key in obsm_aliases.items():
    adata.obsm[alias] = adata.obsm[source_key]


In [None]:
# SKIP if not fresh run
# Combined labels:
#   'ct_or_lin' = cell_type, falling back to lineage_complete
#   'lin_or_ct' = lineage_complete, falling back to cell_type
# Placeholder labels ('unassigned', 'NaN', 'nan') and real NaN count as
# missing. Labels missing from both sources become NaN. Afterwards
# cell_type and lineage_complete themselves are cleaned the same way.
# Treating real NaN as missing makes a re-run of this cell reproduce the same
# result. Writes go through .loc rather than chained
# `adata.obs[col][mask] = ...`, which pandas does not guarantee to write back
# (and never does under Copy-on-Write). np.nan replaces np.NaN, which
# NumPy 2.0 removed.
_placeholder_labels = ['unassigned', 'NaN', 'nan']


def _is_missing_label(values):
    """Boolean mask: True where *values* is NaN or a placeholder label."""
    return values.isna() | values.astype(str).isin(_placeholder_labels)


# Masks are taken before cell_type / lineage_complete are modified below.
ct_missing = _is_missing_label(adata.obs['cell_type'])
lin_missing = _is_missing_label(adata.obs['lineage_complete'])

adata.obs['ct_or_lin'] = adata.obs['cell_type'].astype(str)
adata.obs.loc[ct_missing, 'ct_or_lin'] = adata.obs.loc[ct_missing, 'lineage_complete'].astype(str)
adata.obs.loc[_is_missing_label(adata.obs['ct_or_lin']), 'ct_or_lin'] = np.nan

adata.obs['lin_or_ct'] = adata.obs['lineage_complete'].astype(str)
adata.obs.loc[lin_missing, 'lin_or_ct'] = adata.obs.loc[lin_missing, 'cell_type'].astype(str)
adata.obs.loc[_is_missing_label(adata.obs['lin_or_ct']), 'lin_or_ct'] = np.nan

adata.obs.loc[ct_missing, 'cell_type'] = np.nan
adata.obs.loc[lin_missing, 'lineage_complete'] = np.nan

In [None]:
# plot everything
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import font_manager, rcParams

# Set Arial as the default font
custom_rc = {'font.family': 'Arial'}

# One UMAP panel per benchmarked embedding, colored by the combined
# lineage / cell-type labels, saved as SVG into save_dir.
show_keys = combined_keys
show_cols = ['lin_or_ct', 'ct_or_lin']
basis_types = ['UMAP']

font_size, point_size, alpha = 10, .1, 0.8
figsize = (11.5, 1.5)
ncols = 9
nrows = -(-len(show_keys) // ncols)  # ceiling division

embedding_kwargs = dict(
    color_bys=show_cols,
    basis_types=basis_types,
    pal=pal,
    font_size=font_size,
    point_size=point_size,
    alpha=alpha,
    figsize=figsize,
    ncols=ncols,
    seed=seed,
    save_dir=save_dir,
    file_suffix=file_suffix,
    save_format='svg',
)
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(adata, show_keys, **embedding_kwargs)


### plot

In [None]:
# plot everything
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import font_manager, rcParams

# Set Arial as the default font
custom_rc = {'font.family': 'Arial'}

# UMAPs of every benchmarked embedding colored by embryo time and species.
# ncols equals the number of embeddings, so nrows works out to 1.
show_keys = combined_keys
show_cols = ['embryo.time', 'species']
basis_types = ['UMAP']

font_size, point_size, alpha = 10, .1, 0.8
figsize = (11.5, 1.4)
ncols = len(show_keys)
nrows = -(-len(show_keys) // ncols)  # ceiling division

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata, show_keys,
        color_bys=show_cols, basis_types=basis_types, pal=pal,
        font_size=font_size, point_size=point_size, alpha=alpha,
        figsize=figsize, ncols=ncols, seed=seed,
        save_dir=save_dir, file_suffix=file_suffix,
        dpi=600, save_format='svg',
    )


In [None]:
output_key = 'Concord'
basis = output_key
show_basis = f'{basis}_UMAP'
show_cols = ['embryo.time', 'cell_type', 'species']
# Three labelled panels for the Concord UMAP, written as a 600-dpi PDF.
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata, show_basis, show_cols,
        figsize=(10, 3.3), dpi=600, ncols=3, font_size=4, point_size=0.5,
        legend_loc='on data', pal=pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_wttext.pdf",
    )

In [None]:
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP'
show_cols = ['embryo.time', 'cell_type', 'species']
# Concord-decoder counterpart of the labelled three-panel figure (larger points).
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata, show_basis, show_cols,
        figsize=(10, 3.3), dpi=600, ncols=3, font_size=4, point_size=.8,
        legend_loc='on data', pal=pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_wttext.pdf",
    )

In [None]:
# Leiden clustering
# Broad (resolution 0.5) Leiden clusters for each Concord embedding; the kNN
# graph is recomputed on that embedding right before clustering it.
for rep in ('Concord', 'Concord-decoder'):
    sc.pp.neighbors(adata, n_neighbors=30, use_rep=rep)
    sc.tl.leiden(adata, resolution=0.5, key_added=f'leiden_{rep}_broad')

In [None]:
output_key = 'Concord'
basis = output_key
show_basis = basis + '_UMAP'
show_cols = ['embryo.time', 'cell_type', 'leiden_Concord_broad']
# Concord UMAP with the broad Leiden clusters. Saved with a '_leiden' tag:
# the plain '<basis>_UMAP_<suffix>_wttext.pdf' name is already used by the
# embryo.time / cell_type / species figure of the same embedding, which this
# plot used to overwrite.
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata, show_basis, show_cols, figsize=(10,3.3), dpi=600, ncols=3, font_size=6, point_size=.8, legend_loc='on data',
        pal = pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_leiden_wttext.pdf"
    )

In [None]:
output_key = 'Concord-decoder'
basis = output_key
show_basis = basis + '_UMAP'
show_cols = ['embryo.time', 'cell_type', 'leiden_Concord-decoder_broad']
# Concord-decoder UMAP with the broad Leiden clusters. Saved with a '_leiden'
# tag so it no longer overwrites the embryo.time / cell_type / species figure
# written under the plain '<basis>_UMAP_<suffix>_wttext.pdf' name.
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata, show_basis, show_cols, figsize=(10,3.3), dpi=600, ncols=3, font_size=6, point_size=.8, legend_loc='on data',
        pal = pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_leiden_wttext.pdf"
    )

In [None]:
# Cross tab cluster and cell type, 
import pandas as pd
cluster_key = 'leiden_Concord-decoder_broad'
cell_type_key = 'cell_type'
# Rows: clusters; columns: cell types; values: cell counts.
ct_clus_crosstab = pd.crosstab(adata.obs[cluster_key], adata.obs[cell_type_key])
# For each cluster, the two most frequent cell types.
top_cts = pd.DataFrame(
    [counts.sort_values(ascending=False).index[0:2].values for _, counts in ct_clus_crosstab.iterrows()],
    columns=['top1', 'top2'],
    index=ct_clus_crosstab.index,
)
top_cts

In [None]:
### Assign broad lineage by taking up to first 5 chars of complete_lineage
# Start from the first 5 characters of every lineage_complete label (missing
# values become the string 'nan'), then collapse related sub-lineages.
# Rules are applied in order on a local Series via Series.mask, so later rules
# see the output of earlier ones. The result is written back once instead of
# through chained `adata.obs[col][mask] = ...` assignments, which pandas does
# not guarantee to write back (and never does under Copy-on-Write).
broad_lineage = adata.obs['lineage_complete'].astype(str).str.slice(0, 5)

# (prefixes, broad label): a label starting with any prefix becomes the broad label.
lineage_prefix_rules = [
    (['Cxa', 'Cpa', 'Caa'], 'Cxa'),
    (['Cxp', 'Cpp', 'Cap'], 'Cxp'),
    (['E'], 'E'),
    (['D'], 'D'),
    (['MSxpa', 'MSapa', 'MSppa'], 'MSxpa'),
    (['MSxaa', 'MSpaa', 'MSaaa'], 'MSxaa'),
    (['MSxap', 'MSpap', 'MSaap'], 'MSxap'),
    (['MSxpp', 'MSppp', 'MSapp'], 'MSxpp'),
]
for prefixes, label in lineage_prefix_rules:
    for prefix in prefixes:
        broad_lineage = broad_lineage.mask(broad_lineage.str.startswith(prefix), label)

# AB labels with fewer than 5 characters (plus 'ABaxx') are early AB cells;
# MS labels with fewer than 5 characters are early MS cells. Both, together
# with a few other early/uncertain prefixes, are merged into 'early cells'.
is_ab_early = (broad_lineage.str.startswith('AB') & (broad_lineage.str.len() < 5)) | (broad_lineage == 'ABaxx')
broad_lineage = broad_lineage.mask(is_ab_early, 'AB early')
broad_lineage = broad_lineage.mask(broad_lineage.str.startswith('MS') & (broad_lineage.str.len() < 5), 'MS early')
broad_lineage = broad_lineage.mask(
    broad_lineage.isin(['AB early', 'MS early', 'Cx', '28_ce', 'possi']), 'early cells'
)

broad_lineage = broad_lineage.mask(broad_lineage.str.startswith('Z2/Z3'), 'Z2/Z3')
# Drop a few lineages that are not informative: 'unass' (truncated
# 'unassigned') and 'nan' (missing) are relabelled 'NaN'.
drop_lineages = ['unass', 'nan']
broad_lineage = broad_lineage.mask(broad_lineage.isin(drop_lineages), 'NaN')

adata.obs['broad_lineage'] = broad_lineage
adata.obs['broad_lineage'].value_counts()

In [None]:
# Assign broad cell types based on clustering
# Label each broad Concord-decoder Leiden cluster with a tissue, then apply
# lineage / cell-type based corrections. All steps operate on a local Series
# (Series.map / Series.mask) and the result is written back once, replacing
# chained `adata.obs[col][mask] = ...` assignments, which pandas does not
# guarantee to write back (and never does under Copy-on-Write).
AB_OTHER = 'AB lineage (non-hyp/seam/pha)'
cluster_labels = {
    'Intestine': ['56', '18', '58'],
    'Pharynx': ['30', '22', '12', '3', '23', '72', '57', '77'],
    'Germline': ['45'],
    'Early embryo': ['28'],
    'Mesoderm': ['11', '27', '4', '0', '6', '10', '5', '17', '38', '36', '26', '13', '39', '59', '75', '47'],
    'M cell': ['67'],
    'Z1_Z4': ['74'],
    'doublet/debris': ['46'],
    'Hypodermis/Seam': ['24', '21', '31', '15', '43', '20', '1', '40', '53', '73', '62'],
}
# The cluster lists are disjoint, so one lookup table gives the same result
# as applying them one after another.
cluster_to_label = {cid: label for label, cids in cluster_labels.items() for cid in cids}

cluster_ids = adata.obs['leiden_Concord-decoder_broad'].astype(str)
# Unlisted clusters keep their id for now (handled by the catch-all below).
broad_ct = cluster_ids.map(cluster_to_label).fillna(cluster_ids)

# Early-embryo cells from the Cxp / D lineages are reassigned to mesoderm.
broad_ct = broad_ct.mask(
    (broad_ct == 'Early embryo') & adata.obs['broad_lineage'].isin(['Cxp', 'D']), 'Mesoderm'
)
# Cluster 55: only its hyp3 cells count as hypodermis.
broad_ct = broad_ct.mask((broad_ct == '55') & (adata.obs['cell_type'] == 'hyp3'), 'Hypodermis/Seam')
broad_ct = broad_ct.mask((broad_ct == 'Hypodermis/Seam') & (adata.obs['cell_type'] == 'AMso'), AB_OTHER)
broad_ct = broad_ct.mask(adata.obs['cell_type'].isin(['Excretory_duct_and_pore', 'Excretory_cell']), AB_OTHER)
# Set rest to 'AB lineage'
known_labels = ['Intestine', 'Pharynx', 'Germline', 'Mesoderm', 'M cell', 'Z1_Z4',
                'Early embryo', 'Hypodermis/Seam', 'doublet/debris']
broad_ct = broad_ct.mask(~broad_ct.isin(known_labels), AB_OTHER)

adata.obs['broad_cell_type_qz'] = broad_ct


In [None]:
# 'ct_or_broad_lin': the annotated cell type, or the broad lineage where no
# cell type is annotated. Missing cell types are detected both as real NaN
# and as the string placeholders 'NaN'/'nan'. Comparing only with the string
# 'NaN' missed cells whose cell_type had already been converted to real NaN,
# leaving them labelled 'nan'. .loc replaces chained assignment, which pandas
# does not guarantee to write back.
ct_series = adata.obs['cell_type']
ct_missing = ct_series.isna() | ct_series.astype(str).isin(['NaN', 'nan'])
adata.obs['ct_or_broad_lin'] = ct_series.astype(str)
adata.obs.loc[ct_missing, 'ct_or_broad_lin'] = adata.obs.loc[ct_missing, 'broad_lineage'].astype(str)
adata.obs['ct_or_broad_lin'].value_counts()


In [None]:
# Embedding shown by the following plotting cells; alternatives are kept
# commented out for quick switching.
output_key = 'Concord-decoder'
#output_key = 'Concord'
#output_key = 'Contrastive'

In [None]:
# Display the embeddings currently stored in adata.obsm.
adata.obsm.keys()

In [None]:
# Compute 3-D UMAP layouts ('<key>_UMAP_3D') for the embeddings listed here.
# Note: this rebinds the global other_keys. Every key is announced; keys
# that are missing from adata.obsm are then skipped.
other_keys = ['scVI', 'Seurat']
for basis in other_keys:
    print("Running UMAP for", basis)
    if basis in adata.obsm:
        ccd.ul.run_umap(
            adata,
            source_key=basis,
            result_key=f'{basis}_UMAP_3D',
            n_components=3,
            n_neighbors=30,
            min_dist=0.1,
            metric='euclidean',
            random_state=seed,
        )


In [None]:
basis = output_key
show_basis = basis + '_UMAP'
show_cols = ['broad_cell_type_qz', 'cell_type', 'ct_or_broad_lin', 'broad_lineage', 'embryo.time', 'species']
# Six annotation panels for the selected embedding. Saved with an '_annot'
# tag: the plain '<basis>_UMAP_<suffix>_wttext.pdf' name is already used by
# the three-panel figure of the same embedding, which this plot overwrote.
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata, show_basis, show_cols, figsize=(10,6.7), dpi=600, ncols=3, font_size=3, point_size=.8, legend_loc='on data',
        pal = pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_annot_wttext.pdf"
    )

In [None]:
basis = output_key
# Subset without the doublet/debris cluster; later cells reuse adata_clean.
adata_clean = adata[adata.obs['broad_cell_type_qz'] != 'doublet/debris']
show_basis = f'{basis}_UMAP'
show_cols = ['broad_cell_type_qz', 'cell_type', 'ct_or_broad_lin', 'broad_lineage', 'embryo.time', 'species']
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_clean, show_basis, show_cols,
        figsize=(10, 6.7), dpi=600, ncols=3, font_size=3, point_size=.2, alpha=0.9,
        legend_loc='on data', pal=pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_cleaned_wttext.pdf",
    )

In [None]:
concord_keys = ['Concord-decoder']
# Static 3-D UMAP snapshots over a grid of camera angles (azimuth 30..330 in
# steps of 30, three elevations), colored by embryo time, one PDF per angle.
azims = list(range(30, 331, 30))
elevs = [30, 45, 60]
show_cols = ['embryo.time']

# Settings shared by every snapshot; only the camera angle and file suffix vary.
snapshot_kwargs = dict(
    adata=adata_clean,
    combined_keys=concord_keys,
    color_bys=show_cols,
    pal=pal,
    ncols=2,
    rasterized=True,        # points are rasterized
    point_size=1,
    alpha=0.8,
    zoom_factor=0.05,
    show_grid=True,
    show_axis_labels=False,
    show_ticks=False,
    show_legend=False,
    tick_label_font_size=6,
    legend_font_size=6,
    save_dir=save_dir,
    save_format='pdf',
)
for azim in azims:
    for elev in elevs:
        with plt.rc_context(rc=custom_rc):
            ccd.pl.plot_all_embeddings_3d(
                basis_types=['UMAP_3D'],
                elev=elev,
                azim=azim,
                file_suffix=f'{file_suffix}_azim{azim}_elev{elev}',
                **snapshot_kwargs,
            )

In [None]:
# Plot the distribution of 'est.bg.prop.mus', 'est.bg.prop.hyp', 'est.bg.prop.both', 'est.bg.prop' between the bad cells and good cells as violin plots
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

bad_clus_name = 'doublet/debris'
is_bad = adata.obs['broad_cell_type_qz'] == bad_clus_name
bad_cells = adata.obs[is_bad]
good_cells = adata.obs[~is_bad]
# Downsample good cells to the size of the bad group. The size is capped at
# the number of good cells, so sample() cannot raise when the bad group is
# the larger one.
good_cells = good_cells.sample(n=min(len(bad_cells), len(good_cells)), random_state=seed)

data = pd.concat([bad_cells, good_cells])
# Two-level label: the bad cluster vs. 'Others'. np.where replaces the chained
# `data[col][mask] = ...` assignment, which pandas does not guarantee to write
# back (and never does under Copy-on-Write).
data['plot_cell_type'] = np.where(
    data['broad_cell_type_qz'].astype(str) == bad_clus_name, bad_clus_name, 'Others'
)
data['plot_cell_type'] = data['plot_cell_type'].astype('category')

bg_cols = ['est.bg.prop.mus', 'est.bg.prop.hyp', 'est.bg.prop.both', 'est.bg.prop']
fig, axes = plt.subplots(2, 2, figsize=(7, 7), dpi=300)
axes = axes.flatten()
for ax, col in zip(axes, bg_cols):
    sns.violinplot(x='plot_cell_type', y=col, data=data, ax=ax)
    ax.set_ylabel(col)
fig.tight_layout()

fig.savefig(save_dir / f"est_bg_prop_{file_suffix}.pdf")




In [None]:
# Save adata
# Write the annotated object next to the input data, tagged with this run's suffix.
adata_path = data_dir / f"adata_cbce_{file_suffix}.h5ad"
adata.write_h5ad(adata_path)
print(adata_path)

In [None]:
# Keep only the UMAP layouts (2-D and 3-D); raw embeddings are dropped from
# the in-memory object before export.
adata.obsm = {name: coords for name, coords in adata.obsm.items() if 'UMAP' in name}
adata.obsm

In [None]:
# Convert adata to VisCello format in data_dir / "cello_<proj_name>_<file_suffix>".
ccd.ul.anndata_to_viscello(adata, data_dir / f"cello_{proj_name}_{file_suffix}", project_name = proj_name, organism='cel')

In [None]:
# Register the debris-free subset alongside the global dataset in the
# VisCello directory written by the export cell.
viscello_dir = str(data_dir / f"cello_{proj_name}_{file_suffix}")
adata_subsets = {'Global_dataset_cleaned': adata_clean}
ccd.ul.update_clist_with_subsets(
    global_adata=adata,
    adata_subsets=adata_subsets,
    viscello_dir=viscello_dir,
)

### Plot 3D embeddings

In [None]:
# 'lineage' = lineage_complete as plain strings, with values whose string
# form starts with 'nan' (missing) normalized to the label 'NaN'. .loc replaces
# chained assignment, which pandas does not guarantee to write back.
adata.obs['lineage'] = adata.obs['lineage_complete'].astype(str)
adata.obs.loc[adata.obs['lineage'].str.startswith('nan'), 'lineage'] = 'NaN'

In [None]:
# Interactive (HTML) 3-D UMAP of the Concord-decoder embedding, one file per annotation.
show_cols = ['broad_cell_type_qz', 'cell_type', 'ct_or_lin', 'lineage', 'broad_lineage', 'embryo.time', 'species', 'lin_or_ct']
basis = 'Concord-decoder'
show_basis = f'{basis}_UMAP_3D'
for col in show_cols:
    ccd.pl.plot_embedding_3d(
        adata,
        basis=show_basis,
        color_by=col,
        pal=pal,
        save_path=save_dir / f'{show_basis}_{col}_{file_suffix}.html',
        point_size=1,
        opacity=0.8,
        width=1000,
        height=1000,
        autosize=True,
        static=False,
    )

### Load and map lineage tree table

In [None]:
import pandas as pd
# Lineage tree edge table (its 'from'/'to' columns are used to build the tree
# later); the first CSV column becomes the index.
lineage_tree_tbl = pd.read_csv(data_dir / "cel_lineage_tree_tbl.csv", index_col=0)
lineage_tree_tbl

In [None]:
# Map lineage tree table to actual lineage annotation in adata
# Report how many distinct lineage annotations in adata appear verbatim in
# the tree table's 'lineage' column, its 'to' column, or either one.
unique_adata_lineage = adata.obs['lineage_complete'].unique()
# Build the lookup sets once: `x in series.values` rescans the whole array
# on every query. NaN is dropped so it never counts as a match (it never did
# with the array scan either, since NaN != NaN).
tree_lineage_names = set(lineage_tree_tbl['lineage'].dropna())
tree_to_names = set(lineage_tree_tbl['to'].dropna())

unique_adata_lineage_in_lineage = [x for x in unique_adata_lineage if x in tree_lineage_names]
print(f"Number of unique lineages in adata: {len(unique_adata_lineage)}, in lineage: {len(unique_adata_lineage_in_lineage)}")
unique_adata_lineage_in_to = [x for x in unique_adata_lineage if x in tree_to_names]
print(f"Number of unique lineages in adata: {len(unique_adata_lineage)}, in to: {len(unique_adata_lineage_in_to)}")
unique_adata_lineage_in_tree = [x for x in unique_adata_lineage if x in tree_lineage_names or x in tree_to_names]
print(f"Number of unique lineages in adata: {len(unique_adata_lineage)}, in tree: {len(unique_adata_lineage_in_tree)}")

# lineage not in tree
unique_adata_lineage_not_in_tree = [
    x for x in unique_adata_lineage if x not in tree_lineage_names and x not in tree_to_names
]
unique_adata_lineage_not_in_tree

In [None]:
import itertools

# Every lineage name appearing in the tree table's 'to' column; this is the
# set of valid targets when expanding lineage annotations further down.
lineage_actual = set(lineage_tree_tbl['to'].values)

def expand_x(lineage_str, choices=('a', 'p', 'r', 'l', 'd', 'v')):
    """
    Expand every 'x' wildcard in a lineage name into all concrete variants.

    Each 'x' is replaced independently by every character in *choices*
    (default: a, p, r, l, d, v). A name without 'x' is returned unchanged
    as a one-element list.

    Parameters
    ----------
    lineage_str : str
        Lineage name, possibly containing 'x' wildcards (e.g. 'Dxax').
    choices : iterable of str, optional
        Replacement characters for each 'x', tried in this order.

    Returns
    -------
    list of str
        All expansions, in itertools.product order: the left-most 'x'
        varies slowest. The list has len(choices) ** lineage_str.count('x')
        entries, so it grows exponentially with the number of wildcards.

    Examples
    --------
    >>> expand_x('Dx')
    ['Da', 'Dp', 'Dr', 'Dl', 'Dd', 'Dv']
    """
    if 'x' not in lineage_str:
        return [lineage_str]
    # One option tuple per position: the wildcard choices for 'x', else the character itself.
    char_options = [choices if ch == 'x' else (ch,) for ch in lineage_str]
    return [''.join(combo) for combo in itertools.product(*char_options)]

def map_lineage_name(lineage_str, lineage_actual):
    """
    Map one lineage annotation onto the lineage names present in the tree.

    The annotation may list alternatives separated by '/', and each
    alternative may contain 'x' wildcards (expanded by expand_x). Every
    expansion found in *lineage_actual* is kept.

    Parameters
    ----------
    lineage_str : str
        Annotation such as 'Dxa/ABala'.
    lineage_actual : container of str
        Valid lineage names; a set keeps each membership test O(1).

    Returns
    -------
    list of str
        Unique matches in sorted order; empty if nothing matches. Sorting
        makes the result reproducible, whereas the order of list(set(...))
        changes with string-hash randomization between interpreter runs.
    """
    matches = set()
    for sub_name in lineage_str.split('/'):
        matches.update(exp for exp in expand_x(sub_name) if exp in lineage_actual)
    return sorted(matches)

# -----------------------------
# Map every observed lineage annotation onto tree lineages
# -----------------------------
# Annotation string -> list of matching tree lineage names.
lineage_list = unique_adata_lineage
lin_annot_to_actual = {}
for raw_lin in lineage_list:
    lin = str(raw_lin)
    # Labels such as 'Z2/Z3:pseudotime_bin_2' or '28_cell_or_earlier' are not
    # lineage names, so they map to nothing.
    is_special = ':' in lin or 'cell' in lin
    lin_annot_to_actual[lin] = [] if is_special else map_lineage_name(lin, lineage_actual)

# Inverse map: tree lineage -> annotations that matched it, in annotation order.
lin_actual_to_annot = {}
for annot, targets in lin_annot_to_actual.items():
    for target in targets:
        lin_actual_to_annot.setdefault(target, []).append(annot)

# Print results:
for target, annots in lin_actual_to_annot.items():
    print(f"{target} -> {annots}")


In [None]:
# Canonical cell names: the tree table's 'Cell' column (a pandas Series),
# stringified, with entries whose string form is 'nan' (any case) dropped.
canonical_array = lineage_tree_tbl['Cell']
canonical_cells = [name for name in map(str, canonical_array) if name.lower() != 'nan']

# For fast membership tests:
canonical_cell_set = set(canonical_cells)

# ------------------------------------------------------------------------------
# 1) Curated dictionary for names that should map to known expansions directly
#    (no wildcard or prefix logic).
# ------------------------------------------------------------------------------
# Fuzzy cell-type label -> explicit canonical cell names. An empty list means
# "known label, intentionally unmapped". (The duplicate
# 'mu_int_mu_anal_related' entry that used to appear further down was removed;
# it had the same [] value, so the dict is unchanged.)
SPECIAL_MAPPINGS = {
    # Examples you mentioned
    "OLQ": ["OLQDL", "OLQDR", "OLQVL", "OLQVR"],
    "OLL": ["OLLL", "OLLR"],
    'ADE': ['ADEL', 'ADER'],
    "BWM_headrow1_in": [],
    "BWM_headrow2_in": [],
    "M_cell": ["M"],
    "P_cells": [],
    "B_F_K_Kp_U_Y": ["B", "F", "K", "Kp", "U", "Y"],
    "Seam_hyp_early_cells": [],
    "hyp7_AB_lineage": ["hyp7"],
    "hyp7_C_lineage": ["hyp7"],
    "hyp4_hyp5_hyp6": ["hyp4", "hyp5", "hyp6"],
    "mu_int_mu_anal": ["mu_int_L", "mu_int_R", "mu_anal"],
    "mu_int_mu_anal_related": [],
    # pm3_pm4_pm5c => explicit expansions
    "pm3_pm4_pm5c": [
        "pm3DL", "pm3DR", "pm3L", "pm3R", "pm3VL", "pm3VR",
        "pm4DL", "pm4DR", "pm4L", "pm4R", "pm4VL", "pm4VR",
        "pm5DL", "pm5DR", "pm5L", "pm5R", "pm5VL", "pm5VR"
    ],
    "pm3_pm4_pm5b": [
        "pm3DL", "pm3DR", "pm3L", "pm3R", "pm3VL", "pm3VR",
        "pm4DL", "pm4DR", "pm4L", "pm4R", "pm4VL", "pm4VR",
        "pm5DL", "pm5DR", "pm5L", "pm5R", "pm5VL", "pm5VR"
    ],
    "Coelomocytes": ['ccAL', 'ccAR', 'ccPL', 'ccPR'],
    #"early_arcade_cell": ['arc ant DL', 'arc ant DR', 'arc ant V', 'arc post DL', 'arc post DR', 'arc post V', 'arc post VR', 'arc post VL'],
    'mu_sph': ['mu_sph'],
    'Seam_cells': [],
    # NOTE(review): 'HOL'/'HOR' are spelled with the letter O; seam cells are
    # usually written H0L/H0R (digit zero). Confirm against
    # lineage_tree_tbl['Cell'], otherwise these two never resolve.
    'Seam_cells_early': ['HOL', 'HOR', 'H1L', 'H1R', 'H2L', 'H2R', 'V1L', 'V1R', 'V2L', 'V2R', 'V3L', 'V3R', 'V4L', 'V4R', 'V5L', 'V5R', 'V6L', 'V6R'],
    'mc2b': ['mc2DL', 'mc2DR', 'mc2V'],
    'mc2a': ['mc2DL', 'mc2DR', 'mc2V'],
    'Tail_hypodermis': ['hyp8/9', 'hyp10'],
    'Rectal_gland': ['rect_D', 'rect_VL', 'rect_VR'],
    'Anterior_arcade_cell': ['arc ant DL', 'arc ant DR', 'arc ant V'],
    'Posterior_arcade_cell': ['arc post D', 'arc post DL', 'arc post DR', 'arc post V', 'arc post VR', 'arc post VL'],
    'Pharyngeal_intestinal_valve': ['vpi1', 'vpi2DL', 'vpi2DR', 'vpi2V', 'vpi3D', 'vpi3V'],
    # NOTE(review): 'ant arc V' uses a different word order from the
    # 'arc ant V' used by the arcade entries in this table; one of them likely
    # does not match a canonical name — confirm.
    'hyp1V_and_ant_arc_V': ['ant arc V'],
    'hyp1V': ['ant arc V'],
    'Excretory_cell': ['exc_cell'],
    'Excretory_duct_and_pore': ['exc_duct'],
    'Excretory_gland': ['exc_gl_L', 'exc_gl_R'],
    'CEP': ['CEPDL', 'CEPDR', 'CEPVL', 'CEPVR'],
    'Arcade_cell': ['arc ant DL', 'arc ant DR', 'arc ant V', 'arc post D', 'arc post DL', 'arc post DR', 'arc post V', 'arc post VR', 'arc post VL'],

    # etc. Add more as needed...
}

# ------------------------------------------------------------------------------
# 2) Optionally define a synonym/prefix dictionary for partial expansions
#    This can handle simpler patterns like "AWB" => search for "AWB" in canonical cells
#    or "BWM" => "mu_bod", etc.
# ------------------------------------------------------------------------------
# Token -> replacement prefix applied before matching; currently empty.
SYNONYM_PREFIXES = {
    #"BWM": "mu_bod",  # if you want "BWM_..." => "mu_bod"
}

def custom_name_to_canonical(cell_name: str) -> list[str]:
    """
    Convert one fuzzy cell-type name into the list of matching canonical names.

    Resolution order:
      1. NaN (or its string form 'nan', any case) and empty or whitespace-only
         names map to [].
      2. SPECIAL_MAPPINGS overrides everything else. A copy of the list is
         returned, so callers cannot mutate the shared table.
      3. A '/' acts as OR: each segment is resolved recursively and the
         union is returned sorted.
      4. An exact canonical name (member of canonical_cell_set) maps to
         itself. This is checked before tokenizing, so canonical names that
         contain '_' are not broken into prefix searches.
      5. Otherwise the name is split on '_'. Each non-empty token, after
         SYNONYM_PREFIXES substitution, is matched exactly, or else as a
         prefix of canonical cell names. The union is returned sorted.
    'x' is treated literally (no wildcard expansion). Returns [] when
    nothing matches.
    """
    # Non-string inputs (e.g. float NaN among a column's unique values) are
    # stringified so they reach the 'nan' check instead of failing on .strip().
    cell_name = str(cell_name).strip()
    if not cell_name or cell_name.lower() == 'nan':
        return []

    if cell_name in SPECIAL_MAPPINGS:
        return list(SPECIAL_MAPPINGS[cell_name])

    if '/' in cell_name:
        total_hits = set()
        for seg in cell_name.split('/'):
            total_hits.update(custom_name_to_canonical(seg))  # recursive
        return sorted(total_hits)

    if cell_name in canonical_cell_set:
        return [cell_name]

    final_hits = set()
    for token in cell_name.split('_'):
        if not token:
            # An empty token (from '__' or a leading/trailing '_') would
            # prefix-match every canonical cell.
            continue
        prefix = SYNONYM_PREFIXES.get(token, token)
        if prefix in canonical_cell_set:
            final_hits.add(prefix)
        else:
            # Prefix search, e.g. "AWB" => AWBL, AWBR if they exist.
            final_hits.update(c for c in canonical_cells if c.startswith(prefix))
    return sorted(final_hits)

# ------------------------------------------------------------------------------
# 5) Example usage
# ------------------------------------------------------------------------------

# Resolve every observed cell-type label to canonical cell names.
celltype_to_cellname = {}

fuzzy_list = adata.obs['cell_type'].unique()

for name in fuzzy_list:
    # str() guards against non-string labels: cell_type can contain real NaN
    # (unassigned labels are set to NaN earlier in this notebook), and NaN has
    # no .strip(). The dict key keeps the original value.
    celltype_to_cellname[name] = custom_name_to_canonical(str(name))

# Print results:
for k, v in celltype_to_cellname.items():
    print(f"{k} -> {v}")

# Further map to lineage
# Lookups between tree lineage ('to') and canonical cell name ('Cell').
lineage_to_cellname = dict(zip(lineage_tree_tbl['to'], lineage_tree_tbl['Cell']))
cellname_to_lineage = dict(zip(lineage_tree_tbl['Cell'], lineage_tree_tbl['to']))
cellname_to_lineage


In [None]:
# Cell type -> tree lineage for each of its canonical cell names. Names that
# are absent from the tree become None, so positions still line up with
# celltype_to_cellname.
celltype_to_lineage = {
    ct: [cellname_to_lineage.get(cell) for cell in cells]
    for ct, cells in celltype_to_cellname.items()
}

# Inverse map: tree lineage -> cell types mapping to it (None entries skipped).
lineage_to_celltype = {}
for ct, lineages in celltype_to_lineage.items():
    for tree_lin in lineages:
        if tree_lin is not None:
            lineage_to_celltype.setdefault(tree_lin, []).append(ct)

lineage_tree_tbl['celltype_mapped'] = lineage_tree_tbl['to'].map(lineage_to_celltype)
lineage_tree_tbl['linannot_mapped'] = lineage_tree_tbl['to'].map(lin_actual_to_annot)
# Prefer the lineage-annotation match and fall back to the cell-type match.
# Series.where replaces the chained `tbl[col][mask] = ...` assignment, which
# pandas does not guarantee to write back (and never does under Copy-on-Write).
has_linannot = lineage_tree_tbl['linannot_mapped'].notna()
lineage_tree_tbl['lin_or_ct_mapped'] = lineage_tree_tbl['linannot_mapped'].where(
    has_linannot, lineage_tree_tbl['celltype_mapped']
)

# If any celltype_mapped or linannot_mapped is not NaN, consider it as mapped
lineage_tree_tbl['mapped'] = lineage_tree_tbl['celltype_mapped'].notna() | has_linannot
lineage_tree_tbl['mapped'].value_counts()


In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# lineage_tree_tbl is expected to provide the columns
#   'from', 'to', 'mapped', 'celltype_mapped', 'linannot_mapped';
# a networkx DiGraph is built from it further below.

# Broad lineage groups: group label -> lineage-name prefixes assigned to it.
# Insertion order matters: prefix lookups return the FIRST group that matches.
broad_lineage_groups = dict([
    ('Cxa', ['Cxa', 'Cpa', 'Caa']),
    ('Cxp', ['Cxp', 'Cpp', 'Cap']),
    ('D', ['D']),
    ('E', ['E']),
    ('MSxpa', ['MSxpa', 'MSapa', 'MSppa']),
    ('MSxaa', ['MSxaa', 'MSpaa', 'MSaaa']),
    ('MSxap', ['MSxap', 'MSpap', 'MSaap']),
    ('MSxpp', ['MSxpp', 'MSppp', 'MSapp']),
    ('ABala', ['ABala']),
    ('ABalp', ['ABalp']),
    ('ABara', ['ABara']),
    ('ABarp', ['ABarp']),
    ('ABpla', ['ABpla']),
    ('ABplp', ['ABplp']),
    ('ABpra', ['ABpra']),
    ('ABprp', ['ABprp']),
    ('Z2/Z3', ['Z2', 'Z3']),
])

def map_node_to_broad_group(node_name: str, broad_map: dict) -> "str | None":
    """
    Return the broad group whose prefix list matches the start of ``node_name``.

    Groups are tried in ``broad_map`` insertion order and the first group with
    any matching prefix wins.

    Parameters
    ----------
    node_name : str
        Lineage node name to classify.
    broad_map : dict
        Mapping of group label -> iterable of name prefixes.

    Returns
    -------
    str or None
        The matching group label, or None if no prefix matches.
    """
    for group_key, prefixes in broad_map.items():
        # str.startswith accepts a tuple: True if any of the prefixes matches.
        if node_name.startswith(tuple(prefixes)):
            return group_key
    return None

# -------------------------------------------------------------------
# 1) Build the lineage DiGraph from lineage_tree_tbl ('from' -> 'to')
# -------------------------------------------------------------------
G = nx.DiGraph()

row_columns = zip(
    lineage_tree_tbl['from'],
    lineage_tree_tbl['to'],
    lineage_tree_tbl['mapped'],
    lineage_tree_tbl['celltype_mapped'],
    lineage_tree_tbl['linannot_mapped'],
)
for parent, child, mapped, celltype, linannot in row_columns:
    # Rows lacking either endpoint cannot form an edge.
    if pd.isna(parent) or pd.isna(child):
        continue

    # Combined annotation for the child: lineage annotations first, then cell
    # types; only list-valued entries contribute and missing items are dropped.
    linorct = [
        item
        for annot in (linannot, celltype)
        if isinstance(annot, list)
        for item in annot
        if pd.notna(item)
    ]

    # First sighting of a parent: placeholder attributes, replaced later if the
    # node also shows up as a child.
    if parent not in G:
        G.add_node(parent, mapped=False, celltype=None, linannot=None, linorct=None)

    # Child attributes come straight from the table row.
    G.add_node(child, mapped=mapped, celltype=celltype, linannot=linannot, linorct=linorct)
    G.add_edge(parent, child)

print(f"Graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

# -------------------------------------------------------------------
# 2) Broad lineage group of every node (None when no prefix matches)
# -------------------------------------------------------------------
node2group = {
    graph_node: map_node_to_broad_group(graph_node, broad_lineage_groups)
    for graph_node in G.nodes()
}

# Distinct groups that actually occur, sorted; None is excluded.
all_defined_groups = [grp for grp in node2group.values() if grp is not None]
unique_groups = sorted(set(all_defined_groups))

print("Unique groups from prefix matching:", unique_groups)

# -------------------------------------------------------------------
# 3) One color per group, looked up in the broadlin_pal palette
# -------------------------------------------------------------------
# Alternatives that were tried:
#   group2color = ccd.pl.get_factor_color(unique_groups, pal='Paired', permute=True, seed=42)
#   cmap = matplotlib.cm.get_cmap("tab20", len(unique_groups))
#   group2color = {grp: cmap(i) for i, grp in enumerate(unique_groups)}
group2color = {grp: broadlin_pal[grp] for grp in unique_groups}

# -------------------------------------------------------------------
# 4) Node colors:
#     - unmapped node               -> light grey
#     - mapped, no recognised group -> black
#     - mapped and in a group       -> that group's color
# -------------------------------------------------------------------
node2color = {}
for tree_node, node_attrs in G.nodes(data=True):
    tree_group = node2group[tree_node]
    if not node_attrs['mapped']:
        node2color[tree_node] = 'lightgray'
    elif tree_group is None:
        node2color[tree_node] = 'black'
    else:
        node2color[tree_node] = group2color[tree_group]

# -------------------------------------------------------------------
# 5) Draw the tree (graphviz 'dot' layout; requires pygraphviz)
# -------------------------------------------------------------------
pos = nx.nx_agraph.graphviz_layout(G, prog='dot')

# Colors listed in the same order as the nodes passed to nx.draw.
ordered_nodes = list(G.nodes())
final_node_colors = [node2color[tree_node] for tree_node in ordered_nodes]

plt.figure(figsize=(10, 1), dpi=600)
nx.draw(
    G,
    pos,
    with_labels=True,
    nodelist=ordered_nodes,
    node_color=final_node_colors,
    node_size=3,
    font_size=3,
    arrowsize=2,
    width=0.5,
)

plt.title("Lineage Tree: Node-wise Prefix Grouping (or Black/Grey if Unmapped/No Group)")
plt.savefig(save_dir / f"lineage_tree_nodewise_prefix_{file_suffix}.pdf")
plt.show()


In [None]:
# Broad lineage groups (same definition as in the tree-plotting cell):
# group label -> lineage-name prefixes. Order matters, because prefix lookups
# return the first matching group.
broad_lineage_groups = {
    'Cxa': ['Cxa', 'Cpa', 'Caa'],
    'Cxp': ['Cxp', 'Cpp', 'Cap'],
    'D': ['D'],
    'E': ['E'],
    'MSxpa': ['MSxpa', 'MSapa', 'MSppa'],
    'MSxaa': ['MSxaa', 'MSpaa', 'MSaaa'],
    'MSxap': ['MSxap', 'MSpap', 'MSaap'],
    'MSxpp': ['MSxpp', 'MSppp', 'MSapp'],
    # Each AB sub-lineage is its own group with a single identical prefix.
    **{ab_name: [ab_name] for ab_name in (
        'ABala', 'ABalp', 'ABara', 'ABarp', 'ABpla', 'ABplp', 'ABpra', 'ABprp',
    )},
    'Z2/Z3': ['Z2', 'Z3'],
}

def map_leaf_to_broad_group(leaf_node: str, broad_map: dict) -> str:
    """
    If 'leaf_node' starts with any prefix in 'broad_map[group_key]',
    return group_key. Otherwise return None.
    """
    for group_key, prefixes in broad_map.items():
        for prefix in prefixes:
            if leaf_node.startswith(prefix):
                return group_key
    return None  # no prefix match

# -----------------------------------------------------
# Find the root and leaves, then build one root->leaf path per reachable leaf
# -----------------------------------------------------
root_candidates = [n for n in G.nodes if G.in_degree(n) == 0]
if not root_candidates:
    raise ValueError("Lineage graph has no root node (no node with in-degree 0).")
if len(root_candidates) > 1:
    # Only the first root is traced; leaves below other roots are skipped below.
    print(f"Warning: {len(root_candidates)} root candidates found; "
          f"tracing paths from {root_candidates[0]!r} only.")
root_node = root_candidates[0]
leaf_nodes = [n for n in G.nodes if G.out_degree(n) == 0]

paths = []
unreachable_leaves = []
for leaf in leaf_nodes:
    try:
        path = nx.shortest_path(G, source=root_node, target=leaf)
    except nx.NetworkXNoPath:
        # Leaf hangs under a different root than root_node.
        unreachable_leaves.append(leaf)
        continue
    paths.append(path)

print(f"Found {len(paths)} root-to-leaf paths.")
if unreachable_leaves:
    print(f"Skipped {len(unreachable_leaves)} leaves not reachable from {root_node!r}.")

# -----------------------------------------------------
# 1) Label each path by the broad group of its LAST node
# -----------------------------------------------------
path_labels = []
for path in paths:
    leaf_node = path[-1]
    group_label = map_leaf_to_broad_group(leaf_node, broad_lineage_groups)
    # No prefix match: fall back to the leaf's own name as the label.
    if group_label is None:
        group_label = leaf_node
    path_labels.append(group_label)

# -----------------------------------------------------
# 2) Distinct path labels
# -----------------------------------------------------
unique_labels = sorted(set(path_labels))
print("Unique group labels:", unique_labels)

# No colors are assigned here. The plotting cells use group2color, which only
# covers prefix-matched groups, so fallback labels (raw leaf names) have no
# entry in it.

# -----------------------------------------------------
# Some utility functions
# -----------------------------------------------------
def parse_annotation(annot_value):
    """
    Normalize an annotation value into a list of non-missing strings.

    Parameters
    ----------
    annot_value : None, scalar, or list-like
        A single value (string, NaN, ...) or a list / tuple / numpy array /
        pandas Series or Index of such values.

    Returns
    -------
    list of str
        The annotation entries converted with ``str``, dropping None, real NaN
        and the literal string 'nan'. Empty when nothing valid remains.
    """
    if annot_value is None:
        return []
    # Unpack every common list-like container. Wrapping e.g. a tuple or ndarray
    # as one scalar would make pd.isna return an array, and the truth test
    # below would raise "truth value of an array is ambiguous".
    if isinstance(annot_value, (list, tuple, np.ndarray, pd.Series, pd.Index)):
        annot_list = list(annot_value)
    else:
        annot_list = [annot_value]

    valid_strings = []
    for val in annot_list:
        if pd.isna(val):  # None or real NaN
            continue
        if val == 'nan':  # string 'nan'
            continue
        valid_strings.append(str(val))
    return valid_strings

def get_representative_point(coords, method='medoid', max_n_medoid=2000,
                             k_top=3, seed=0, jitter=0, return_idx=False):
    """
    Pick one existing point of ``coords`` to stand in for the whole set.

    Parameters
    ----------
    coords : array-like, shape (n, d)
        Point coordinates; any dimensionality d is supported.
    method : str
        'medoid': when n <= max_n_medoid, choose (seeded) at random among the
        ``k_top`` points with the smallest summed pairwise distance (scipy
        ``cdist`` default metric). Any other value, or n > max_n_medoid: the
        point nearest the centroid.
    max_n_medoid : int
        Largest n for which the O(n^2) pairwise-distance matrix is computed.
    k_top : int
        Number of best medoid candidates to sample from, clamped to [1, n].
    seed : int
        Seed for the candidate choice and for the jitter. Private RNGs are
        used, so the global ``random`` and ``np.random`` states are untouched.
    jitter : float
        Std. dev. of Gaussian noise added to the returned coordinate
        (0 -> no noise).
    return_idx : bool
        If True, also return the row index of the chosen point in ``coords``.

    Returns
    -------
    point : np.ndarray, shape (d,)
        Float64 copy of the chosen coordinate plus jitter. All-NaN when
        ``coords`` is empty (d = 2 if it cannot be inferred).
    idx : int or None
        Only when ``return_idx``; None for empty input.
    """
    import random
    from scipy.spatial.distance import cdist

    coords = np.asarray(coords)
    n = len(coords)
    if n == 0:
        n_dims = coords.shape[1] if coords.ndim == 2 else 2
        empty_point = np.full(n_dims, np.nan)
        return (empty_point, None) if return_idx else empty_point

    if method == 'medoid' and n <= max_n_medoid:
        # Medoid candidates: points with the smallest total distance to all others.
        sum_dists = cdist(coords, coords).sum(axis=1)
        n_candidates = max(1, min(k_top, n))
        best_indices = np.argsort(sum_dists)[:n_candidates]
        # A private Random(seed) yields the same draw as random.seed(seed) +
        # random.choice, without resetting the caller's global RNG.
        chosen_idx = random.Random(seed).choice(best_indices)
    else:
        # Nearest point to the centroid (also the fallback for large n).
        centroid = coords.mean(axis=0)
        chosen_idx = np.argmin(np.linalg.norm(coords - centroid, axis=1))

    point = np.array(coords[chosen_idx], dtype=np.float64)
    if jitter:
        # Seeded, shape-matched noise for visual de-overlapping.
        rng = np.random.default_rng(seed)
        point = point + jitter * rng.standard_normal(point.shape)

    if return_idx:
        return point, chosen_idx
    return point

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import networkx as nx
from scipy.spatial.distance import cdist
import random

# -----------------------------------------------------
# 3) Plot every root-to-leaf path on a 2-D embedding
# -----------------------------------------------------
#basis = 'Concord-decoder_UMAP'
basis = 'Seurat_UMAP'
#basis = 'scVI_UMAP'
plot_label = False  # set True if you want text labels

plt.figure(figsize=(4, 4), dpi=600)

# a) All cells as a light-grey background
plt.scatter(
    adata.obsm[basis][:, 0],
    adata.obsm[basis][:, 1],
    rasterized=True,
    zorder=0,
    s=0.1, color="lightgray", alpha=0.4,
    edgecolors="none"
)

# b) One polyline per path. Each node is placed at a representative cell among
#    those carrying its lineage annotation (preferred) or else its cell type.
#    Nodes with no matching cells are dropped from the line.
for path_idx, path in enumerate(paths):
    group_label = path_labels[path_idx]
    # path_labels falls back to the raw leaf name when no broad group matched,
    # and group2color only holds matched groups: draw such paths in black
    # instead of raising KeyError.
    color = group2color.get(group_label, 'black')

    rep_points = []
    labels = []

    for node in path:
        attrs = G.nodes[node]
        linannot_list = parse_annotation(attrs.get('linannot'))
        celltype_list = parse_annotation(attrs.get('celltype'))

        # Decide which annotation is used
        if linannot_list:
            mask = adata.obs['lineage_complete'].isin(linannot_list)
            used_annot = linannot_list
        elif celltype_list:
            mask = adata.obs['cell_type'].isin(celltype_list)
            used_annot = celltype_list
        else:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue

        cell_indices = np.where(mask)[0]
        if len(cell_indices) == 0:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
        else:
            coords = adata.obsm[basis][cell_indices]
            rp = get_representative_point(coords, method='medoid',
                                          max_n_medoid=2000,
                                          k_top=3, jitter=0, seed=seed)
            rep_points.append(rp)
            labels.append((node, used_annot))

    # Keep only nodes that received a position.
    rep_points = np.array(rep_points)
    valid_mask = ~np.isnan(rep_points[:, 0])
    valid_rep_points = rep_points[valid_mask]
    valid_labels = [lab for lab, keep in zip(labels, valid_mask) if keep]

    # Draw path in chosen color
    plt.plot(
        valid_rep_points[:, 0],
        valid_rep_points[:, 1],
        color=color,  # Fill color
        marker='o',
        markersize=2,
        markeredgecolor='black',  # Edge color
        markeredgewidth=0.1,  # Edge thickness
        linewidth=0.3,
        alpha=0.8,
        zorder=1
    )

    # Optionally label text
    if plot_label:
        for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
            label_text = f"{node_name}\n{annot_list}"
            plt.text(cx, cy, label_text, fontsize=2, color="black", zorder=2, alpha=0.5)


plt.title(f"Lineage Paths on {basis}")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

text_ext = "with_text" if plot_label else "no_text"
plt.savefig(save_dir / f"lineage_paths_{basis}_{file_suffix}_{text_ext}.pdf")
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

basis = 'Concord-decoder_UMAP'
#basis = 'Seurat_UMAP'
#basis = 'scVI_UMAP'
plot_label = False  # set True if you want text labels

# One subplot per path group; 'Z2/Z3' is excluded.
unique_path_groups = sorted(set(path_labels))
unique_path_groups = [group for group in unique_path_groups if group != 'Z2/Z3']
# Determine grid size for subplots
n_groups = len(unique_path_groups)
n_cols = 8  # Customize the number of columns
# At least one row: plt.subplots rejects nrows=0 when there are no groups.
n_rows = max(1, int(np.ceil(n_groups / n_cols)))

# squeeze=False always returns a 2-D array of Axes, so .flatten() is safe for
# any grid size.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), dpi=300, squeeze=False)
axes = axes.flatten()  # Flatten for easier indexing

for i, group_label in enumerate(unique_path_groups):
    print(f"Plotting group {group_label} in subplot {i}")
    ax = axes[i]

    # Plot all points in the background
    ax.scatter(
        adata.obsm[basis][:, 0],
        adata.obsm[basis][:, 1],
        rasterized=True,
        zorder=0,
        s=0.1,
        color="lightgray",
        alpha=0.5,
        edgecolors="none"
    )

    # Fallback labels (raw leaf names) have no entry in group2color; draw them
    # in black rather than raising KeyError. Same color for every path here.
    color = group2color.get(group_label, 'black')

    # Filter paths that belong to the current group
    for path_idx, path in enumerate(paths):
        if path_labels[path_idx] != group_label:
            continue  # Skip paths that don't belong to this group

        rep_points = []
        labels = []

        for node in path:
            attrs = G.nodes[node]
            linannot_list = parse_annotation(attrs.get('linannot'))
            celltype_list = parse_annotation(attrs.get('celltype'))

            # Decide which annotation is used (lineage annotation preferred)
            if linannot_list:
                mask = adata.obs['lineage_complete'].isin(linannot_list)
                used_annot = linannot_list
            elif celltype_list:
                mask = adata.obs['cell_type'].isin(celltype_list)
                used_annot = celltype_list
            else:
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
                continue

            cell_indices = np.where(mask)[0]
            if len(cell_indices) == 0:
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
            else:
                coords = adata.obsm[basis][cell_indices]
                rp = get_representative_point(coords, method='medoid',
                                              max_n_medoid=2000,
                                              k_top=3, jitter=0, seed=seed)
                rep_points.append(rp)
                labels.append((node, used_annot))

        # Keep only nodes that received a position.
        rep_points = np.array(rep_points)
        valid_mask = ~np.isnan(rep_points[:, 0])
        valid_rep_points = rep_points[valid_mask]
        valid_labels = [lab for lab, keep in zip(labels, valid_mask) if keep]

        # Draw path in chosen color
        ax.plot(
            valid_rep_points[:, 0],
            valid_rep_points[:, 1],
            color=color,  # Fill color
            marker='o',
            markersize=3,
            markeredgecolor='black',  # Edge color
            markeredgewidth=0.1,  # Edge thickness
            linewidth=0.4,
            alpha=0.8,
            zorder=1
        )

        # Optionally label text
        if plot_label:
            for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
                label_text = f"{node_name}\n{annot_list}"
                ax.text(cx, cy, label_text, fontsize=2, color="black", zorder=2, alpha=0.5)

    # Add title for the group; hide axis labels and ticks
    ax.set_title(f"{group_label}", fontsize=12)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])

# Turn off unused axes
for j in range(len(unique_path_groups), len(axes)):
    axes[j].axis("off")

# Adjust layout
plt.tight_layout()

# Save the figure
plt.savefig(save_dir / f"lineage_paths_subplots_{basis}_{file_suffix}.pdf", dpi=300)
plt.show()


#### Trace selected lineage

In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# -------------------------------------------------------------------
# 2) Extract the subgraph below one or more roots
# -------------------------------------------------------------------
# Define the multiple roots
#roots = ['ABarappppp', 'ABalpapppa', 'ABalpappaa', 'ABalappppa', 'ABalapappa', 'ABalappapp', 'ABalapaapp', 'ABalaappp']  # Replace with your desired roots
roots = ['ABpra']

# Fail early with a clear message instead of networkx's generic error.
missing_roots = [r for r in roots if r not in G]
if missing_roots:
    raise ValueError(f"Roots not found in the lineage graph: {missing_roots}")

# Each root together with all of its descendants
all_descendants = set()
for root in roots:
    all_descendants |= nx.descendants(G, root) | {root}

# Create the subgraph with all collected descendants
subgraph = G.subgraph(all_descendants).copy()

# Optionally trim branches: each listed node is removed together with all of
# its descendants. Nodes not in the current subgraph (e.g. after changing
# `roots`, or already removed with a pruned ancestor) are skipped instead of
# raising NetworkXError.
prune_at = ['ABprap', 'ABpraap', 'ABpraaaa', 'ABpraaapa', 'ABpraaappa']
for cut_node in prune_at:
    if cut_node in subgraph:
        subgraph.remove_nodes_from(list(nx.descendants(subgraph, cut_node)) + [cut_node])

# -------------------------------------------------------------------
# 3) For each node, figure out which broad_lineage_group it belongs to
# -------------------------------------------------------------------
node2group = {}
for node in subgraph.nodes():
    group_label = map_node_to_broad_group(node, broad_lineage_groups)
    node2group[node] = group_label

# Collect all *defined* group labels (i.e. not None)
all_defined_groups = [g for g in node2group.values() if g is not None]
unique_groups = sorted(set(all_defined_groups))

# -------------------------------------------------------------------
# 4) Node colors: group color if mapped AND in a group; otherwise black if
#    mapped, light grey if not
# -------------------------------------------------------------------
node2color = {}
for node in subgraph.nodes():
    grp = node2group[node]  # could be None
    mapped_flag = subgraph.nodes[node].get('mapped', False)

    if grp is not None and mapped_flag:
        node2color[node] = group2color[grp]
    else:
        node2color[node] = 'black' if mapped_flag else 'lightgray'

# -------------------------------------------------------------------
# 5) Layout: graphviz 'dot', left-to-right (requires pygraphviz)
# -------------------------------------------------------------------
pos = nx.nx_agraph.graphviz_layout(subgraph, prog='dot', args='-Grankdir=LR')

# Extract node colors in consistent order
ordered_nodes = list(subgraph.nodes())
final_node_colors = [node2color[n] for n in ordered_nodes]

# -------------------------------------------------------------------
# 6) Label each node "<node>/<mapped cell types>" when it has any
# -------------------------------------------------------------------
node_labels = {}
for node in subgraph.nodes():
    celltype_mapped = subgraph.nodes[node].get('celltype', [])
    print(celltype_mapped)
    if celltype_mapped:
        if isinstance(celltype_mapped, list):
            label = f"{node}/{','.join(celltype_mapped)}"
        elif pd.notna(celltype_mapped):
            label = f"{node}/{celltype_mapped}"
        else:
            label = node
    else:
        label = node
    node_labels[node] = label

# Draw using the combined labels
plt.figure(figsize=(2, 1), dpi=600)
nx.draw(
    subgraph, pos,
    with_labels=True,
    labels=node_labels,
    nodelist=ordered_nodes,
    node_color=final_node_colors,
    node_size=15,
    font_size=5,
    arrowsize=4,
    width=0.5
)

# Use a combined title for the roots
plt.title(f'Subgraph for Roots: {", ".join(roots)}')
plt.savefig(save_dir / f"lineage_subtree_{'_'.join(roots)}_{file_suffix}.pdf")
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

k = 30  # neighbors for the optional kNN-graph trajectory
# Embedding for this plot; alternatives: 'Concord-decoder', 'scVI'.
basis = 'Seurat'
show_basis = f'{basis}_UMAP'
zoom_in = True
show_square = False


# Define multiple roots
#roots = ['ABpra', 'ABarapppp', 'ABalp']  # Add your desired roots here

subset_name = 'global'
adata_subset = adata

# -------------------------------------------------------------------
# Root-to-leaf paths inside the (pruned) subgraph, for every root
# -------------------------------------------------------------------
subpaths = []
leaf_nodes = [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]

for root in roots:
    for leaf in leaf_nodes:
        if nx.has_path(subgraph, source=root, target=leaf):
            path = nx.shortest_path(subgraph, source=root, target=leaf)
            subpaths.append(path)

print(f"Found {len(subpaths)} root-to-leaf paths for roots: {roots}")

# -------------------------------------------------------------------
# Cells whose lin_or_ct label occurs anywhere in the subgraph
# -------------------------------------------------------------------
add_inferred_trajectory = False
cellpaths = []
cells_in_subgraph = [subgraph.nodes[node].get('linorct', []) for node in subgraph.nodes]
cells_in_subgraph = [item for sublist in cells_in_subgraph if isinstance(sublist, list) for item in sublist]

# Map colors to lin_or_ct
adata_subsub = adata_subset[adata_subset.obs['lin_or_ct'].isin(cells_in_subgraph)]
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subsub, 'lin_or_ct', pal='Set1')
colors = adata_subsub.obs['lin_or_ct'].astype(str).map(lin_or_ct_palette)

# -------------------------------------------------------------------
# One path per leaf, kept only if the leaf has a mapped cell type
# -------------------------------------------------------------------
for end_cell in leaf_nodes:
    selected_path = [path for path in subpaths if path[-1] == end_cell]
    if len(selected_path) == 0:
        continue
    start_node = selected_path[0][0]
    end_node = selected_path[0][-1]
    selected_path_start = subgraph.nodes[start_node].get('linorct', [])
    selected_path_end = subgraph.nodes[end_node].get('celltype', [])
    # parse_annotation copes with None / NaN / multi-element lists; the former
    # `x and pd.notna(x)` test raised "truth value of an array is ambiguous"
    # whenever a terminal node mapped to more than one cell type.
    if parse_annotation(selected_path_end):
        print(f"Start: {start_node} ({selected_path_start}), End: {end_node} ({selected_path_end})")
        cellpaths.append(selected_path[0])

# -------------------------------------------------------------------
# Plot subpaths on the UMAP
# -------------------------------------------------------------------
plot_label = True  # Set True if you want text labels

plt.figure(figsize=(2, 2), dpi=600)

# a) Plot all points in the background
plt.scatter(
    adata_subset.obsm[show_basis][:, 0],
    adata_subset.obsm[show_basis][:, 1],
    rasterized=True,
    zorder=0,
    s=0.1, color="lightgray", alpha=0.8,
    edgecolors="none"
)

# Highlight cells in the subgraph, colored by lin_or_ct
plt.scatter(
    adata_subsub.obsm[show_basis][:, 0],
    adata_subsub.obsm[show_basis][:, 1],
    c=colors,
    rasterized=True,
    zorder=1,
    s=0.3, alpha=0.8,
    edgecolors="none"
)

# The kNN structure depends only on the embedding, so build it once.
neighborhood = None
if add_inferred_trajectory:
    neighborhood = ccd.ml.Neighborhood(adata_subset.obsm[show_basis], k=k, use_faiss=False)

all_rep_points = []
for path in cellpaths:
    leaf_node = path[-1]
    group_label = map_leaf_to_broad_group(leaf_node, broad_lineage_groups)
    print(group_label)
    # Leaves without a prefix-matched group (None) get black markers instead
    # of a KeyError.
    color = group2color.get(group_label, 'black')

    rep_points = []
    rep_points_idx = []
    labels = []

    for node in path:
        attrs = G.nodes[node]
        linorct_list = parse_annotation(attrs.get('linorct'))

        # Nodes without any lin_or_ct annotation get no position.
        if not linorct_list:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue

        cell_indices = np.where(adata_subset.obs['lin_or_ct'].isin(linorct_list))[0]
        if len(cell_indices) == 0:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue

        coords = adata_subset.obsm[show_basis][cell_indices]
        rp, idx = get_representative_point(coords, method='medoid',
                                           max_n_medoid=2000,
                                           k_top=10, jitter=0, return_idx=True, seed=seed)
        rep_points.append(rp)
        rep_points_idx.append(cell_indices[idx])
        labels.append((node, linorct_list))

    # Cell indices of the first/last positioned nodes; None if no node had cells.
    start_idx = rep_points_idx[0] if rep_points_idx else None
    end_idx = rep_points_idx[-1] if rep_points_idx else None

    rep_points = np.array(rep_points)
    valid_mask = ~np.isnan(rep_points[:, 0])
    valid_rep_points = rep_points[valid_mask]
    valid_labels = [lab for lab, keep in zip(labels, valid_mask) if keep]
    all_rep_points.append(valid_rep_points)

    # Draw path in the chosen color
    plt.plot(
        valid_rep_points[:, 0],
        valid_rep_points[:, 1],
        color='black',  # Line color
        marker='o',
        markersize=3,
        markerfacecolor=color,  # Fill color
        markeredgecolor='black',  # Edge color
        markeredgewidth=0.2,  # Edge thickness
        linewidth=0.3,
        alpha=0.8,
        zorder=2
    )

    if add_inferred_trajectory and start_idx is not None:
        celltrajectory, _ = ccd.ul.shortest_path_on_knn_graph(neighborhood, k=k, point_a=start_idx, point_b=end_idx, use_faiss=False)

        plt.plot(
            adata_subset.obsm[show_basis][celltrajectory, 0],
            adata_subset.obsm[show_basis][celltrajectory, 1],
            color='black',
            marker='o',
            markersize=0.5,
            markeredgecolor='black',
            markeredgewidth=0.1,
            linewidth=0.3,
            alpha=0.8,
            zorder=1
        )

    # Optionally label text
    if plot_label:
        for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
            label_text = f"{node_name}\n{annot_list}"
            plt.text(cx, cy, label_text, fontsize=0.5, color="black", zorder=2, alpha=0.5)

# Zoom to the drawn paths plus a 10% margin (skipped when nothing was drawn).
if zoom_in and all_rep_points:
    all_rep_points = np.concatenate(all_rep_points, axis=0)
    if len(all_rep_points) > 0:
        min_x, min_y = np.nanmin(all_rep_points, axis=0)
        max_x, max_y = np.nanmax(all_rep_points, axis=0)
        span = max(max_x - min_x, max_y - min_y)
        margin = 0.1 * span

        if show_square:
            # Square window centred on the paths
            center_x = (min_x + max_x) / 2
            center_y = (min_y + max_y) / 2
            half_side = span / 2 + margin
            plt.xlim(center_x - half_side, center_x + half_side)
            plt.ylim(center_y - half_side, center_y + half_side)
        else:
            plt.xlim(min_x - margin, max_x + margin)
            plt.ylim(min_y - margin, max_y + margin)

plt.title(f"Lineage Paths in UMAP (Roots: {', '.join(roots)})")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

text_ext = "with_text" if plot_label else "no_text"
zoomin_ext = "zoomin" if zoom_in else "nozoomin"
square_ext = "square" if show_square else "nosquare"
plt.savefig(save_dir / f"lineage_subpaths_{subset_name}_{'_'.join(roots)}_{show_basis}_{file_suffix}_{text_ext}_{zoomin_ext}_{square_ext}.pdf")
plt.show()


In [None]:
# 3D views of the traced-subtree cells over a grid of camera angles.
sanitized_ct = 'ABpra_Neuron'
# Earlier alternatives: show_cols = ['embryo.time'];
# show_keys = ['Concord-decoder'] or ['scVI'].
show_cols = ['lin_or_ct']
show_keys = ['Seurat']
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subsub, 'lin_or_ct', pal='Set1')
azims = list(range(30, 331, 30))  # 30, 60, ..., 330
#azims =[130]
elevs = list(range(15, 76, 15))   # 15, 30, 45, 60, 75

for azim in azims:
    for elev in elevs:
        view_suffix = f"{file_suffix}_azim{azim}_elev{elev}_{sanitized_ct}_{show_keys[0]}_{show_cols[0]}"
        with plt.rc_context(rc=custom_rc):
            ccd.pl.plot_all_embeddings_3d(
                adata=adata_subsub,
                combined_keys=show_keys,
                color_bys=show_cols,
                basis_types=['UMAP_3D'],
                pal=lin_or_ct_palette,
                ncols=2,
                rasterized=True,
                point_size=20,
                alpha=0.8,
                elev=elev,
                azim=azim,
                zoom_factor=0.05,
                show_grid=True,
                show_axis_labels=False,
                show_ticks=False,
                show_legend=False,
                tick_label_font_size=6,
                legend_font_size=6,
                save_dir=save_dir,
                file_suffix=view_suffix,
                save_format='pdf',
            )

In [None]:
# Persist the traced-subtree cells.
# NOTE(review): the 'aseasjaua' tag in the file name is hard-coded and does not
# follow `roots`; confirm it is the intended label.
subsub_out_path = data_dir / f"adata_subsub_aseasjaua_{file_suffix}.h5ad"
adata_subsub.write_h5ad(subsub_out_path)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

basis = 'Concord-decoder'
#basis = f'scVI'
#basis = 'Seurat'
show_basis = f'{basis}_UMAP_zoomrecomp'

# Cells of the traced subtree, colored by lin_or_ct. .copy() gives an
# independent AnnData (not a view of adata_subset) before run_umap stores the
# recomputed embedding in it.
adata_subsub = adata_subset[adata_subset.obs['lin_or_ct'].isin(cells_in_subgraph)].copy()
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subsub, 'lin_or_ct', pal='Set1')
colors = adata_subsub.obs['lin_or_ct'].astype(str).map(lin_or_ct_palette)
use_seed = 0
# Recompute a 2-D UMAP on the subset only, stored under show_basis.
ccd.ul.run_umap(adata_subsub, source_key=basis, result_key=show_basis, n_components=2, n_neighbors=30, min_dist=0.2, metric='cosine', random_state=use_seed)

# -------------------------------------------------------------------
# Identify valid cell paths
# -------------------------------------------------------------------
# Reset first: the previous cell already filled `cellpaths`, so appending
# again duplicated every path (and kept growing on each re-run of this cell).
cellpaths = []
for end_cell in leaf_nodes:
    selected_path = [path for path in subpaths if path[-1] == end_cell]
    if len(selected_path) == 0:
        continue
    start_node = selected_path[0][0]
    end_node = selected_path[0][-1]
    selected_path_start = subgraph.nodes[start_node].get('linorct', [])
    selected_path_end = subgraph.nodes[end_node].get('celltype', [])
    # Keep only terminals with at least one mapped cell type. parse_annotation
    # handles None / NaN / multi-element lists, where `x and pd.notna(x)`
    # raised for lists with more than one entry.
    if parse_annotation(selected_path_end):
        print(f"Start: {start_node} ({selected_path_start}), End: {end_node} ({selected_path_end})")
        cellpaths.append(selected_path[0])

# -------------------------------------------------------------------
# Plot subpaths on the UMAP
# -------------------------------------------------------------------
plot_label = True  # Set True if you want text labels

plt.figure(figsize=(1.5, 1.5), dpi=600)

# Highlight cells in the subgraph, colored by lin_or_ct
plt.scatter(
    adata_subsub.obsm[show_basis][:, 0],
    adata_subsub.obsm[show_basis][:, 1],
    c=colors,
    rasterized=True,
    zorder=1,
    s=0.3, alpha=0.8,
    edgecolors="none"
)

all_rep_points = []
for path in cellpaths:
    leaf_node = path[-1]
    group_label = map_leaf_to_broad_group(leaf_node, broad_lineage_groups)
    color = group2color[group_label]  # Get color

    rep_points = []
    rep_points_idx = []
    labels = []

    for node in path:
        attrs = G.nodes[node]
        linorct_list = parse_annotation(attrs.get('linorct'))

        # Decide which annotation is used
        if linorct_list:
            mask = adata_subsub.obs['lin_or_ct'].isin(linorct_list)
            used_annot = linorct_list
        else:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue

        cell_indices = np.where(mask)[0]
        if len(cell_indices) == 0:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
        else:
            coords = adata_subsub.obsm[show_basis][cell_indices]
            rp, idx = get_representative_point(coords, method='medoid',
                                               max_n_medoid=2000,
                                               k_top=10, jitter=0, return_idx=True, seed=seed)
            rep_points.append(rp)
            rep_points_idx.append(cell_indices[idx])
            labels.append((node, used_annot))

    start_idx = rep_points_idx[0]
    end_idx = rep_points_idx[-1]

    rep_points = np.array(rep_points)
    valid_mask = ~np.isnan(rep_points[:, 0])
    valid_rep_points = rep_points[valid_mask]
    valid_labels = [labels[i] for i in range(len(labels)) if valid_mask[i]]
    all_rep_points.append(valid_rep_points)

    # Draw path in the chosen color
    plt.plot(
        valid_rep_points[:, 0],
        valid_rep_points[:, 1],
        color='black',  # Line color
        marker='o',
        markersize=3,
        markerfacecolor=color,  # Fill color
        markeredgecolor='black',  # Edge color
        markeredgewidth=0.1,  # Edge thickness
        linewidth=0.1,
        alpha=0.8,
        zorder=2
    )

    if add_inferred_trajectory:
        neighborhood = ccd.ml.Neighborhood(adata_subsub.obsm[show_basis], k=k, use_faiss=False)
        celltrajectory, _ = ccd.ul.shortest_path_on_knn_graph(neighborhood, k=k, point_a=start_idx, point_b=end_idx, use_faiss=False)

        plt.plot(
            adata_subsub.obsm[show_basis][celltrajectory, 0],
            adata_subsub.obsm[show_basis][celltrajectory, 1],
            color='black',
            marker='o',
            markersize=0.5,
            markeredgecolor='black',
            markeredgewidth=0.1,
            linewidth=0.3,
            alpha=0.8,
            zorder=1
        )

    # Optionally label text
    if plot_label:
        for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
            label_text = f"{node_name}\n{annot_list}"
            plt.text(cx, cy, label_text, fontsize=0.5, color="black", zorder=2, alpha=0.5)


plt.title(f"Lineage Paths in UMAP (Roots: {', '.join(roots)})")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

text_ext = "with_text" if plot_label else "no_text"
plt.savefig(save_dir / f"lineage_subpaths_{subset_name}_{'_'.join(roots)}_{show_basis}_{file_suffix}_{text_ext}_{use_seed}.pdf")
plt.show()


#### Hierarchical clustering of the latent embedding

In [None]:
# Terminal cell types whose latent profiles are compared in the heatmap below.
include_terminals = 'ASJ AUA ASE'.split()

In [None]:
# Plot hierarchical clustering of the latent embedding, colored by lin_or_ct:
# restrict to the `include_terminals` cells, keep the `top_k` highest-variance latent
# dimensions, and draw a clustered heatmap annotated by lin_or_ct and embryo.time.
show_basis = 'Concord-decoder'
#show_basis = 'Seurat'
#show_basis = 'scVI'
top_k = 50

adata_plot = adata_subsub[adata_subsub.obs['lin_or_ct'].isin(include_terminals)]
lin_or_ct_palette_use = {
    label: colour
    for label, colour in lin_or_ct_palette.items()
    if label in include_terminals
}

# Rank latent dimensions by variance (largest first) and keep the top_k of them.
latent_vals = adata_plot.obsm[show_basis]
variance = np.var(latent_vals, axis=0)
top_k_indices = np.argsort(variance)[::-1][:top_k]
latent_vals = latent_vals[:, top_k_indices]

heatmap_path = save_dir / f"heatmap_latent_embedding_top_{top_k}_{show_basis}_{file_suffix}_yticks.pdf"
ccd.pl.heatmap_with_annotations(
    adata_plot,
    val=latent_vals,
    transpose=True,
    obs_keys=['lin_or_ct', 'embryo.time'],
    cmap='viridis',
    vmin=None,
    vmax=None,
    cluster_rows=True,
    cluster_cols=True,
    pal={'lin_or_ct': lin_or_ct_palette_use, 'embryo.time': 'BlueGreenRed'},
    add_color_legend=True,
    value_annot=False,
    title=None,
    title_fontsize=16,
    annot_fontsize=8,
    yticklabels=True,
    xticklabels=False,
    use_clustermap=True,
    cluster_method='average',
    rasterize=True,
    ax=None,
    figsize=(14, 12),
    dpi=300,
    show=True,
    save_path=heatmap_path,
)

plt.show()


### Model activation pattern

In [None]:
# Load model and predict latent: read the saved run configuration and point it at the
# trained weights so the model can be rebuilt for feature-importance analysis.
model_dir = '../save/dev_cbce_1217-Dec18'
config_file = f'{model_dir}/config.json'
model_file = f'{model_dir}/final_model.pth'
concord_args = ccd.ul.load_json(config_file)
concord_args['pretrained_model'] = model_file

# Downsample data to a small subset for fast estimation of feature contribution to the latent space
import re
import numpy as np
layer_index = 6  # passed to compute_feature_importance as the layer to attribute
#adata_subset = adata.copy()[np.random.choice(adata.n_obs, 10000, replace=False), cur_ccd.config.input_feature]


In [None]:
# Rebuild the pretrained Concord model on the selected cells, restricted to the model's input
# features, and compute how much each input gene contributes to layer `layer_index`.
adata_trace = adata_plot[:, concord_args['input_feature']]

cur_ccd = ccd.Concord(adata=adata_trace, **concord_args)
cur_ccd.config.device = 'mps'
cur_ccd.init_model()
cur_ccd.init_dataloader(input_layer_key='X', preprocess=True, train_frac=1.0, use_sampler=False)

# X may be a sparse matrix or a dense array, so densify either way instead of assuming
# `.toarray()` exists. Cast to float32 because the MPS backend does not support float64.
# NOTE(review): assumes the model weights are float32 (PyTorch default) -- confirm.
trace_X = adata_trace.X
trace_X = trace_X.toarray() if hasattr(trace_X, 'toarray') else np.asarray(trace_X)
input_tensors = torch.tensor(trace_X, dtype=torch.float32).to(cur_ccd.config.device)
importance_matrix = ccd.ul.compute_feature_importance(cur_ccd.model, input_tensors, layer_index=layer_index)


In [None]:

# Thresholds for dropping genes from the ranked lists (both are strict lower bounds):
min_zero_frac = 0.03        # minimum "Nonzero Fraction" a gene must exceed
min_expression_level = 0    # minimum "Expression Level" a gene must exceed
# Per-unit ranked gene tables built from the importance matrix, with expression statistics.
ranked_lists = ccd.ul.prepare_ranked_list(importance_matrix, adata=adata_trace, expr_level=True)

# Filter function
def filter_genes(df, min_zero_frac, min_expression_level):
    """Return the rows of a ranked-gene table that pass both expression filters.

    Despite its name, ``min_zero_frac`` bounds the "Nonzero Fraction" column from
    below. Both comparisons are strict (``>``).

    Args:
        df: DataFrame with "Nonzero Fraction" and "Expression Level" columns.
        min_zero_frac: Keep rows whose "Nonzero Fraction" is greater than this.
        min_expression_level: Keep rows whose "Expression Level" is greater than this.

    Returns:
        The matching rows of ``df``, in their original order and with their original index.
    """
    detected_often = df["Nonzero Fraction"] > min_zero_frac
    expressed = df["Expression Level"] > min_expression_level
    return df[detected_often & expressed]

# Apply the expression filters to every ranked list, keeping the same keys.
filtered_gene_lists = {
    unit: filter_genes(table, min_zero_frac, min_expression_level)
    for unit, table in ranked_lists.items()
}

In [None]:
# Display the filtered gene tables (one entry per key of ranked_lists).
filtered_gene_lists

In [None]:
# Show the top genes of selected latent units on the recomputed zoom-in UMAP.
show_neurons = ['Neuron 8', 'Neuron 9', 'Neuron 4', 'Neuron 2', 'Neuron 15', 'Neuron 17']
show_gene_lists = {name: filtered_gene_lists[name] for name in show_neurons}
show_basis = 'Concord-decoder_UMAP_zoomrecomp'
ccd.pl.plot_top_genes_embedding(
    adata_subsub,
    show_gene_lists,
    show_basis,
    top_x=8,
    figsize=(7.5, 1),
    point_size=1,
    font_size=7,
    colorbar_loc=None,
    vmax_quantile=.99,
    save_path=save_dir / f"ASEASJAUA_embeddings_{show_basis}",
)

### Early 200

In [None]:
# Reload the previously saved 'early200' subset from data_dir.
subset_name = 'early200'
adata_subset = sc.read(data_dir / f"adata_cbce_Dec23-1707_{subset_name}.h5ad")

In [None]:
# Build the early-embryo subset: cells whose embryo.time is at most 200.
subset_name = 'early200'
adata_subset = adata[adata.obs['embryo.time'].le(200)]

In [None]:
# Compute a subset-specific 2D UMAP for every embedding in combined_keys that exists on
# adata_subset. The progress message is printed for every key, including skipped ones.
for basis in combined_keys:
    print("Running UMAP for", basis)
    if basis in adata_subset.obsm:
        ccd.ul.run_umap(
            adata_subset,
            source_key=basis,
            result_key=f'{basis}_UMAP_{subset_name}',
            n_components=2,
            n_neighbors=30,
            min_dist=0.1,
            metric='euclidean',
            random_state=seed,
        )


In [None]:
# plot everything: the subset UMAP of every embedding in combined_keys, colored by each
# column in show_cols.
import matplotlib.pyplot as plt
import pandas as pd

show_keys = combined_keys
show_cols = ['embryo.time', 'cell_type', 'species', 'lineage_complete', 'ct_or_lin']
basis_types = [f'UMAP_{subset_name}']
font_size, point_size, alpha = 10, .4, 0.8
figsize = (10, 1.4)
ncols = 8
# Panel-grid row count; kept for reference, not passed to plot_all_embeddings below.
nrows = int(np.ceil(len(show_keys) / ncols))

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata_subset,
        show_keys,
        color_bys=show_cols,
        basis_types=basis_types,
        pal=pal,
        font_size=font_size,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        save_dir=save_dir,
        file_suffix=file_suffix,
        dpi=600,
        save_format='svg',
    )


In [None]:
# Inspect the 50 least frequent lineage labels (value_counts sorts most frequent first).
lineage_labels = adata_subset.obs['lineage_complete']
ct_counts = lineage_labels.value_counts()
ct_counts.tail(50)

In [None]:
# Collapse rare labels to 'NaN' so the early-200 plots only name well-populated groups, and
# derive the combined label columns:
#   plot_cell_type  -> cell_type, types seen < 50 times set to 'NaN'
#   plot_lineage    -> lineage_complete, lineages seen < 10 times set to 'NaN'
#   ct_or_broad_lin -> plot_cell_type, else broad_lineage (labels < 50 cells -> 'NaN')
#   broad_lin_or_ct -> broad_lineage, else plot_cell_type (labels < 50 cells -> 'NaN')
# Masked writes go through `obs.loc[mask, col] = ...`. The chained form
# `obs[col][mask] = ...` assigns into an intermediate Series. Under pandas Copy-on-Write
# (opt-in in pandas 2.x, default in 3.0) that write never reaches obs, and the warning pandas
# emits is hidden by the notebook-wide `warnings.filterwarnings('ignore')`.
ct_counts = adata_subset.obs['cell_type'].value_counts()
ignore_cts = ct_counts[ct_counts < 50].index
print(ignore_cts)
adata_subset.obs['plot_cell_type'] = adata_subset.obs['cell_type'].astype(str)
adata_subset.obs.loc[adata_subset.obs['plot_cell_type'].isin(ignore_cts), 'plot_cell_type'] = 'NaN'

ct_counts = adata_subset.obs['lineage_complete'].value_counts()
ignore_cts = ct_counts[ct_counts < 10].index
print(ignore_cts)
adata_subset.obs['plot_lineage'] = adata_subset.obs['lineage_complete'].astype(str)
adata_subset.obs.loc[adata_subset.obs['plot_lineage'].isin(ignore_cts), 'plot_lineage'] = 'NaN'


# Cell type where known, otherwise the broad lineage.
no_ct = adata_subset.obs['plot_cell_type'] == 'NaN'
adata_subset.obs['ct_or_broad_lin'] = adata_subset.obs['plot_cell_type'].astype(str)
adata_subset.obs.loc[no_ct, 'ct_or_broad_lin'] = adata_subset.obs.loc[no_ct, 'broad_lineage'].astype(str)
ct_counts = adata_subset.obs['ct_or_broad_lin'].value_counts()
ignore_cts = ct_counts[ct_counts < 50].index
print(ignore_cts)
adata_subset.obs.loc[adata_subset.obs['ct_or_broad_lin'].isin(ignore_cts), 'ct_or_broad_lin'] = 'NaN'

# Broad lineage where known, otherwise the cell type.
no_lin = adata_subset.obs['broad_lineage'] == 'NaN'
adata_subset.obs['broad_lin_or_ct'] = adata_subset.obs['broad_lineage'].astype(str)
adata_subset.obs.loc[no_lin, 'broad_lin_or_ct'] = adata_subset.obs.loc[no_lin, 'plot_cell_type'].astype(str)
ct_counts = adata_subset.obs['broad_lin_or_ct'].value_counts()
ignore_cts = ct_counts[ct_counts < 50].index
print(ignore_cts)
adata_subset.obs.loc[adata_subset.obs['broad_lin_or_ct'].isin(ignore_cts), 'broad_lin_or_ct'] = 'NaN'

# Palettes for the derived columns, layered on top of the global palette.
use_pal = pal.copy()
_,_,use_pal['plot_cell_type'] = ccd.pl.get_color_mapping(adata_subset, 'plot_cell_type', pal='Paired', seed=42)
_,_,use_pal['ct_or_broad_lin'] = ccd.pl.get_color_mapping(adata_subset, 'ct_or_broad_lin', pal='Paired', seed=seed)
_,_,use_pal['broad_lin_or_ct'] = ccd.pl.get_color_mapping(adata_subset, 'broad_lin_or_ct', pal='Paired', seed=42)

In [None]:
# Concord-decoder early-200 UMAP colored by the collapsed label columns, labels on data.
basis = 'Concord-decoder'
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = ['embryo.time', 'plot_cell_type', 'plot_lineage', 'ct_or_broad_lin', 'broad_lin_or_ct', 'species']
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_subset, show_basis, show_cols, figsize=(10,7), dpi=600, ncols=3, font_size=4, point_size=.8, legend_loc='on data',
        # use_pal was built in the preceding cell specifically for plot_cell_type,
        # ct_or_broad_lin and broad_lin_or_ct. Passing plain `pal` left it unused
        # (the AB-broad version of this plot already uses use_pal).
        pal = use_pal, seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_wttext.pdf"
    )

In [None]:
# Overlay every lineage path (`paths`, each labelled by `path_labels`) on the early-200
# Concord-decoder UMAP. Each node is placed at the point returned by
# get_representative_point(method='medoid'):
#   - for cells whose plot_lineage matches the node's 'linannot' annotation, or
#   - failing that, for cells whose plot_cell_type matches its 'celltype' annotation.
# Nodes with no annotation or no matching cells are dropped from the polyline.
# Relies on notebook globals: paths, path_labels, group2color, G, parse_annotation,
# get_representative_point.
basis = f'Concord-decoder_UMAP_{subset_name}'
plot_label = True  # set True if you want text labels

plt.figure(figsize=(8, 8), dpi=600)

# a) Plot all points in background
plt.scatter(
    adata_subset.obsm[basis][:,0],
    adata_subset.obsm[basis][:,1],
    rasterized=True, 
    zorder=0,  
    s=0.5, color="lightgray", alpha=0.8,
    edgecolors="none"
)

# b) For each path, fetch the group label, color, etc.
for path_idx, path in enumerate(paths):
    group_label = path_labels[path_idx]  # assigned above
    color = group2color[group_label]     # get color

    rep_points = []  # one [x, y] per node; [nan, nan] when the node cannot be placed
    labels = []      # (node, annotation list) per node; ("", []) when not placed

    for node in path:
        attrs = G.nodes[node]
        linannot_list = parse_annotation(attrs.get('linannot'))
        celltype_list = parse_annotation(attrs.get('celltype'))

        # Decide which annotation is used (lineage annotation takes precedence over cell type)
        if linannot_list:
            mask = adata_subset.obs['plot_lineage'].isin(linannot_list)
            used_annot = linannot_list
        elif celltype_list:
            mask = adata_subset.obs['plot_cell_type'].isin(celltype_list)
            used_annot = celltype_list
        else:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue

        cell_indices = np.where(mask)[0]
        if len(cell_indices)==0:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
        else:
            coords = adata_subset.obsm[basis][cell_indices]
            rp = get_representative_point(coords, method='medoid', 
                                          max_n_medoid=2000, 
                                          k_top=3, jitter=0, seed=seed)
            rep_points.append(rp)
            labels.append((node, used_annot))

    # Drop unplaced nodes so the polyline connects only the placed ones.
    rep_points = np.array(rep_points)
    valid_mask = ~np.isnan(rep_points[:,0])
    valid_rep_points = rep_points[valid_mask]
    valid_labels = [labels[i] for i in range(len(labels)) if valid_mask[i]]

    # Draw path in chosen color
    plt.plot(
        valid_rep_points[:,0],
        valid_rep_points[:,1],
        color=color,  # line and marker fill color
        marker='o', 
        markersize=2, 
        markeredgecolor='black',  # Edge color
        markeredgewidth=0.1,  # Edge thickness
        linewidth=0.3,
        alpha=0.8,
        zorder=1
    )

    # Optionally label text
    if plot_label:
        for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
            label_text = f"{node_name}\n{annot_list}"
            plt.text(cx, cy, label_text, fontsize=2, color="black", zorder=2, alpha=0.5)

plt.title("Lineage Paths in UMAP")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

text_ext = "with_text" if plot_label else "no_text"
plt.savefig(save_dir / f"lineage_paths_{subset_name}_{file_suffix}_{text_ext}.pdf")
plt.show()

### Run for each major tissue type

#### Ectoderm

In [None]:
# If load previous: reload the saved 'AB broad' subset instead of rebuilding it.
subset_name = 'AB broad'
adata_subset = sc.read(data_dir / f"adata_cbce_Dec26-1019_{subset_name}.h5ad")

In [None]:
# Copy the Contrastive embedding from the full object, selecting rows by the subset's cell names.
adata_subset.obsm['Contrastive'] = adata[adata_subset.obs_names, :].obsm['Contrastive']

In [None]:
# Build the AB-lineage subset from the broad cell-type annotation and report its shape.
subset_name = 'AB broad'
selected_lins = [
    'AB lineage (non-hyp/seam/pha)',
    'Hypodermis/Seam',
    'Early embryo',
]
is_selected = adata.obs['broad_cell_type_qz'].isin(selected_lins)
adata_subset = adata[is_selected]
print(adata_subset.shape)

In [None]:
# Display which embeddings are stored on the subset.
adata_subset.obsm

In [None]:
# Compute a subset-specific 2D UMAP for each Concord embedding present on adata_subset.
# The progress message is printed for every key, including skipped ones.
for basis in concord_keys:
    print("Running UMAP for", basis)
    if basis in adata_subset.obsm:
        ccd.ul.run_umap(
            adata_subset,
            source_key=basis,
            result_key=f'{basis}_UMAP_{subset_name}',
            n_components=2,
            n_neighbors=30,
            min_dist=0.1,
            metric='euclidean',
            random_state=seed,
        )


In [None]:
# Compute a subset-specific 3D UMAP for each Concord embedding present on adata_subset.
# The progress message is printed for every key, including skipped ones.
for basis in concord_keys:
    print("Running 3D UMAP for", basis)
    if basis in adata_subset.obsm:
        ccd.ul.run_umap(
            adata_subset,
            source_key=basis,
            result_key=f'{basis}_UMAP_3D_{subset_name}',
            n_components=3,
            n_neighbors=30,
            min_dist=0.1,
            metric='euclidean',
            random_state=seed,
        )


In [None]:
# Keep only cell types whose name contains 'hyp' (hypodermis); everything else becomes 'NaN'.
# - The mask is built on the string-converted column. A missing cell_type ('nan') then
#   counts as "not hyp" instead of producing NaN entries that break the `~` negation.
# - The write uses .loc, because the chained `obs[col][mask] = ...` form silently does
#   nothing under pandas Copy-on-Write.
adata_subset.obs['hyp_only_cell_type'] = adata_subset.obs['cell_type'].astype(str)
is_hyp = adata_subset.obs['cell_type'].astype(str).str.contains('hyp', regex=False)
adata_subset.obs.loc[~is_hyp, 'hyp_only_cell_type'] = 'NaN'
adata_subset.obs['hyp_only_cell_type'].value_counts()

In [None]:
# plot everything: subset UMAPs of the Contrastive and Concord embeddings colored by each
# column in show_cols, with legends drawn on the data.
import matplotlib.pyplot as plt
import pandas as pd

# show_keys = combined_keys
show_keys = ['Contrastive', 'Concord', 'Concord-decoder']
show_cols = ['embryo.time', 'cell_type', 'species', 'lineage_complete', 'ct_or_lin', 'broad_cell_type_qz']
#show_cols = ['hyp_only_cell_type']
basis_types = [f'UMAP_{subset_name}']
font_size, point_size, alpha = 10, .2, 0.8
figsize = (4.5, 1.4)
ncols = 3
# Panel-grid row count; kept for reference, not passed to plot_all_embeddings below.
nrows = int(np.ceil(len(show_keys) / ncols))

legend_loc = 'on data'
# Encode the legend placement in the output file name ('on_data' / 'no_legend').
legend_tag = 'no_legend' if legend_loc is None else legend_loc.replace(' ', '_')
save_file_suffix = f"{file_suffix}_{legend_tag}"
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata_subset,
        show_keys,
        color_bys=show_cols,
        basis_types=basis_types,
        pal=pal,
        font_size=font_size,
        legend_font_size=1,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        legend_loc=legend_loc,
        save_dir=save_dir,
        file_suffix=save_file_suffix,
        dpi=600,
        save_format='pdf',
    )


In [None]:
# broad_lin_or_ct: broad_lineage where known, otherwise plot_cell_type; then show the 30
# rarest labels. The fallback write uses .loc because the chained `obs[col][mask] = ...`
# form silently does nothing under pandas Copy-on-Write.
adata_subset.obs['broad_lin_or_ct'] = adata_subset.obs['broad_lineage'].astype(str)
no_lin = adata_subset.obs['broad_lineage'] == 'NaN'
adata_subset.obs.loc[no_lin, 'broad_lin_or_ct'] = adata_subset.obs.loc[no_lin, 'plot_cell_type'].astype(str)
ct_counts = adata_subset.obs['broad_lin_or_ct'].value_counts()
ct_counts.tail(30)

In [None]:
# Collapse rare labels to 'NaN' for the AB-broad plots and derive the combined columns:
#   plot_cell_type  -> cell_type, types seen < 100 times set to 'NaN'
#   ct_or_broad_lin -> plot_cell_type, else broad_lineage (labels < 50 cells -> 'NaN')
#   broad_lin_or_ct -> broad_lineage, else plot_cell_type (labels < 50 cells -> 'NaN')
#   plot_lin_or_ct  -> lineage_complete, else plot_cell_type (labels < 20 cells -> 'NaN')
# Masked writes go through `obs.loc[mask, col] = ...`. The chained `obs[col][mask] = ...`
# form never reaches obs under pandas Copy-on-Write (opt-in in 2.x, default in 3.0), and the
# resulting warning is hidden by the notebook-wide warnings filter.
ct_counts = adata_subset.obs['cell_type'].value_counts()
ignore_cts = ct_counts[ct_counts < 100].index
print(ignore_cts)
adata_subset.obs['plot_cell_type'] = adata_subset.obs['cell_type'].astype(str)
adata_subset.obs.loc[adata_subset.obs['plot_cell_type'].isin(ignore_cts), 'plot_cell_type'] = 'NaN'

no_ct = adata_subset.obs['plot_cell_type'] == 'NaN'
adata_subset.obs['ct_or_broad_lin'] = adata_subset.obs['plot_cell_type'].astype(str)
adata_subset.obs.loc[no_ct, 'ct_or_broad_lin'] = adata_subset.obs.loc[no_ct, 'broad_lineage'].astype(str)
ct_counts = adata_subset.obs['ct_or_broad_lin'].value_counts()
ignore_cts = ct_counts[ct_counts < 50].index
print(ignore_cts)
adata_subset.obs.loc[adata_subset.obs['ct_or_broad_lin'].isin(ignore_cts), 'ct_or_broad_lin'] = 'NaN'

no_lin = adata_subset.obs['broad_lineage'] == 'NaN'
adata_subset.obs['broad_lin_or_ct'] = adata_subset.obs['broad_lineage'].astype(str)
adata_subset.obs.loc[no_lin, 'broad_lin_or_ct'] = adata_subset.obs.loc[no_lin, 'plot_cell_type'].astype(str)
ct_counts = adata_subset.obs['broad_lin_or_ct'].value_counts()
ignore_cts = ct_counts[ct_counts < 50].index
print(ignore_cts)
adata_subset.obs.loc[adata_subset.obs['broad_lin_or_ct'].isin(ignore_cts), 'broad_lin_or_ct'] = 'NaN'

no_lineage = adata_subset.obs['lineage_complete'] == 'NaN'
adata_subset.obs['plot_lin_or_ct'] = adata_subset.obs['lineage_complete'].astype(str)
adata_subset.obs.loc[no_lineage, 'plot_lin_or_ct'] = adata_subset.obs.loc[no_lineage, 'plot_cell_type'].astype(str)
ct_counts = adata_subset.obs['plot_lin_or_ct'].value_counts()
ignore_cts = ct_counts[ct_counts < 20].index
print(ignore_cts)
adata_subset.obs.loc[adata_subset.obs['plot_lin_or_ct'].isin(ignore_cts), 'plot_lin_or_ct'] = 'NaN'

# Palettes for the derived columns, layered on top of the global palette.
use_pal = pal.copy()
_,_,use_pal['plot_cell_type'] = ccd.pl.get_color_mapping(adata_subset, 'plot_cell_type', pal='Paired', seed=42)
_,_,use_pal['ct_or_broad_lin'] = ccd.pl.get_color_mapping(adata_subset, 'ct_or_broad_lin', pal='Paired', seed=seed)
_,_,use_pal['broad_lin_or_ct'] = ccd.pl.get_color_mapping(adata_subset, 'broad_lin_or_ct', pal='Paired', seed=42)

In [None]:
# Concord-decoder AB-broad UMAP colored by the broad annotations and the collapsed label
# columns, using the derived palettes in use_pal and labels drawn on the data.
output_key = 'Concord-decoder'
basis = output_key
#ccd.ul.run_umap(adata, source_key=basis, result_key=f'{basis}_UMAP', n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed)
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = [
    'broad_cell_type_qz',
    'plot_cell_type',
    'broad_lineage',
    'embryo.time',
    'ct_or_broad_lin',
    'broad_lin_or_ct',
]
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_subset,
        show_basis,
        show_cols,
        figsize=(10, 6.7),
        dpi=600,
        ncols=3,
        font_size=3,
        point_size=.8,
        legend_loc='on data',
        pal=use_pal,
        seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_{subset_name}.pdf",
    )

In [None]:
# Plot lin_or_ct with text labels, to find cases where terminal cells are almost fully
# connected to early cells.

output_key = 'Concord-decoder'
basis = output_key
#ccd.ul.run_umap(adata, source_key=basis, result_key=f'{basis}_UMAP', n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed)
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = ['lin_or_ct']
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_subset,
        show_basis,
        show_cols,
        figsize=(10, 10),
        dpi=600,
        ncols=1,
        font_size=3,
        point_size=.8,
        legend_loc='on data',
        pal=pal,
        seed=seed,
        save_path=save_dir / f"{show_basis}_{file_suffix}_{subset_name}_lin_or_ct.pdf",
    )

In [None]:
# Overlay every lineage path (`paths`, each labelled by `path_labels`) on the AB-broad UMAP.
# Each node is placed at the point returned by get_representative_point(method='medoid')
# for the cells whose plot_lin_or_ct label matches the node's 'linannot' annotation, or its
# 'celltype' annotation when it has no lineage annotation. Nodes without an annotation or
# without matching cells are dropped from the polyline.
# Relies on notebook globals: paths, path_labels, group2color, G, parse_annotation,
# get_representative_point.
basis = f'Concord-decoder_UMAP_{subset_name}'
plot_label = False  # set True if you want text labels

plt.figure(figsize=(4, 4), dpi=600)

# a) Plot all points in background
plt.scatter(
    adata_subset.obsm[basis][:,0],
    adata_subset.obsm[basis][:,1],
    rasterized=True,
    zorder=0,
    s=0.5, color="lightgray", alpha=0.8,
    edgecolors="none"
)

# b) For each path, fetch the group label, color, etc.
for path_idx, path in enumerate(paths):
    group_label = path_labels[path_idx]  # assigned above
    color = group2color[group_label]     # get color

    rep_points = []  # one [x, y] per node; [nan, nan] when the node cannot be placed
    labels = []

    for node in path:
        attrs = G.nodes[node]
        linannot_list = parse_annotation(attrs.get('linannot'))
        celltype_list = parse_annotation(attrs.get('celltype'))

        # Decide which annotation is used (plot_lin_or_ct merges lineage and cell type,
        # so both branches match against the same column).
        if linannot_list:
            mask = adata_subset.obs['plot_lin_or_ct'].isin(linannot_list)
            used_annot = linannot_list
        elif celltype_list:
            mask = adata_subset.obs['plot_lin_or_ct'].isin(celltype_list)
            used_annot = celltype_list
        else:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue

        cell_indices = np.where(mask)[0]
        if len(cell_indices)==0:
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
        else:
            coords = adata_subset.obsm[basis][cell_indices]
            rp = get_representative_point(coords, method='medoid',
                                          max_n_medoid=2000,
                                          k_top=3, jitter=0, seed=seed)
            rep_points.append(rp)
            labels.append((node, used_annot))

    # Drop unplaced nodes so the polyline connects only the placed ones.
    rep_points = np.array(rep_points)
    valid_mask = ~np.isnan(rep_points[:,0])
    valid_rep_points = rep_points[valid_mask]
    valid_labels = [labels[i] for i in range(len(labels)) if valid_mask[i]]

    # Draw path in chosen color
    plt.plot(
        valid_rep_points[:,0],
        valid_rep_points[:,1],
        color=color,  # line and marker fill color
        marker='o',
        markersize=2,
        markeredgecolor='black',  # Edge color
        markeredgewidth=0.1,  # Edge thickness
        linewidth=0.3,
        alpha=0.8,
        zorder=1
    )

    # Optionally label text
    if plot_label:
        for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
            label_text = f"{node_name}\n{annot_list}"
            plt.text(cx, cy, label_text, fontsize=0.5, color="black", zorder=2, alpha=0.5)

# f-string so the title names the plotted embedding (without the `f` prefix the literal
# text "{basis}" was rendered).
plt.title(f"Lineage Paths in {basis}")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

text_ext = "with_text" if plot_label else "no_text"
plt.savefig(save_dir / f"lineage_paths_{subset_name}_{basis}_{file_suffix}_{text_ext}.pdf")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One subplot per lineage-path group (excluding 'Z2/Z3'). Each panel draws all cells in
# light gray and overlays that group's paths through per-node representative points
# (get_representative_point, method='medoid'), matched on plot_lin_or_ct.
# NOTE(review): the UMAP loops in this section only create `<key>_UMAP_<subset_name>` for
# concord_keys, so 'scVI_UMAP_AB broad' must come from the reloaded .h5ad file -- confirm it
# exists before running with the scVI basis.
#basis = f'Concord-decoder_UMAP_{subset_name}'
basis = f'scVI_UMAP_{subset_name}'
plot_label = False  # set True if you want text labels

# Get unique groups
unique_path_groups = sorted(set(path_labels))
unique_path_groups = [group for group in unique_path_groups if group != 'Z2/Z3']
# Determine grid size for subplots
# NOTE(review): with no remaining groups n_rows is 0, which plt.subplots rejects.
n_groups = len(unique_path_groups)
n_cols = 8  # Customize the number of columns
n_rows = int(np.ceil(n_groups / n_cols))

# Set up the figure with subplots
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), dpi=300)
axes = axes.flatten()  # Flatten for easier indexing

for i, group_label in enumerate(unique_path_groups):
    print(f"Plotting group {group_label} in subplot {i}")
    ax = axes[i]

    # Plot all points in the background
    ax.scatter(
        adata_subset.obsm[basis][:, 0],
        adata_subset.obsm[basis][:, 1],
        rasterized=True,
        zorder=0,
        s=0.1,
        color="lightgray",
        alpha=0.5,
        edgecolors="none"
    )

    # Filter paths that belong to the current group
    for path_idx, path in enumerate(paths):
        if path_labels[path_idx] != group_label:
            continue  # Skip paths that don't belong to this group

        color = group2color[group_label]  # Get color for this group

        rep_points = []  # one [x, y] per node; [nan, nan] when the node cannot be placed
        labels = []

        for node in path:
            attrs = G.nodes[node]
            linannot_list = parse_annotation(attrs.get('linannot'))
            celltype_list = parse_annotation(attrs.get('celltype'))

            # Decide which annotation is used (lineage annotation first, then cell type;
            # both are matched against plot_lin_or_ct)
            if linannot_list:
                mask = adata_subset.obs['plot_lin_or_ct'].isin(linannot_list)
                used_annot = linannot_list
            elif celltype_list:
                mask = adata_subset.obs['plot_lin_or_ct'].isin(celltype_list)
                used_annot = celltype_list
            else:
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
                continue

            cell_indices = np.where(mask)[0]
            if len(cell_indices) == 0:
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
            else:
                coords = adata_subset.obsm[basis][cell_indices]
                rp = get_representative_point(coords, method='medoid',
                                              max_n_medoid=2000,
                                              k_top=3, jitter=0, seed=seed)
                rep_points.append(rp)
                labels.append((node, used_annot))

        # Drop unplaced nodes so the polyline connects only the placed ones.
        rep_points = np.array(rep_points)
        valid_mask = ~np.isnan(rep_points[:, 0])
        valid_rep_points = rep_points[valid_mask]
        valid_labels = [labels[j] for j in range(len(labels)) if valid_mask[j]]

        # Draw path in chosen color
        ax.plot(
            valid_rep_points[:, 0],
            valid_rep_points[:, 1],
            color=color,  # line and marker fill color
            marker='o',
            markersize=3,
            markeredgecolor='black',  # Edge color
            markeredgewidth=0.1,  # Edge thickness
            linewidth=0.4,
            alpha=0.8,
            zorder=1
        )

        # Optionally label text
        if plot_label:
            for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
                label_text = f"{node_name}\n{annot_list}"
                ax.text(cx, cy, label_text, fontsize=0.3, color="black", zorder=2, alpha=0.5)

    # Add title for the group
    ax.set_title(f"{group_label}", fontsize=12)
    # Remove x,y axis label , remove ticks
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])

# Turn off unused axes
for j in range(len(unique_path_groups), len(axes)):
    axes[j].axis("off")

# Adjust layout
plt.tight_layout()

# Save the figure
text_ext = "with_text" if plot_label else "no_text"
plt.savefig(save_dir / f"lineage_paths_subplots_{subset_name}_{basis}_{file_suffix}_{text_ext}.pdf", dpi=300)
plt.show()


#### Trace the ABpraaa and ABalppp lineages that give rise to neurons

In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# Draw the lineage subtree under `root` after pruning selected branches.
# - Nodes that are mapped and belong to a broad lineage group get that group's color.
# - Other mapped nodes are black; unmapped nodes are light gray.
# - Each node is labelled with its name plus its mapped cell types.

# -------------------------------------------------------------------
# 2) Extract the subgraph rooted at `root`
# -------------------------------------------------------------------
#root = 'ABpraaapp'
root = 'ABpra'
descendants = nx.descendants(G, root)
descendants.add(root)
subgraph = G.subgraph(descendants).copy()

# Prune each listed node together with all of its descendants, in this order.
for pruned in ['ABprap', 'ABpraap', 'ABpraaaa', 'ABpraaapa']:
    subgraph.remove_nodes_from(list(nx.descendants(subgraph, pruned)) + [pruned])


# Alternative root (kept for reference):
# root = 'ABalp'
# descendants = nx.descendants(G, root)
# descendants.add(root)
# subgraph = G.subgraph(descendants).copy()
# for pruned in ['ABalpa', 'ABalppa', 'ABalpppa', 'ABalppppa']:
#     subgraph.remove_nodes_from(list(nx.descendants(subgraph, pruned)) + [pruned])


# -------------------------------------------------------------------
# 3) Broad lineage group of every node (None when the node has no group)
# -------------------------------------------------------------------
node2group = {
    node: map_node_to_broad_group(node, broad_lineage_groups)
    for node in subgraph.nodes()
}

# Collect all *defined* group labels (i.e. not None)
all_defined_groups = [g for g in node2group.values() if g is not None]
unique_groups = sorted(set(all_defined_groups))


# -------------------------------------------------------------------
# 4) Node colors: group color for mapped nodes with a group,
#    black for other mapped nodes, light gray for unmapped nodes
# -------------------------------------------------------------------
node2color = {}
for node, grp in node2group.items():
    mapped_flag = subgraph.nodes[node]['mapped']
    if grp is not None and mapped_flag:
        node2color[node] = group2color[grp]
    elif mapped_flag:
        node2color[node] = 'black'
    else:
        node2color[node] = 'lightgray'

# -------------------------------------------------------------------
# 5) Layout: graphviz 'dot' with left-to-right rank direction (requires pygraphviz)
# -------------------------------------------------------------------
pos = nx.nx_agraph.graphviz_layout(subgraph, prog='dot', args='-Grankdir=LR')

ordered_nodes = list(subgraph.nodes())
final_node_colors = [node2color[n] for n in ordered_nodes]

# -------------------------------------------------------------------
# 6) Label each node as "<node>/<celltypes>" when it has mapped cell types
# -------------------------------------------------------------------
node_labels = {}
for node in subgraph.nodes():
    celltype_mapped = subgraph.nodes[node].get('celltype', [])
    has_celltype = celltype_mapped and pd.notna(celltype_mapped)
    node_labels[node] = f"{node}/{','.join(celltype_mapped)}" if has_celltype else node

plt.figure(figsize=(2, 1), dpi=600)
nx.draw(
    subgraph,
    pos,
    with_labels=True,
    labels=node_labels,
    nodelist=ordered_nodes,
    node_color=final_node_colors,
    node_size=15,
    font_size=5,
    arrowsize=4,
    width=0.5,
)

plt.title(f'{root}')
plt.savefig(save_dir / f"lineage_subtree_{subset_name}_{root}_{file_suffix}.pdf")
plt.show()



In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Overlay the root-to-leaf lineage paths of `subgraph` on the Concord-decoder
# UMAP of the current subset. Each lineage node is drawn at a representative
# (medoid) cell among the cells whose `lin_or_ct` label is one of the node's
# `linorct` annotations.
k = 30
basis = 'Concord-decoder'
show_basis = f'{basis}_UMAP_{subset_name}'
zoom_in = True
show_square = False


def _has_annotation(value):
    """Return True if `value` holds at least one non-missing annotation.

    Node attributes here are either a list of labels or a scalar (possibly
    NaN). `value and pd.notna(value)` cannot be used for lists: pd.notna then
    returns an array, whose truth value is ambiguous for more than one element.
    """
    if isinstance(value, (list, tuple)):
        return any(pd.notna(item) for item in value)
    return bool(value) and bool(pd.notna(value))


subpaths = []
leaf_nodes = [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]

# Collect paths for each root
for root in roots:
    for leaf in leaf_nodes:
        if nx.has_path(subgraph, source=root, target=leaf):
            subpaths.append(nx.shortest_path(subgraph, source=root, target=leaf))

print(f"Found {len(subpaths)} root-to-leaf paths for roots: {roots}")

# -------------------------------------------------------------------
# Map cells in the subgraph
# -------------------------------------------------------------------
add_inferred_trajectory = False
cellpaths = []
cells_in_subgraph = [subgraph.nodes[node].get('linorct', []) for node in subgraph.nodes]
cells_in_subgraph = [item for sublist in cells_in_subgraph if isinstance(sublist, list) for item in sublist]

# Map colors to lin_or_ct
adata_subsub = adata_subset[adata_subset.obs['lin_or_ct'].isin(cells_in_subgraph)]
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subsub, 'lin_or_ct', pal='Set1')
colors = adata_subsub.obs['lin_or_ct'].astype(str).map(lin_or_ct_palette)

# -------------------------------------------------------------------
# Identify valid cell paths: keep paths whose leaf has a cell-type annotation
# -------------------------------------------------------------------
for end_cell in leaf_nodes:
    selected_path = [path for path in subpaths if path[-1] == end_cell]
    if not selected_path:
        continue
    start_node = selected_path[0][0]
    end_node = selected_path[0][-1]
    selected_path_start = subgraph.nodes[start_node].get('linorct', [])
    selected_path_end = subgraph.nodes[end_node].get('celltype', [])
    if _has_annotation(selected_path_end):
        print(f"Start: {start_node} ({selected_path_start}), End: {end_node} ({selected_path_end})")
        cellpaths.append(selected_path[0])

# -------------------------------------------------------------------
# Plot subpaths on the UMAP
# -------------------------------------------------------------------
plot_label = True  # Set False to omit text labels
embedding = adata_subset.obsm[show_basis]

plt.figure(figsize=(2, 2), dpi=600)

# a) Plot all points in the background
plt.scatter(
    embedding[:, 0],
    embedding[:, 1],
    rasterized=True,
    zorder=0,
    s=0.1, color="lightgray", alpha=0.8,
    edgecolors="none"
)

# b) Highlight cells in the subgraph, colored by lin_or_ct
plt.scatter(
    adata_subsub.obsm[show_basis][:, 0],
    adata_subsub.obsm[show_basis][:, 1],
    c=colors,
    rasterized=True,
    zorder=1,
    s=0.3, alpha=0.8,
    edgecolors="none"
)

# The kNN graph depends only on the embedding, so build it once rather than per path.
neighborhood = (
    ccd.ml.Neighborhood(embedding, k=k, use_faiss=False)
    if add_inferred_trajectory else None
)

all_rep_points = []
for path in cellpaths:
    leaf_node = path[-1]
    group_label = map_leaf_to_broad_group(leaf_node, broad_lineage_groups)
    print(group_label)
    # Leaves outside any broad group fall back to black instead of raising KeyError.
    color = group2color.get(group_label, 'black')

    rep_points = []
    rep_points_idx = []
    labels = []

    for node in path:
        linorct_list = parse_annotation(G.nodes[node].get('linorct'))
        cell_indices = (
            np.where(adata_subset.obs['lin_or_ct'].isin(linorct_list))[0]
            if linorct_list else np.array([], dtype=int)
        )
        if len(cell_indices) == 0:
            # Node is unannotated or has no cells in this subset: keep a NaN placeholder.
            rep_points.append([np.nan, np.nan])
            labels.append(("", []))
            continue
        rp, idx = get_representative_point(embedding[cell_indices], method='medoid',
                                           max_n_medoid=2000,
                                           k_top=10, jitter=0, return_idx=True, seed=seed)
        rep_points.append(rp)
        rep_points_idx.append(cell_indices[idx])
        labels.append((node, linorct_list))

    if not rep_points_idx:
        # No node on this path could be located in the embedding; nothing to draw.
        continue

    start_idx = rep_points_idx[0]
    end_idx = rep_points_idx[-1]

    rep_points = np.array(rep_points)
    valid_mask = ~np.isnan(rep_points[:, 0])
    valid_rep_points = rep_points[valid_mask]
    valid_labels = [label for label, keep in zip(labels, valid_mask) if keep]
    all_rep_points.append(valid_rep_points)

    # Draw path in the chosen color
    plt.plot(
        valid_rep_points[:, 0],
        valid_rep_points[:, 1],
        color='black',  # Line color
        marker='o',
        markersize=3,
        markerfacecolor=color,  # Fill color
        markeredgecolor='black',  # Edge color
        markeredgewidth=0.2,  # Edge thickness
        linewidth=0.3,
        alpha=0.8,
        zorder=2
    )

    if neighborhood is not None:
        celltrajectory, _ = ccd.ul.shortest_path_on_knn_graph(neighborhood, k=k, point_a=start_idx, point_b=end_idx, use_faiss=False)

        plt.plot(
            embedding[celltrajectory, 0],
            embedding[celltrajectory, 1],
            color='black',
            marker='o',
            markersize=0.5,
            markeredgecolor='black',
            markeredgewidth=0.1,
            linewidth=0.3,
            alpha=0.8,
            zorder=1
        )

    # Optionally label text
    if plot_label:
        for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
            label_text = f"{node_name}\n{annot_list}"
            plt.text(cx, cy, label_text, fontsize=0.5, color="black", zorder=2, alpha=0.5)

# Zoom to the drawn paths with some margin (skipped when nothing was drawn)
if zoom_in and all_rep_points:
    all_rep_points = np.concatenate(all_rep_points, axis=0)
    min_x, min_y = np.nanmin(all_rep_points, axis=0)
    max_x, max_y = np.nanmax(all_rep_points, axis=0)
    span = max(max_x - min_x, max_y - min_y)
    margin = 0.1 * span

    if show_square:
        # Ensure square aspect ratio
        center_x = (min_x + max_x) / 2
        center_y = (min_y + max_y) / 2
        half_side = span / 2 + margin
        plt.xlim(center_x - half_side, center_x + half_side)
        plt.ylim(center_y - half_side, center_y + half_side)
    else:
        plt.xlim(min_x - margin, max_x + margin)
        plt.ylim(min_y - margin, max_y + margin)

plt.title(f"Lineage Paths in UMAP (Roots: {', '.join(roots)})")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

text_ext = "with_text" if plot_label else "no_text"
zoomin_ext = "zoomin" if zoom_in else "nozoomin"
square_ext = "square" if show_square else "nosquare"
plt.savefig(save_dir / f"lineage_subpaths_{subset_name}_{'_'.join(roots)}_{show_basis}_{file_suffix}_{text_ext}_{zoomin_ext}_{square_ext}.pdf")
plt.show()


In [None]:
# Persist the current subset (with its per-subset embeddings) and report the file name.
subset_file = f"adata_cbce_{file_suffix}_{subset_name}.h5ad"
adata_subset.write_h5ad(data_dir / subset_file)
print(subset_file)

#### Mesoderm

In [None]:
# Shortcut: reload a previously saved Mesoderm subset instead of recomputing it.
subset_name = 'Mesoderm'
mesoderm_path = data_dir / "adata_cbce_Dec21-0244_Mesoderm.h5ad"
adata_subset = sc.read(mesoderm_path)

In [None]:
subset_name = 'Mesoderm'
selected_lins = ['Mesoderm']
# Boolean indexing returns an AnnData *view*. Later cells write new obsm/obs
# entries, which would force an implicit view-to-copy conversion
# (ImplicitModificationWarning). Materialise the subset explicitly instead.
adata_subset = adata[adata.obs['broad_cell_type_qz'].isin(selected_lins)].copy()
print(adata_subset.shape)

In [None]:
# Compute a subset-specific 2-D UMAP from each integration method's latent space.
umap_params = dict(n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed)
for basis in combined_keys:
    print("Running UMAP for", basis)
    if basis in adata_subset.obsm:
        ccd.ul.run_umap(adata_subset, source_key=basis, result_key=f'{basis}_UMAP_{subset_name}', **umap_params)


In [None]:
# Persist the subset together with its freshly computed per-subset UMAPs.
mesoderm_out = data_dir / f"adata_cbce_{file_suffix}_{subset_name}.h5ad"
adata_subset.write_h5ad(mesoderm_out)

In [None]:
# Display the broad-lineage-group -> colour mapping (defined earlier in the
# notebook) that is used to colour lineage-tree nodes and path markers.
group2color

In [None]:
# Build the label columns used to colour the Mesoderm plots. Each column starts
# from a primary annotation, falls back to a secondary annotation where the
# primary is missing, and then has rare labels collapsed into the 'NaN'
# placeholder so they do not clutter the legend.
#
# All writes use whole-column assignment or `.loc[mask, column]`. The chained
# form `obs[column][mask] = value` raises SettingWithCopy/ChainedAssignment
# warnings and silently writes nothing under pandas Copy-on-Write.

# Strings treated as "missing": the placeholder written below and str(np.nan).
MISSING_LABELS = ['NaN', 'nan']


def _collapse_rare_labels(column, min_count, count_column=None):
    """Replace labels in `column` counted fewer than `min_count` times with 'NaN'.

    Counts are taken from `count_column` (defaults to `column`). The collapsed
    labels are printed and returned.
    """
    counts = adata_subset.obs[count_column or column].value_counts()
    rare = counts[counts < min_count].index
    print(rare)
    adata_subset.obs.loc[adata_subset.obs[column].isin(rare), column] = 'NaN'
    return rare


def _label_with_fallback(column, primary, fallback):
    """Set `column` to str(primary), using str(fallback) wherever primary is missing."""
    labels = adata_subset.obs[primary].astype(str)
    missing = labels.isin(MISSING_LABELS)
    # Pass a positional array so the replacement does not depend on index alignment.
    fallback_values = adata_subset.obs[fallback].astype(str).to_numpy()
    adata_subset.obs[column] = labels.mask(missing, fallback_values)


# Cell type, with rare types collapsed and redundant sub-types merged.
adata_subset.obs['plot_cell_type'] = adata_subset.obs['cell_type'].astype(str)
ignore_cts = _collapse_rare_labels('plot_cell_type', 100, count_column='cell_type')
# Merge redundant cell types: any label containing the key becomes the key.
for merged_label in ['BWM_headrow2', 'hmc', 'BWM_headrow1']:
    hits = adata_subset.obs['plot_cell_type'].str.contains(merged_label)
    adata_subset.obs.loc[hits, 'plot_cell_type'] = merged_label

# Cell type, falling back to broad lineage.
_label_with_fallback('ct_or_broad_lin', 'plot_cell_type', 'broad_lineage')
_collapse_rare_labels('ct_or_broad_lin', 30)

# Broad lineage, falling back to cell type.
_label_with_fallback('broad_lin_or_ct', 'broad_lineage', 'plot_cell_type')
_collapse_rare_labels('broad_lin_or_ct', 30)

# Complete lineage, falling back to cell type.
_label_with_fallback('plot_lin_or_ct', 'lineage_complete', 'plot_cell_type')
_collapse_rare_labels('plot_lin_or_ct', 20)

use_pal = pal.copy()
_,_,use_pal['plot_cell_type'] = ccd.pl.get_color_mapping(adata_subset, 'plot_cell_type', pal='Paired', seed=42)
_,_,use_pal['ct_or_broad_lin'] = ccd.pl.get_color_mapping(adata_subset, 'ct_or_broad_lin', pal='Paired', seed=seed)

_,_,use_pal['broad_lin_or_ct'] = ccd.pl.get_color_mapping(adata_subset, 'broad_lin_or_ct', pal='Set1', seed=42)
# Use the shared broad-group colours for labels that are broad lineage groups.
for key, value in group2color.items():
    if key in use_pal['broad_lin_or_ct']:
        use_pal['broad_lin_or_ct'][key] = value

In [None]:
# Inspect how many cells carry each broad_lin_or_ct plotting label.
adata_subset.obs['broad_lin_or_ct'].value_counts()

In [None]:
# Leave Z1/Z4 cells out of the plots below: most of them form a separate cluster.
is_z1_z4 = adata_subset.obs['plot_cell_type'].isin(['Z1_Z4'])
adata_subset = adata_subset[~is_z1_z4]

In [None]:
# Overview of the Concord-decoder subset UMAP, coloured by several annotations.
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = ['plot_cell_type', 'ct_or_broad_lin', 'broad_lin_or_ct', 'broad_lineage', 'embryo.time', 'species']
overview_pdf = save_dir / f"{show_basis}_{file_suffix}_{subset_name}.pdf"
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_subset,
        show_basis,
        show_cols,
        figsize=(10, 6.7),
        dpi=600,
        ncols=3,
        font_size=3,
        point_size=.8,
        legend_loc='on data',
        pal=use_pal,
        seed=seed,
        save_path=overview_pdf,
    )

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Plot every integration method's subset UMAP once per annotation column.
show_keys = combined_keys
show_cols = [
    'embryo.time', 'plot_cell_type', 'species', 'dataset3', 'lineage_complete',
    'ct_or_broad_lin', 'broad_lin_or_ct', 'broad_lineage', 'ct_or_lin', 'broad_cell_type_qz',
]

basis_types = [f'UMAP_{subset_name}']
font_size = 10
point_size = .2
alpha = 0.8
figsize = (10, 1.4)
ncols = 8
nrows = int(np.ceil(len(show_keys) / ncols))  # grid rows for ncols (not passed to the plot call)

legend_loc = None
legend_tag = 'no_legend' if legend_loc is None else legend_loc.replace(' ', '_')
save_file_suffix = f"{file_suffix}_{legend_tag}"

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata_subset,
        show_keys,
        color_bys=show_cols,
        basis_types=basis_types,
        pal=use_pal,
        font_size=font_size,
        legend_font_size=1,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        legend_loc=legend_loc,
        save_dir=save_dir,
        file_suffix=save_file_suffix,
        dpi=600,
        save_format='pdf',
    )


In [None]:
# Large single-panel view of the Concord-decoder subset UMAP coloured by broad_lin_or_ct.
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = ['broad_lin_or_ct']
single_pdf = save_dir / f"{show_basis}_{file_suffix}_{subset_name}_plot_lin_or_ct.pdf"
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_subset,
        show_basis,
        show_cols,
        figsize=(10, 10),
        dpi=600,
        ncols=1,
        font_size=3,
        point_size=.8,
        legend_loc='on data',
        pal=use_pal,
        seed=seed,
        save_path=single_pdf,
    )

In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# Draw the lineage sub-tree below the root lineages in `roots`. Mapped nodes
# are coloured by their broad lineage group, and nodes are labelled with any
# attached cell type.
roots = ['C', 'D', 'MS', 'ABprpppppa', 'ABplpppppa']

# Sub-tree: every root plus all of its descendants in the full lineage graph G.
all_descendants = set(roots)
for root in roots:
    all_descendants |= nx.descendants(G, root)
subgraph = G.subgraph(all_descendants).copy()
# Branches can be pruned here if needed, e.g.
# subgraph.remove_nodes_from(list(nx.descendants(subgraph, 'ABprap')) + ['ABprap'])

# Broad lineage group of every node (None when the node belongs to no group).
node2group = {node: map_node_to_broad_group(node, broad_lineage_groups) for node in subgraph.nodes()}
all_defined_groups = [grp for grp in node2group.values() if grp is not None]
unique_groups = sorted(set(all_defined_groups))


def _tree_node_color(node):
    """Group colour for mapped, grouped nodes; otherwise black (mapped) or light grey (unmapped)."""
    is_mapped = subgraph.nodes[node].get('mapped', False)
    grp = node2group[node]
    if grp is None or not is_mapped:
        return 'black' if is_mapped else 'lightgray'
    return group2color[grp]


def _tree_node_label(node):
    """Node name, suffixed with '/<celltype>' when a non-missing cell type is attached."""
    celltype = subgraph.nodes[node].get('celltype', [])
    if not celltype:
        return node
    if isinstance(celltype, list):
        return f"{node}/{','.join(celltype)}"
    return f"{node}/{celltype}" if pd.notna(celltype) else node


node2color = {node: _tree_node_color(node) for node in subgraph.nodes()}

# Layout: graphviz 'dot', left-to-right (requires pygraphviz).
pos = nx.nx_agraph.graphviz_layout(subgraph, prog='dot', args='-Grankdir=LR')
ordered_nodes = list(subgraph.nodes())
final_node_colors = [node2color[node] for node in ordered_nodes]
node_labels = {node: _tree_node_label(node) for node in ordered_nodes}

plt.figure(figsize=(2, 10), dpi=600)
nx.draw(
    subgraph,
    pos,
    with_labels=True,
    labels=node_labels,
    nodelist=ordered_nodes,
    node_color=final_node_colors,
    node_size=15,
    font_size=5,
    arrowsize=4,
    width=0.5,
)

# One combined title covering all roots
plt.title(f'Subgraph for Roots: {", ".join(roots)}')
plt.savefig(save_dir / f"lineage_subtree_{'_'.join(roots)}_{file_suffix}.pdf")
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# For each integration method in `combined_keys`, overlay the root-to-leaf
# lineage paths of `subgraph` on that method's subset UMAP and save one small
# PDF per method. Each lineage node is drawn at a representative (medoid) cell
# among the cells whose `plot_lin_or_ct` label is one of the node's `linorct`
# annotations.
k = 30
figsize = (1.5, 1.5)
plot_label = False
zoom_in = False
show_square = False
add_inferred_trajectory = False

# -------------------------------------------------------------------
# Lineage paths do not depend on the embedding, so compute them once.
# -------------------------------------------------------------------
leaf_nodes = [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]
subpaths = []
for root in roots:
    for leaf in leaf_nodes:
        if nx.has_path(subgraph, source=root, target=leaf):
            subpaths.append(nx.shortest_path(subgraph, source=root, target=leaf))
print(f"Found {len(subpaths)} root-to-leaf paths for roots: {roots}")

cells_in_subgraph = [subgraph.nodes[node].get('linorct', []) for node in subgraph.nodes]
cells_in_subgraph = [item for sublist in cells_in_subgraph if isinstance(sublist, list) for item in sublist]

# plot_lin_or_ct colours (the per-method figures below colour cells by
# broad_lin_or_ct through plot_embedding instead).
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subset, 'plot_lin_or_ct', pal='Set1')
colors = adata_subset.obs['plot_lin_or_ct'].astype(str).map(lin_or_ct_palette)

# Keep only paths whose terminal node carries a non-missing linorct annotation
# (for list annotations, the first entry is checked).
cellpaths = []
for end_cell in leaf_nodes:
    selected_path = [path for path in subpaths if path[-1] == end_cell]
    if not selected_path:
        continue
    end_annot = subgraph.nodes[selected_path[0][-1]].get('linorct', [])
    if not end_annot:
        continue
    if isinstance(end_annot, list):
        end_annot = end_annot[0]
    if pd.notna(end_annot):
        cellpaths.append(selected_path[0])

# -------------------------------------------------------------------
# One figure per integration method
# -------------------------------------------------------------------
for basis in combined_keys:
    show_basis = f'{basis}_UMAP_{subset_name}'
    if show_basis not in adata_subset.obsm:
        # The UMAP cell skips methods absent from this subset; do the same here.
        print(f"Skipping {basis}: {show_basis} not in adata_subset.obsm")
        continue
    embedding = adata_subset.obsm[show_basis]

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=600, constrained_layout=True)

    ccd.pl.plot_embedding(
        adata_subset, show_basis, color_by=['broad_lin_or_ct'], ax=ax, font_size=3, point_size=.8, alpha=0.5, text_alpha=0.9, legend_loc=None,
        pal=use_pal, seed=seed,
        save_path=None
    )

    # The kNN graph depends only on the embedding: build it once per method, not per path.
    neighborhood = (
        ccd.ml.Neighborhood(embedding, k=k, use_faiss=False)
        if add_inferred_trajectory else None
    )

    all_rep_points = []
    for path in cellpaths:
        group_label = map_leaf_to_broad_group(path[-1], broad_lineage_groups)
        # Leaves outside any broad group fall back to black instead of raising KeyError.
        color = group2color.get(group_label, 'black')

        rep_points = []
        rep_points_idx = []
        labels = []

        for node in path:
            linorct_list = parse_annotation(G.nodes[node].get('linorct'))
            cell_indices = (
                np.where(adata_subset.obs['plot_lin_or_ct'].isin(linorct_list))[0]
                if linorct_list else np.array([], dtype=int)
            )
            if len(cell_indices) == 0:
                # Node is unannotated or has no cells in this subset: keep a NaN placeholder.
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
                continue
            rp, idx = get_representative_point(embedding[cell_indices], method='medoid',
                                               max_n_medoid=2000,
                                               k_top=10, jitter=0, return_idx=True, seed=seed)
            rep_points.append(rp)
            rep_points_idx.append(cell_indices[idx])
            labels.append((node, linorct_list))

        # A path needs at least two located nodes to be drawn as a line.
        if len(rep_points_idx) <= 1:
            continue

        start_idx = rep_points_idx[0]
        end_idx = rep_points_idx[-1]

        rep_points = np.array(rep_points)
        valid_mask = ~np.isnan(rep_points[:, 0])
        valid_rep_points = rep_points[valid_mask]
        valid_labels = [label for label, keep in zip(labels, valid_mask) if keep]
        all_rep_points.append(valid_rep_points)

        # Draw path in the chosen color
        ax.plot(
            valid_rep_points[:, 0],
            valid_rep_points[:, 1],
            color='black',  # Line color
            marker='o',
            markersize=3,
            markerfacecolor=color,  # Fill color
            markeredgecolor='black',  # Edge color
            markeredgewidth=0.2,  # Edge thickness
            linewidth=0.3,
            alpha=0.8,
            zorder=2
        )

        if neighborhood is not None:
            celltrajectory, _ = ccd.ul.shortest_path_on_knn_graph(neighborhood, k=k, point_a=start_idx, point_b=end_idx, use_faiss=False)

            ax.plot(
                embedding[celltrajectory, 0],
                embedding[celltrajectory, 1],
                color='black',
                marker='o',
                markersize=0.5,
                markeredgecolor='black',
                markeredgewidth=0.1,
                linewidth=0.3,
                alpha=0.8,
                zorder=1
            )

        # Optionally label text
        if plot_label:
            for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
                label_text = f"{node_name}\n{annot_list}"
                ax.text(cx, cy, label_text, fontsize=0.5, color="black", zorder=2, alpha=0.5)

    # Zoom to the drawn paths with some margin (skipped when nothing was drawn).
    # Axes objects have set_xlim/set_ylim; the pyplot-style ax.xlim/ax.ylim do not exist.
    if zoom_in and all_rep_points:
        stacked_points = np.concatenate(all_rep_points, axis=0)
        min_x, min_y = np.nanmin(stacked_points, axis=0)
        max_x, max_y = np.nanmax(stacked_points, axis=0)
        span = max(max_x - min_x, max_y - min_y)
        margin = 0.1 * span

        if show_square:
            # Ensure square aspect ratio
            center_x = (min_x + max_x) / 2
            center_y = (min_y + max_y) / 2
            half_side = span / 2 + margin
            ax.set_xlim(center_x - half_side, center_x + half_side)
            ax.set_ylim(center_y - half_side, center_y + half_side)
        else:
            ax.set_xlim(min_x - margin, max_x + margin)
            ax.set_ylim(min_y - margin, max_y + margin)

    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])

    text_ext = "with_text" if plot_label else "no_text"
    zoomin_ext = "zoomin" if zoom_in else "nozoomin"
    square_ext = "square" if show_square else "nosquare"
    fig.savefig(save_dir / f"lineage_subpaths_{subset_name}_{'_'.join(roots)}_{show_basis}_{file_suffix}_{text_ext}_{zoomin_ext}_{square_ext}_sm.pdf")
    # Close each figure once it is saved; otherwise one figure per method stays open.
    plt.close(fig)


In [None]:
# Inspect the representative-cell indices collected for the most recently
# processed lineage path in the path-overlay loop.
rep_points_idx

In [None]:
# Inspect cell counts per plot_lin_or_ct label (the column the path overlay
# matches lineage-node annotations against).
adata_subset.obs['plot_lin_or_ct'].value_counts()

### Pharynx

In [None]:
# Shortcut: reload a previously saved Pharynx subset instead of recomputing it.
subset_name = 'Pharynx'
pharynx_path = data_dir / "adata_cbce_Dec23-1049_Pharynx.h5ad"
adata_subset = sc.read(pharynx_path)

In [None]:
subset_name = 'Pharynx'
selected_lins = ['Pharynx']
# Boolean indexing returns an AnnData *view*. Later cells write new obsm/obs
# entries, which would force an implicit view-to-copy conversion
# (ImplicitModificationWarning). Materialise the subset explicitly instead.
adata_subset = adata[adata.obs['broad_cell_type_qz'].isin(selected_lins)].copy()
print(adata_subset.shape)

In [None]:
# Compute a subset-specific 2-D UMAP from each integration method's latent space.
umap_params = dict(n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed)
for basis in combined_keys:
    print("Running UMAP for", basis)
    if basis in adata_subset.obsm:
        ccd.ul.run_umap(adata_subset, source_key=basis, result_key=f'{basis}_UMAP_{subset_name}', **umap_params)


In [None]:
# Persist the subset together with its freshly computed per-subset UMAPs.
pharynx_out = data_dir / f"adata_cbce_{file_suffix}_{subset_name}.h5ad"
adata_subset.write_h5ad(pharynx_out)

In [None]:
# Build the label columns used to colour the Pharynx plots. Each column starts
# from a primary annotation, falls back to a secondary annotation where the
# primary is missing, and then has rare labels collapsed into the 'NaN'
# placeholder so they do not clutter the legend.
#
# All writes use whole-column assignment or `.loc[mask, column]`. The chained
# form `obs[column][mask] = value` raises SettingWithCopy/ChainedAssignment
# warnings and silently writes nothing under pandas Copy-on-Write.

# Strings treated as "missing": the placeholder written below and str(np.nan).
MISSING_LABELS = ['NaN', 'nan']


def _collapse_rare_labels(column, min_count, count_column=None):
    """Replace labels in `column` counted fewer than `min_count` times with 'NaN'.

    Counts are taken from `count_column` (defaults to `column`). The collapsed
    labels are printed and returned.
    """
    counts = adata_subset.obs[count_column or column].value_counts()
    rare = counts[counts < min_count].index
    print(rare)
    adata_subset.obs.loc[adata_subset.obs[column].isin(rare), column] = 'NaN'
    return rare


def _label_with_fallback(column, primary, fallback):
    """Set `column` to str(primary), using str(fallback) wherever primary is missing."""
    labels = adata_subset.obs[primary].astype(str)
    missing = labels.isin(MISSING_LABELS)
    # Pass a positional array so the replacement does not depend on index alignment.
    fallback_values = adata_subset.obs[fallback].astype(str).to_numpy()
    adata_subset.obs[column] = labels.mask(missing, fallback_values)


# Cell type, with rare types collapsed.
adata_subset.obs['plot_cell_type'] = adata_subset.obs['cell_type'].astype(str)
ignore_cts = _collapse_rare_labels('plot_cell_type', 80, count_column='cell_type')

# Cell type, falling back to broad lineage.
_label_with_fallback('ct_or_broad_lin', 'plot_cell_type', 'broad_lineage')
_collapse_rare_labels('ct_or_broad_lin', 80)

# Broad lineage, falling back to cell type.
_label_with_fallback('broad_lin_or_ct', 'broad_lineage', 'plot_cell_type')
_collapse_rare_labels('broad_lin_or_ct', 80)

# Complete lineage, falling back to cell type.
_label_with_fallback('plot_lin_or_ct', 'lineage_complete', 'plot_cell_type')
_collapse_rare_labels('plot_lin_or_ct', 20)

use_pal = pal.copy()
_,_,use_pal['plot_cell_type'] = ccd.pl.get_color_mapping(adata_subset, 'plot_cell_type', pal='Paired', seed=42)
_,_,use_pal['ct_or_broad_lin'] = ccd.pl.get_color_mapping(adata_subset, 'ct_or_broad_lin', pal='Paired', seed=seed)

_,_,use_pal['broad_lin_or_ct'] = ccd.pl.get_color_mapping(adata_subset, 'broad_lin_or_ct', pal='Set1', seed=seed)
# Pharynx-specific copy of the broad-group colours. 'ABara' is overridden,
# presumably to avoid a clash with another colour in this subset (confirm).
group2color2 = group2color.copy()
group2color2['ABara'] = '#61b128'
# Use these group colours for labels that are broad lineage groups.
for key, value in group2color2.items():
    if key in use_pal['broad_lin_or_ct']:
        use_pal['broad_lin_or_ct'][key] = value

In [None]:
# Drop the 'ABprp' broad lineage: these cells are likely the result of an annotation error.
keep_cells = ~adata_subset.obs['broad_lineage'].isin(['ABprp'])
adata_subset = adata_subset[keep_cells]

In [None]:
# Overview of the Concord-decoder Pharynx UMAP (ABprp removed), coloured by several annotations.
output_key = 'Concord-decoder'
basis = output_key
show_basis = f'{basis}_UMAP_{subset_name}'
show_cols = ['plot_cell_type', 'ct_or_broad_lin', 'broad_lin_or_ct', 'plot_lin_or_ct', 'embryo.time', 'species']
overview_pdf = save_dir / f"{show_basis}_{file_suffix}_{subset_name}_wtABprp.pdf"
with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_embedding(
        adata_subset,
        show_basis,
        show_cols,
        figsize=(10, 6.7),
        dpi=600,
        ncols=3,
        font_size=3,
        point_size=1.2,
        legend_loc='on data',
        pal=use_pal,
        seed=seed,
        save_path=overview_pdf,
    )

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Plot every integration method's subset UMAP once per annotation column.
show_keys = combined_keys
show_cols = [
    'embryo.time', 'plot_cell_type', 'species', 'dataset3', 'lineage_complete',
    'ct_or_broad_lin', 'broad_lin_or_ct', 'broad_lineage', 'ct_or_lin', 'broad_cell_type_qz',
]

basis_types = [f'UMAP_{subset_name}']
font_size = 10
point_size = .5
alpha = 0.8
figsize = (10, 1.35)
ncols = 8
nrows = int(np.ceil(len(show_keys) / ncols))  # grid rows for ncols (not passed to the plot call)

legend_loc = 'on data'
legend_tag = 'no_legend' if legend_loc is None else legend_loc.replace(' ', '_')
save_file_suffix = f"{file_suffix}_{legend_tag}"

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(
        adata_subset,
        show_keys,
        color_bys=show_cols,
        basis_types=basis_types,
        pal=use_pal,
        font_size=font_size,
        legend_font_size=1,
        point_size=point_size,
        alpha=alpha,
        figsize=figsize,
        ncols=ncols,
        seed=seed,
        legend_loc=legend_loc,
        save_dir=save_dir,
        file_suffix=save_file_suffix,
        dpi=600,
        save_format='pdf',
    )


In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# -------------------------------------------------------------------
# 2) Extract the subgraph spanned by several lineage roots
# -------------------------------------------------------------------
# Alternative root set:
#roots = ['ABarappppp', 'ABalpapppa', 'ABalpappaa', 'ABalappppa', 'ABalapappa', 'ABalappapp', 'ABalapaapp', 'ABalaappp']  # Replace with your desired roots
roots = ['MS', 'ABalp', 'ABara']

# Each root together with every node reachable from it in G.
all_descendants = set(roots).union(*(nx.descendants(G, root) for root in roots))
subgraph = G.subgraph(all_descendants).copy()

# Specific branches can be trimmed here if needed, e.g.:
# subgraph.remove_nodes_from(list(nx.descendants(subgraph, 'ABprap')) + ['ABprap'])
# subgraph.remove_nodes_from(list(nx.descendants(subgraph, 'ABpraap')) + ['ABpraap'])
# subgraph.remove_nodes_from(list(nx.descendants(subgraph, 'ABpraaaa')) + ['ABpraaaa'])
# subgraph.remove_nodes_from(list(nx.descendants(subgraph, 'ABpraaapa')) + ['ABpraaapa'])

# -------------------------------------------------------------------
# 3) Broad lineage group of every node (None when it has none)
# -------------------------------------------------------------------
node2group = {
    node: map_node_to_broad_group(node, broad_lineage_groups)
    for node in subgraph.nodes()
}
all_defined_groups = [g for g in node2group.values() if g is not None]
unique_groups = sorted(set(all_defined_groups))

# -------------------------------------------------------------------
# 4) Node colours: unmapped -> lightgray, mapped without a group -> black,
#    mapped with a group -> that group's colour
# -------------------------------------------------------------------
node2color = {}
for node, grp in node2group.items():
    mapped_flag = subgraph.nodes[node].get('mapped', False)
    if not mapped_flag:
        node2color[node] = 'lightgray'
    elif grp is None:
        node2color[node] = 'black'
    else:
        node2color[node] = group2color2[grp]

# -------------------------------------------------------------------
# 5) Left-to-right tree layout (graphviz 'dot'; requires pygraphviz)
# -------------------------------------------------------------------
pos = nx.nx_agraph.graphviz_layout(subgraph, prog='dot', args='-Grankdir=LR')
ordered_nodes = list(subgraph.nodes())
final_node_colors = [node2color[n] for n in ordered_nodes]

# -------------------------------------------------------------------
# 6) Label = node name, plus "/<celltype>" when the node carries a
#    non-missing 'celltype' attribute (lists are comma-joined)
# -------------------------------------------------------------------
node_labels = {}
for node in subgraph.nodes():
    celltype_mapped = subgraph.nodes[node].get('celltype', [])
    label = node
    if isinstance(celltype_mapped, list):
        if celltype_mapped:
            label = f"{node}/{','.join(celltype_mapped)}"
    elif celltype_mapped and pd.notna(celltype_mapped):
        label = f"{node}/{celltype_mapped}"
    node_labels[node] = label

draw_style = dict(node_size=15, font_size=5, arrowsize=4, width=0.5)
plt.figure(figsize=(2, 10), dpi=600)
nx.draw(
    subgraph,
    pos,
    with_labels=True,
    labels=node_labels,
    nodelist=ordered_nodes,
    node_color=final_node_colors,
    **draw_style,
)

roots_title = ", ".join(roots)
plt.title(f'Subgraph for Roots: {roots_title}')
plt.savefig(save_dir / f"lineage_subtree_{'_'.join(roots)}_{file_suffix}.pdf")
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Overlay the root-to-leaf lineage paths of `subgraph` on the UMAP of each
# method in `show_keys`, saving one PDF per method (4x4-inch version).
k = 30  # neighbours for the optional kNN shortest-path trajectory
#figsize=(1.5,1.5)
figsize = (4, 4)
plot_label = False               # annotate representative points with node/annotation text
zoom_in = False                  # crop the axes to the drawn paths
show_square = False              # when zooming, keep a square window
add_inferred_trajectory = False  # also draw a kNN shortest path between path ends
#basis = f'Concord-decoder'
show_keys = ['Concord-decoder', 'Seurat', 'scVI']

# Styling tied to this figure size.
embedding_point_size = 4
background_path_width = 0.5
path_marker_size = 5
path_marker_edge_width = 0.5

# -------------------------------------------------------------------
# Root-to-leaf paths. They depend only on the lineage graph, so they are
# computed once instead of once per embedding.
# -------------------------------------------------------------------
leaf_nodes = [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]
subpaths = [
    nx.shortest_path(subgraph, source=root, target=leaf)
    for root in roots
    for leaf in leaf_nodes
    if nx.has_path(subgraph, source=root, target=leaf)
]
print(f"Found {len(subpaths)} root-to-leaf paths for roots: {roots}")

# Not used by the plotting below; kept for interactive inspection.
cells_in_subgraph = [subgraph.nodes[node].get('linorct', []) for node in subgraph.nodes]
cells_in_subgraph = [item for sublist in cells_in_subgraph if isinstance(sublist, list) for item in sublist]
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subset, 'plot_lin_or_ct', pal='Set1')
colors = adata_subset.obs['plot_lin_or_ct'].astype(str).map(lin_or_ct_palette)

# -------------------------------------------------------------------
# Keep one path per leaf whose terminal node has a non-missing 'linorct'
# annotation, and choose the paths to highlight (last matching leaf wins).
# -------------------------------------------------------------------
cellpaths = []
highlightpaths = {}
for end_cell in leaf_nodes:
    selected_path = [path for path in subpaths if path[-1] == end_cell]
    if not selected_path:
        continue
    chosen = selected_path[0]  # path from the first root (in `roots` order) reaching this leaf
    start_node, end_node = chosen[0], chosen[-1]
    selected_path_end = subgraph.nodes[end_node].get('linorct', [])
    if not selected_path_end:
        continue
    if isinstance(selected_path_end, list):
        selected_path_end = selected_path_end[0]
    if not pd.notna(selected_path_end):
        continue
    cellpaths.append(chosen)
    # Substring tests against the terminal annotation (assumed to be a string).
    if 'pm1_pm2' in selected_path_end:
        highlightpaths['pm1_pm2'] = chosen
    if 'pm3_pm4_pm5c' in selected_path_end and start_node == 'MS':
        highlightpaths['pm3_pm4_pm5c_MS'] = chosen
    if 'pm3_pm4_pm5c' in selected_path_end and start_node == 'ABalp':
        highlightpaths['pm3_pm4_pm5c_ABalp'] = chosen
    if 'pm6' in selected_path_end:
        highlightpaths['pm6'] = chosen

# File-name tags. The figure size is included so that figures of different
# sizes drawn with otherwise identical settings do not overwrite each other.
text_ext = "with_text" if plot_label else "no_text"
zoomin_ext = "zoomin" if zoom_in else "nozoomin"
square_ext = "square" if show_square else "nosquare"
size_ext = f"{figsize[0]}x{figsize[1]}"

# -------------------------------------------------------------------
# Plot the paths on each method's UMAP
# -------------------------------------------------------------------
for basis in show_keys:
    show_basis = f'{basis}_UMAP_{subset_name}'
    embedding = adata_subset.obsm[show_basis]

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=600, constrained_layout=True)
    ccd.pl.plot_embedding(
        adata_subset, show_basis, color_by=['broad_lin_or_ct'], ax=ax, font_size=3,
        point_size=embedding_point_size, alpha=0.8, text_alpha=0.9, legend_loc=None,
        pal=use_pal, seed=seed,
        save_path=None
    )

    all_rep_points = []
    for path in cellpaths:
        if path in highlightpaths.values():
            line_color, line_width, line_alpha = 'black', 1, 0.7
        else:
            line_color, line_width, line_alpha = 'lightgrey', background_path_width, 0.5
        group_label = map_leaf_to_broad_group(path[-1], broad_lineage_groups)
        color = group2color2[group_label]  # marker fill colour for this leaf's group

        # One representative point per node; NaN where the node has no cells.
        rep_points = []
        rep_points_idx = []
        labels = []
        for node in path:
            linorct_list = parse_annotation(G.nodes[node].get('linorct'))
            if linorct_list:
                cell_indices = np.where(adata_subset.obs['plot_lin_or_ct'].isin(linorct_list))[0]
            else:
                cell_indices = []
            if len(cell_indices) == 0:
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
                continue
            rp, idx = get_representative_point(embedding[cell_indices], method='medoid',
                                               max_n_medoid=2000,
                                               k_top=10, jitter=0, return_idx=True, seed=seed)
            rep_points.append(rp)
            rep_points_idx.append(cell_indices[idx])
            labels.append((node, linorct_list))

        # At least two nodes with cells are needed to draw a path.
        if len(rep_points_idx) <= 1:
            continue
        start_idx, end_idx = rep_points_idx[0], rep_points_idx[-1]

        rep_points = np.array(rep_points)
        valid_mask = ~np.isnan(rep_points[:, 0])
        valid_rep_points = rep_points[valid_mask]
        valid_labels = [lab for lab, ok in zip(labels, valid_mask) if ok]
        all_rep_points.append(valid_rep_points)

        ax.plot(
            valid_rep_points[:, 0],
            valid_rep_points[:, 1],
            color=line_color,                        # line colour
            marker='o',
            markersize=path_marker_size,
            markerfacecolor=color,                   # fill colour
            markeredgecolor='black',
            markeredgewidth=path_marker_edge_width,
            linewidth=line_width,
            alpha=line_alpha,
            zorder=2
        )

        if add_inferred_trajectory:
            neighborhood = ccd.ml.Neighborhood(embedding, k=k, use_faiss=False)
            celltrajectory, _ = ccd.ul.shortest_path_on_knn_graph(neighborhood, k=k, point_a=start_idx, point_b=end_idx, use_faiss=False)
            ax.plot(
                embedding[celltrajectory, 0],
                embedding[celltrajectory, 1],
                color='black',
                marker='o',
                markersize=0.5,
                markeredgecolor='black',
                markeredgewidth=0.1,
                linewidth=0.5,
                alpha=0.8,
                zorder=1
            )

        if plot_label:
            for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
                ax.text(cx, cy, f"{node_name}\n{annot_list}", fontsize=0.5, color="black", zorder=2, alpha=0.5)

    # Crop to the drawn paths (with a 10% margin); skipped when nothing was drawn.
    if zoom_in and all_rep_points:
        stacked = np.concatenate(all_rep_points, axis=0)
        min_x, min_y = np.nanmin(stacked, axis=0)
        max_x, max_y = np.nanmax(stacked, axis=0)
        span = max(max_x - min_x, max_y - min_y)
        margin = 0.1 * span
        if show_square:
            center_x = (min_x + max_x) / 2
            center_y = (min_y + max_y) / 2
            half_side = span / 2 + margin
            ax.set_xlim(center_x - half_side, center_x + half_side)
            ax.set_ylim(center_y - half_side, center_y + half_side)
        else:
            ax.set_xlim(min_x - margin, max_x + margin)
            ax.set_ylim(min_y - margin, max_y + margin)

    #ax.set_title(f"Lineage Paths in UMAP (Roots: {', '.join(roots)})")
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])

    fig.savefig(save_dir / f"lineage_subpaths_{subset_name}_{'_'.join(roots)}_{show_basis}_{file_suffix}_{size_ext}_{text_ext}_{zoomin_ext}_{square_ext}.pdf")


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Overlay the root-to-leaf lineage paths of `subgraph` on the UMAP of each
# method in `show_keys`, saving one PDF per method (1.5x1.5-inch version).
k = 30  # neighbours for the optional kNN shortest-path trajectory
figsize = (1.5, 1.5)
#figsize=(4,4)
plot_label = False               # annotate representative points with node/annotation text
zoom_in = False                  # crop the axes to the drawn paths
show_square = False              # when zooming, keep a square window
add_inferred_trajectory = False  # also draw a kNN shortest path between path ends
#basis = f'Concord-decoder'
show_keys = ['Concord-decoder', 'Seurat', 'scVI']

# Styling tied to this figure size.
embedding_point_size = 1
background_path_width = 0.2
path_marker_size = 2
path_marker_edge_width = 0.2

# -------------------------------------------------------------------
# Root-to-leaf paths. They depend only on the lineage graph, so they are
# computed once instead of once per embedding.
# -------------------------------------------------------------------
leaf_nodes = [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0]
subpaths = [
    nx.shortest_path(subgraph, source=root, target=leaf)
    for root in roots
    for leaf in leaf_nodes
    if nx.has_path(subgraph, source=root, target=leaf)
]
print(f"Found {len(subpaths)} root-to-leaf paths for roots: {roots}")

# Not used by the plotting below; kept for interactive inspection.
cells_in_subgraph = [subgraph.nodes[node].get('linorct', []) for node in subgraph.nodes]
cells_in_subgraph = [item for sublist in cells_in_subgraph if isinstance(sublist, list) for item in sublist]
_, _, lin_or_ct_palette = ccd.pl.get_color_mapping(adata_subset, 'plot_lin_or_ct', pal='Set1')
colors = adata_subset.obs['plot_lin_or_ct'].astype(str).map(lin_or_ct_palette)

# -------------------------------------------------------------------
# Keep one path per leaf whose terminal node has a non-missing 'linorct'
# annotation, and choose the paths to highlight (last matching leaf wins).
# -------------------------------------------------------------------
cellpaths = []
highlightpaths = {}
for end_cell in leaf_nodes:
    selected_path = [path for path in subpaths if path[-1] == end_cell]
    if not selected_path:
        continue
    chosen = selected_path[0]  # path from the first root (in `roots` order) reaching this leaf
    start_node, end_node = chosen[0], chosen[-1]
    selected_path_end = subgraph.nodes[end_node].get('linorct', [])
    if not selected_path_end:
        continue
    if isinstance(selected_path_end, list):
        selected_path_end = selected_path_end[0]
    if not pd.notna(selected_path_end):
        continue
    cellpaths.append(chosen)
    # Substring tests against the terminal annotation (assumed to be a string).
    if 'pm1_pm2' in selected_path_end:
        highlightpaths['pm1_pm2'] = chosen
    if 'pm3_pm4_pm5c' in selected_path_end and start_node == 'MS':
        highlightpaths['pm3_pm4_pm5c_MS'] = chosen
    if 'pm3_pm4_pm5c' in selected_path_end and start_node == 'ABalp':
        highlightpaths['pm3_pm4_pm5c_ABalp'] = chosen
    if 'pm6' in selected_path_end:
        highlightpaths['pm6'] = chosen

# File-name tags. The figure size is included so that figures of different
# sizes drawn with otherwise identical settings do not overwrite each other.
text_ext = "with_text" if plot_label else "no_text"
zoomin_ext = "zoomin" if zoom_in else "nozoomin"
square_ext = "square" if show_square else "nosquare"
size_ext = f"{figsize[0]}x{figsize[1]}"

# -------------------------------------------------------------------
# Plot the paths on each method's UMAP
# -------------------------------------------------------------------
for basis in show_keys:
    show_basis = f'{basis}_UMAP_{subset_name}'
    embedding = adata_subset.obsm[show_basis]

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=600, constrained_layout=True)
    ccd.pl.plot_embedding(
        adata_subset, show_basis, color_by=['broad_lin_or_ct'], ax=ax, font_size=3,
        point_size=embedding_point_size, alpha=0.8, text_alpha=0.9, legend_loc=None,
        pal=use_pal, seed=seed,
        save_path=None
    )

    all_rep_points = []
    for path in cellpaths:
        if path in highlightpaths.values():
            line_color, line_width, line_alpha = 'black', 1, 0.7
        else:
            line_color, line_width, line_alpha = 'lightgrey', background_path_width, 0.5
        group_label = map_leaf_to_broad_group(path[-1], broad_lineage_groups)
        color = group2color2[group_label]  # marker fill colour for this leaf's group

        # One representative point per node; NaN where the node has no cells.
        rep_points = []
        rep_points_idx = []
        labels = []
        for node in path:
            linorct_list = parse_annotation(G.nodes[node].get('linorct'))
            if linorct_list:
                cell_indices = np.where(adata_subset.obs['plot_lin_or_ct'].isin(linorct_list))[0]
            else:
                cell_indices = []
            if len(cell_indices) == 0:
                rep_points.append([np.nan, np.nan])
                labels.append(("", []))
                continue
            rp, idx = get_representative_point(embedding[cell_indices], method='medoid',
                                               max_n_medoid=2000,
                                               k_top=10, jitter=0, return_idx=True, seed=seed)
            rep_points.append(rp)
            rep_points_idx.append(cell_indices[idx])
            labels.append((node, linorct_list))

        # At least two nodes with cells are needed to draw a path.
        if len(rep_points_idx) <= 1:
            continue
        start_idx, end_idx = rep_points_idx[0], rep_points_idx[-1]

        rep_points = np.array(rep_points)
        valid_mask = ~np.isnan(rep_points[:, 0])
        valid_rep_points = rep_points[valid_mask]
        valid_labels = [lab for lab, ok in zip(labels, valid_mask) if ok]
        all_rep_points.append(valid_rep_points)

        ax.plot(
            valid_rep_points[:, 0],
            valid_rep_points[:, 1],
            color=line_color,                        # line colour
            marker='o',
            markersize=path_marker_size,
            markerfacecolor=color,                   # fill colour
            markeredgecolor='black',
            markeredgewidth=path_marker_edge_width,
            linewidth=line_width,
            alpha=line_alpha,
            zorder=2
        )

        if add_inferred_trajectory:
            neighborhood = ccd.ml.Neighborhood(embedding, k=k, use_faiss=False)
            celltrajectory, _ = ccd.ul.shortest_path_on_knn_graph(neighborhood, k=k, point_a=start_idx, point_b=end_idx, use_faiss=False)
            ax.plot(
                embedding[celltrajectory, 0],
                embedding[celltrajectory, 1],
                color='black',
                marker='o',
                markersize=0.5,
                markeredgecolor='black',
                markeredgewidth=0.1,
                linewidth=0.5,
                alpha=0.8,
                zorder=1
            )

        if plot_label:
            for (cx, cy), (node_name, annot_list) in zip(valid_rep_points, valid_labels):
                ax.text(cx, cy, f"{node_name}\n{annot_list}", fontsize=0.5, color="black", zorder=2, alpha=0.5)

    # Crop to the drawn paths (with a 10% margin); skipped when nothing was drawn.
    if zoom_in and all_rep_points:
        stacked = np.concatenate(all_rep_points, axis=0)
        min_x, min_y = np.nanmin(stacked, axis=0)
        max_x, max_y = np.nanmax(stacked, axis=0)
        span = max(max_x - min_x, max_y - min_y)
        margin = 0.1 * span
        if show_square:
            center_x = (min_x + max_x) / 2
            center_y = (min_y + max_y) / 2
            half_side = span / 2 + margin
            ax.set_xlim(center_x - half_side, center_x + half_side)
            ax.set_ylim(center_y - half_side, center_y + half_side)
        else:
            ax.set_xlim(min_x - margin, max_x + margin)
            ax.set_ylim(min_y - margin, max_y + margin)

    #ax.set_title(f"Lineage Paths in UMAP (Roots: {', '.join(roots)})")
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])

    fig.savefig(save_dir / f"lineage_subpaths_{subset_name}_{'_'.join(roots)}_{show_basis}_{file_suffix}_{size_ext}_{text_ext}_{zoomin_ext}_{square_ext}.pdf")


In [None]:
# Display the embedding keys compared in this benchmark.
combined_keys

### Plot the run time of each integration method

In [None]:
# Load the per-method run-time logs and convert them to hours.
import pickle

# Method name -> pickled run-time log. Each log is expected to be a dict
# holding a single {name: seconds} entry.
timelog_dict = {
    'Concord/Concord-decoder': '../save/dev_cbce_1217-Dec18/time_log_Dec18-1358.pkl',
    'Harmony': '../save/dev_cbce_1217-Dec18/time_log_Harmony_Dec18-1449.pkl',
    'Liger': '../save/dev_cbce_1217-Dec18/time_log_Liger_Dec18-2152.pkl',
    'Scanorama': '../save/dev_cbce_1217-Dec19/time_log_Scanorama_Dec19-0701.pkl',
    'scVI': '../save/dev_cbce_1217-Dec18/time_log_scVI_Dec18-1855.pkl',
}

run_times_seconds = {}
for method, pkl_path in timelog_dict.items():
    # Unpickling can execute arbitrary code: these logs are assumed to be
    # trusted, locally generated files.
    with open(pkl_path, 'rb') as f:
        time_log = pickle.load(f)

    # Fail loudly on an empty log instead of an opaque IndexError.
    if not time_log:
        raise ValueError(f"Run-time log for {method!r} is empty: {pkl_path}")
    # Only the first entry is used. Report when more are present, so a log
    # that should have been summed or indexed by key is not silently misread.
    if len(time_log) > 1:
        print(f"Note: {pkl_path} has {len(time_log)} entries {list(time_log)}; "
              f"using the first one for {method!r}.")
    run_time_value = next(iter(time_log.values()))
    run_times_seconds[method] = run_time_value

# Convert seconds to hours.
run_times_hours = {method: secs / 3600.0 for method, secs in run_times_seconds.items()}


In [None]:
# Seurat has no pickled log in `timelog_dict`; its run time (in hours) is entered by hand.
run_times_hours.update(Seurat=6.084378)

In [None]:
# Inspect the per-method run times (hours).
run_times_hours

In [None]:
import pandas as pd

# One row per method, ordered from fastest to slowest.
df = (
    pd.DataFrame.from_dict(run_times_hours, orient='index', columns=['Run Time (hours)'])
    .sort_values(by='Run Time (hours)')
)

# One x tick per hour, from 0 up to the first whole hour above the slowest method.
max_hours = df['Run Time (hours)'].max()
upper_hour = int(max_hours) + 1
x_ticks = np.arange(0, upper_hour + 1, 1)

with plt.rc_context(rc=custom_rc):
    fig, ax = plt.subplots(figsize=(4.5, 2))
    df.plot(kind='barh', legend=False, ax=ax)

    ax.set_xlabel("Run Time (hours)")
    ax.set_ylabel("Integration Method")
    ax.set_title("Run Time Comparison")
    ax.set_xticks(x_ticks)
    ax.set_xlim([0, upper_hour])

    fig.tight_layout()
    fig.savefig(save_dir / f"integration_methods_run_time_{file_suffix}.pdf")
    plt.show()

In [None]:
# Longest run time across methods (hours), as used to size the bar-plot x-axis.
max_hours

### Save to VisCello

In [None]:
# Subset name -> h5ad file in `data_dir` holding that AnnData subset.
ct_dict = dict(
    Neuron_ASE_ASJ_AUA=data_dir / 'adata_subsub_aseasjaua_Jan30-1028.h5ad',
    AB_nonpharynx=data_dir / 'adata_cbce_Dec26-1019_AB broad.h5ad',
    Mesoderm_nonpharynx=data_dir / 'adata_cbce_Dec21-0244_Mesoderm.h5ad',
    Pharynx=data_dir / 'adata_cbce_Dec23-1049_Pharynx.h5ad',
    Intestine=data_dir / 'adata_cbce_Dec21-0244_Intestine.h5ad',
    Early200min=data_dir / 'adata_cbce_Dec23-1707_early200.h5ad',
)

In [None]:
# Gather the cleaned global dataset (doublet/debris removed) plus every subset
# listed in `ct_dict`. Each subset keeps only the obsm entries whose key
# contains 'UMAP' and does not contain 'Dec'.
adata_clean = adata[adata.obs['broad_cell_type_qz'] != 'doublet/debris']
adata_subsets = {'Global_dataset_cleaned': adata_clean}
for ct, path in ct_dict.items():
    sub = sc.read(path)
    sub.obsm = {
        name: emb
        for name, emb in sub.obsm.items()
        if 'UMAP' in name and 'Dec' not in name
    }
    adata_subsets[ct] = sub

In [None]:
# Check which embeddings were kept for the Pharynx subset.
adata_subsets['Pharynx'].obsm

In [None]:
# Build the VisCello output directory path under `data_dir` and pass the global
# dataset plus all subsets to `update_clist_with_subsets`.
# NOTE(review): this passes the full `adata`, not `adata_clean` — confirm that is intended.
cello_path = data_dir / f"cello_{proj_name}_{file_suffix}"
viscello_dir = str(cello_path)
ccd.ul.update_clist_with_subsets(
    global_adata=adata,
    adata_subsets=adata_subsets,
    viscello_dir=viscello_dir,
)