In [None]:
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import scipy
import scipy.stats as stats
import seaborn as sns

from tqdm import tqdm

In [None]:
# Render inline figures at high resolution (300 dpi).
%config InlineBackend.print_figure_kwargs={'dpi': 300.0}

# E11.5

## Load our data and the reference

In [None]:
# Reference (scRNA-seq, restricted to the Stereo-seq gene intersection) and the E11.5 Stereo-seq data.
e115_ref_path = "/home/danielyumengzhu/Single-cell-data/Shendure_reference/embryo_E11.5_ref_stereo-seq_gene_intersection.h5ad"
e115_path = "/home/danielyumengzhu/Single-cell-data/cellbin_mouse_embryo_aligned_E11.5/E11.5_full_final_mod.h5ad"
e115_ref = anndata.read_h5ad(e115_ref_path)
e115 = anndata.read_h5ad(e115_path)

In [None]:
# Allow up to 200 rows so the full list of mapped cell types is displayed.
pd.options.display.max_rows = 200
e115.obs.mapped_celltype.value_counts()

In [None]:
# Map cell types to their major trajectories
major_trajectories = e115_ref.obs["major_trajectory"]
cell_types = e115_ref.obs["celltypes_Spateo_match"]

category_dict_ref = {}

for category in major_trajectories.unique():
    subset_obs2 = cell_types[major_trajectories == category]
    unique_obs2_categories = subset_obs2.unique().tolist()
    category_dict_ref[category] = unique_obs2_categories
category_dict_ref

In [None]:
# Flatten the per-trajectory lists into one list of reference cell types, grouped by trajectory.
ref_categories_ordered = [
    ref_cell_type
    for trajectory_members in category_dict_ref.values()
    for ref_cell_type in trajectory_members
]

ref_categories_ordered

### Optionally look at cell type markers

In [None]:
# Quick preprocessing of the reference for marker exploration: QC filters, per-cell
# normalization to 1e4 total counts, log1p, highly-variable-gene subsetting, then z-scaling
# (clipped at 10). Note this modifies `e115_ref` itself.
sc.pp.filter_cells(e115_ref, min_genes=200)
sc.pp.filter_genes(e115_ref, min_cells=3)
sc.pp.normalize_total(e115_ref, target_sum=1e4)
sc.pp.log1p(e115_ref)
sc.pp.highly_variable_genes(e115_ref, min_mean=0.0125, max_mean=3, min_disp=0.5)
hvg_mask = e115_ref.var["highly_variable"]
e115_ref = e115_ref[:, hvg_mask]
sc.pp.scale(e115_ref, max_value=10)

In [None]:
e115_ref.obs.celltypes_Spateo_match.value_counts()

In [None]:
# Differential expression between major trajectories (t-test), then plot the top 25 per group.
sc.tl.rank_genes_groups(e115_ref, groupby='major_trajectory', method='t-test')
sc.pl.rank_genes_groups(e115_ref, n_genes=25, sharey=False)

In [None]:
# Major trajectory whose reference marker genes are inspected below.
trajectory = "Epithelial_cells"

In [None]:
degs = e115_ref.uns['rank_genes_groups']

# One column per group, holding that group's ranked gene names (best first).
ranked_names = degs['names']
groups = ranked_names.dtype.names
n_ranked = len(ranked_names[groups[0]])
marker_genes_df = pd.DataFrame(
    {group_name: ranked_names[group_name] for group_name in groups},
    index=range(n_ranked),
)
# Top 25 markers of the selected trajectory.
gene_list = marker_genes_df[trajectory].head(25)

In [None]:
# Rank marker genes between the fine-grained cell types (`celltype_update`) of the selected
# trajectory. Uses the `trajectory` variable instead of repeating the literal, so changing the
# trajectory in one place keeps this cell consistent with the marker table above.
subset = e115_ref[e115_ref.obs["major_trajectory"] == trajectory].copy()
sc.tl.rank_genes_groups(subset, 'celltype_update', method='t-test')
sc.pl.rank_genes_groups(subset, n_genes=25, sharey=False)

In [None]:
# Reference cells of one cell type (independent copy).
is_selected_ct = e115_ref.obs["celltype_update"] == "Apical ectodermal ridge"
ct_ref_sub = e115_ref[is_selected_ct].copy()

In [None]:
# Percentage of cells of the selected reference cell type with X > 0 for each gene in
# `gene_list`; genes not present in `ct_ref_sub` are reported as None.
# NOTE(review): the preprocessing cell above z-scales `e115_ref` (sc.pp.scale) before
# `ct_ref_sub` is taken from it, so "X > 0" here means "above the gene's reference-wide
# mean", not "detected" -- TODO compute from unscaled values if detection rates are intended.
percentages = {}
for gene in gene_list:
    # Check membership on the object that is actually indexed; `e115_ref` may have been
    # reloaded with a different gene set since `ct_ref_sub` was created.
    if gene in ct_ref_sub.var_names:
        expressed_cells = ct_ref_sub[:, gene].X > 0
        percentages[gene] = np.mean(expressed_cells) * 100
    else:
        percentages[gene] = None

percentages_df = pd.DataFrame.from_dict(percentages, orient='index', columns=['PercentExpressing'])
percentages_df

In [None]:
# Cell type examined below; used both as a reference `celltype_update` group and as a
# Stereo-seq `mapped_celltype` label.
ct = "Apical ectodermal ridge"

In [None]:
degs = subset.uns['rank_genes_groups']

# One column per cell type of the trajectory subset, holding its ranked gene names.
ranked_names = degs['names']
groups = ranked_names.dtype.names
n_ranked = len(ranked_names[groups[0]])
marker_genes_df = pd.DataFrame(
    {group_name: ranked_names[group_name] for group_name in groups},
    index=range(n_ranked),
)

In [None]:
# Genes to check in the Stereo-seq data. To use the reference's top markers instead, note that
# `marker_genes_df` above is keyed by `celltype_update` groups, so index it with `ct`:
#gene_list = marker_genes_df[ct].head(25).tolist()
gene_list = ["Shh", "Wnt7a"]

In [None]:
# Stereo-seq cells labeled with the selected cell type.
ct_sub = e115[e115.obs["mapped_celltype"].eq(ct)].copy()
ct_sub

In [None]:
# All other Stereo-seq cells.
ct_rest = e115[e115.obs["mapped_celltype"].ne(ct)].copy()
ct_rest

In [None]:
# Percentage of cells of the selected type with X > 0 for each gene in `gene_list`;
# genes absent from the Stereo-seq data are reported as None.
percentages = {
    gene: (np.mean(ct_sub[:, gene].X > 0) * 100 if gene in e115.var_names else None)
    for gene in gene_list
}

percentages_df = pd.DataFrame.from_dict(percentages, orient='index', columns=['PercentExpressing'])
percentages_df

In [None]:
# Same statistic for all cells NOT of the selected type.
percentages = {
    gene: (np.mean(ct_rest[:, gene].X > 0) * 100 if gene in e115.var_names else None)
    for gene in tqdm(gene_list)
}

percentages_df = pd.DataFrame.from_dict(percentages, orient='index', columns=['PercentExpressing'])
percentages_df

In [None]:
# Find the genes detected (X > 0) in the highest proportion of `ct_sub` cells, skipping
# housekeeping/ribosomal/mitochondrial and unannotated genes. A gene is excluded when any entry
# of `exclude` occurs anywhere in its name (substring match, so e.g. "Rik" also removes
# "...Rik" genes and "mt-" removes mitochondrial genes).
exclude = ["Act", "Tub", "Rpl", "Rps", "Ub", "Gapdh", "Hk", "Pfk", "Plk", "Cs", "Aco", "Idh", "Sdh", "Ogd", "Fh", "Mdh", "Aca", "Fas", "Cpt", "Glu", "Got", "Shmt", "Rrm", "Dhf", "Snr", "Hnrn", "Ldha", "Hsp", "H2", "H3", "H4", "Hmgb", "Eef", "Eif", "Atp", "Cox", "Ran", "Gnai", "Malat", "Ppia", "mt-", "Ywh", "Elo", "Ptm", "Tms", "Marck", "Nedd", "Fau", "Hba", "Hbb", "Gm", "Rik"]
keep = np.array(
    [not any(item in gene for item in exclude) for gene in ct_sub.var_names], dtype=bool
)

# Count detecting cells for all kept genes in one pass (instead of slicing AnnData once per
# gene); `.sum(axis=0)` works for both sparse and dense X.
detected_counts = np.asarray((ct_sub[:, keep].X > 0).sum(axis=0)).ravel()
gene_expression_df = pd.DataFrame(
    {"PercentExpressing": detected_counts / ct_sub.n_obs * 100},
    index=ct_sub.var_names[keep],
)

# Sort to find the most widely detected genes and display the top 25
top_genes = gene_expression_df.sort_values(by='PercentExpressing', ascending=False)
top_genes.head(25)

In [None]:
# Stereo-seq labels with more than 1000 cells, ordered by decreasing count.
mapped_counts = e115.obs["mapped_celltype"].value_counts()
cell_types_to_keep = mapped_counts.index[mapped_counts > 1000].tolist()

In [None]:
# Drop cells whose label has too few cells.
keep_cells = e115.obs['mapped_celltype'].isin(cell_types_to_keep)
e115 = e115[keep_cells].copy()

