In [None]:
import os
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple
from scipy.spatial import distance


# ---------------------------------------------------------------------
# Synthetic test data
# ---------------------------------------------------------------------

def generate_cell_coordinates(
    n_cells: int = 500,
    fov_size_px: int = 512,
    seed: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw random integer pixel positions for a set of synthetic cells.

    Parameters
    ----------
    n_cells : int
        How many cells to place.
    fov_size_px : int
        Side length of the square field of view; coordinates are drawn
        from ``0`` (inclusive) up to ``fov_size_px`` (exclusive).
    seed : int
        Seed handed to ``np.random.default_rng`` so layouts are repeatable.

    Returns
    -------
    xpix, ypix : np.ndarray
        Two integer arrays of shape ``(n_cells,)``.
    """
    generator = np.random.default_rng(seed)
    # The generator expression is consumed left to right, so x is always
    # drawn before y and a given seed maps to one fixed layout.
    xpix, ypix = (
        generator.integers(low=0, high=fov_size_px, size=n_cells)
        for _ in range(2)
    )
    return xpix, ypix


# ---------------------------------------------------------------------
# Region and selection utilities
# ---------------------------------------------------------------------

def select_cells_in_region(
    xpix: np.ndarray,
    ypix: np.ndarray,
    x_bounds: Tuple[int, int],
    y_bounds: Tuple[int, int]
) -> Tuple[List[int], List[int]]:
    """
    Split cell indices into those inside and outside a rectangular region.

    Both bounds are exclusive: a cell is inside only when
    ``x_bounds[0] < x < x_bounds[1]`` and ``y_bounds[0] < y < y_bounds[1]``.
    To include pixel 0, pass a lower bound of ``-1``.

    Parameters
    ----------
    xpix, ypix : np.ndarray
        Cell coordinates; expected to have the same length.
    x_bounds, y_bounds : tuple[int, int]
        ``(low, high)`` limits of the region along each axis (exclusive).

    Returns
    -------
    cells : list[int]
        Indices of cells inside region, in ascending order.
    excluded : list[int]
        Indices of cells outside region, in ascending order.
    """
    xs = np.asarray(xpix)
    ys = np.asarray(ypix)

    # One boolean mask replaces the per-index `i not in cells` list scan,
    # which was quadratic in the number of cells.
    inside = (
        (xs > x_bounds[0]) & (xs < x_bounds[1])
        & (ys > y_bounds[0]) & (ys < y_bounds[1])
    )

    # .tolist() yields plain Python ints, matching the original return type.
    cells = np.flatnonzero(inside).tolist()
    excluded = np.flatnonzero(~inside).tolist()
    return cells, excluded


def compute_possible_neighbours(
    xpix: np.ndarray,
    ypix: np.ndarray,
    stim_fov_size: float,
    umperpix: float
) -> List[List[int]]:
    """
    For each cell, find the other cells lying within the stim FOV radius.

    The radius is ``(stim_fov_size / umperpix) / 2``, i.e. half the stim FOV
    size after dividing by ``umperpix``. A neighbour must be strictly closer
    than that radius (Euclidean distance); a cell is never its own neighbour.

    Parameters
    ----------
    xpix, ypix : np.ndarray
        Cell coordinates; expected to have the same length.
    stim_fov_size : float
        Stimulation field-of-view size, in the units that ``umperpix``
        converts to pixels.
    umperpix : float
        Conversion factor dividing ``stim_fov_size``.

    Returns
    -------
    possible_cells : list[list[int]]
        ``possible_cells[i]`` lists, in ascending order, the indices of the
        cells within the radius of cell ``i``.
    """
    radius = (stim_fov_size / umperpix) / 2

    # Work in float so coordinate differences are ordinary signed values
    # regardless of the integer dtype the coordinates arrive in.
    xs = np.asarray(xpix, dtype=float)
    ys = np.asarray(ypix, dtype=float)

    possible_cells = []
    for idx in range(len(xs)):
        # Distances from cell `idx` to every cell in one vectorised step,
        # instead of one Python-level distance call per pair.
        within = np.hypot(xs - xs[idx], ys - ys[idx]) < radius
        within[idx] = False  # a cell is not its own neighbour
        possible_cells.append(np.flatnonzero(within).tolist())
    return possible_cells


# ---------------------------------------------------------------------
# Trial generation
# ---------------------------------------------------------------------

def find_indexes_longest(list_of_lists: List[List[int]]) -> List[int]:
    """Return the positions of every inner list whose length ties for the maximum.

    An empty outer list yields an empty result.
    """
    sizes = [len(entry) for entry in list_of_lists]
    if not sizes:
        return []
    largest = max(sizes)
    return [pos for pos, size in enumerate(sizes) if size == largest]


def generate_stimulation_trials(
    cells: List[int],
    possible_cells: List[List[int]],
    trials_per_cell: int,
    n_cells_per_clust: int
) -> List[List[int]]:
    """
    Greedily build stimulation clusters of exactly `n_cells_per_clust` cells.

    Every cell starts with a budget of `trials_per_cell` stimulations. Each
    cluster is seeded with a cell that has the largest remaining budget
    (ties broken with ``random.choice``) and is then filled, one cell at a
    time, with the seed's neighbours that have the largest remaining budget.
    A cluster that cannot be filled to the target size is discarded; the
    budget it consumed is not refunded, so a cell may end up in fewer than
    `trials_per_cell` clusters.

    Parameters
    ----------
    cells : list[int]
        Ids of the cells that may be stimulated.
    possible_cells : list[list[int]]
        Parallel to `cells`: ``possible_cells[k]`` holds the ids of the cells
        allowed to share a cluster with ``cells[k]``. Ids that do not appear
        in `cells` are ignored.
    trials_per_cell : int
        Stimulation budget per cell.
    n_cells_per_clust : int
        Number of cells in every returned cluster; must be at least 1.

    Returns
    -------
    final : list[list[int]]
        Each element is a cluster of exactly `n_cells_per_clust` cell ids,
        seed first.

    Raises
    ------
    ValueError
        If `n_cells_per_clust` is smaller than 1.
    """
    if n_cells_per_clust < 1:
        raise ValueError("n_cells_per_clust must be at least 1")

    # remaining[k] is the number of stimulations still owed to cells[k].
    # Tracking counts by position avoids looking cells up by id, which
    # resolved to the wrong slot when an id occurred more than once.
    remaining = [max(trials_per_cell, 0)] * len(cells)
    total_remaining = sum(remaining)

    final = []
    while total_remaining >= n_cells_per_clust:
        # Seed with one of the cells that has the most budget left.
        top = max(remaining)
        seed_pos = random.choice(
            [k for k, left in enumerate(remaining) if left == top]
        )
        cluster = [cells[seed_pos]]
        in_cluster = {cells[seed_pos]}
        remaining[seed_pos] -= 1
        total_remaining -= 1

        # Only the seed's neighbours may join this cluster.
        neighbours = set(possible_cells[seed_pos])

        while len(cluster) < n_cells_per_clust:
            candidates = [
                k for k, cell in enumerate(cells)
                if remaining[k] > 0
                and cell in neighbours
                and cell not in in_cluster
            ]
            if not candidates:
                break
            top = max(remaining[k] for k in candidates)
            pick = random.choice([k for k in candidates if remaining[k] == top])
            cluster.append(cells[pick])
            in_cluster.add(cells[pick])
            remaining[pick] -= 1
            total_remaining -= 1

        # Only keep clusters that reached the full target size.
        if len(cluster) == n_cells_per_clust:
            final.append(cluster)

    return final


# ---------------------------------------------------------------------
# Analysis helpers
# ---------------------------------------------------------------------

def compute_overlap(final: List[List[int]], n_cells_per_clust: int) -> float:
    """
    Compute the percentage of cluster pairs that share more than one cell.

    Every unordered pair of distinct clusters is examined once. Two separate
    clusters with identical membership count as overlapping.

    Parameters
    ----------
    final : list[list[int]]
        Stimulation clusters (lists of cell indices).
    n_cells_per_clust : int
        Not used by the computation; kept so existing callers keep working.

    Returns
    -------
    float
        ``100 * overlapping_pairs / total_pairs``, or ``0.0`` when there are
        fewer than two clusters.
    """
    members = [set(cluster) for cluster in final]

    n_pairs = 0
    n_overlapping = 0
    # combinations() visits each unordered pair once and never pairs a
    # cluster with itself, so no size-based filter is needed to discard
    # self-comparisons and the pairs are never all held in memory.
    for first, second in itertools.combinations(members, 2):
        n_pairs += 1
        if len(first & second) > 1:
            n_overlapping += 1

    if n_pairs == 0:
        return 0.0
    return 100 * (n_overlapping / n_pairs)


def cluster_spread(xpix: np.ndarray, ypix: np.ndarray, clusters: List[List[int]]) -> List[float]:
    """
    Mean Euclidean distance between the members of each cluster.

    Each pair of members is counted once (the pair is taken when the second
    index is greater than the first). A cluster that yields no such pair
    gets a spread of ``0.0``.
    """
    def _position(cell):
        # (x, y) coordinate of one cell
        return (xpix[cell], ypix[cell])

    spreads = []
    for members in clusters:
        pair_dists = [
            distance.euclidean(_position(a), _position(b))
            for a in members
            for b in members
            if b > a
        ]
        if pair_dists:
            spreads.append(np.mean(pair_dists))
        else:
            spreads.append(0.0)
    return spreads


# ---------------------------------------------------------------------
# Example driver
# ---------------------------------------------------------------------

if __name__ == "__main__":
    # Example parameters
    FOVsizeum = 600
    fov_size_px = 512
    umperpix = FOVsizeum / fov_size_px
    trials_per_cell = 50
    n_cells_per_clust = 15
    stim_fov_size = 400

    # Generate random coordinates
    xpix, ypix = generate_cell_coordinates(fov_size_px=fov_size_px)

    # Select all cells (full FOV in this demo). The region bounds are
    # exclusive, so the lower bound is -1 to keep cells on pixel row/column 0.
    cells, excluded = select_cells_in_region(
        xpix, ypix, (-1, fov_size_px), (-1, fov_size_px)
    )

    # Neighbours are computed for every cell, then re-expressed per selected
    # cell so that possible_cells[k] always describes cells[k] even when the
    # region excludes some cells.
    all_neighbours = compute_possible_neighbours(xpix, ypix, stim_fov_size, umperpix)
    selected = set(cells)
    possible_cells = [
        [n for n in all_neighbours[c] if n in selected]
        for c in cells
    ]

    # Filter out cells with too few neighbours, dropping the same positions
    # from both lists so they stay aligned.
    keep = [
        k for k, neigh in enumerate(possible_cells)
        if len(neigh) >= n_cells_per_clust
    ]
    cells = [cells[k] for k in keep]
    possible_cells = [possible_cells[k] for k in keep]

    # Generate stimulation trials
    final = generate_stimulation_trials(cells, possible_cells, trials_per_cell, n_cells_per_clust)

    # Overlap
    perc_overlap = compute_overlap(final, n_cells_per_clust)
    print(f"{len(final)} stimulation trials generated")
    print(f"{perc_overlap:.2f}% of trial pairs share >1 cell")

    # Cluster spread (skipped when no trial was generated, to avoid taking
    # the mean of an empty list)
    spreads = cluster_spread(xpix, ypix, final)
    if spreads:
        print(f"Mean cluster spread: {np.mean(spreads):.1f} px")


In [None]:
def plot_cluster_density(xpix, ypix, cells, final, fov_size_px=512):
    """
    Plot cluster density analysis with four panels:
      1. All cells in the FOV
      2. Histogram of mean intra-cluster distances
      3. Least dense cluster (largest mean intra-cluster distance)
      4. Most dense cluster (smallest mean intra-cluster distance)

    Parameters
    ----------
    xpix, ypix : np.ndarray
        Cell coordinates (pixels).
    cells : list[int]
        Indices of included cells.
    final : list[list[int]]
        List of stimulation clusters (cell indices).
    fov_size_px : int, optional
        Upper axis limit for the scatter panels (default 512).

    Returns
    -------
    cluster_spread : list[float]
        Mean pairwise Euclidean distance of each cluster; ``0.0`` for a
        cluster with fewer than two cells.
    least : int
        Index into `final` of the cluster with the smallest spread.
    most : int
        Index into `final` of the cluster with the largest spread.

    Raises
    ------
    ValueError
        If `final` contains no clusters.
    """
    if len(final) == 0:
        raise ValueError("`final` contains no clusters to plot")

    xs = np.asarray(xpix)
    ys = np.asarray(ypix)

    # --- core computation ---
    spreads = []
    for cluster in final:
        members = np.asarray(cluster, dtype=int)
        if members.size < 2:
            # No pair to measure; avoids the nan a mean over nothing gives.
            spreads.append(0.0)
            continue
        coords = np.column_stack((xs[members], ys[members])).astype(float)
        # pdist returns each unordered pair once, so self-distances never
        # appear and no `> 0` filter is needed to remove them.
        spreads.append(float(distance.pdist(coords).mean()))

    least = int(np.argmin(spreads))  # smallest spread -> densest cluster
    most = int(np.argmax(spreads))   # largest spread -> sparsest cluster

    # --- plotting ---
    fig = plt.figure(figsize=(16, 4))
    gs = fig.add_gridspec(1, 4)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[0, 2], sharex=ax1, sharey=ax1)
    ax4 = fig.add_subplot(gs[0, 3], sharex=ax1, sharey=ax1)

    ax1.set_title('All cells', fontsize=20)
    ax2.set_title('Cluster density', fontsize=20)
    ax3.set_title('Least dense', fontsize=20)
    ax4.set_title('Most dense', fontsize=20)

    ax1.set_xlim(0, fov_size_px)
    ax1.set_ylim(0, fov_size_px)

    # One scatter call per panel instead of one call per cell.
    shown = np.asarray(cells, dtype=int)
    ax1.scatter(xs[shown], ys[shown], color='C0', s=8)

    ax2.hist(spreads, bins=20, color='C1', alpha=0.7)

    # The cluster with the largest spread is the least dense one, and the
    # cluster with the smallest spread is the most dense one.
    sparsest = np.asarray(final[most], dtype=int)
    densest = np.asarray(final[least], dtype=int)
    ax3.scatter(xs[sparsest], ys[sparsest], color='C0', s=12)
    ax4.scatter(xs[densest], ys[densest], color='C0', s=12)

    plt.tight_layout()
    plt.show()

    return spreads, least, most

In [None]:
# The spreads are bound to `cluster_spreads` so the module-level
# `cluster_spread` function is not overwritten by this list.
cluster_spreads, least_idx, most_idx = plot_cluster_density(xpix, ypix, cells, final)