# Mean Euclidean distance to all same-type cells

In [1]:
# Date stamp of this run; it is interpolated into the results directory path below.
folder_date = "2025-07-03"

In [2]:
import sys
import os
import scib
import math
import warnings
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
plt.style.use("../nature.mplstyle")
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist, jensenshannon
from scipy.special import kl_div
from scipy.stats import wasserstein_distance
from sklearn.neighbors import NearestNeighbors

In [3]:
# Silence every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so deprecation and runtime warnings
# (e.g. divide-by-zero, empty-slice means) raised by later cells are hidden too.
warnings.filterwarnings('ignore')

In [4]:
# Output root for this run, keyed by the date stamp; created if it is missing.
results_dir = f"../results/{folder_date}/out"
os.makedirs(results_dir, exist_ok=True)
# Echo the path so the cell output records where results go.
results_dir

'../results/2025-07-03/out'

In [5]:
# Add the parent directory to sys.path
sys.path.append(os.path.abspath('..'))

from graph_connectivity_per_celltype import graph_connectivity_per_celltype
from isolated_labels import isolated_labels_asw

In [6]:
# File name of the scenario being analysed (presumably the .h5ad loaded later - confirm)
scenario = "2025-04-17_scenario_01_all_celltypes_evenly_distr_across_batches.h5ad"
# Short scenario id; used as the name of this scenario's output sub-folder.
sc_order = "01"

sc_order, scenario

('01', '2025-04-17_scenario_01_all_celltypes_evenly_distr_across_batches.h5ad')

In [7]:
# Per-scenario output folder: <results_dir>/<sc_order>; created if it is missing.
out_dir = f"{results_dir}/{sc_order}"
os.makedirs(out_dir, exist_ok=True)

out_dir

'../results/2025-07-03/out/01'

In [8]:
# Column labels, in display order, for the metrics table.
# Adjust the batch-count columns ('#batch_1', '#batch_2', ...) to match the
# batches present in the dataset being analysed.
# NOTE(review): the 'n' / 'w' / 'wn' / 'max' prefixes look like normalised /
# weighted / weighted-normalised / maximum variants of each metric - confirm
# against the code that fills the table.
col_order = [
    '#batch_1',      '#batch_2',        'weight',
    'KL loc|glob',  'max KL loc|glob',  'nKL loc|glob', 'wKL loc|glob', 'wnKL loc|glob',
    'KL glob|loc',  'max KL glob|loc',  'nKL glob|loc', 'wKL glob|loc', 'wnKL glob|loc',
    'JS Dist',      'max JS Dist',      'nJS Dist',     'wJS Dist',     'wnJS Dist',
    'JS Div',       'max JS Div',       'nJS Div',      'wJS Div',      'wnJS Div',
    'TV',           'max TV',           'nTV',          'wTV',          'wnTV', 
    'H',            'max H',            'nH',           'wH',           'wnH', 
    'chiSD',        'max chiSD',        'nChiSD',       'wChiSD',       'wnChiSD',
    'WD',           'max WD',           'nWD',          'wWD',          'wnWD',
    'cLISI',
    'iLISI',        'n_iLISI',          'w_nILISI',
    'ASW',          '1-ASW',            'wASW',
    'kBET',         'wkBET',            'PCR',
    'graph_conn',   '1-graph_conn',     'wGraph_conn',
    'isolated_labels_f1',               'isolated_labels_asw'
    ]

col_order

['#batch_1',
 '#batch_2',
 'weight',
 'KL loc|glob',
 'max KL loc|glob',
 'nKL loc|glob',
 'wKL loc|glob',
 'wnKL loc|glob',
 'KL glob|loc',
 'max KL glob|loc',
 'nKL glob|loc',
 'wKL glob|loc',
 'wnKL glob|loc',
 'JS Dist',
 'max JS Dist',
 'nJS Dist',
 'wJS Dist',
 'wnJS Dist',
 'JS Div',
 'max JS Div',
 'nJS Div',
 'wJS Div',
 'wnJS Div',
 'TV',
 'max TV',
 'nTV',
 'wTV',
 'wnTV',
 'H',
 'max H',
 'nH',
 'wH',
 'wnH',
 'chiSD',
 'max chiSD',
 'nChiSD',
 'wChiSD',
 'wnChiSD',
 'WD',
 'max WD',
 'nWD',
 'wWD',
 'wnWD',
 'cLISI',
 'iLISI',
 'n_iLISI',
 'w_nILISI',
 'ASW',
 '1-ASW',
 'wASW',
 'kBET',
 'wkBET',
 'PCR',
 'graph_conn',
 '1-graph_conn',
 'wGraph_conn',
 'isolated_labels_f1',
 'isolated_labels_asw']

In [9]:
%load_ext rpy2.ipython

In [10]:
def knn(df, k=90, metric="euclidean", include_self=False):
    """
    Finds the k nearest neighbors of each cell among the cells in `df`.

    Parameters:
    -----------
    df : pd.DataFrame (rows=cells, cols=features)
    k : # neighbors to return per cell
    metric : 'euclidean', 'cosine' etc (any metric accepted by NearestNeighbors)
    include_self : True or False
        If True, every cell is queried like an external point, so it is
        normally returned as its own first neighbor (distance 0).
        If False, the cell itself is excluded from its neighbor list.

    Returns:
    --------
    indices : np.array()
        row positions (into `df`) of the k nearest neighbors of each cell
    distances : np.array()
        distances to those neighbors, in the same order as `indices`
    """
    X = df.to_numpy(dtype=np.float32, copy=False)

    nn = NearestNeighbors(
        n_neighbors=k,
        metric=metric,
        algorithm='auto',
        n_jobs=-1   # use all cores
    ).fit(X)

    if include_self:
        # Passing X explicitly treats each row as a new query point, so the
        # cell itself is eligible to be returned.
        distances, indices = nn.kneighbors(X, return_distance=True)
    else:
        # Querying without X asks for the neighbors of the fitted points and
        # lets sklearn exclude each point by its own index. Dropping column 0
        # of a (k + 1)-query instead is wrong when exact duplicates exist,
        # because the cell itself is then not guaranteed to come first.
        distances, indices = nn.kneighbors(return_distance=True)

    return indices, distances


In [11]:
def mean_dists_to_same_type_cells(pca_coords, cell_types):
    """
    Mean Euclidean distance of each cell to the other cells of its type,
    rescaled so that the smallest value across all cells equals 1.

    Parameters
    ----------
    pca_coords : np.ndarray
        Coordinate matrix (cells x features), e.g. PCA coordinates.
    cell_types : pd.Series or array-like
        Cell type label of each cell, in the same row order as `pca_coords`.

    Returns
    -------
    np.ndarray
        One value per cell: (mean distance to same-type cells) / (minimum of
        those means over all cells). Cells that have no other cell of their
        type (singleton types or missing labels) get NaN. If the smallest
        mean distance is 0, the division follows numpy's divide-by-zero rules
        (inf / NaN).
    """
    coords = np.asarray(pca_coords)
    labels = np.asarray(cell_types)

    # NaN marks "no same-type partner"; starting from zeros would let such
    # cells win the minimum below and turn the normalisation into x / 0.
    mean_dists = np.full(coords.shape[0], np.nan)

    for cell_type in pd.unique(labels):
        mask = labels == cell_type
        n_same = int(mask.sum())
        if n_same < 2:
            # nothing to average over (also covers NaN labels, which never
            # compare equal to themselves)
            continue

        group = coords[mask]
        # all pairwise distances within the current cell type
        dists = cdist(group, group, metric='euclidean')

        # the diagonal (self-distance) is 0, so the row sum already excludes
        # it; divide by the number of *other* cells
        mean_dists[mask] = dists.sum(axis=1) / (n_same - 1)

    if np.isnan(mean_dists).all():
        # no type has two or more cells: nothing to normalise by
        return mean_dists

    # normalize distances so the minimum distance is 1
    min_dist = np.nanmin(mean_dists)
    return mean_dists / min_dist

In [12]:
# adjust it