In [None]:
# Sets of cell type labels in the reference and in the Stereo-seq data.
ref_cts = {*e115_ref.obs["celltype_update"]}
e115_cts = {*e115.obs["mapped_celltype"]}

### The E11.5 analysis can be started from here

In [None]:
# Map each Stereo-seq `mapped_celltype` label (key) to the reference `celltypes_Spateo_match`
# label (value) it is compared against. Several Stereo-seq labels may map onto the same
# reference label (e.g. the three cardiomyocyte labels all map to "First heart field").
# Commented-out entries are labels with relatively few cells assigned in the E11.5 data.
mapping = {
    "Anterior floor plate": "Anterior floor plate",
    "Anterior roof plate": "Anterior roof plate",
    "Apical ectodermal ridge": "Apical ectodermal ridge",
    "Arterial endothelial cells": "Arterial endothelial cells",
    "Astrocytes": "Astrocytes",
    # Note: the reference has few heart cell types, although the heart is present and relatively large at this stage
    "Atrial cardiomyocytes": "First heart field",
    "Border-associated macrophages": "Border-associated macrophages",
    "Branchial arch epithelium": "Branchial arch epithelium",
    "CNS capillary endothelial cells": "Brain capillary endothelial cells",
    "Cajal-Retzius cells": "Cajal-Retzius cells",
    "Cardiac fibroblasts": "Fibroblasts",
    "Cardiac mesoderm": "Cardiopharyngeal mesoderm",
    #"Cerebellar Purkinje cells": "Cerebellar Purkinje cells",
    "Chondrocytes": "Chondrocytes (Atp1a2+)",
    "Choroid plexus epithelial precursors": "Choroid plexus",
    # "Cranial mesoderm" likely corresponds to a mix of chondrocytes and sclerotome in the single-cell reference.
    # Head-region mesoderm-derived cells were labeled cranial mesoderm because "sclerotome" technically does not extend into the head region.
    "Cranial mesoderm": "Sclerotome",
    "Cranial motor neurons": "Cranial motor neurons",
    "Definitive early erythroblasts": "Definitive early erythroblasts (CD36-)",
    "Dermatome": "Dermatome",
    "Dermomyotome": "Dermomyotome",
    "Diencephalon neuroectoderm": "Diencephalon",
    #"Dorsal root ganglion neurons": "Dorsal root ganglion neurons",
    "Dorsal telencephalon neuroectoderm": "Dorsal telencephalon",
    "Early chondrocytes": "Early chondrocytes",
    "Ectoderm-derived": "Placodal area",
    "Endocardial cells": "Endocardial cells",
    "Endothelium": "Endothelium",
    "Enteric neurons": "Enteric neurons",
    "Epithelial precursors": "Pre-epidermal keratinocytes",
    "Eye field": "Eye field",
    "Facial mesenchyme": "Facial mesenchyme",
    "Fibroblasts": "Fibroblasts",
    "Floorplate and p3 domain": "Floorplate and p3 domain",
    "GABAergic interneurons": "GABAergic cortical interneurons",
    "GABAergic neurons": "GABAergic neurons",
    "Glutamatergic neurons": "Glutamatergic neurons",
    "Gut": "Gut",
    "Gut mesenchyme": "Gut mesenchyme",
    "Hematopoietic progenitors": "Hematopoietic stem cells (Cd34+)",
    "Hepatic mesenchyme": "Hepatic mesenchyme",
    "Hepatocytes": "Hepatocytes",
    "Hindbrain neuroectoderm": "Hindbrain",
    "Hypothalamus (Sim1+) neuroectoderm": "Hypothalamus (Sim1+)",
    "Hypothalamus neuroectoderm": "Hypothalamus",
    "Intermediate neuronal progenitors": "Intermediate neuronal progenitors",
    "Kupffer cells": "Kupffer cells",
    "Lateral plate and intermediate mesoderm": "Splanchnic mesoderm",
    "Limb progenitors and lateral plate mesenchyme": "Limb mesenchyme progenitors",
    "Liver sinusoidal endothelial cells": "Liver sinusoidal endothelial cells",
    "Lung mesenchyme": "Lung mesenchyme",
    "Lung progenitor cells": "Lung progenitor cells",
    "Lymphatic vessel endothelial cells": "Lymphatic vessel endothelial cells",
    "Megakaryocytes": "Megakaryocytes",
    #"Melanocyte cells": "Melanocyte cells",
    "Mesoderm-derived": "Somatic mesoderm",
    "Mesodermal progenitors": "Mesodermal progenitors (Tbx6+)",
    "Metanephric mesenchyme": "Metanephric mesenchyme",
    "Microglia": "Microglia",
    "Midbrain neuroectoderm": "Midbrain",
    "Midbrain-hindbrain boundary": "Midbrain-hindbrain boundary",
    "Muscle progenitor cells": "Muscle progenitor cells",
    "Myelinating Schwann cells": "Myelinating Schwann cells",
    "Myoblasts": "Myoblasts",
    "Myofibroblasts": "Myofibroblasts",
    "Myotubes": "Myotubes",
    "NMPs and spinal cord progenitors": "NMPs and spinal cord progenitors",
    "Neural crest (PNS glia)": "Neural crest (PNS glia)",
    "Neural crest (PNS neurons)": "Neural crest (PNS neurons)",
    "Neural progenitors": "Neural progenitor cells (Neurod1+)",
    # Scattered across the brain; compared against the telencephalon reference type
    "Neuroectoderm-derived": "Telencephalon",
    #"Neurons (Slc17a8+)": "Neurons (Slc17a8+)",
    "Olfactory epithelial cells": "Olfactory epithelial cells",
    "Otic epithelial cells": "Otic epithelial cells",
    "Otic sensory neurons": "Otic sensory neurons",
    "Outflow tract cardiomyocytes": "First heart field",
    "Pituitary and Pineal gland progenitors": "Pituitary/Pineal gland progenitors",
    "Placode and neural crest-derived": "Placodal area",
    "Posterior roof plate": "Posterior roof plate",
    "Primitive erythroid cells": "Primitive erythroid cells",
    "Proepicardium": "Proepicardium",
    "Renal mesenchyme": "Metanephric mesenchyme",
    #"Retinal pigment cells": "Retinal pigment cells",
    "Sclerotome": "Sclerotome",
    "Skeletal muscle progenitors": "Muscle progenitor cells (Prdm1+)",
    "Somitic muscle progenitors": "Muscle progenitor cells (Prdm1+)",
    "Spinal V0/1 interneurons": "Spinal V0 interneurons",
    "Spinal V2/3 interneurons": "Spinal V2a interneurons",
    "Spinal cord dorsal progenitors": "Spinal cord dorsal progenitors",
    "Spinal cord motor neurons": "Spinal cord motor neurons",
    "Spinal cord neuroectoderm": "Spinal cord/r7/r8",
    "Spinal cord ventral progenitors": "Spinal cord ventral progenitors",
    "Spinal dI1 interneurons": "Spinal dI1 interneurons",
    "Spinal dI2/3/4 interneurons": "Spinal dI2 interneurons",
    "Spinal dI5/6 interneurons": "Spinal dI5 interneurons",
    "Sympathetic neurons": "Sympathetic neurons",
    "Telencephalon neuroectoderm": "Telencephalon",
    #"Thalamic neuronal precursors": "Thalamic neuronal precursors",
    "Ureteric bud": "Ureteric bud",
    "Ventricular cardiomyocytes": "First heart field",
    "Zona limitans intrathalamica": "Thalamic neuronal precursors"
}

In [None]:
cell_types = e115.obs["mapped_celltype"]

# Invert `mapping` once: reference label -> Stereo-seq labels mapped onto it,
# in the insertion order of `mapping`.
stereo_labels_by_ref = {}
for stereo_label, ref_label in mapping.items():
    stereo_labels_by_ref.setdefault(ref_label, []).append(stereo_label)

# For each major trajectory, the Stereo-seq labels corresponding to its reference cell types.
category_dict = {
    trajectory_category: [
        label
        for ref_ct in ref_members
        for label in stereo_labels_by_ref.get(ref_ct, [])
    ]
    for trajectory_category, ref_members in category_dict_ref.items()
}

category_dict

In [None]:
# Flatten into one list of Stereo-seq labels, grouped by trajectory.
categories_ordered = [
    stereo_label
    for trajectory_members in category_dict.values()
    for stereo_label in trajectory_members
]

categories_ordered

## Find reference markers

In [None]:
# Marker genes for each cell type in the reference.
# Reload the raw reference first: the exploration cells above already normalized,
# log-transformed, HVG-subset and z-scaled `e115_ref` in place, so running
# normalize_total/log1p again would stack a second transform on already-transformed
# (and partly negative) values.
e115_ref = anndata.read_h5ad("/home/danielyumengzhu/Single-cell-data/Shendure_reference/embryo_E11.5_ref_stereo-seq_gene_intersection.h5ad")
sc.pp.filter_cells(e115_ref, min_genes=200)
sc.pp.filter_genes(e115_ref, min_cells=3)
sc.pp.normalize_total(e115_ref, target_sum=1e4)
sc.pp.log1p(e115_ref)
e115_ref

In [None]:
# Restrict the reference to highly variable genes.
sc.pp.highly_variable_genes(e115_ref, min_mean=0.0125, max_mean=3, min_disp=0.5)
hvg_mask = e115_ref.var["highly_variable"]
e115_ref = e115_ref[:, hvg_mask]
e115_ref

In [None]:
# t-test ranking of genes for every reference cell type vs. the rest.
sc.tl.rank_genes_groups(e115_ref, groupby='celltypes_Spateo_match', method='t-test')
marker_genes = e115_ref.uns['rank_genes_groups']

In [None]:
# Table of the top `n_genes` ranked genes per reference cell type (one column each).
n_genes = 50
ranked_names = marker_genes['names']
groups = ranked_names.dtype.names
top_genes = {}
for group_name in groups:
    top_genes[group_name] = list(ranked_names[group_name][:n_genes])
top_genes_df = pd.DataFrame(top_genes)
top_genes_df

### Restrict the reference to the epithelial and mesoderm trajectories to get finer-grained markers for their cell types

In [None]:
# Fresh copy of the reference, then keep only the epithelial and mesoderm trajectories.
e115_ref = anndata.read_h5ad("/home/danielyumengzhu/Single-cell-data/Shendure_reference/embryo_E11.5_ref_stereo-seq_gene_intersection.h5ad")
subset_trajectories = ["Epithelial_cells", "Mesoderm"]
e115_ref_sub = e115_ref[e115_ref.obs["major_trajectory"].isin(subset_trajectories)].copy()

In [None]:
# Marker genes for each cell type in the epithelial/mesoderm subset: QC-filter,
# normalize and log-transform the subset in place.
sc.pp.filter_cells(e115_ref_sub, min_genes=200, inplace=True)
sc.pp.filter_genes(e115_ref_sub, min_cells=3, inplace=True)
sc.pp.normalize_total(e115_ref_sub, target_sum=1e4, inplace=True)
sc.pp.log1p(e115_ref_sub, copy=False)
e115_ref_sub

In [None]:
# Restrict the subset to its own highly variable genes.
sc.pp.highly_variable_genes(e115_ref_sub, min_mean=0.0125, max_mean=3, min_disp=0.5)
sub_hvg_mask = e115_ref_sub.var["highly_variable"]
e115_ref_sub = e115_ref_sub[:, sub_hvg_mask]
e115_ref_sub

In [None]:
# t-test ranking within the epithelial/mesoderm subset only.
sc.tl.rank_genes_groups(e115_ref_sub, groupby='celltypes_Spateo_match', method='t-test')
marker_genes = e115_ref_sub.uns['rank_genes_groups']

In [None]:
# Top `n_genes` ranked genes per cell type of the subset.
n_genes = 50
ranked_names = marker_genes['names']
groups = ranked_names.dtype.names
top_genes = {}
for group_name in groups:
    top_genes[group_name] = list(ranked_names[group_name][:n_genes])
top_genes_df_sub = pd.DataFrame(top_genes)
top_genes_df_sub

In [None]:
# Replace the whole-reference markers of epithelial/mesoderm cell types with the
# subset-derived ones (rows aligned by index).
for sub_column in top_genes_df_sub.columns:
    top_genes_df[sub_column] = top_genes_df_sub[sub_column]

In [None]:
# Persist the combined reference marker table.
top_genes_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_reference_marker_genes.csv")

In [None]:
# Reload the combined marker table (first CSV column is the row index).
top_genes_df = pd.read_csv(
    filepath_or_buffer="/home/danielyumengzhu/E11.5_reference_marker_genes.csv",
    index_col=0,
)
top_genes_df

In [None]:
# Optional (disabled): log10(1 + x) transform of the reference matrix before computing mean expression.
#e115_ref.X = np.log1p(e115_ref.X) / np.log(10)

In [None]:
# Every gene that appears as a marker for any reference cell type (NaN padding removed).
all_genes = pd.unique(top_genes_df.values.ravel())
all_genes = all_genes[pd.notna(all_genes)]

# Mean expression of those genes within each reference cell type; cell types with no
# cells in `e115_ref` are skipped. Cells and genes are selected on a view instead of
# copying the full object for every cell type.
per_type_means = {}
for cell_type in tqdm(top_genes_df.columns):
    in_type = e115_ref.obs["celltypes_Spateo_match"] == cell_type
    if not in_type.any():
        continue
    expr = e115_ref[in_type, all_genes].X
    expr = expr.toarray() if scipy.sparse.issparse(expr) else expr
    per_type_means[cell_type] = np.mean(expr, axis=0)

mean_expression_df = pd.DataFrame(per_type_means, index=all_genes)
# Min-max scale each column (cell type) to [0, 1].
col_min, col_max = mean_expression_df.min(), mean_expression_df.max()
mean_expression_df = (mean_expression_df - col_min) / (col_max - col_min)
mean_expression_df

In [None]:
# Keep only cell types with at least one non-NaN value (constant columns become all-NaN after scaling).
mean_expression_df = mean_expression_df.loc[:, mean_expression_df.notna().any()]
mean_expression_df

In [None]:
# Persist the normalized reference mean-expression table.
mean_expression_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_reference_mean_expr.csv")

In [None]:
# Significance test for the reference.
# For each cell type (column), the background is every gene that is NOT one of that cell
# type's reference markers. A gene's empirical p-value is the fraction of background genes
# whose normalized mean expression is strictly HIGHER than the gene's own; the denominator
# is the size of that background, excluding the gene itself when it belongs to it.
# Genes with a NaN mean get a NaN p-value so they are never flagged as significant.
p_values_df = pd.DataFrame(index=mean_expression_df.index, columns=mean_expression_df.columns, dtype=float)

for column in tqdm(mean_expression_df.columns):
    markers = top_genes_df[column].dropna().tolist()
    values = mean_expression_df[column].to_numpy(dtype=float)
    is_nan = np.isnan(values)
    in_background = ~mean_expression_df.index.isin(markers) & ~is_nan

    background = np.sort(values[in_background])
    # Number of background genes strictly greater than each gene's value
    n_higher = background.size - np.searchsorted(background, values, side="right")
    # Comparison-set size: the background minus the gene itself when it is part of it
    n_compared = background.size - in_background.astype(int)
    with np.errstate(divide="ignore", invalid="ignore"):
        p_vals = np.where(n_compared > 0, n_higher / n_compared, np.nan)
    p_vals[is_nan] = np.nan
    p_values_df[column] = p_vals
p_values_df

In [None]:
# Persist reference p-values.
p_values_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_reference_p_vals.csv")

In [None]:
# 1 where the empirical p-value is below 0.05, else 0.
sig_df = p_values_df.lt(0.05).astype(int)

In [None]:
# Persist the reference significance indicators.
sig_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_reference_significant.csv")

## Load processed reference information

In [None]:
# Reload the saved reference tables (first CSV column is the row index in each).
ref_markers, mean_expression_df, p_values_df, sig_df = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        "/home/danielyumengzhu/E11.5_reference_marker_genes.csv",
        "/home/danielyumengzhu/E11.5_reference_mean_expr.csv",
        "/home/danielyumengzhu/E11.5_reference_p_vals.csv",
        "/home/danielyumengzhu/E11.5_reference_significant.csv",
    )
)

In [None]:
# Inspect the reloaded reference mean-expression table.
mean_expression_df

## Process the E11.5 data with respect to the reference markers

In [None]:
# Mean expression of every reference marker gene in each Stereo-seq cell type.
all_genes = pd.unique(ref_markers.values.ravel())
all_genes = all_genes[pd.notna(all_genes)]  # Remove NaN values

# One column per Stereo-seq label (key of `mapping`), added in reference-column order.
mean_expression_df = pd.DataFrame(index=all_genes)

for cell_type in tqdm(ref_markers.columns):
    # Every Stereo-seq label mapped onto this reference cell type gets its own column.
    # (Loop variables are local names so the global `ct` is no longer overwritten.)
    stereo_labels = [label for label, ref_label in mapping.items() if ref_label == cell_type]

    for label in stereo_labels:
        in_label = (e115.obs["mapped_celltype"] == label).to_numpy()
        # Select cells and genes together on a view instead of copying the whole object
        expression_data = e115[in_label, all_genes].X
        if scipy.sparse.issparse(expression_data):
            expression_data = expression_data.toarray()

        if expression_data.shape[0] == 0:
            # No cells left with this label (e.g. removed by the cell-count filter): keep the
            # column so the table layout stays stable, but fill it with NaN explicitly.
            print(label)
            mean_expression_df[label] = np.nan
            continue

        # Mean across cells (rows)
        mean_expression_df[label] = np.asarray(np.mean(expression_data, axis=0)).ravel()

# Min-max scale each column to [0, 1]
mean_expression_df = (mean_expression_df - mean_expression_df.min()) / (mean_expression_df.max() - mean_expression_df.min())
mean_expression_df

In [None]:
# Persist the normalized Stereo-seq mean-expression table.
mean_expression_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_Stereo-seq_mean_expr.csv")

In [None]:
# Significance test for the Stereo-seq data.
# For each Stereo-seq label (column), the background is every gene that is NOT a reference
# marker of the reference cell type the label maps onto. A gene's empirical p-value is the
# fraction of background genes whose normalized mean expression is strictly HIGHER than the
# gene's own; the denominator is the size of that background, excluding the gene itself when
# it belongs to it. Genes with a NaN mean (e.g. a label with no remaining cells) get a NaN
# p-value, so they are never flagged significant downstream.
p_values_df = pd.DataFrame(index=mean_expression_df.index, columns=mean_expression_df.columns, dtype=float)