def plot_emb_and_distrs(
        scenario, coords,
        cell_types, batches,
        local_dist, global_dist,
        out_file, emb="pca",
        local_dist_k60=None, local_dist_k30=None
):
    """
    Plot a 2-D embedding together with global and local batch compositions.

    Layout
    - 1st row: UMAP/PCA scatter in a square frame (color = batch, marker = cell type)
    - 2nd row: global distr (left) and local k=90 by source batch (right)
    - 3rd row: local k=60 (left) and local k=30 (right), both by source batch

    The figure is saved to `out_file` (dpi=300); it is neither shown nor closed here.

    Parameters
    ----------
    scenario : str
        The name of the scenario being analyzed (used in the title).
    coords : np.ndarray
        UMAP/PCA coordinates (cells x 2).
    cell_types : pd.Series
        Series containing the cell type for each cell.
    batches : pd.Series
        Series containing the batch information for each cell.
    local_dist : pd.DataFrame
        Batch composition of each cell's neighborhood (k=90). Its column order
        defines the stacking order and the batch colors.
    global_dist : pd.DataFrame
        Cell IDs are stored in indices and columns are unique batch labels
    out_file : str
        Output file path for the plot.
    emb : str
        Embedding method used ('umap' or 'pca'); only used for labels.
    local_dist_k60 : pd.DataFrame
        Batch composition of each cell's neighborhood (k=60); panel is
        replaced by a placeholder text when None.
    local_dist_k30 : pd.DataFrame
        Batch composition of each cell's neighborhood (k=30); panel is
        replaced by a placeholder text when None.
    """

    # color palette and markers
    base_palette = list(plt.cm.Set1.colors)
    markers = ['o', 's', '^', 'v', '<', '>', 'D', 'p', 'h', '*', '+', 'x', '|', '_', '.', '1', '2', '3', '4', '8', 'H', 'P']

    unique_batches = sorted(batches.unique())
    unique_cell_types = sorted(cell_types.unique())
    batch_cols = list(local_dist.columns)   # column order for stacked bars

    # unified batch_to_color mapping for all subplots
    ordered_batches = batch_cols + [b for b in unique_batches if b not in batch_cols]
    # cycle through Set1 if more batches than colors
    batch_to_color = {b: base_palette[i % len(base_palette)] for i, b in enumerate(ordered_batches)}
    bar_colors = [batch_to_color[b] for b in batch_cols]

    # figure and layout; 3 rows
    fig = plt.figure(figsize=(16, 18))
    outer = fig.add_gridspec(3, 1, height_ratios=[3, 1, 1], hspace=0.45)

    # 1st row: full width (pca/umap)
    ax = fig.add_subplot(outer[0])

    # 2nd row: global distr (left) and local k=90 (right), with gutters
    row2 = outer[1].subgridspec(1, 4, wspace=0.25, width_ratios=[0.6, 1.0, 1.0, 0.6])
    ax_global    = fig.add_subplot(row2[1])    # left plot
    ax_local_k90 = fig.add_subplot(row2[2])    # right plot

    # 3rd row: k=60 (left) and k=30 (right)
    row3 = outer[2].subgridspec(1, 4, wspace=0.25, width_ratios=[0.6, 1.0, 1.0, 0.6])
    ax_k60 = fig.add_subplot(row3[1])
    ax_k30 = fig.add_subplot(row3[2])

    # scatter (pca/umap): one call per (batch, cell type) so each group gets
    # its own color/marker combination. All points are drawn (no subsampling).
    for batch in unique_batches:
        for j, ct in enumerate(unique_cell_types):
            mask = np.asarray((batches == batch) & (cell_types == ct))
            # only plot if there are cells
            if not mask.any():
                continue
            pts = coords[mask]

            ax.scatter(
                pts[:, 0], pts[:, 1],
                c=batch_to_color[batch],
                marker=markers[j % len(markers)],
                s=20,                       # Slightly larger size makes shape distinguishable
                alpha=0.6,                  # Higher alpha helps shape pop, still keeps density visible
                edgecolors='black',         # Brings out marker shape
                linewidth=0.2,              # Thin border for definition without being distracting
                zorder=2,
                rasterized=True             # Keep it if exporting to PDF
            )

    # square box for scatter plot: same span on both axes, 5% padding
    x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
    y_min, y_max = coords[:, 1].min(), coords[:, 1].max()
    x_c = 0.5 * (x_min + x_max)
    y_c = 0.5 * (y_min + y_max)
    span = max(x_max - x_min, y_max - y_min)
    pad = 0.05 * span
    half = 0.5 * span + pad
    ax.set_xlim(x_c - half, x_c + half)
    ax.set_ylim(y_c - half, y_c + half)
    ax.set_box_aspect(1)

    # legends (built from explicit handles, independent of the scatter artists)
    batch_legend = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=batch_to_color[b],
                   markersize=8, label=b)
        for b in ordered_batches
    ]
    celltype_legend = [
        plt.Line2D([0], [0], marker=markers[j % len(markers)],
                   color='white', markeredgecolor='black',
                   markersize=8, label=ct, linestyle='None')
        for j, ct in enumerate(unique_cell_types)
    ]
    leg1 = ax.legend(handles=batch_legend, title='Batch',
                     loc='upper left', bbox_to_anchor=(1.02, 1),
                     fontsize=10, title_fontsize=11)
    ax.legend(handles=celltype_legend, title='Cell type',
              loc='upper left', bbox_to_anchor=(1.02, 0.6),
              fontsize=10, title_fontsize=11)
    # the second ax.legend() call replaces the first one; re-attach it
    ax.add_artist(leg1)

    ax.set_title(f"Sc_{scenario}; {emb}: Batch (colors) & Cell Type (markers)", fontsize=10)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xlabel(f"{emb} 1", fontsize=10)
    ax.set_ylabel(f"{emb} 2", fontsize=10)

    # helpers
    def _local_props_by_source(local_df):
        """Average neighbor composition per (celltype, source_batch) then row-normalize"""
        # ensure same columns/order as batch_cols
        local_df = local_df.reindex(columns=batch_cols, fill_value=0)
        tmp = local_df.copy()
        tmp["celltype"] = cell_types.values
        tmp["source_batch"] = batches.values
        means = tmp.groupby(["celltype", "source_batch"])[batch_cols].mean()
        props = means.div(means.sum(axis=1), axis=0).fillna(0)

        # order bars by (cell type, batch); pairs without any cell are not in
        # the grouped result and must be skipped, otherwise .loc raises KeyError
        idx = [
            (ct, b)
            for ct in unique_cell_types
            for b in ordered_batches
            if (ct, b) in props.index
        ]
        if idx:
            props = props.loc[idx]
        return props

    def _plot_local_props(ax_, props, title):
        """Stacked vertical bars, one bar per (cell type, source batch) row of `props`."""
        props.plot(kind="bar", stacked=True, ax=ax_,
                   color=bar_colors,
                   width=0.85, legend=False)
        ax_.set_title(title, fontsize=10)
        ax_.set_ylabel("Proportion of Neighbors", fontsize=9)
        ax_.set_xlabel("group (source batch)", fontsize=9)
        ax_.set_xticklabels([f"{ct}\n({sb})" for (ct, sb) in props.index],
                            rotation=90, ha="right")
        ax_.grid(axis='y', alpha=0.3)
        ax_.tick_params(axis="x", labelsize=7)
        ax_.tick_params(axis="y", labelsize=7)

    def _plot_or_placeholder(ax_, local_df, k, arg_name):
        """Draw the local panel for neighborhood size `k`, or a note when no data was passed."""
        if local_df is not None:
            _plot_local_props(ax_=ax_, props=_local_props_by_source(local_df),
                              title=f"Local (k={k}): by source batch")
        else:
            ax_.axis('off')
            ax_.text(0.5, 0.5, f"{arg_name} not provided",
                     ha='center', va='center', fontsize=10)

    # 2nd row: global (left) - per cell type mean of the global rows, row-normalized
    global_with_type = global_dist.reindex(columns=batch_cols, fill_value=0).copy()
    global_with_type["celltype"] = cell_types.values
    global_means = global_with_type.groupby("celltype")[batch_cols].mean()
    global_props = global_means.div(global_means.sum(axis=1), axis=0).fillna(0)
    global_props.plot(kind="bar", stacked=True, ax=ax_global,
                      color=bar_colors,
                      width=0.85, legend=False)
    ax_global.set_title("Global distribution", fontsize=10)
    ax_global.set_ylabel("Proportion of Cells", fontsize=9)
    ax_global.set_xlabel("celltype", fontsize=9)
    ax_global.tick_params(axis="x", rotation=90, labelsize=7)
    ax_global.tick_params(axis="y", labelsize=7)
    ax_global.grid(axis='y', alpha=0.3)

    # 2nd row: local k=90 (right)
    props90 = _local_props_by_source(local_df=local_dist)
    _plot_local_props(ax_=ax_local_k90, props=props90, title="Local (k=90): by source batch")

    # 3rd row: local k=60 (left) and k=30 (right)
    _plot_or_placeholder(ax_k60, local_dist_k60, 60, "local_dist_k60")
    _plot_or_placeholder(ax_k30, local_dist_k30, 30, "local_dist_k30")

    plt.tight_layout()
    plt.savefig(out_file, dpi=300, bbox_inches="tight")
    # plt.show()