for column in tqdm(mean_expression_df.columns):
    ref_ct = mapping[column]
    markers = ref_markers[ref_ct].dropna().tolist()
    values = mean_expression_df[column].to_numpy(dtype=float)
    is_nan = np.isnan(values)
    in_background = ~mean_expression_df.index.isin(markers) & ~is_nan

    background = np.sort(values[in_background])
    # Number of background genes strictly greater than each gene's value
    n_higher = background.size - np.searchsorted(background, values, side="right")
    # Comparison-set size: the background minus the gene itself when it is part of it
    n_compared = background.size - in_background.astype(int)
    with np.errstate(divide="ignore", invalid="ignore"):
        p_vals = np.where(n_compared > 0, n_higher / n_compared, np.nan)
    p_vals[is_nan] = np.nan
    p_values_df[column] = p_vals
p_values_df

In [None]:
# Persist Stereo-seq p-values.
p_values_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_Stereo-seq_p_vals.csv")

In [None]:
# 1 where the empirical p-value is below 0.05, else 0.
sig_df = p_values_df.lt(0.05).astype(int)

In [None]:
# Persist the Stereo-seq significance indicators.
sig_df.to_csv(path_or_buf="/home/danielyumengzhu/E11.5_Stereo-seq_significant.csv")

## Load significance dataframes for E11.5 and the reference

In [None]:
# Significance indicators (genes x cell types) for the Stereo-seq data and the reference.
e115_sig, e115_ref_sig = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        "/home/danielyumengzhu/E11.5_Stereo-seq_significant.csv",
        "/home/danielyumengzhu/E11.5_reference_significant.csv",
    )
)

In [None]:
# Snapshot of the full Stereo-seq labels before they are replaced by abbreviations.
orig_cols = e115_sig.columns.copy()

In [None]:
# Snapshot of the full reference labels before they are replaced by abbreviations.
orig_ref_cols = [*e115_ref_sig.columns]

In [None]:
# Shorten cell type names for the display. Each list is positional: the i-th abbreviation
# replaces the i-th column of the significance table, so the order must follow the CSV column
# order. Abbreviations must also be unique: a repeated label merges two cell types in the
# label -> abbreviation dicts and makes `.loc` lookups return or overwrite the wrong rows/columns
# (cardiac vs. cranial mesoderm are 'CMes' / 'CrMes'; olfactory vs. otic sensory neurons in
# the reference are 'OlfSN' / 'OSN').
e115_sig.columns = [
    'AFP', 'ARP', 'AER', 'AEC', 'Astro', 'BAM', 'CCEC', 'BAE',
    'CRC', 'CMes', 'Chon', 'CPEP', 'CMN', 'DEEB', 'Derma', 
    'Dermomyo', 'Dien', 'DTelen', 'EChon', 'Endocardial', 'Endo', 
    'EN', 'EF', 'FMesen', 'CFibro', 'Fibro', 'ACM', 'OTCM', 'VCM', 'FP3D', 
    'GABAin', 'GABAn', 'GlutN', 'Gut', 'GMesen', 'HSC', 
    'HepMes', 'Hepato', 'Hind', 'Hypo', 'Hypo(Sim1+)', 
    'INP', 'Kupf', 'LMP', 'LSEC', 'LungMes', 
    'LungP', 'LVEC', 'MegaK', 'MesoP', 'MM', 'RM', 'Micro', 
    'Mid', 'MHB', 'MusP', 'SMP', 'SomMP',
    'MySch', 'Myob', 'Myofibro', 'Myot', 'NSCP', 
    'NC(PG)', 'NC(PN)', 'NeuroP', 'OEC', 'OticEC', 
    'OSN', 'PPGP', 'ED', 'PNC', 
    'PRP', 'EpiP', 'PEry', 'ProEpi', 'CrMes', 
    'Sclero', 'MD', 'SCV0/1', 'SCV2/3', 'SCDP', 
    'SCMN', 'SCVP', 'SCN', 'SCdI1', 'SCdI2/3/4', 
    'SCdI5/6', 'LPM', 'SymN', 'ND', 'Telen', 
    'ZLI', 'UB'
]
assert e115_sig.columns.is_unique, "Stereo-seq abbreviations must be unique"

label_to_abbrev_mapping = {l: a for l, a in zip(orig_cols, e115_sig.columns)}
trajectory_to_abbrev_mapping = {
    key: [label_to_abbrev_mapping[category] for category in categories] for key, categories in category_dict.items()
}

e115_ref_sig.columns = [
    'ASMC', 'Allantois', 'AmiEcto', 'AmiMeso', 'AFP', 'ARP', 'AER', 'AEC', 'Astro', 'BAM',
    'BCEC', 'BAE', 'CRC', 'CPM', 'CPC', 'Chon(Atp1a2+)', 'CP', 'CMN', 'DLN', 'DEEB(CD36-)', 'Derma',
    'Dermomyo', 'Dien', 'DRGN', 'DTelen', 'EChon', 'Endocardial', 'Endo', 'EN', 'EEMesen', 'EF',
    'FMesen', 'Fibro', 'FHF', 'FP3D', 'FGMesen', 'GABAci', 'GABAn', 'GISMC', 'GlutN', 'GPC',
    'GranKera', 'Gut', 'GMesen', 'HSC(CD34+)', 'HepMes', 'Hepato', 'Hind', 'Hypo', 'Hypo(Sim1+)',
    'INP', 'Kupf', 'LMP', 'LSEC', 'LungMes', 'LungP', 'LVEC', 'MegaK', 'Melano', 'Meninges',
    'MesoP(Tbx6+)', 'Mesothelial', 'MM', 'Micro', 'Mid', 'MHB', 'MusP', 'MusP(Prdm1+)',
    'MySch', 'MySch(Tgfb2+)', 'Myob', 'Myofibro', 'Myot', 'NSCP', 'NaiveRP', 'NC(PG)', 'NC(PN)',
    'NeuroP(ND1+)', 'Neurons(Slc17a8+)', 'OEC', 'OlfSN', 'OticEC', 'OSN', 'PPGP', 'PlacA', 'PRP',
    'PEK', 'PEry', 'PEpi', 'RPC(M+P)', 'RPC', 'RetPigC', 'Sclero', 'Sertoli', 'SMes',
    'SCV0', 'SCV1', 'SCV2a', 'SCV2b', 'SCV3', 'SCDP', 'SCMN', 'SCVP', 'SC(r7/r8)', 'SCdI1',
    'SCdI2', 'SCdI3', 'SCdI4', 'SCdI5', 'SCdI6', 'SpMeso', 'SymN', 'Telen', 'ThalNP', 'UB',
    'VSMC', 'VSMC(Pparg+)'
]
assert e115_ref_sig.columns.is_unique, "reference abbreviations must be unique"

label_to_abbrev_ref_mapping = {l: a for l, a in zip(orig_ref_cols, e115_ref_sig.columns)}
trajectory_to_abbrev_ref_mapping = {
    key: [label_to_abbrev_ref_mapping[category] for category in categories] for key, categories in category_dict_ref.items()
}

In [None]:
def compute_cross_dataframe_proportions(df_a, df_b):
    """Overlap fraction between every column of ``df_a`` and every column of ``df_b``.

    For a column ``a`` of ``df_a`` and a column ``b`` of ``df_b``, the result is the sum of
    ``df_b[b]`` over the rows where ``df_a[a] == 1``, divided by the number of such rows
    (i.e. the fraction of ``a``'s flagged rows that are also flagged in ``b`` when ``df_b``
    holds 0/1 indicators). Columns of ``df_a`` with no 1s get 0.0.

    Rows are matched by index label (``df_b`` is looked up with ``df_a``'s row labels).
    Columns are handled positionally, so duplicate column labels each get their own values
    instead of all reusing the first occurrence.

    Parameters
    ----------
    df_a, df_b : pd.DataFrame
        Indicator frames (e.g. 0/1 significance flags) sharing the same row labels.

    Returns
    -------
    pd.DataFrame
        Float frame indexed by ``df_a.columns`` with columns ``df_b.columns``.

    Raises
    ------
    ValueError
        If the two frames have different numbers of rows.
    KeyError
        If a row label of ``df_a`` is missing from ``df_b``.
    """
    if df_a.shape[0] != df_b.shape[0]:
        raise ValueError("DataFrames must have the same number of rows")

    # Align df_b's rows to df_a's by label; NaNs contribute nothing, like Series.sum() skipping NaN.
    b_values = df_b.loc[df_a.index].astype(float).fillna(0.0).to_numpy()
    # 1.0 where df_a is exactly 1, else 0.0 (rows x columns_a)
    a_hits = (df_a == 1).to_numpy(dtype=float)

    n_hits = a_hits.sum(axis=0)        # number of flagged rows per df_a column
    overlap = a_hits.T @ b_values      # (columns_a x columns_b) sums of df_b over those rows
    with np.errstate(divide="ignore", invalid="ignore"):
        proportions = np.where(n_hits[:, None] > 0, overlap / n_hits[:, None], 0.0)

    return pd.DataFrame(proportions, index=df_a.columns, columns=df_b.columns)

In [None]:
# Optional sanity check (disabled): overlap proportions between the reference and itself.
# within_dataframe_proportions = compute_cross_dataframe_proportions(e115_ref_sig, e115_ref_sig)
# within_dataframe_proportions

In [None]:
# Rows: reference cell types; columns: Stereo-seq cell types.
cross_dataframe_proportions = compute_cross_dataframe_proportions(df_a=e115_ref_sig, df_b=e115_sig)
cross_dataframe_proportions

In [None]:
# Keep only reference cell types that at least one Stereo-seq label maps onto.
# dict.fromkeys de-duplicates while keeping first-seen order, so the row order is the same
# on every run (iterating a set of strings is not reproducible across interpreter sessions
# because of string hash randomization).
ref_cts_in_mapping = list(dict.fromkeys(mapping.values()))
ref_cts_in_mapping_abbrev = [label_to_abbrev_ref_mapping[l] for l in ref_cts_in_mapping]
cross_dataframe_proportions = cross_dataframe_proportions.loc[ref_cts_in_mapping_abbrev, :]
cross_dataframe_proportions

In [None]:
# Trajectory-grouped ordering for rows (reference) and columns (Stereo-seq), abbreviated and
# restricted to labels present in the proportions table.
ref_order = ref_categories_ordered
spateo_order = categories_ordered
ref_order_abbrev = [
    abbrev
    for abbrev in [label_to_abbrev_ref_mapping[label] for label in ref_order]
    if abbrev in cross_dataframe_proportions.index
]
spateo_order_abbrev = [
    abbrev
    for abbrev in [label_to_abbrev_mapping[label] for label in spateo_order]
    if abbrev in cross_dataframe_proportions.columns
]

In [None]:
# Apply the trajectory-grouped row/column order.
row_order, column_order = ref_order_abbrev, spateo_order_abbrev
cross_dataframe_proportions = cross_dataframe_proportions.loc[row_order, column_order]

In [None]:
# Optional (disabled): the five reference cell types with the highest proportion for the 'CCEC' column.
# top_5_rows = cross_dataframe_proportions.sort_values(by='CCEC', ascending=False).head(5)
# top_5_rows

In [None]:
# Square the fractions for more visual contrast, then min-max normalize each column
# (Stereo-seq cell type) to [0, 1]. No in-place mutation: re-running this cell does not
# square the values again. Casting to float makes constant columns yield NaN instead of a
# ZeroDivisionError from object-dtype division.
cross_dataframe_proportions_sq = cross_dataframe_proportions.astype(float) ** 2
col_min = cross_dataframe_proportions_sq.min()
col_range = cross_dataframe_proportions_sq.max() - col_min
cross_dataframe_proportions_scale = (cross_dataframe_proportions_sq - col_min) / col_range

In [None]:
# Abbreviation -> trajectory lookups (a later trajectory wins if an abbreviation repeats).
reverse_map = {}
for trajectory_name, abbrevs in trajectory_to_abbrev_mapping.items():
    for abbrev in abbrevs:
        reverse_map[abbrev] = trajectory_name
reverse_map_ref = {}
for trajectory_name, abbrevs in trajectory_to_abbrev_ref_mapping.items():
    for abbrev in abbrevs:
        reverse_map_ref[abbrev] = trajectory_name

# One color per trajectory, shared by the reference and Stereo-seq annotations.
group_colors = sns.color_palette("tab20", len(trajectory_to_abbrev_mapping))
color_mapping = dict(zip(trajectory_to_abbrev_mapping, group_colors))

In [None]:
def _stereo_trajectory_color(abbrev):
    """Trajectory color of a Stereo-seq abbreviation (None if unknown)."""
    return color_mapping.get(reverse_map.get(abbrev, ''))


def _ref_trajectory_color(abbrev):
    """Trajectory color of a reference abbreviation (None if unknown)."""
    return color_mapping.get(reverse_map_ref.get(abbrev, ''))


# Heatmap rows after transposing are Stereo-seq types; columns are reference types.
row_colors = cross_dataframe_proportions_scale.columns.map(_stereo_trajectory_color)
col_colors = cross_dataframe_proportions_scale.index.map(_ref_trajectory_color)

In [None]:
import matplotlib.patches as patches

# Heatmap of the column-normalized overlap fractions, transposed so rows are Stereo-seq cell
# types and columns are reference cell types. Trajectory membership is drawn as colored strips
# just outside the heatmap, with a legend mapping colors to trajectories.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 20))
#mask = cross_dataframe_proportions.astype(float) < 0.2
m = sns.heatmap(
    cross_dataframe_proportions_scale.T.astype(float),
    square=True,
    linecolor="grey",
    linewidths=0.3,
    cbar_kws={"location": "top"},
    cmap="magma",
    center=0.5,
    vmin=0,
    vmax=1,
    #mask=mask,
    ax=ax
)

# Remove the default row/column labels and ticks
ax.set_xticks([])
ax.set_yticks([])

# Row (Stereo-seq) trajectory strip: two cells wide, left of the heatmap (x from -2 to 0),
# in data coordinates; clip_on=False lets it render outside the axes
for i, color in enumerate(row_colors):
    rect = patches.Rectangle((-2, i), 2, 1, linewidth=0, edgecolor=None, facecolor=color, clip_on=False)
    ax.add_patch(rect)

# Column (reference) trajectory strip: starts at y = number of heatmap rows. Seaborn draws row 0
# at the top (inverted y-axis), so this presumably places the strip below the last row rather
# than on top -- NOTE(review): confirm intended placement
for i, color in enumerate(col_colors):
    rect = patches.Rectangle((i, cross_dataframe_proportions_scale.shape[1]), 1, 2, linewidth=0, edgecolor=None, facecolor=color, clip_on=False)
    ax.add_patch(rect)

for _, spine in m.spines.items():
    # spine.set_visible(True)
    # spine.set_linewidth(0.75)
    spine.set_visible(False)

# Adjust colorbar label font size
cbar = m.collections[0].colorbar
cbar.set_label("Normalized Fraction of Significantly Enriched Marker Genes Matching", fontsize=24)
# Adjust colorbar tick font size
cbar.ax.tick_params(labelsize=16)
cbar.ax.set_aspect(0.04)