In [13]:
def plot_emb_and_distrs_pub(
        scenario, coords,
        cell_types, batches,
        local_dist, global_dist,
        out_file, emb="pca",
        local_dist_k60=None, local_dist_k30=None,
        export_png=False, png_dpi=600
):
    """
    Nature-ready layout with horizontal bars for rows 2 & 3
    Row1 (A): Embedding (square)
    Row2 (B,C): Global distribution (barh) | Local k=90 (barh)
    Row3 (D,E): Local k=60 (barh) | Local k=30 (barh)

    The figure is always written as PDF, optionally also as PNG, and then closed.

    Parameters
    ----------
    scenario : str
        Not used inside this function.
    coords : np.ndarray
        Embedding coordinates; columns 0 and 1 are plotted.
    cell_types : pd.Series
        Cell type of each cell (selects the marker shape).
    batches : pd.Series
        Batch of each cell (selects the color).
    local_dist : pd.DataFrame
        Per-cell neighborhood batch composition shown in panel C; its column
        order also fixes the batch order and colors.
    global_dist : pd.DataFrame
        Per-cell global batch composition shown in panel B.
    out_file : str
        Output path; ".pdf" is appended unless the path already ends with it.
    emb : str
        Embedding name, used only for the axis labels.
    local_dist_k60, local_dist_k30 : pd.DataFrame or None
        Data for panels D and E; a placeholder text is drawn when None.
    export_png : bool
        Also save a PNG (".png" appended unless the path already ends with it).
    png_dpi : int
        Resolution of the PNG export.
    """

    # --- Palette: Okabe–Ito (colorblind-safe) ---
    okabe_ito = ["#0072B2", "#E69F00", "#009E73", "#D55E00",
                 "#CC79A7", "#56B4E9", "#F0E442", "#000000"]

    unique_batches = list(pd.Index(batches.unique()).sort_values())
    unique_cell_types = list(pd.Index(cell_types.unique()).sort_values())

    # Preserve incoming local_dist column order, then append any missing batches
    batch_cols = list(local_dist.columns)
    ordered_batches = batch_cols + [b for b in unique_batches if b not in batch_cols]
    # Colors are reused cyclically when there are more batches than palette entries
    batch_to_color = {b: okabe_ito[i % len(okabe_ito)] for i, b in enumerate(ordered_batches)}

    # Figure (let the mplstyle control figure size)
    fig = plt.figure(constrained_layout=True)
    outer = fig.add_gridspec(3, 1, height_ratios=[3.2, 1.6, 1.6])

    # Axes
    ax = fig.add_subplot(outer[0])
    row2 = outer[1].subgridspec(1, 2, wspace=0.3)
    ax_global    = fig.add_subplot(row2[0])
    ax_local_k90 = fig.add_subplot(row2[1])
    row3 = outer[2].subgridspec(1, 2, wspace=0.3)
    ax_k60 = fig.add_subplot(row3[0])
    ax_k30 = fig.add_subplot(row3[1])

    # ----- Embedding scatter: batch=color, cell type=marker -----
    # Markers are reused cyclically when there are more cell types than markers
    markers = ['o', 's', '^', 'v', '<', '>', 'D', 'p', 'h', '*', '+', 'x', 'P', 'H', '1', '2', '3', '4', '.', '|', '_']
    for j, ct in enumerate(unique_cell_types):
        for i, b in enumerate(unique_batches):
            mask = (cell_types == ct) & (batches == b)
            # skip (cell type, batch) combinations without cells
            if not np.any(mask):
                continue
            # index coords positionally, whether mask is a Series or an array
            pts = coords[mask.values] if isinstance(mask, pd.Series) else coords[mask]
            ax.scatter(
                pts[:, 0], pts[:, 1],
                s=10,
                marker=markers[j % len(markers)],
                facecolors=batch_to_color[b],
                edgecolors='black', linewidths=0.15,
                alpha=0.85, zorder=2, rasterized=True
            )

    # Square frame: same span on both axes, centered on the data, 5% padding
    x_min, x_max = np.min(coords[:, 0]), np.max(coords[:, 0])
    y_min, y_max = np.min(coords[:, 1]), np.max(coords[:, 1])
    x_c, y_c = 0.5 * (x_min + x_max), 0.5 * (y_min + y_max)
    span = max(x_max - x_min, y_max - y_min)
    pad = 0.05 * span
    half = 0.5 * span + pad
    ax.set_xlim(x_c - half, x_c + half)
    ax.set_ylim(y_c - half, y_c + half)
    ax.set_box_aspect(1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xlabel(f"{emb} 1"); ax.set_ylabel(f"{emb} 2")

    # Shared legends (outside), built from explicit proxy handles
    batch_handles = [Line2D([0], [0], marker='o', linestyle='None',
                            markerfacecolor=batch_to_color[b], markeredgecolor='black',
                            markersize=6, label=b)
                     for b in ordered_batches]
    ct_handles = [Line2D([0], [0],
                         marker=markers[j % len(markers)], linestyle='None',
                         markerfacecolor='white', markeredgecolor='black',
                         markersize=6, label=ct)
                  for j, ct in enumerate(unique_cell_types)]
    leg1 = ax.legend(handles=batch_handles, title="Batch",
                     loc='upper left', bbox_to_anchor=(1.02, 1.0), borderaxespad=0.)
    leg2 = ax.legend(handles=ct_handles, title="Cell type",
                     loc='upper left', bbox_to_anchor=(1.02, 0.62), borderaxespad=0.)
    # the second ax.legend() call replaces the first legend, so add it back
    ax.add_artist(leg1)

    # Panel label
    ax.text(-0.08, 1.02, "A", transform=ax.transAxes,
            fontsize=11, fontweight='bold', va='bottom', ha='right')

    # ----- Helpers -----
    def _local_props_by_source(local_df):
        """Average neighbor composition per (celltype, source_batch) then row-normalize."""
        # align columns to the shared batch order; batches absent from local_df become 0
        local_df = local_df.reindex(columns=ordered_batches, fill_value=0)
        tmp = local_df.copy()
        tmp["celltype"] = cell_types.values
        tmp["source_batch"] = batches.values
        means = tmp.groupby(["celltype", "source_batch"])[ordered_batches].mean()
        # rows summing to 0 divide to NaN; show them as empty bars
        props = means.div(means.sum(axis=1), axis=0).fillna(0)

        # Order rows as (ct, batch) if they exist
        ordered_idx = []
        for ct in unique_cell_types:
            for b in ordered_batches:
                key = (ct, b)
                if key in props.index:
                    ordered_idx.append(key)
        if ordered_idx:
            props = props.loc[ordered_idx]
        return props

    def _plot_local_props_h(ax_, props):
        """Horizontal stacked bars for local distributions."""
        props.plot(kind="barh", stacked=True, ax=ax_,
                   color=[batch_to_color[b] for b in ordered_batches],
                   width=0.85, legend=False)
        ax_.set_xlabel("Proportion")
        ax_.set_ylabel("")
        # y tick labels from MultiIndex (ct, source_batch)
        ylabels = [f"{ct} ({sb})" for (ct, sb) in props.index]
        ax_.set_yticklabels(ylabels)
        ax_.set_xlim(0.0, 1.0)
        ax_.grid(axis='x', alpha=0.25)

    # ----- Row 2: Global (B) & Local k=90 (C) — HORIZONTAL -----
    # per cell type: mean of the per-cell global rows, then row-normalized
    global_with_type = global_dist.reindex(columns=ordered_batches, fill_value=0).copy()
    global_with_type["celltype"] = cell_types.values
    global_means = global_with_type.groupby("celltype")[ordered_batches].mean()
    global_props = global_means.div(global_means.sum(axis=1), axis=0).fillna(0)

    global_props.plot(kind="barh", stacked=True, ax=ax_global,
                      color=[batch_to_color[b] for b in ordered_batches],
                      width=0.85, legend=False)
    ax_global.set_xlabel("Proportion")
    ax_global.set_ylabel("")
    ax_global.set_yticklabels(list(global_props.index))
    ax_global.set_xlim(0.0, 1.0)
    ax_global.grid(axis='x', alpha=0.25)
    ax_global.text(-0.12, 1.02, "B", transform=ax_global.transAxes,
                   fontsize=11, fontweight='bold', va='bottom', ha='right')

    props90 = _local_props_by_source(local_dist)
    _plot_local_props_h(ax_local_k90, props90)
    ax_local_k90.text(-0.12, 1.02, "C", transform=ax_local_k90.transAxes,
                      fontsize=11, fontweight='bold', va='bottom', ha='right')

    # ----- Row 3: Local k=60 (D) & k=30 (E) — HORIZONTAL -----
    # the panel letter is drawn in both branches, even when the data is missing
    if local_dist_k60 is not None:
        props60 = _local_props_by_source(local_dist_k60)
        _plot_local_props_h(ax_k60, props60)
    else:
        ax_k60.axis('off')
        ax_k60.text(0.5, 0.5, "k=60 not provided", ha='center', va='center')
    ax_k60.text(-0.12, 1.02, "D", transform=ax_k60.transAxes,
                fontsize=11, fontweight='bold', va='bottom', ha='right')

    if local_dist_k30 is not None:
        props30 = _local_props_by_source(local_dist_k30)
        _plot_local_props_h(ax_k30, props30)
    else:
        ax_k30.axis('off')
        ax_k30.text(0.5, 0.5, "k=30 not provided", ha='center', va='center')
    ax_k30.text(-0.12, 1.02, "E", transform=ax_k30.transAxes,
                fontsize=11, fontweight='bold', va='bottom', ha='right')

    # ----- Export: vector PDF (preferred) + optional PNG -----
    # NOTE(review): the suffix is appended, not replaced, so "x.png" yields
    # "x.png.pdf" and "x.pdf" with export_png=True yields "x.pdf.png".
    pdf_path = out_file if out_file.lower().endswith(".pdf") else out_file + ".pdf"
    fig.savefig(pdf_path, bbox_inches="tight")
    if export_png:
        png_path = out_file if out_file.lower().endswith(".png") else out_file + ".png"
        fig.savefig(png_path, dpi=png_dpi, bbox_inches="tight")
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


In [14]:
def plot_emb_and_global_pub(
        scenario, coords,
        cell_types, batches,
        local_dist, global_dist,
        out_file, emb="pca",
        local_dist_k60=None, local_dist_k30=None,
        export_png=False, png_dpi=600,
        sort_by_total=False, ascending=False,
        bar_width=0.4,           # `width` passed to the grouped bar plot
        legend_col_width=0.32    # width fraction for right legend column
):
    """
    Nature-ready: ONLY embedding + grouped COUNT bars.
    A and B are EXACTLY aligned in width by:
      - placing both in the same left GridSpec column,
      - disabling constrained_layout,
      - and hard-locking their left/right positions to the left-column cell.

    The figure is always written as PDF, optionally also as PNG, and then closed.

    Parameters
    ----------
    scenario, local_dist, local_dist_k60, local_dist_k30, emb
        Not used inside this function.
    coords : np.ndarray
        Embedding coordinates; columns 0 and 1 are plotted.
    cell_types : pd.Series
        Cell type of each cell (marker shape in A, x-axis groups in B).
    batches : pd.Series
        Batch of each cell (color in A, bar series in B).
    global_dist : pd.DataFrame
        Only its column order is used, to fix the batch order and colors.
    out_file : str
        Output path; ".pdf" is appended unless the path already ends with it.
    export_png : bool
        Also save a PNG (".png" appended unless the path already ends with it).
    png_dpi : int
        Resolution of the PNG export.
    sort_by_total : bool
        Sort cell types in panel B by their total cell count.
    ascending : bool
        Sort direction used when `sort_by_total` is True.
    bar_width : float
        `width` argument of the grouped bar plot.
    legend_col_width : float
        Width ratio of the legend column relative to the plot column (1.0).
    """

    # ---- Palette: Okabe–Ito (colorblind-safe) ----
    okabe_ito = ["#0072B2", "#E69F00", "#009E73", "#D55E00",
                 "#CC79A7", "#56B4E9", "#F0E442", "#000000"]

    # Entities
    unique_batches = list(pd.Index(batches.unique()).sort_values())
    unique_cell_types = list(pd.Index(cell_types.unique()).sort_values())

    # Series order for grouped bars: global_dist columns first, then any
    # remaining batches; fall back to the sorted batch labels
    if isinstance(global_dist, pd.DataFrame) and len(global_dist.columns) > 0:
        ordered_batches = list(global_dist.columns) + [b for b in unique_batches if b not in global_dist.columns]
    else:
        ordered_batches = unique_batches

    # Colors are reused cyclically when there are more batches than palette entries
    batch_to_color = {b: okabe_ito[i % len(okabe_ito)] for i, b in enumerate(ordered_batches)}

    # ===== Figure / Grid =====
    fig = plt.figure(constrained_layout=False)  # we'll manage layout manually
    gs = fig.add_gridspec(
        nrows=2, ncols=2,
        width_ratios=[1.0, legend_col_width],   # left = plots, right = legends
        height_ratios=[1.0, 1.0]
    )
    # Fixed margins; legends won't compress plot widths
    fig.subplots_adjust(left=0.10, right=0.90, top=0.98, bottom=0.12, hspace=0.28, wspace=0.15)

    ax_emb = fig.add_subplot(gs[0, 0])   # A (embedding)
    ax_bar = fig.add_subplot(gs[1, 0])   # B (grouped bars)
    ax_leg = fig.add_subplot(gs[:, 1])   # legends column (spans both rows)
    ax_leg.axis("off")

    # ===== Panel A: Embedding (square) =====
    # Markers are reused cyclically when there are more cell types than markers
    markers = ['o', 's', '^', 'v', '<', '>', 'D', 'p', 'h', '*', '+', 'x', 'P', 'H', '1', '2', '3', '4', '.', '|', '_']
    for j, ct in enumerate(unique_cell_types):
        for b in unique_batches:
            mask = (cell_types == ct) & (batches == b)
            # skip (cell type, batch) combinations without cells
            if not np.any(mask):
                continue
            # index coords positionally, whether mask is a Series or an array
            pts = coords[mask.values] if isinstance(mask, pd.Series) else coords[mask]
            ax_emb.scatter(
                pts[:, 0], pts[:, 1],
                s=10,
                marker=markers[j % len(markers)],
                facecolors=batch_to_color[b],
                edgecolors='black', linewidths=0.15,
                alpha=0.85, zorder=2, rasterized=True
            )

    # Square frame (uses full left-column width): same span on both axes, 5% padding
    x_min, x_max = np.min(coords[:, 0]), np.max(coords[:, 0])
    y_min, y_max = np.min(coords[:, 1]), np.max(coords[:, 1])
    x_c, y_c = 0.5 * (x_min + x_max), 0.5 * (y_min + y_max)
    span = max(x_max - x_min, y_max - y_min); pad = 0.05 * span; half = 0.5 * span + pad
    ax_emb.set_xlim(x_c - half, x_c + half)
    ax_emb.set_ylim(y_c - half, y_c + half)
    ax_emb.set_box_aspect(1)
    ax_emb.set_xticks([]); ax_emb.set_yticks([])
    # Axis labels are fixed strings here; the `emb` argument is not used
    ax_emb.set_xlabel("emb 1"); ax_emb.set_ylabel("emb 2")
    ax_emb.text(-0.08, 1.02, "A", transform=ax_emb.transAxes,
                fontsize=11, fontweight='bold', va='bottom', ha='right')

    # ===== Panel B: Grouped vertical COUNT bars (x = cell types, y = counts) =====
    # cells per (cell type, batch); combinations without cells are filled with 0
    counts = pd.crosstab(cell_types, batches).reindex(
        index=unique_cell_types, columns=ordered_batches, fill_value=0
    )
    if sort_by_total:
        counts = counts.loc[counts.sum(axis=1).sort_values(ascending=ascending).index]

    # Grouped bars (thinner width)
    counts.plot(kind="bar", stacked=False, ax=ax_bar,
                color=[batch_to_color[b] for b in ordered_batches],
                width=bar_width, legend=False)

    ax_bar.set_ylabel("Cells (count)")
    ax_bar.set_xlabel("")
    ax_bar.set_xticklabels(list(counts.index), rotation=0, ha="center")
    ax_bar.grid(axis='y', alpha=0.25)
    ax_bar.text(-0.12, 1.02, "B", transform=ax_bar.transAxes,
                fontsize=11, fontweight='bold', va='bottom', ha='right')
    ax_bar.margins(x=0.02)

    # ===== Legends (right column) =====
    batch_handles = [Line2D([0], [0], marker='o', linestyle='None',
                            markerfacecolor=batch_to_color[b], markeredgecolor='black',
                            markersize=6, label=b)
                     for b in ordered_batches]
    leg_batches = ax_leg.legend(handles=batch_handles, title="Batch",
                                loc='upper left', bbox_to_anchor=(0.0, 1.0),
                                ncol=1, borderaxespad=0., frameon=False)
    # keep the batch legend alive; the next ax_leg.legend() call would replace it
    ax_leg.add_artist(leg_batches)

    markerset = [Line2D([0], [0], marker=markers[j % len(markers)], linestyle='None',
                        markerfacecolor='white', markeredgecolor='black',
                        markersize=6, label=ct)
                 for j, ct in enumerate(unique_cell_types)]
    ax_leg.legend(handles=markerset, title="Cell type",
                  loc='upper left', bbox_to_anchor=(0.0, 0.60),
                  ncol=1, borderaxespad=0., frameon=False)

    # ===== Hard lock A & B to the SAME left/right =====
    # Use the left-column grid cell as the authority for x0/x1 so both axes match exactly.
    fig.canvas.draw()  # finalize initial layout
    left_cell_bbox = gs[0, 0].get_position(fig)  # left column cell bbox
    left_x0, left_x1 = left_cell_bbox.x0, left_cell_bbox.x1
    # Preserve each axis' own y/height, but force identical left/right
    for ax_ in (ax_emb, ax_bar):
        p = ax_.get_position()
        ax_.set_position([left_x0, p.y0, left_x1 - left_x0, p.height])

    # ===== Export =====
    # NOTE(review): the suffix is appended, not replaced, so "x.png" yields
    # "x.png.pdf" and "x.pdf" with export_png=True yields "x.pdf.png".
    pdf_path = out_file if out_file.lower().endswith(".pdf") else out_file + ".pdf"
    fig.savefig(pdf_path, bbox_inches="tight")
    if export_png:
        png_path = out_file if out_file.lower().endswith(".png") else out_file + ".png"
        fig.savefig(png_path, dpi=png_dpi, bbox_inches="tight")
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


In [15]:
def build_distributions(adata, dknn_df, celltypes_df, batches_df):
    """
    Build local and global batch distributions for each cell.

    For every cell the *local* row counts, per batch, the neighbors that share
    the cell's own cell type; the *global* row counts, per batch, all cells of
    that cell type in `adata`.

    Parameters
    ----------
    adata : anndata object
        Must provide `obs["Batch"]` and `obs["Group"]`.
    dknn_df : pd.DataFrame
        Indices of the k nearest neighbors of each cell (one row per cell, one
        column per neighbor). Its index must hold the integer row positions of
        the cells, and the neighbor indices are positions into that same set
        of cells.
    celltypes_df : pd.DataFrame
        Cell type label for each cell (column 'Group').
    batches_df : pd.DataFrame
        Batch label for each cell; only its index is used, as the row index of
        the returned frames.

    Returns
    --------
    pd.DataFrame, pd.DataFrame
        Local and global distributions (float counts). Rows follow
        `batches_df.index`, columns are the batch labels in order of first
        appearance. Rows of `batches_df.index` that are absent from
        `dknn_df.index` are all zero.
    """
    cell_pos = np.asarray(dknn_df.index)

    # batch / cell type labels of the cells listed in dknn_df, in row order
    obs = adata[dknn_df.index].obs
    batches = np.asarray(obs["Batch"])
    celltypes = np.asarray(obs["Group"])
    batch_labels = list(pd.unique(batches))

    # labels of every neighbor, shape (n_cells, k)
    neigh_idx = dknn_df.to_numpy().astype(np.intp)
    neigh_batches = batches[neigh_idx]
    neigh_celltypes = celltypes[neigh_idx]

    # cell type of each cell itself, and which neighbors share it
    own_types = np.asarray(celltypes_df["Group"].iloc[cell_pos])
    same_type = neigh_celltypes == own_types[:, None]

    # local: number of same-type neighbors that belong to each batch
    local_vals = np.zeros((len(cell_pos), len(batch_labels)), dtype=float)
    for col, b in enumerate(batch_labels):
        local_vals[:, col] = (same_type & (neigh_batches == b)).sum(axis=1)

    # global: total cells of the cell's type in each batch across the whole
    # dataset; counted once per (type, batch) instead of once per cell
    type_by_batch = pd.crosstab(np.asarray(adata.obs["Group"]),
                                np.asarray(adata.obs["Batch"]))
    global_vals = type_by_batch.reindex(
        index=own_types, columns=batch_labels, fill_value=0
    ).to_numpy(dtype=float)

    # Assemble whole frames at once: element-wise chained assignment
    # (df[col][row] = value) is silently lost under pandas Copy-on-Write.
    local_dist = pd.DataFrame(
        local_vals, index=dknn_df.index, columns=batch_labels
    ).reindex(batches_df.index, fill_value=0.0)
    global_dist = pd.DataFrame(
        global_vals, index=dknn_df.index, columns=batch_labels
    ).reindex(batches_df.index, fill_value=0.0)

    return local_dist, global_dist

### Distribution-based metrics

In [16]:
def kl(p, q, epsilon=1e-10):
    """
    KL divergence KL(p || q) between two distributions.

    Parameters:
    -----------
    p and q : np.array
        Two distributions (non-negative weights; they are re-normalized here)
    epsilon : scalar
        Lower bound applied to every entry before normalizing, so that zero
        entries do not produce infinities

    Returns:
    --------
    scalar
        KL divergence score between two distributions (natural log)
    """
    normalized = []
    for weights in (p, q):
        # floor first, e.g. [0, 0] -> [eps, eps]; [0.7, 0] -> [0.7, eps]
        floored = np.clip(weights, a_min=epsilon, a_max=None)
        # then rescale so the entries sum to 1
        normalized.append(floored / floored.sum())

    p_norm, q_norm = normalized
    return np.sum(kl_div(p_norm, q_norm))

In [17]:
def js(p, q, epsilon=1e-10):
    """
    Jensen-Shannon distance (base 2) between two distributions.

    Parameters:
    -----------
    p and q : np.array
        Two distributions (non-negative weights; they are re-normalized here)
    epsilon : scalar
        Lower bound applied to every entry before normalizing

    Returns:
    --------
    scalar :
        JS distance between two distributions
    """

    def _as_distribution(weights):
        # floor, e.g. [0.7, 0] -> [0.7, eps], then rescale to sum to 1
        floored = np.clip(weights, a_min=epsilon, a_max=None)
        return floored / floored.sum()

    # jensenshannon() defaults to the natural log; base 2 is requested explicitly
    return jensenshannon(_as_distribution(p), _as_distribution(q), base=2.0)


In [18]:
def tv(p, q, epsilon=1e-10):
    """
    Total variation distance of two discrete distributions.

    Parameters:
    -----------
    p and q : np.array
        Non-negative weights over the same bins; floored at `epsilon` and
        rescaled to sum to 1 before the comparison.
    epsilon : scalar
        Lower bound applied to every entry of p and q

    Returns:
    --------
    scalar :
        Half of the summed absolute differences between the normalized distributions
    """

    p_hat = np.clip(p, a_min=epsilon, a_max=None)   # [0, 0] -> [eps, eps]; [0.7, 0] -> [0.7, eps]
    p_hat /= p_hat.sum()

    q_hat = np.clip(q, a_min=epsilon, a_max=None)
    q_hat /= q_hat.sum()

    abs_gap = np.abs(p_hat - q_hat)
    return 0.5 * abs_gap.sum()

In [19]:
def hellinger(p, q, epsilon=1e-10):
    """
    Calculates Hellinger distance between two distributions

    Parameters:
    -----------
    p and q : np.array
        Two distributions (non-negative weights over the same bins); they are
        floored at `epsilon` and rescaled to sum to 1 before the comparison.
    epsilon : scalar
        Lower bound applied to every entry of p and q

    Returns:
    --------
    scalar :
        Hellinger distance between two distributions, as a Python float
    """
    
    p = np.clip(p, a_min=epsilon, a_max=None)   # [0, 0] -> [eps, eps]; [0.7, 0] -> [0.7, eps]
    q = np.clip(q, a_min=epsilon, a_max=None)

    p /= p.sum()
    q /= q.sum()

    # vectorised form of sqrt(sum((sqrt(p_i) - sqrt(q_i))^2)) / sqrt(2);
    # each square root is taken once instead of twice per element in a Python loop
    root_gap = np.sqrt(p) - np.sqrt(q)
    return float(np.sqrt(np.sum(root_gap ** 2)) / np.sqrt(2.0))

In [20]:
def chi_sd(p, q, epsilon=1e-10):
    """
    Symmetric chi-square distance of two discrete distributions.

    Parameters:
    -----------
    p and q : np.array
        Non-negative weights over the same bins; floored at `epsilon` and
        rescaled to sum to 1 before the comparison.
    epsilon : scalar
        Lower bound applied to every entry; it also keeps the denominator p + q positive

    Returns:
    --------
    scalar :
        0.5 * sum((p - q)^2 / (p + q)) over the normalized distributions
    """

    p_hat = np.clip(p, a_min=epsilon, a_max=None)   # [0, 0] -> [eps, eps]; [0.7, 0] -> [0.7, eps]
    p_hat /= p_hat.sum()

    q_hat = np.clip(q, a_min=epsilon, a_max=None)
    q_hat /= q_hat.sum()

    gap = p_hat - q_hat
    total = p_hat + q_hat
    return 0.5 * ((gap ** 2) / total).sum()

In [21]:
# Wasserstein Distance

# wasserstein_distance([0.2, 0.8], [0.8, 0.2]) does not treat those arrays as probability distributions.
# It treats them as samples (locations).

# As samples, both arrays contain the same points {0.2, 0.8}. The function sorts them first, so both become [0.2, 0.8].
# Two identical sample sets ⇒ Wasserstein distance 0.0.

# That’s why calling wasserstein_distance([0.2, 0.8], [0.8, 0.2]) returns 0.0.

# To compare discrete distributions over known bins, pass the bin locations as values and the probabilities as weights:
# wasserstein_distance(
#    u_values=[0, 1], v_values=[0, 1],
#    u_weights=[0.2, 0.8], v_weights=[0.8, 0.2]
#) = 0.6

In [22]:
# Sanity check for the note above: left call passes the numbers as sample locations,
# right call passes bin locations [0, 1] with the probabilities as weights
wasserstein_distance([0.2, 0.8], [0.8, 0.2]), wasserstein_distance([0, 1], [0, 1], [0.2, 0.8], [0.8, 0.2])

(np.float64(0.0), np.float64(0.6000000000000001))

In [23]:
def wd(p, q, n_batches, epsilon=1e-10):
    """
    Wasserstein distance of two discrete distributions over batch bins.

    Parameters:
    -----------
    p and q : np.array
        Non-negative weights, one entry per batch; floored at `epsilon` and
        rescaled to sum to 1 before the comparison.
    n_batches : scalar
        Number of bins; the integers 0 .. n_batches-1 are used as the bin locations
    epsilon : scalar
        Lower bound applied to every entry of p and q

    Returns:
    --------
    scalar :
        Wasserstein distance between the two weighted distributions
    """

    p_hat = np.clip(p, a_min=epsilon, a_max=None)   # [0, 0] -> [eps, eps]; [0.7, 0] -> [0.7, eps]
    p_hat /= p_hat.sum()

    q_hat = np.clip(q, a_min=epsilon, a_max=None)
    q_hat /= q_hat.sum()

    # the probabilities go in as weights; the values are the shared bin positions
    bin_positions = np.arange(n_batches)

    return wasserstein_distance(bin_positions, bin_positions, p_hat, q_hat)

In [24]:
def distr_based_metrics(scores, local_dist, global_dist, n_celltypes, n_batches):
    """
    Calculates distribution-based metrics (i.e., KL, JS, TV, H, chiSD and WD).

    For every cell the local batch distribution is compared with the global one.
    Each metric also gets a "max ..." companion column: the largest value among the
    cell's own score and the scores of every one-hot distribution (all mass in a
    single batch) measured against the same global distribution.

    Results are collected in numpy arrays and written to `scores` one whole column
    at a time. The earlier per-cell `scores[col][cell_id] = value` chained
    assignment does not update the frame when pandas Copy-on-Write is active.

    Parameters:
    -----------
    scores : pd.DataFrame
        Dataframe which will hold metric calculations (cell IDs are stored as the indices).
        It is modified in place and also returned.
    local_dist : pd.DataFrame
        Local distribution for each cell
    global_dist : pd.DataFrame
        Global distribution for each cell
    n_celltypes : scalar
        Number of (unique) celltypes in the dataset (not used by the computation;
        kept so existing callers keep working)
    n_batches : scalar
        Number of (unique) batches in the dataset
    
    Returns:
    --------
    scores : pd.DataFrame
        Metric calculation results
    """

    # identity matrix with shape n_batches x n_batches: one one-hot distribution per row
    one_hots = np.eye(n_batches)

    # metric name -> callable(local, global)
    metric_fns = {
        "KL loc|glob": lambda loc, glob: kl(loc, glob),
        "KL glob|loc": lambda loc, glob: kl(glob, loc),
        "JS Dist": lambda loc, glob: js(loc, glob),
        "TV": lambda loc, glob: tv(loc, glob),
        "H": lambda loc, glob: hellinger(loc, glob),
        "chiSD": lambda loc, glob: chi_sd(loc, glob),
        "WD": lambda loc, glob: wd(loc, glob, n_batches),
    }

    n_cells = len(scores.index)
    raw = {name: np.zeros(n_cells) for name in metric_fns}
    upper = {name: np.zeros(n_cells) for name in metric_fns}

    # the one-hot maxima depend only on the global distribution, so they are
    # computed once per distinct global distribution instead of once per cell
    one_hot_max_cache = {}

    for row, cell_id in enumerate(scores.index):
        loc = np.array(local_dist.loc[cell_id])
        glob = np.array(global_dist.loc[cell_id])

        cache_key = tuple(glob.ravel().tolist())
        if cache_key not in one_hot_max_cache:
            one_hot_max_cache[cache_key] = {
                name: max((fn(one_hot, glob) for one_hot in one_hots), default=float("-inf"))
                for name, fn in metric_fns.items()
            }
        one_hot_max = one_hot_max_cache[cache_key]

        for name, fn in metric_fns.items():
            value = fn(loc, glob)
            raw[name][row] = value
            # the cell's own score takes part in the maximum, as before
            upper[name][row] = max(value, one_hot_max[name])

    # column creation order is kept identical to the previous implementation;
    # JS divergence is the squared JS distance
    columns = {
        "KL loc|glob": raw["KL loc|glob"],
        "KL glob|loc": raw["KL glob|loc"],
        "max KL loc|glob": upper["KL loc|glob"],
        "max KL glob|loc": upper["KL glob|loc"],
        "JS Dist": raw["JS Dist"],
        "JS Div": raw["JS Dist"] ** 2,
        "max JS Dist": upper["JS Dist"],
        "max JS Div": upper["JS Dist"] ** 2,
        "TV": raw["TV"],
        "max TV": upper["TV"],
        "H": raw["H"],
        "max H": upper["H"],
        "chiSD": raw["chiSD"],
        "max chiSD": upper["chiSD"],
        "WD": raw["WD"],
        "max WD": upper["WD"],
    }
    for name, values in columns.items():
        scores[name] = values

    return scores


### Conventional metrics

In [25]:
def conventional_metrics(adata, batch_key, group_key, embed="X_pca"):
    """
    Calculates conventional metrics (i.e., kBET, ASW batch, PCR, graph connectivity and isolated labels f1/asw).

    Note: `sc.pp.neighbors` is called on `adata`, so its neighbor graph is recomputed
    from `embed` as a side effect.

    Parameters:
    -----------
    adata : anndata object
    batch_key : str
        name of batch column in adata.obs
    group_key : str
        name of cell identity labels column in adata.obs
    embed : str
        embedding key in adata.obsm for embedding and feature input
    
    Returns:
    --------
    tuple
        (kBET_scores, asw_arr, pcr, graph_conn, iso_f1, iso_asw):
        the kBET dataframe, the silhouette scores from scib's per-cell dataframe as an array,
        the PCR score, the graph connectivity result per cell type, and the isolated-label
        F1 / ASW results
    """

    # kBET
    # NOTE(review): with return_df=True scib appears to return the per-cluster values unscaled,
    # i.e. 0 = well-mixed batches, which is how the caller interprets them — confirm against scib docs
    kBET_scores = scib.me.kBET(
            adata, batch_key=batch_key, label_key=group_key, type_="full", embed=embed, return_df=True
        )

    # ASW batch; only the detailed dataframe is needed here
    _, _, sil_df = scib.me.silhouette_batch(
                adata, batch_key=batch_key, group_key=group_key, embed=embed, return_all=True)

    asw_arr = np.array(sil_df["silhouette_score"])

    # PCR batch (uses the batch_key / embed arguments instead of hard-coded "Batch" / "X_pca")
    pcr = scib.me.pcr(adata=adata, covariate=batch_key, embed=embed)
    
    # graph connectivity, on a neighbor graph built from the requested embedding
    sc.pp.neighbors(adata, use_rep=embed)
    graph_conn = graph_connectivity_per_celltype(adata, label_key=group_key)

    # isolated_labels_f1
    iso_f1 = scib.me.isolated_labels_f1(adata=adata, label_key=group_key, batch_key=batch_key, embed=embed, return_all=True)

    # isolated_labels_asw
    iso_asw = isolated_labels_asw(adata=adata, label_key=group_key, batch_key=batch_key, embed=embed, return_all=True)
    
    return kBET_scores, asw_arr, pcr, graph_conn, iso_f1, iso_asw

In [26]:
def normalize_metrics(scores):
    """
    Scales every distribution-based metric by its per-cell maximum.

    Parameters:
    -----------
    scores : pd.DataFrame
        Per-cell metric table (cell IDs as index) that already holds each raw
        metric column and its matching "max ..." column.

    Returns:
    --------
    scores : pd.DataFrame
        The same dataframe, with one added "n..." column per metric
    """

    # (normalized column, raw column, column holding the raw metric's maximum)
    column_plan = (
        ("nKL loc|glob", "KL loc|glob", "max KL loc|glob"),
        ("nKL glob|loc", "KL glob|loc", "max KL glob|loc"),
        ("nJS Dist", "JS Dist", "max JS Dist"),
        ("nJS Div", "JS Div", "max JS Div"),
        ("nTV", "TV", "max TV"),
        ("nH", "H", "max H"),
        ("nChiSD", "chiSD", "max chiSD"),
        ("nWD", "WD", "max WD"),
    )

    for normalized_col, raw_col, max_col in column_plan:
        scores[normalized_col] = scores[raw_col] / scores[max_col]

    return scores

In [27]:
def weighted_metrics(scores):
    """
    Multiplies each metric column by the per-cell "weight" column.

    Parameters:
    -----------
    scores : pd.DataFrame
        Per-cell metric table (cell IDs as index) holding "weight", the raw and
        normalized distribution metrics, "n_iLISI" and "1-ASW".

    Returns:
    --------
    scores : pd.DataFrame
        The same dataframe, with one added weighted column per listed metric
    """

    # (weighted column, source column), in the order the columns are created
    column_plan = (
        ("wKL loc|glob", "KL loc|glob"),
        ("wnKL loc|glob", "nKL loc|glob"),
        ("wKL glob|loc", "KL glob|loc"),
        ("wnKL glob|loc", "nKL glob|loc"),
        ("wJS Dist", "JS Dist"),
        ("wnJS Dist", "nJS Dist"),
        ("wJS Div", "JS Div"),
        ("wnJS Div", "nJS Div"),
        ("wTV", "TV"),
        ("wnTV", "nTV"),
        ("wH", "H"),
        ("wnH", "nH"),
        ("w_nILISI", "n_iLISI"),
        ("wChiSD", "chiSD"),
        ("wnChiSD", "nChiSD"),
        ("wWD", "WD"),
        ("wnWD", "nWD"),
        ("wASW", "1-ASW"),
    )

    for weighted_col, source_col in column_plan:
        scores[weighted_col] = scores["weight"] * scores[source_col]

    return scores

### Main

In [28]:
# Load the scenario file named by `scenario` (set at the top of the notebook)
adata = sc.read_h5ad(f"../scenarios/{scenario}")

adata

AnnData object with n_obs × n_vars = 3000 × 2000
    obs: 'cell_id', 'batch', 'cell_type', 'umap_1', 'umap_2'
    obsm: 'X_umap'

In [29]:
# remove and rename
# NOTE(review): `del` and `pop` below raise KeyError when this cell is re-run on the same
# `adata` (the keys are already gone); re-run the loading cell first.
del adata.obs["umap_1"]
del adata.obs["umap_2"]
adata.obsm.pop("X_umap")

# the rest of the notebook reads the labels as adata.obs.Group / adata.obs.Batch
adata.obs.rename(columns={
    "cell_type": "Group",
    "batch": "Batch"
}, inplace=True)

adata

AnnData object with n_obs × n_vars = 3000 × 2000
    obs: 'cell_id', 'Batch', 'Group'

In [30]:
print("Reducing data...")
# scib preprocessing on `adata`; per the log printed below it runs HVG selection, PCA,
# nearest neighbours and UMAP, and the results show up in adata.var/.uns/.obsm/.obsp
scib.preprocessing.reduce_data(adata, batch_key="Batch", umap=True)

adata

Reducing data...
HVG
Using 2000 HVGs from full intersect set
Using 0 HVGs from n_batch-1 set
Using 2000 HVGs
Computed 2000 highly variable genes
PCA
Nearest Neigbours
UMAP


AnnData object with n_obs × n_vars = 3000 × 2000
    obs: 'cell_id', 'Batch', 'Group'
    var: 'highly_variable'
    uns: 'pca', 'neighbors', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [31]:
# keep only HVGs
# Subset the genes to those flagged in adata.var["highly_variable"]; .copy() is applied to the
# sliced object before it replaces `adata`. In this scenario all 2000 genes are kept (see output).
adata = adata[:, adata.var["highly_variable"]].copy()

adata

AnnData object with n_obs × n_vars = 3000 × 2000
    obs: 'cell_id', 'Batch', 'Group'
    var: 'highly_variable'
    uns: 'pca', 'neighbors', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [32]:
print("Calculating mean distances to same type cells...")
# gene_expression = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
# gene_expression.shape

# The PCA embedding is used as the coordinate space (the raw-expression variant above is disabled)
pca_coords = adata.obsm["X_pca"]

Calculating mean distances to same type cells...


In [33]:
pca_coords.shape, pca_coords

((3000, 50),
 array([[-3.0819336e-01, -1.8723118e+01, -2.2236454e-01, ...,
          2.8082314e-01, -1.6306000e-02,  7.7267259e-01],
        [-1.6995955e-02, -1.8590261e+01, -2.0694648e-01, ...,
          2.1941593e-01, -7.3635012e-02, -3.6280543e-01],
        [-4.3151140e-01, -1.8566200e+01,  3.7363585e-02, ...,
         -2.5001088e-01,  2.1918282e-01,  1.2480956e-01],
        ...,
        [-1.3094796e-01, -1.8464907e+01,  1.7160770e-02, ...,
         -5.2044958e-02, -6.3519341e-01,  3.7912294e-01],
        [-1.5569174e-01, -1.8361698e+01, -3.8564491e-01, ...,
          4.2384756e-01,  3.6009499e-01,  9.2723764e-02],
        [-4.6501976e+01,  9.2302456e+00,  2.2291602e-01, ...,
         -3.2479957e-01,  4.8601322e-02, -1.5119648e-01]], dtype=float32))

In [34]:
# Per-cell label columns from adata.obs; passed to the distance and plotting helpers below
cell_types = adata.obs["Group"]
batches = adata.obs["Batch"]

In [35]:
# Stored per cell in adata.obs; later copied into scores["weight"] to weight every metric.
# mean_dists_to_same_type_cells() is defined outside this section — presumably the normalized
# mean Euclidean distance to all cells of the same type (per the notebook title); confirm there.
adata.obs['normalized_mean_dist_to_same_type'] = mean_dists_to_same_type_cells(pca_coords, cell_types)

adata

AnnData object with n_obs × n_vars = 3000 × 2000
    obs: 'cell_id', 'Batch', 'Group', 'normalized_mean_dist_to_same_type'
    var: 'highly_variable'
    uns: 'pca', 'neighbors', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [36]:
out_dir

'../results/2025-07-03/out/01'

In [37]:
sc_order

'01'

In [38]:
# Pick the 2-D coordinates handed to the plotting helper: first two PCs, or the UMAP embedding
emb = "pca"
if emb == "pca":
    emb_coords = adata.obsm["X_pca"][:, :2]
# umap
else:
    emb_coords = adata.obsm["X_umap"]

# NOTE(review): out_file is reassigned in the plotting cell before it is used there
out_file = f"{out_dir}/{sc_order}_{emb}_and_distrs.png"

In [39]:
print("Building distributions...")

# PCA coordinates as a DataFrame (default integer row labels); input to knn() and to the LISI calls
int_df = pd.DataFrame(pca_coords)
int_df

Building distributions...


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,-0.308193,-18.723118,-0.222365,0.173728,-0.109960,0.091032,0.022229,-0.568084,0.127034,-0.475285,...,-0.148898,0.150370,-0.316477,-0.507563,0.005914,-0.004248,0.004117,0.280823,-0.016306,0.772673
1,-0.016996,-18.590261,-0.206946,-0.345269,0.253954,0.167777,-0.192839,-0.060571,-0.100360,-0.110383,...,0.221897,0.247005,-0.196879,0.492513,0.058445,-0.066590,-0.280863,0.219416,-0.073635,-0.362805
2,-0.431511,-18.566200,0.037364,0.711880,0.360014,-0.072071,0.663141,0.076312,0.164106,0.042735,...,-0.029734,-0.156944,-0.480401,0.031721,-0.526340,0.189687,-0.423752,-0.250011,0.219183,0.124810
3,-46.970257,9.339377,-0.206623,0.178732,-0.350372,-0.184252,0.484058,-0.096217,0.707637,-0.225452,...,-0.037104,-0.808915,-0.534499,0.103032,-0.435024,-0.320718,0.338513,-0.259204,-0.112704,-0.096582
4,47.569527,9.249630,0.264705,0.005034,-0.167559,0.337533,0.211542,0.023955,-0.222447,-0.189657,...,-0.173505,-0.373178,0.670260,-0.445444,0.147383,-0.084340,0.312966,-0.037651,-0.213379,0.042307
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2995,-0.212553,-18.598145,-0.281074,-0.179805,0.075338,0.291520,-0.241758,-0.277927,-0.172462,0.079976,...,-0.118801,0.083794,-0.205832,0.464777,0.175157,0.145229,0.277795,-0.413220,-0.523655,0.057600
2996,-0.144551,-18.468899,-0.241049,-0.188320,0.752085,-0.153387,-0.239993,-0.592431,-0.044067,-0.537232,...,0.094726,-0.333842,-0.145666,0.005572,0.742716,-0.358578,-0.335165,0.835790,0.283500,-0.477106
2997,-0.130948,-18.464907,0.017161,-0.332873,0.142114,-0.090164,-0.003018,0.147479,0.022505,-0.124801,...,0.723167,-0.294975,0.023774,-0.949414,0.437239,-0.151126,0.318093,-0.052045,-0.635193,0.379123
2998,-0.155692,-18.361698,-0.385645,-0.081822,-0.278615,0.177356,0.260813,0.243684,0.800808,-0.240949,...,0.007641,-0.396460,-0.598229,0.598075,0.316376,-0.074438,0.019125,0.423848,0.360095,0.092724


In [40]:
k = 90  # number of neighbors for each cell

# knn() is defined outside this section; judging by the unpacking it returns
# (neighbor indices, distances) with one row per cell — confirm against its definition
idx, dists = knn(int_df, k=k)
dknn_df = pd.DataFrame(idx, index=int_df.index)
dists_df = pd.DataFrame(dists, index=int_df.index)

In [41]:
# Smaller neighborhoods are the first 60 / 30 columns of the k=90 neighbor table
# (assumes knn() orders neighbors nearest-first — TODO confirm)
dknn_df_k60 = dknn_df.iloc[:, :60]
dknn_df_k30 = dknn_df.iloc[:, :30]

# single-column DataFrames built from the adata.obs label columns
celltypes_df = pd.DataFrame(adata.obs.Group)
batches_df = pd.DataFrame(adata.obs.Batch)

In [42]:
# One local distribution per neighborhood size (k=90, 60, 30). The global distribution is
# built from dataset-wide counts only, so it is kept from the first call and discarded after.
local_dist, global_dist = build_distributions(adata,     dknn_df, celltypes_df, batches_df)
local_dist_k60, _       = build_distributions(adata, dknn_df_k60, celltypes_df, batches_df)
local_dist_k30, _       = build_distributions(adata, dknn_df_k30, celltypes_df, batches_df)

<class 'str'> batch_1
<class 'str'> batch_2
<class 'str'> batch_1
<class 'str'> batch_2
<class 'str'> batch_1
<class 'str'> batch_2


In [43]:
sc_order, out_dir

('01', '../results/2025-07-03/out/01')

In [44]:
# Output path stem for the figure (no file extension is given here)
out_file = f"{out_dir}/{sc_order}_n_genes_2000_emb_and_global"

# plot_emb_and_global_pub() is defined outside this section; it receives the 2-D coordinates
# and the per-cell labels, and is asked for a PNG export at 600 dpi
plot_emb_and_global_pub(
    scenario=sc_order,
    coords=emb_coords,
    cell_types=cell_types,
    batches=batches,
    local_dist=None, global_dist=None,  # ignored
    out_file=out_file,
    emb="pca",
    export_png=True, png_dpi=600,
    sort_by_total=False, ascending=False,  # optional
)

In [45]:
# plot_emb_and_distrs(scenario=sc_order, coords=emb_coords, cell_types=cell_types, batches=batches,
#                        local_dist=local_dist.copy(), global_dist=global_dist.copy(), out_file=out_file,
#                        emb=emb, local_dist_k60=local_dist_k60.copy(), local_dist_k30=local_dist_k30.copy())

In [46]:
# Number of distinct labels; both are passed to distr_based_metrics, and n_batches
# is also used to clip and normalize iLISI
n_celltypes = len(celltypes_df["Group"].unique())
n_batches = len(batches_df["Batch"].unique())

In [47]:
sc_order, out_dir

('01', '../results/2025-07-03/out/01')

In [48]:
for k in [90, 60, 30]:
    print(f"Calculating metrics k={k}...")
    scores = pd.DataFrame(0., columns=[], index=global_dist.index)
    scores["Cell type"] = celltypes_df
    scores["Batch"] = batches_df
    scores["weight"] = adata.obs["normalized_mean_dist_to_same_type"]

    if k == 60:
        local_dist = local_dist_k60
    if k == 30:
        local_dist = local_dist_k30

    scores = distr_based_metrics(scores, local_dist.copy(), global_dist.copy(), n_celltypes, n_batches)

    # LISI
    %R library(lisi)
    %R -i int_df,batches_df,celltypes_df,k
    %R cLISI=lisi::compute_lisi(int_df, data.frame(celltypes_df), colnames(celltypes_df), perplexity=k/3)
    %R iLISI=lisi::compute_lisi(int_df, data.frame(batches_df), colnames(batches_df), perplexity=k/3)
    %R -o iLISI,cLISI

    scores["iLISI"] = np.array(iLISI["Batch"])
    scores["cLISI"] = np.array(cLISI["Group"])

    kBET_scores, asw_arr, pcr, graph_conn, iso_f1, iso_asw = conventional_metrics(adata, batch_key="Batch", group_key="Group")

    if len(asw_arr) == len(scores.index):
        scores["ASW"] = asw_arr
    else:
        scores["ASW"] = np.nan
        
    scores["PCR"] = pcr

    scores["iLISI"] = scores["iLISI"].clip(upper=n_batches)

    # Normalize iLISI so 0 can be interpreted as good batch mixing and 1 as bad batch mixing
    scores["n_iLISI"] = (n_batches - scores["iLISI"]) / (n_batches - 1)

    # Initially, ASW = 1 means different batches are mixing well. After inversion, ASW = 0 means batches are not mixing well.
    scores["1-ASW"] = 1 - scores["ASW"]

    # normalize distribution-based metrics
    scores = normalize_metrics(scores)

    # weight scores
    scores = weighted_metrics(scores)

    # detailed scores
    detailed_scores = pd.DataFrame(columns=scores.columns, index=sorted(adata.obs.Group.unique()))
    detailed_scores = detailed_scores.drop(columns=["Cell type", "Batch"])

    for bt in adata.obs.Batch.unique():
        detailed_scores[f"#{bt}"] = 0.

    for ct in adata.obs.Group.unique():
        print(ct)
        detailed_scores.loc[ct] = scores.loc[celltypes_df[celltypes_df.Group==ct].index].drop(columns=["Cell type", "Batch"]).mean()
            
        for bt in adata.obs.Batch.unique():
            detailed_scores[f"#{bt}"][ct] = scores[(scores["Batch"] == bt) & (scores["Cell type"] == ct)].shape[0]

    detailed_scores["kBET"] = np.nan
    for i in range(len(kBET_scores["cluster"])):
        print(kBET_scores["cluster"][i])
        detailed_scores["kBET"][kBET_scores["cluster"][i]] = kBET_scores["kBET"][i]

    # kBET = 0 means different batches are mixing well
    detailed_scores["wkBET"] = detailed_scores["kBET"] * detailed_scores["weight"]

    # add graph_conn col
    detailed_scores["graph_conn"] = detailed_scores.index.map(graph_conn)

    detailed_scores["1-graph_conn"] = 1 - detailed_scores["graph_conn"]

    detailed_scores["wGraph_conn"] = detailed_scores["1-graph_conn"] * detailed_scores["weight"]

    if iso_f1.empty:
        detailed_scores["isolated_labels_f1"] = np.nan
    else:
        detailed_scores["isolated_labels_f1"] = iso_f1

    if iso_asw.empty:
        detailed_scores["isolated_labels_asw"] = np.nan
    else:
        detailed_scores["isolated_labels_asw"] = iso_asw

    # reorder columns
    assert sorted(detailed_scores.columns) == sorted(col_order)
    detailed_scores = detailed_scores[col_order]

    round(detailed_scores,5).to_excel(f"{out_dir}/{sc_order}_n_genes_2000_k_{k}_scores_celltypes_all_batches.xlsx")
    round(scores,5).to_excel(f"{out_dir}/{sc_order}_n_genes_2000_k_{k}_scores_all_batches.xlsx")


Calculating metrics k=90...
0 labels consist of a single batch or is too small. Skip.
mean silhouette per group:              silhouette_score
group                        
cell_type_1          0.996888
cell_type_2          0.992612
cell_type_3          0.995357
isolated labels: []
isolated labels: []
cell_type_2
cell_type_1
cell_type_3
cell_type_1
cell_type_2
cell_type_3
Calculating metrics k=60...
0 labels consist of a single batch or is too small. Skip.
mean silhouette per group:              silhouette_score
group                        
cell_type_1          0.996888
cell_type_2          0.992612
cell_type_3          0.995357
isolated labels: []
isolated labels: []
cell_type_2
cell_type_1
cell_type_3
cell_type_1
cell_type_2
cell_type_3
Calculating metrics k=30...
0 labels consist of a single batch or is too small. Skip.
mean silhouette per group:              silhouette_score
group                        
cell_type_1          0.996888
cell_type_2          0.992612
cell_type_3      

In [49]:
print("End of the notebook")

End of the notebook