plt.xlabel("Cell Types- scRNA-seq Reference", fontsize=30, labelpad=20)
plt.ylabel("Cell Types- Stereo-seq data", fontsize=30, labelpad=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.title("Comparison for E11.5- Stereo-seq v. scRNA-seq Reference", fontsize=28, pad=25)

# Create a custom legend for the row/column colors
legend_handles = [patches.Patch(color=color, label=key) for key, color in color_mapping.items()]
ax.legend(handles=legend_handles, title="Cell Type Groups", bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

plt.tight_layout()
plt.show()

# E9.5

## Load our data and the reference

In [None]:
# E9.5 scRNA-seq reference (gene set intersected with Stereo-seq, per its file name) and E9.5 Stereo-seq cellbin data
e95_ref = anndata.read_h5ad(
    "/home/danielyumengzhu/Single-cell-data/Shendure_reference/embryo_E9.5_ref_stereo-seq_gene_intersection.h5ad"
)
e95 = anndata.read_h5ad(
    "/home/danielyumengzhu/Single-cell-data/cellbin_mouse_embryo_aligned_E9.5/E9.5_full.h5ad"
)
e95

In [None]:
# Inspect the E9.5 reference AnnData (obs columns used below: "major_trajectory", "celltype_update", "celltype_Spateo_match")
e95_ref

In [None]:
# Map cell types to their major trajectories
# NOTE(review): the cell-type column read here is "celltype_update", whereas marker ranking and
# mean expression below group by "celltype_Spateo_match" — confirm the two annotations use the
# same labels, otherwise the trajectory ordering and the marker tables will not line up.
major_trajectories = e95_ref.obs["major_trajectory"]
cell_types = e95_ref.obs["celltype_update"]

In [None]:
# {trajectory: [reference cell types seen in that trajectory, in order of first appearance]}
category_dict_ref = {
    traj: cell_types[major_trajectories == traj].unique().tolist()
    for traj in major_trajectories.unique()
}
category_dict_ref

In [None]:
# Flatten the trajectory groups into a single list, keeping trajectory order
ref_categories_ordered = [
    ct
    for cts in category_dict_ref.values()
    for ct in cts
]

ref_categories_ordered

In [None]:
# Stereo-seq label (the values of e95.obs["mapped_celltype"]) -> reference cell type name.
# The mapping is many-to-one: e.g. "Endothelium", "Placodal area", "Sclerotome" and
# "Neural progenitor cells (Neurod1+)" each receive two Stereo-seq labels. Code below inverts
# it by scanning `mapping.items()` for every key whose value equals a reference cell type.
mapping = {
    "Amniotic ectoderm": "Amniotic ectoderm",
    "Anterior floor plate": "Anterior floor plate",
    "Anterior intermediate mesoderm": "Anterior intermediate mesoderm",
    "Anterior roof plate": "Anterior roof plate",
    "Arterial endothelial cells": "Arterial endothelial cells",
    "Cardiopharyngeal mesoderm": "Cardiopharyngeal mesoderm",
    "Cranial mesoderm": "Sclerotome",
    "Cranial motor neurons": "Cranial motor neurons",
    "Dermomyotome": "Dermomyotome",
    "Diencephalon neuroectoderm": "Diencephalon",
    "Dorsal telencephalon neuroectoderm": "Dorsal telencephalon",
    "Ectoderm-derived": "Placodal area",
    "Endocardial cells": "Endocardial cells",
    "Endothelium": "Endothelium",
    "Epithelial precursors": "Pre-epidermal keratinocytes",
    "Glutamatergic neurons": "Glutamatergic neurons",
    "Gut": "Gut",
    "Gut mesenchyme": "Gut mesenchyme",
    "Head and facial mesenchyme": "Facial mesenchyme",
    "Hematoendothelial progenitors": "Hematoendothelial progenitors",
    "Hepatic mesenchyme": "Hepatic mesenchyme",
    "Hepatocytes": "Hepatocytes",
    "Hindbrain neuroectoderm": "Hindbrain",
    "Hypothalamus neuroectoderm": "Hypothalamus",
    "Hypothalamus (Sim1+) neuroectoderm": "Hypothalamus (Sim1+)",
    "Intermediate mesoderm": "Posterior intermediate mesoderm",
    "Lateral plate mesoderm": "Lateral plate and intermediate mesoderm",
    "Limb progenitors and lateral plate mesoderm": "Limb mesenchyme progenitors",
    "Liver sinusoidal endothelial cell precursors": "Endothelium",
    "LV/atrioventricular canal/common atrium cardiomyocytes": "First heart field",
    "Mesodermal progenitors": "Mesodermal progenitors (Tbx6+)",
    "Midbrain neuroectoderm": "Midbrain",
    "Midbrain-hindbrain boundary": "Midbrain-hindbrain boundary",
    "Neural crest (PNS glia)": "Neural crest (PNS glia)",
    "Neural crest (PNS neurons)": "Neural crest (PNS neurons)",
    "Neural progenitor cells": "Neural progenitor cells (Neurod1+)",
    "Neuroectoderm-derived": "Neural progenitor cells (Neurod1+)",
    "NMPs and spinal cord progenitors": "NMPs and spinal cord progenitors",
    "Olfactory epithelial cells": "Olfactory epithelial cells",
    "Otic epithelial cells": "Otic epithelial cells",
    "Placode and neural crest-derived": "Placodal area",
    "Posterior roof plate": "Posterior roof plate",
    "Primitive erythroid cells": "Primitive erythroid cells",
    "Proepicardium": "Proepicardium",
    "Renal mesenchyme": "Renal pericytes and mesangial cells",
    "Sclerotome": "Sclerotome",
    "Second heart field-derived cardiomyocytes": "Second heart field",
    "Spinal cord dorsal progenitors": "Spinal cord dorsal progenitors",
    "Spinal cord motor neurons": "Spinal cord motor neurons",
    "Spinal cord neuroectoderm": "Spinal cord/r7/r8",
    "Spinal cord ventral progenitors": "Spinal cord ventral progenitors",
    "Telencephalon neuroectoderm": "Telencephalon",
    "Vessel endothelial cells": "Brain capillary endothelial cells"
}

In [None]:
# Group the Stereo-seq labels by the trajectory of the reference cell type they map to,
# following the reference's within-trajectory order and then `mapping` order.
cell_types = e95.obs["mapped_celltype"]

category_dict = {
    traj: [
        stereo_label
        for ref_ct in ref_cts
        for stereo_label, mapped_to in mapping.items()
        if mapped_to == ref_ct
    ]
    for traj, ref_cts in category_dict_ref.items()
}

category_dict

In [None]:
# Flatten the Stereo-seq trajectory groups into one ordered list of labels
categories_ordered = [
    label
    for labels in category_dict.values()
    for label in labels
]

categories_ordered

## Find reference markers

In [None]:
# Marker genes for each cell type in the reference:
# Standard scanpy preprocessing. These calls modify `e95_ref` in place, so re-running this cell
# on the same object would normalize and log-transform the already transformed data again.
sc.pp.filter_cells(e95_ref, min_genes=200)
sc.pp.filter_genes(e95_ref, min_cells=3)
sc.pp.normalize_total(e95_ref, target_sum=1e4)
sc.pp.log1p(e95_ref)
e95_ref

In [None]:
sc.pp.highly_variable_genes(e95_ref, min_mean=0.0125, max_mean=3, min_disp=0.5)
# Keep only the genes flagged as highly variable (an AnnData view of the original)
e95_ref = e95_ref[:, e95_ref.var["highly_variable"]]
e95_ref

In [None]:
# Rank genes per reference cell type ("celltype_Spateo_match") with a t-test; results are stored in .uns
sc.tl.rank_genes_groups(e95_ref, 'celltype_Spateo_match', method='t-test')
marker_genes = e95_ref.uns['rank_genes_groups']

In [None]:
n_genes = 50  # number of top-ranked genes kept per cell type
groups = marker_genes["names"].dtype.names
# One column per cell type holding its `n_genes` top-ranked gene names
top_genes = {
    grp: list(marker_genes["names"][grp][:n_genes])
    for grp in groups
}
top_genes_df = pd.DataFrame(top_genes)
top_genes_df

### Subset to epithelial and mesoderm cells to get markers for these cell types

In [None]:
# Re-read the reference from disk. This replaces the normalized/log1p/HVG-subset `e95_ref` built
# above, so every later cell that uses `e95_ref` works with the object as stored in the h5ad.
e95_ref = anndata.read_h5ad("/home/danielyumengzhu/Single-cell-data/Shendure_reference/embryo_E9.5_ref_stereo-seq_gene_intersection.h5ad")
# Only the epithelial and mesoderm trajectories, so markers are ranked within these cells
e95_ref_sub = e95_ref[e95_ref.obs["major_trajectory"].isin(["Epithelial_cells", "Mesoderm"])].copy()

In [None]:
# Marker genes for each cell type in the reference:
# Same preprocessing as for the full reference, applied in place to the epithelial/mesoderm
# subset (re-running this cell would transform the already transformed data again).
sc.pp.filter_cells(e95_ref_sub, min_genes=200)
sc.pp.filter_genes(e95_ref_sub, min_cells=3)
sc.pp.normalize_total(e95_ref_sub, target_sum=1e4)
sc.pp.log1p(e95_ref_sub)
e95_ref_sub

In [None]:
sc.pp.highly_variable_genes(e95_ref_sub, min_mean=0.0125, max_mean=3, min_disp=0.5)
# Keep only the genes flagged as highly variable within the subset
e95_ref_sub = e95_ref_sub[:, e95_ref_sub.var["highly_variable"]]
e95_ref_sub

In [None]:
# Rank genes per cell type within the epithelial/mesoderm subset only (t-test)
sc.tl.rank_genes_groups(e95_ref_sub, 'celltype_Spateo_match', method='t-test')
marker_genes = e95_ref_sub.uns['rank_genes_groups']

In [None]:
n_genes = 50  # number of top-ranked genes kept per cell type
groups = marker_genes["names"].dtype.names
# One column per epithelial/mesoderm cell type holding its `n_genes` top-ranked gene names
top_genes = {
    grp: list(marker_genes["names"][grp][:n_genes])
    for grp in groups
}
top_genes_df_sub = pd.DataFrame(top_genes)
top_genes_df_sub

In [None]:
# For epithelial/mesoderm cell types, replace the whole-reference markers with the ones ranked within the subset
top_genes_df.loc[:, top_genes_df_sub.columns] = top_genes_df_sub

In [None]:
# Cache the merged marker table (one column per reference cell type, 50 ranked genes each)
top_genes_df.to_csv("/home/danielyumengzhu/E9.5_reference_marker_genes.csv")

In [None]:
# Reload the cached marker table; rows are rank positions, columns are reference cell types
top_genes_df = pd.read_csv("/home/danielyumengzhu/E9.5_reference_marker_genes.csv", index_col=0)
top_genes_df

In [None]:
# Scaled mean expression of every reference marker gene in every reference cell type.
# NOTE(review): `e95_ref` at this point is the object re-read from the h5ad in the
# epithelial/mesoderm section, not the normalized/log1p copy from "Find reference markers" —
# confirm that computing means on it is intended.

# Gene universe = union of all marker genes in the table
all_genes = pd.unique(top_genes_df.values.ravel())
all_genes = all_genes[pd.notna(all_genes)]  # Remove NaN values

# Subset genes once instead of copying the full AnnData for every cell type
ref_marker_adata = e95_ref[:, all_genes]
ref_labels = ref_marker_adata.obs["celltype_Spateo_match"]

mean_by_celltype = {}
for cell_type in tqdm(top_genes_df.columns):
    expression_data = ref_marker_adata[(ref_labels == cell_type).to_numpy()].X
    # Densify sparse matrices so np.mean behaves the same for both storage types
    if scipy.sparse.issparse(expression_data):
        expression_data = expression_data.toarray()
    # Mean across cells (rows); a cell type with no cells yields an all-NaN column
    mean_by_celltype[cell_type] = np.mean(expression_data, axis=0)

# Build the frame once rather than growing it column by column
mean_expression_df = pd.DataFrame(mean_by_celltype, index=all_genes)
# Min-max scale each cell type (column) to [0, 1]; constant or all-NaN columns become NaN
mean_expression_df = (mean_expression_df - mean_expression_df.min()) / (mean_expression_df.max() - mean_expression_df.min())
mean_expression_df

In [None]:
# Keep only cell types with at least one non-NaN scaled value (drops empty or constant columns)
mean_expression_df = mean_expression_df.loc[:, mean_expression_df.notna().any(axis=0)]
mean_expression_df

In [None]:
# Cache the reference's scaled mean-expression table
mean_expression_df.to_csv("/home/danielyumengzhu/E9.5_reference_mean_expr.csv")

In [None]:
# Significance test (empirical, rank-based), one reference cell type (column) at a time.
# Background = every gene that is NOT one of that cell type's markers in `top_genes_df`.
# For each gene:
#   p = (# background genes with strictly HIGHER scaled mean expression)
#       / (# background genes other than the gene itself)
# so a small p means the gene sits near the top of the expression ranking. Numerator and
# denominator now refer to the same background set (the denominator used to be all genes - 1).
# Genes whose scaled value is NaN get p = NaN; previously they got p = 0, i.e. "significant".
p_values_df = pd.DataFrame(index=mean_expression_df.index, columns=mean_expression_df.columns, dtype=float)

for column in tqdm(mean_expression_df.columns):
    values = mean_expression_df[column].to_numpy(dtype=float)
    markers = set(top_genes_df[column].dropna())
    in_background = ~mean_expression_df.index.isin(markers) & ~np.isnan(values)
    background = np.sort(values[in_background])

    # Count background genes strictly above each gene's value in one vectorized pass
    n_higher = background.size - np.searchsorted(background, values, side="right")
    # A background gene is compared with every background gene except itself
    n_compared = background.size - in_background.astype(int)

    p_col = np.divide(n_higher, n_compared, out=np.full(values.shape, np.nan), where=n_compared > 0)
    p_col[np.isnan(values)] = np.nan
    p_values_df[column] = p_col
p_values_df

In [None]:
# Cache the reference's per-gene, per-cell-type p-values
p_values_df.to_csv("/home/danielyumengzhu/E9.5_reference_p_vals.csv")

In [None]:
# 1 where the gene's p-value is below 0.05 for that cell type, else 0
sig_df = p_values_df.lt(0.05).astype(int)

In [None]:
# Cache the reference's 0/1 significance table
sig_df.to_csv("/home/danielyumengzhu/E9.5_reference_significant.csv")

## Load processed reference information

In [None]:
# Reload the cached reference tables so the Stereo-seq analysis below can start from here
ref_markers = pd.read_csv("/home/danielyumengzhu/E9.5_reference_marker_genes.csv", index_col=0)
mean_expression_df = pd.read_csv("/home/danielyumengzhu/E9.5_reference_mean_expr.csv", index_col=0)
p_values_df = pd.read_csv("/home/danielyumengzhu/E9.5_reference_p_vals.csv", index_col=0)
sig_df = pd.read_csv("/home/danielyumengzhu/E9.5_reference_significant.csv", index_col=0)

## Process the E9.5 Stereo-seq data with respect to the reference markers

In [None]:
# Scaled mean expression of every reference marker gene in every Stereo-seq cell type.
# Stereo-seq labels are visited grouped by the reference cell type they map to in `mapping`.

# Gene universe = union of all reference marker genes
all_genes = pd.unique(ref_markers.values.ravel())
all_genes = all_genes[pd.notna(all_genes)]  # Remove NaN values

# Subset genes once instead of copying the full AnnData for every label
stereo_marker_adata = e95[:, all_genes]
stereo_labels = stereo_marker_adata.obs["mapped_celltype"]

mean_by_label = {}
for cell_type in tqdm(ref_markers.columns):
    # Every Stereo-seq label that `mapping` sends to this reference cell type
    for key in [k for k, v in mapping.items() if v == cell_type]:
        in_label = (stereo_labels == key).to_numpy()
        if not in_label.any():
            # No cells carry this label: report it and keep an all-NaN column so the column
            # layout (which the positional abbreviations further down depend on) is unchanged
            print(key)
            mean_by_label[key] = np.full(len(all_genes), np.nan)
            continue

        expression_data = stereo_marker_adata[in_label].X
        # Densify sparse matrices so np.mean behaves the same for both storage types
        if scipy.sparse.issparse(expression_data):
            expression_data = expression_data.toarray()
        # Mean across cells (rows)
        mean_by_label[key] = np.mean(expression_data, axis=0)

# Build the frame once rather than growing it column by column
mean_expression_df = pd.DataFrame(mean_by_label, index=all_genes)
# Min-max scale each label (column) to [0, 1]; constant or all-NaN columns become NaN
mean_expression_df = (mean_expression_df - mean_expression_df.min()) / (mean_expression_df.max() - mean_expression_df.min())
mean_expression_df

In [None]:
# Cache the Stereo-seq scaled mean-expression table
mean_expression_df.to_csv("/home/danielyumengzhu/E9.5_Stereo-seq_mean_expr.csv")

In [None]:
# Significance test (empirical, rank-based), one Stereo-seq label (column) at a time, using
# the markers of the reference cell type the label maps to.
# Background = every gene that is NOT one of those reference markers.
# For each gene:
#   p = (# background genes with strictly HIGHER scaled mean expression)
#       / (# background genes other than the gene itself)
# Numerator and denominator now refer to the same background set (the denominator used to be
# all genes - 1). Genes whose scaled value is NaN (e.g. a label with no cells, or a constant
# column) get p = NaN. Previously every gene in such a column got p = 0, i.e. "significant".
p_values_df = pd.DataFrame(index=mean_expression_df.index, columns=mean_expression_df.columns, dtype=float)

for column in tqdm(mean_expression_df.columns):
    ct = mapping[column]
    values = mean_expression_df[column].to_numpy(dtype=float)
    markers = set(ref_markers[ct].dropna())
    in_background = ~mean_expression_df.index.isin(markers) & ~np.isnan(values)
    background = np.sort(values[in_background])

    # Count background genes strictly above each gene's value in one vectorized pass
    n_higher = background.size - np.searchsorted(background, values, side="right")
    # A background gene is compared with every background gene except itself
    n_compared = background.size - in_background.astype(int)

    p_col = np.divide(n_higher, n_compared, out=np.full(values.shape, np.nan), where=n_compared > 0)
    p_col[np.isnan(values)] = np.nan
    p_values_df[column] = p_col
p_values_df

In [None]:
# Cache the Stereo-seq per-gene, per-label p-values
p_values_df.to_csv("/home/danielyumengzhu/E9.5_Stereo-seq_p_vals.csv")

In [None]:
# 1 where the gene's p-value is below 0.05 for that Stereo-seq label, else 0
sig_df = p_values_df.lt(0.05).astype(int)

In [None]:
# Cache the Stereo-seq 0/1 significance table
sig_df.to_csv("/home/danielyumengzhu/E9.5_Stereo-seq_significant.csv")

## Load significance dataframes for E9.5 and the reference

In [None]:
# 0/1 significance tables: rows = marker genes, columns = Stereo-seq labels / reference cell types
e95_sig = pd.read_csv("/home/danielyumengzhu/E9.5_Stereo-seq_significant.csv", index_col=0)
e95_ref_sig = pd.read_csv("/home/danielyumengzhu/E9.5_reference_significant.csv", index_col=0)

In [None]:
# Keep the full Stereo-seq labels before the columns are replaced by abbreviations
orig_cols = e95_sig.columns

In [None]:
# Full reference cell-type labels, captured before abbreviation
orig_ref_cols = e95_ref_sig.columns.tolist()

In [None]:
# Show the full reference labels in column order; the abbreviation list below must follow this order
orig_ref_cols

In [None]:
# Shorten names for display:
# NOTE: the abbreviations are assigned by POSITION. Each list must follow the exact column order
# of the significance CSV it relabels (that order comes from how the mean-expression tables were
# built above). If either CSV's column order changes, cell types are silently mislabelled; a list
# of the wrong length makes pandas raise a ValueError on assignment.
e95_sig.columns = [
    'AmniEcto', 'AFP', 'AIM', 'ARP', 'AEC', 'VEC', 'CMN', 'Dermomyo', 'DienNE', 
    'DTelenNE', 'Endocardial', 'Endo', 'LSECP', 'HFM', 'LVACCA', 'GlutN', 'Gut', 
    'HEP', 'Hepato', 'HindNE', 'HypoNE', 'Hypo(Sim1+)NE', 'LPM', 'LimbMesen', 
    'MesoP', 'MidNE', 'MHB', 'NMPSCP', 'NCPG', 'NCPN', 'NeuroP', 'NeuroDer', 
    'OEC', 'OticEC', 'EctoDer', 'PNCD', 'IM', 'PRP', 'EpiP', 'PEC', 'CM', 
    'Sclero', 'SHFDC', 'SCDP', 'SCMN', 'SCVP', 'SCNE', 'TelenNE', 'CPM', 
    'RenMes', 'HepMes', 'GutMes', 'ProEpi'
]

# Full Stereo-seq label -> abbreviation, and trajectory -> list of Stereo-seq abbreviations
label_to_abbrev_mapping = {l: a for l, a in zip(orig_cols, e95_sig.columns)}
trajectory_to_abbrev_mapping = {
    key: [label_to_abbrev_mapping[category] for category in categories] for key, categories in category_dict.items()
}

# Reference abbreviations, also positional (must follow the order shown by `orig_ref_cols`)
e95_ref_sig.columns = [
    'AmniEcto', 'AFP', 'AIM', 'ARP', 'AER', 'AEC', 'BCECs', 'CMN', 
    'Dermomyo', 'Dien', 'DTelen', 'Endocardial', 'Endo', 'EntericN', 
    'EyeF', 'FM', 'FHF', 'FP3D', 'GABAergN', 'GlutN', 'Gut', 
    'HEP', 'HSC(CD34+)', 'Hepato', 'Hind', 'Hypo', 'Hypo(Sim1+)', 
    'LPIM', 'LimbMesen', 'Megakaryo', 'MesoP(Tbx6+)', 'Mid', 
    'MHB', 'MuscP', 'MuscP(Prdm1+)', 'NMPSCP', 'NCPG', 'NCPN', 
    'NeuroP(Neurod1+)', 'Noto', 'OEC', 'OlfPitC', 'OticEC', 
    'OticSensN', 'Peri', 'PA', 'PIM', 'PRP', 'PreEpidK', 'PEC', 
    'Sclero', 'SHF', 'SCDP', 'SCMN', 'SCVP', 'SCR7R8', 'Telen', 
    'CPM', 'RenPeriMes', 'HepMes', 'GutMes', 'ProEpi'
]

# Full reference label -> abbreviation, and trajectory -> list of reference abbreviations
label_to_abbrev_ref_mapping = {l: a for l, a in zip(orig_ref_cols, e95_ref_sig.columns)}
trajectory_to_abbrev_ref_mapping = {
    key: [label_to_abbrev_ref_mapping[category] for category in categories] for key, categories in category_dict_ref.items()
}

In [None]:
def compute_cross_dataframe_proportions(df_a, df_b):
    """Cross-tabulate how often rows flagged in one indicator table are also flagged in another.

    For every column ``col_a`` of ``df_a`` and every column ``col_b`` of ``df_b``, the
    result holds the sum of ``df_b[col_b]`` over the rows where ``df_a[col_a] == 1``,
    divided by the number of such rows. For 0/1 tables this is the fraction of
    ``col_a``'s flagged rows that are also flagged in ``col_b``. Rows are matched by
    index label, not position.

    Parameters
    ----------
    df_a, df_b : pandas.DataFrame
        Indicator tables with the same number of rows (in this notebook: rows = marker
        genes, columns = cell types). Missing values in ``df_b`` count as 0. If a column
        label is duplicated, its first occurrence is used for every copy of that label.

    Returns
    -------
    pandas.DataFrame
        Float table indexed by ``df_a.columns`` with columns ``df_b.columns``. A ``df_a``
        column with no 1s gets 0.0 across its whole row.

    Raises
    ------
    ValueError
        If the two tables have a different number of rows.
    KeyError
        If a row label of ``df_a`` is missing from ``df_b``.
    """
    if df_a.shape[0] != df_b.shape[0]:
        raise ValueError("DataFrames must have the same number of rows")

    # Use the first column for any duplicated label. The previous per-pair loop tried to do this
    # with np.argmax on the result of get_loc, which mishandles the slice get_loc returns
    # for adjacent duplicates.
    a_first = df_a.loc[:, ~df_a.columns.duplicated()]
    b_first = df_b.loc[:, ~df_b.columns.duplicated()]

    # 1.0 where a row is flagged in df_a, else 0.0
    hits = (a_first == 1).to_numpy(dtype=float)
    # df_b rows aligned to df_a's row labels; NaN -> 0 mirrors pandas' skipna summation
    b_values = b_first.loc[df_a.index].fillna(0).to_numpy(dtype=float)

    # overlap[i, j] = sum of df_b column j over the rows flagged in df_a column i
    overlap = hits.T @ b_values
    n_flagged = hits.sum(axis=0)[:, None]
    # Columns of df_a with no flagged rows get 0.0 instead of a division by zero
    proportions = np.divide(overlap, n_flagged, out=np.zeros_like(overlap), where=n_flagged > 0)

    result = pd.DataFrame(proportions, index=a_first.columns, columns=b_first.columns)
    # Expand back to the original (possibly duplicated) labels
    return result.reindex(index=df_a.columns, columns=df_b.columns)

In [None]:
# # First for the reference, check the proportions between the reference and itself:
# within_dataframe_proportions = compute_cross_dataframe_proportions(e95_ref_sig, e95_ref_sig)
# within_dataframe_proportions **= 2
# within_dataframe_proportions

In [None]:
# Reference-vs-Stereo-seq overlap of significant markers, cubed to increase visual contrast
cross_dataframe_proportions = compute_cross_dataframe_proportions(e95_ref_sig, e95_sig) ** 3
cross_dataframe_proportions

In [None]:
# Filter the reference cell types to only those that some Stereo-seq label maps to.
# dict.fromkeys de-duplicates while keeping first-seen order. The previous list(set(...)) gave an
# order that depends on Python's per-process string hash seed, so the displayed table changed
# between kernel restarts.
ref_cts_in_mapping = list(dict.fromkeys(mapping.values()))
ref_cts_in_mapping_abbrev = [label_to_abbrev_ref_mapping[l] for l in ref_cts_in_mapping]
cross_dataframe_proportions = cross_dataframe_proportions.loc[ref_cts_in_mapping_abbrev, :]
cross_dataframe_proportions

In [None]:
# Row order = reference cell types grouped by trajectory; column order = Stereo-seq labels
# grouped the same way. Abbreviations missing from the proportions table are dropped.
ref_order = ref_categories_ordered
ref_order_abbrev = [
    abbrev
    for abbrev in (label_to_abbrev_ref_mapping[label] for label in ref_order)
    if abbrev in cross_dataframe_proportions.index
]
spateo_order = categories_ordered
spateo_order_abbrev = [
    abbrev
    for abbrev in (label_to_abbrev_mapping[label] for label in spateo_order)
    if abbrev in cross_dataframe_proportions.columns
]

In [None]:
# Reorder rows (reference) and columns (Stereo-seq) into trajectory-grouped order
cross_dataframe_proportions = (
    cross_dataframe_proportions
    .loc[ref_order_abbrev]
    .loc[:, spateo_order_abbrev]
)

In [None]:
# Min-max scale each column (Stereo-seq label) to [0, 1]; a constant column becomes NaN
cross_dataframe_proportions_scale = cross_dataframe_proportions.sub(
    cross_dataframe_proportions.min()
).div(
    cross_dataframe_proportions.max() - cross_dataframe_proportions.min()
)

In [None]:
# Invert {trajectory: [abbreviation, ...]} into {abbreviation: trajectory} for each dataset
reverse_map = {
    abbrev: traj
    for traj, abbrevs in trajectory_to_abbrev_mapping.items()
    for abbrev in abbrevs
}
reverse_map_ref = {
    abbrev: traj
    for traj, abbrevs in trajectory_to_abbrev_ref_mapping.items()
    for abbrev in abbrevs
}

# One color per trajectory, shared by the reference and the Stereo-seq strips
group_colors = sns.color_palette("tab20", len(trajectory_to_abbrev_mapping))
color_mapping = dict(zip(trajectory_to_abbrev_mapping, group_colors))

In [None]:
# Stereo-seq labels are the columns here and become heatmap rows once transposed
row_colors = cross_dataframe_proportions_scale.columns.map(
    lambda abbrev: color_mapping.get(reverse_map.get(abbrev, ""))
)
# Reference cell types are the index here and become heatmap columns once transposed
col_colors = cross_dataframe_proportions_scale.index.map(
    lambda abbrev: color_mapping.get(reverse_map_ref.get(abbrev, ""))
)

In [None]:
import matplotlib.patches as patches

# Heatmap rows = Stereo-seq cell types, columns = scRNA-seq reference cell types (hence the transpose)
heatmap_data = cross_dataframe_proportions_scale.T.astype(float)
n_rows = heatmap_data.shape[0]

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 20))
sns.heatmap(
    heatmap_data,
    square=True,
    linecolor="grey",
    linewidths=0.3,
    cbar_kws={"location": "top"},
    cmap="magma",
    center=0.5,
    vmin=0,
    vmax=1,
    ax=ax,
)

# Hide the default tick labels; cell types are identified by the trajectory color strips below
ax.set_xticks([])
ax.set_yticks([])

# seaborn's heatmap draws row 0 at the top (inverted y-axis), so x < 0 lies left of
# the grid and y >= n_rows lies *below* it.
# Trajectory color of each Stereo-seq cell type: strip along the left edge
for i, color in enumerate(row_colors):
    ax.add_patch(patches.Rectangle((-2, i), 2, 1, linewidth=0, edgecolor=None, facecolor=color, clip_on=False))

# Trajectory color of each reference cell type: strip along the bottom edge
for i, color in enumerate(col_colors):
    ax.add_patch(patches.Rectangle((i, n_rows), 1, 2, linewidth=0, edgecolor=None, facecolor=color, clip_on=False))

for spine in ax.spines.values():
    spine.set_visible(False)

# Colorbar label and tick styling
cbar = ax.collections[0].colorbar
cbar.set_label("Normalized Fraction of Significantly Enriched Marker Genes Matching", fontsize=18)
cbar.ax.tick_params(labelsize=16)
cbar.ax.set_aspect(0.04)

# Larger labelpad leaves room for the color strips between the grid and the axis labels
ax.set_xlabel("Cell Types- scRNA-seq Reference", fontsize=30, labelpad=35)
ax.set_ylabel("Cell Types- Stereo-seq data", fontsize=30, labelpad=35)
ax.set_title("Comparison for E9.5- Stereo-seq v. scRNA-seq Reference", fontsize=22, pad=25)

# Legend for the trajectory colors shared by both strips
legend_handles = [patches.Patch(color=color, label=key) for key, color in color_mapping.items()]
ax.legend(handles=legend_handles, title="Cell Type Groups", bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

plt.tight_layout()
plt.show()