# Packages

In [None]:
%%capture
!pip install squidpy parmap imantics scanpy igraph leidenalg umap imagecodecs GDAL

In [None]:
%%capture
# packages
import os
import json
import glob
import random
import base64
import typing as tp
import warnings
import pickle

# Data handling and analysis
import numpy as np
import pandas as pd
import scipy
from scipy.special import softmax
from scipy.linalg import sqrtm
from scipy.sparse import csr_matrix
from scipy.io import loadmat

# Machine learning and statistics
import sklearn
from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans
import umap
import igraph as ig
import leidenalg

# Image processing and visualization
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.figure import Figure as _Figure
import seaborn as sns
import PIL
from PIL import Image, ImageSequence, ImageOps
import imagecodecs
import skimage
from skimage import io as tiff
from skimage import measure, morphology
from skimage.segmentation import relabel_sequential
from skimage.measure import label, regionprops
import imageio

# Bioinformatics and spatial data analysis
import scanpy as sc
import anndata
from anndata import AnnData
import squidpy as sq

# Geometry and masks
from shapely.geometry import Polygon
from imantics import Mask

# Progress bar
from tqdm import tqdm

# Parallel processing
import parmap
import networkx as nx
import pathlib

# image processing
from osgeo import gdal

# Data loading and handling functions

In [None]:
# Helpers
def read_tiff_image(tiff_path):
    """Load the image at ``tiff_path`` with ``skimage.io.imread`` and return it as an array."""
    image = tiff.imread(tiff_path)
    return image

def calculate_centroids(mask):
    """
    Map every labelled region of ``mask`` to its centroid.

    Uses ``skimage.measure.regionprops``; each value is the region centroid in
    the (row, col) order that regionprops reports. Background (label 0) is
    ignored by regionprops and therefore gets no entry.
    """
    return {region.label: region.centroid for region in measure.regionprops(mask)}


Functions specific to the lung cancer tissue sample stored in .tif format. The metadata is embedded in the TIFF itself and has to be extracted. The segmentation mask is stored in the same .tif file as the feature data, with nucleus and cytoplasm labeled separately.

In [None]:
def extract_data_from_gdal_tiffs(feature_tiff_data, mask):
    """
    Extract per-cell channel totals from a channel-last image for use in REDSEA.

    Args:
        feature_tiff_data: 3D numpy array of shape (height, width, num_channels),
            e.g. as returned by ``read_tiff_all_bands``.
        mask: 2D integer label image with the same height/width; 0 is background.

    Returns:
        data: (n_cells, num_channels) summed intensity of each cell per channel.
        cell_ids: sorted unique non-zero labels; row i of every per-cell output
            belongs to ``cell_ids[i]``.
        centroids: dict label -> centroid, from ``calculate_centroids``.
        dataScaleSize: ``data`` divided by each cell's pixel area.
        cellSizes: (n_cells, 1) pixel area of each cell.

    Raises:
        ValueError: if the image's spatial dimensions differ from the mask's.
    """
    if feature_tiff_data.shape[:2] != mask.shape:
        raise ValueError(f"Intensity image spatial dimensions {feature_tiff_data.shape[:2]} must match mask shape {mask.shape}.")

    num_channels = feature_tiff_data.shape[2]  # Number of channels in the feature image
    cell_ids = np.unique(mask[mask != 0])
    data = np.zeros((len(cell_ids), num_channels))
    dataScaleSize = np.zeros_like(data)
    cellSizes = np.zeros((len(cell_ids), 1))

    # O(1) label -> output row lookup (list.index per cell made the loop quadratic).
    row_of_label = {cell_id: row for row, cell_id in enumerate(cell_ids.tolist())}

    for stat in regionprops(mask):
        idx = row_of_label[stat.label]
        rows, cols = stat.coords[:, 0], stat.coords[:, 1]
        # Gather this cell's pixels once for all channels instead of comparing
        # the whole mask against the label for every channel.
        data[idx, :] = feature_tiff_data[rows, cols, :].sum(axis=0)
        dataScaleSize[idx, :] = data[idx, :] / stat.area
        cellSizes[idx] = stat.area

    centroids = calculate_centroids(mask)

    return data, cell_ids, centroids, dataScaleSize, cellSizes

def tif_create_data(feature_img, cell_ids, centroids, data, channels_df, channel_indices):
    """
    Wrap precomputed per-cell measurements in an AnnData object.

    Args:
        feature_img: not used by this function (kept for call compatibility).
        cell_ids: labels identifying the rows of ``data``.
        centroids: mapping label -> centroid tuple. Element 0 is stored as
            ``X_centroid`` and element 1 as ``Y_centroid``.
            NOTE(review): with ``calculate_centroids`` output, which is (row, col),
            this puts the row coordinate into ``X_centroid``.
        data: (n_cells, n_channels) matrix used as ``adata.X``.
        channels_df: DataFrame with a ``channel`` column of channel names.
        channel_indices: positional rows of ``channels_df`` naming the columns of ``data``.

    Returns:
        AnnData whose ``var`` holds the per-channel mean/std of ``data`` and whose
        ``obsm['spatial']`` is [X_centroid, Y_centroid].
    """
    per_channel_mean = np.mean(data, axis=0)
    per_channel_std = np.std(data, axis=0)

    x_values = [centroids[cell][0] for cell in cell_ids]
    y_values = [centroids[cell][1] for cell in cell_ids]
    obs_df = pd.DataFrame({'roi': cell_ids, 'X_centroid': x_values, 'Y_centroid': y_values})

    names = [channels_df.iloc[position]['channel'] for position in channel_indices]
    var_df = pd.DataFrame({'mean': per_channel_mean, 'std': per_channel_std}, index=names)

    adata = AnnData(X=data, obs=obs_df, var=var_df)
    adata.obsm['spatial'] = np.array(adata.obs[['X_centroid', 'Y_centroid']])
    return adata

# Functions for extracting the metadata and simplifying the mask so that comparison is possible
def read_tiff_all_bands(tiff_path):
    """
    Read every raster band of a (multi-band) TIFF with GDAL.

    Returns:
        img_data: array of shape (rows, cols, bands) using the dtype of band 1.
            A single-band file still yields a trailing axis of length 1.
        metadata: dataset-level metadata from ``GetMetadata()``.

    Raises:
        FileNotFoundError: if GDAL cannot open ``tiff_path``.
    """
    dataset = gdal.Open(tiff_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise FileNotFoundError("The specified TIFF file could not be opened.")

    band_count = dataset.RasterCount
    first_band = dataset.GetRasterBand(1).ReadAsArray()

    if band_count == 1:
        # Keep the result 3D so callers can always index [:, :, channel].
        return first_band.reshape(first_band.shape[0], first_band.shape[1], 1), dataset.GetMetadata()

    n_rows, n_cols = first_band.shape[0], first_band.shape[1]
    img_data = np.zeros((n_rows, n_cols, band_count), dtype=first_band.dtype)
    img_data[:, :, 0] = first_band
    # GDAL band numbers are 1-based; output channel = band number - 1.
    for band_number in range(2, band_count + 1):
        img_data[:, :, band_number - 1] = dataset.GetRasterBand(band_number).ReadAsArray()

    return img_data, dataset.GetMetadata()

def read_single_band(tiff_path, band_number=1):
    """
    Return one raster band of a TIFF as a 2D array.

    Args:
        tiff_path: path to the TIFF file.
        band_number: 1-based GDAL band index (default 1).

    Raises:
        FileNotFoundError: if GDAL cannot open ``tiff_path``.
    """
    dataset = gdal.Open(tiff_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise FileNotFoundError("The specified TIFF file could not be opened.")
    return dataset.GetRasterBand(band_number).ReadAsArray()

def preprocess_mask(mask, core_label=2, cytoplasm_label=7, connectivity=2):
    """
    Turn a class-coded mask into an instance mask with one label per cell.

    Pixels equal to ``core_label`` or ``cytoplasm_label`` are merged into a single
    foreground. Each connected component of that foreground (found by
    ``skimage.measure.label`` with the given ``connectivity``) becomes one cell.
    Components are relabelled sequentially starting at 1, and every other pixel is 0.
    Assumes the same class codes are used for all cells.
    """
    foreground = np.isin(mask, (core_label, cytoplasm_label))
    components = label(foreground, connectivity=connectivity)
    sequential, _forward_map, _inverse_map = relabel_sequential(components)
    return sequential

def crop_tiff_to_match_mask(img_data, mask_shape):
    """
    Trim channel-last image data to the mask's height and width.

    Keeps the top-left ``mask_shape[0] x mask_shape[1]`` window and all channels.

    Raises:
        ValueError: if the image is smaller than the mask in either spatial dimension.
    """
    if img_data.shape[0] < mask_shape[0] or img_data.shape[1] < mask_shape[1]:
        raise ValueError("Image data is smaller than mask dimensions or has incorrect shape.")
    return img_data[:mask_shape[0], :mask_shape[1], :]

Functions specific to the urothelial carcinoma tissue sample stored in .tiff format. Metadata is supplied as a separate .csv file. The segmentation mask is stored in its own .tiff file, in which each cell object carries its own incrementally numbered integer label.

In [None]:
# for use in utag after redsea
def extract_data_from_tiffs(feature_tiff_path, mask, channel_indices):
    """
    Read a channel-first feature TIFF and sum the selected channels over every cell (REDSEA input).

    Args:
        feature_tiff_path: path to a multi-channel image indexed as ``img[channel, row, col]``.
        mask: 2D label image; 0 is background.
        channel_indices: channel positions to extract (e.g. the non-noisy ones).
            Column i of the outputs corresponds to ``channel_indices[i]``.

    Returns:
        data: (n_cells, len(channel_indices)) summed intensity per cell and channel.
        cell_ids: sorted unique non-zero labels; row i belongs to ``cell_ids[i]``.
        centroids: dict label -> centroid, from ``calculate_centroids``.
        dataScaleSize: ``data`` divided by each cell's pixel area.
        cellSizes: (n_cells, 1) pixel area of each cell.
    """
    feature_img = read_tiff_image(feature_tiff_path)
    cell_ids = np.unique(mask[mask != 0])
    channel_idx = np.asarray(channel_indices, dtype=int)
    num_channels = len(channel_idx)  # Process only the requested (non-noisy) channels
    data = np.zeros((len(cell_ids), num_channels))
    dataScaleSize = np.zeros_like(data)
    cellSizes = np.zeros((len(cell_ids), 1))

    # O(1) label -> output row lookup (list.index per cell made the loop quadratic).
    row_of_label = {cell_id: row for row, cell_id in enumerate(cell_ids.tolist())}

    for stat in measure.regionprops(mask):
        idx = row_of_label[stat.label]
        rows, cols = stat.coords[:, 0], stat.coords[:, 1]
        # (num_channels, n_pixels): every selected channel's values inside this cell,
        # gathered once instead of scanning the full mask per channel.
        pixels = feature_img[channel_idx[:, np.newaxis], rows, cols]
        data[idx, :] = pixels.sum(axis=1)
        dataScaleSize[idx, :] = data[idx, :] / stat.area
        cellSizes[idx] = stat.area

    centroids = calculate_centroids(mask)
    return data, cell_ids, centroids, dataScaleSize, cellSizes

def create_anndata(data, cell_ids, centroids, means, stdevs, channels_df, valid_indices):
    """
    Wrap REDSEA-compensated measurements in an AnnData object.

    Args:
        data: compensated values, already restricted to the kept cells.
        cell_ids: labels of all cells. ``cell_ids[valid_indices]`` must line up
            with the rows of ``data``.
        centroids: mapping label -> centroid tuple. Element 1 is stored as
            ``X_centroid`` and element 0 as ``Y_centroid``.
        means, stdevs: per-channel statistics stored in ``var``.
        channels_df: DataFrame whose ``channel`` column names the columns of ``data``.
        valid_indices: selector (e.g. boolean mask) of the kept cells.

    Returns:
        AnnData with ``obsm['spatial']`` = [Y_centroid, X_centroid].
    """
    kept_ids = cell_ids[valid_indices]
    obs_df = pd.DataFrame({
        'roi': kept_ids,
        'X_centroid': [centroids[cell][1] for cell in kept_ids],
        'Y_centroid': [centroids[cell][0] for cell in kept_ids],
    })

    var_df = pd.DataFrame({'mean': means, 'std': stdevs}, index=channels_df['channel'])

    adata = AnnData(X=data, obs=obs_df, var=var_df)
    adata.obsm['spatial'] = np.array(obs_df[['Y_centroid', 'X_centroid']])
    return adata

# for use in utag without prior application of redsea
def utag_extract_data_from_tiffs(feature_img, mask, channel_indices):
    """
    Compute per-cell mean intensities of selected channels for UTAG (no REDSEA).

    Args:
        feature_img: channel-first image indexed as ``feature_img[channel, row, col]``.
        mask: 2D label image of the same spatial shape; 0 is background.
        channel_indices: channels to extract. Column i of ``data`` corresponds to
            ``channel_indices[i]``.

    Returns:
        data: (n_cells, n_channels) mean intensity of each cell per channel.
        cell_ids: sorted unique non-zero labels; row i belongs to ``cell_ids[i]``.
        centroids: dict label -> centroid, from ``calculate_centroids``.
        global_means, global_stds: per-channel mean/std over all foreground pixels.
    """
    centroids = calculate_centroids(mask)

    flat_mask = np.asarray(mask).ravel()
    in_cell = flat_mask != 0  # Exclude background
    # cell_ids equals np.unique(mask[mask != 0]). cell_of_pixel gives, for every
    # foreground pixel, its row in the output, so per-cell sums need a single bincount.
    cell_ids, cell_of_pixel = np.unique(flat_mask[in_cell], return_inverse=True)
    cell_of_pixel = cell_of_pixel.ravel()
    pixel_counts = np.bincount(cell_of_pixel, minlength=len(cell_ids))

    num_channels = len(channel_indices)
    data = np.zeros((len(cell_ids), num_channels))
    global_means = np.zeros(num_channels)
    global_stds = np.zeros(num_channels)

    for i, channel in enumerate(channel_indices):
        values = np.asarray(feature_img[channel, :, :]).ravel()[in_cell]
        global_means[i] = np.mean(values)
        global_stds[i] = np.std(values)
        # Per-cell mean = per-cell sum / per-cell pixel count, in one O(pixels) pass.
        sums = np.bincount(cell_of_pixel, weights=values, minlength=len(cell_ids))
        data[:, i] = sums / pixel_counts

    return data, cell_ids, centroids, global_means, global_stds

def utag_create_anndata(feature_img, mask, channels_df, noisy_channels=None):
    """
    Create an AnnData object from the non-noisy channels of an image and its mask.

    Args:
        feature_img: channel-first image indexed as ``feature_img[channel, row, col]``.
        mask: 2D label image; 0 is background.
        channels_df: DataFrame whose ``channel`` column names the image channels in order.
        noisy_channels: channel names to exclude. If omitted, a module-level
            ``noisy_channels`` variable is used when one is defined (the original
            implicit behaviour); otherwise no channel is excluded.

    Returns:
        AnnData with per-cell mean intensities in ``X``. ``obs`` stores centroid
        element 0 as ``X_centroid`` and element 1 as ``Y_centroid``, and
        ``obsm['spatial']`` is [Y_centroid, X_centroid].
    """
    if noisy_channels is None:
        # Backward compatibility: the notebook previously read this global implicitly.
        noisy_channels = globals().get('noisy_channels', ())
    excluded = set(noisy_channels)

    channel_names = channels_df['channel'].tolist()
    remaining_channels = [i for i, name in enumerate(channel_names) if name not in excluded]
    data, cell_ids, centroids, global_means, global_stds = utag_extract_data_from_tiffs(feature_img, mask, remaining_channels)

    obs_df = pd.DataFrame({
        'roi': cell_ids,
        'X_centroid': [centroids[cell][0] for cell in cell_ids],
        'Y_centroid': [centroids[cell][1] for cell in cell_ids]
    })

    var_df = pd.DataFrame({
        'mean': global_means,
        'std': global_stds
    }, index=[channel_names[ch] for ch in remaining_channels])

    adata = AnnData(X=data, obs=obs_df, var=var_df)
    adata.obsm['spatial'] = np.array(adata.obs[['Y_centroid', 'X_centroid']])

    return adata

# REDSEA functions

In [None]:
def ismember(a, b):
    """
    For each element of ``a``, return the index of its first occurrence in ``b``.

    Elements of ``a`` that do not occur in ``b`` map to ``None``. Elements must be hashable.
    """
    first_position = {}
    for position, value in enumerate(b):
        first_position.setdefault(value, position)
    return [first_position.get(value) for value in a]

def compute_cell_contact_matrix(mask):
    """
    Count, for every pair of cells, the background pixels they share as a border.

    A background pixel (value 0) adds 1 to ``cellPairMap[a, b]`` and to
    ``cellPairMap[b, a]`` for every pair of distinct positive labels a, b found in
    its 3x3 neighbourhood. Pixels outside the image count as background.

    Args:
        mask: 2D label image; 0 is background and positive values are cell labels.

    Returns:
        cellPairMap: (cellNum, cellNum) symmetric float array with a zero diagonal.
            Row/column i refers to the i-th smallest non-zero label.
        cellNum: number of distinct non-zero labels.
        rowNum, colNum: shape of the mask after padding by one pixel on every side
            (``mask.shape`` + 2).
    """
    mask = np.asarray(mask)
    unique_labels = np.unique(mask)
    unique_labels = unique_labels[unique_labels != 0]  # drop background wherever it sorts
    cellNum = len(unique_labels)  # This is the number of cells
    n_rows, n_cols = mask.shape
    rowNum, colNum = n_rows + 2, n_cols + 2
    cellPairMap = np.zeros((cellNum, cellNum))

    # Replace each positive label by its 1-based position in unique_labels (0 stays 0).
    positive = mask > 0
    index_mask = np.zeros(mask.shape, dtype=np.int64)
    index_mask[positive] = np.searchsorted(unique_labels, mask[positive]) + 1
    padded = np.pad(index_mask, pad_width=1, mode='constant', constant_values=0)

    # Min and max positive index in every 3x3 window. min < max means the window
    # holds at least two different cells; only such background pixels contribute.
    no_cell = cellNum + 1
    window_max = np.zeros(mask.shape, dtype=np.int64)
    window_min = np.full(mask.shape, no_cell, dtype=np.int64)
    for dr in range(3):
        for dc in range(3):
            shifted = padded[dr:dr + n_rows, dc:dc + n_cols]
            np.maximum(window_max, shifted, out=window_max)
            np.minimum(window_min, np.where(shifted > 0, shifted, no_cell), out=window_min)
    rows, cols = np.nonzero((mask == 0) & (window_min < window_max))
    if rows.size == 0:
        return cellPairMap, cellNum, rowNum, colNum

    # (n_candidates, 9) neighbourhoods. After sorting, repeated labels are adjacent,
    # so zeroing repeats leaves each distinct label exactly once per row.
    windows = np.stack([padded[rows + dr, cols + dc] for dr in range(3) for dc in range(3)], axis=1)
    windows.sort(axis=1)
    repeated = np.zeros(windows.shape, dtype=bool)
    repeated[:, 1:] = windows[:, 1:] == windows[:, :-1]
    windows[repeated] = 0

    # Every pair of surviving (distinct, positive) entries is one contact.
    # np.add.at accumulates correctly when an index pair occurs several times.
    for first in range(9):
        for second in range(first + 1, 9):
            a = windows[:, first]
            b = windows[:, second]
            both = (a > 0) & (b > 0)
            np.add.at(cellPairMap, (a[both] - 1, b[both] - 1), 1)
            np.add.at(cellPairMap, (b[both] - 1, a[both] - 1), 1)  # Ensure symmetry

    np.fill_diagonal(cellPairMap, 0)  # Ignore self-pairing
    return cellPairMap, cellNum, rowNum, colNum

def process_boundary_signals(mask, countsNoNoise, cellNum, rowNum, colNum, elementShape, elementSize, clusterChannels):
    """
    Sum, per cell, the channel signal of its pixels that lie near background.

    A cell pixel counts as a boundary pixel when the structuring element centred on
    it covers at least one background (0) pixel. The element is a
    (2*elementSize+1)-wide square for ``elementShape == 1`` and a diamond (city-block
    radius ``elementSize``) for ``elementShape == 2``. A pixel at (r, c) is only
    considered when ``elementSize - 1 < r < rowNum - elementSize - 2`` and
    ``elementSize - 1 < c < colNum - elementSize - 2``.

    Args:
        mask: 2D label image; 0 is background.
        countsNoNoise: (rows, cols, n_channels) channel-last signal image.
        cellNum: number of output rows (cells).
        rowNum, colNum: bounds for the edge check above. ``redsea_compensation`` passes
            the padded shape from ``compute_cell_contact_matrix`` (mask.shape + 2).
        elementShape: 1 = square, 2 = diamond.
        elementSize: radius of the structuring element.
        clusterChannels: only its length is used (number of output columns).

    Returns:
        (cellNum, len(clusterChannels)) float array. Row i belongs to the i-th smallest
        non-zero label of ``mask``, matching the row order of ``compute_cell_contact_matrix``.

    Raises:
        ValueError: if ``elementShape`` is neither 1 nor 2.
    """
    MIBIdataNearEdge1 = np.zeros((cellNum, len(clusterChannels)))
    mask = np.asarray(mask)
    size = elementSize

    # Offsets of the structuring element relative to its centre pixel.
    d_row, d_col = np.mgrid[-size:size + 1, -size:size + 1]
    if elementShape == 1:  # square
        footprint = np.ones(d_row.shape, dtype=bool)
    elif elementShape == 2:  # diamond: city-block distance <= size
        footprint = np.abs(d_row) + np.abs(d_col) <= size
    else:
        raise ValueError(f"elementShape must be 1 (square) or 2 (diamond), got {elementShape!r}.")

    n_rows, n_cols = mask.shape
    # near_background[r, c]: the element centred on (r, c) touches a 0 pixel.
    background = np.pad(mask == 0, size, mode='constant', constant_values=False)
    near_background = np.zeros(mask.shape, dtype=bool)
    for dr, dc in zip(d_row[footprint], d_col[footprint]):
        near_background |= background[size + dr:size + dr + n_rows, size + dc:size + dc + n_cols]

    # Same edge exclusion as the per-pixel REDSEA port: the element must fit inside the bounds.
    row_pos = np.arange(n_rows)
    col_pos = np.arange(n_cols)
    row_ok = (size - 1 < row_pos) & (row_pos < rowNum - size - 2)
    col_ok = (size - 1 < col_pos) & (col_pos < colNum - size - 2)
    in_bounds = row_ok[:, np.newaxis] & col_ok[np.newaxis, :]

    # Output row i <-> i-th smallest non-zero label (labels need not be 1..N).
    cell_labels = np.unique(mask)
    cell_labels = cell_labels[cell_labels != 0][:cellNum]
    if cell_labels.size == 0:
        return MIBIdataNearEdge1

    rr, cc = np.nonzero(near_background & in_bounds & (mask != 0))
    pixel_labels = mask[rr, cc]
    positions = np.searchsorted(cell_labels, pixel_labels)
    matched = cell_labels[np.minimum(positions, cell_labels.size - 1)] == pixel_labels
    # Scatter-add each boundary pixel's channel vector into its cell's row.
    np.add.at(MIBIdataNearEdge1, positions[matched], countsNoNoise[rr[matched], cc[matched], :])
    return MIBIdataNearEdge1

def assemble_df(data_array, suffix, columns=None):
    """
    Assemble a DataFrame for a REDSEA output array, one column per channel.

    Args:
        data_array: 2D array with one column per channel.
        suffix: currently unused; kept for call compatibility.
        columns: column names. Defaults to a module-level ``clusterChannels``
            variable, which is what this function previously read implicitly.

    Returns:
        pandas.DataFrame holding ``data_array`` with the given column names.

    Raises:
        NameError: if ``columns`` is omitted and no global ``clusterChannels`` exists.
    """
    if columns is None:
        try:
            columns = globals()['clusterChannels']
        except KeyError:
            raise NameError("assemble_df needs `columns` or a global `clusterChannels`.") from None
    # Concatenating with an empty DataFrame (previous behaviour) added nothing.
    return pd.DataFrame(data_array, columns=columns)

In [None]:
def redsea_compensation(data, mask, filtered_img_data, channels_df, feature_tiff_path, cell_sizes,
                        element_shape=2, element_size=2, cluster_channels=None):
    """
    Apply REDSEA boundary-signal spillover compensation to per-cell data.

    Args:
        data: (n_cells, n_channels) summed signal per cell. Rows must follow the
            sorted non-zero labels of ``mask``.
        mask: 2D label image; 0 is background.
        filtered_img_data: signal image. Treated as channel-first (C, H, W) when
            ``feature_tiff_path`` ends with ``.tiff`` and as channel-last (H, W, C)
            when it ends with ``.tif`` (case-insensitive).
        channels_df: DataFrame with a ``channel`` column of channel names.
        feature_tiff_path: only its extension is used, to choose the layout above.
        cell_sizes: (n_cells, 1) pixel area of each cell.
        element_shape: structuring element for boundary detection
            (1 = square, 2 = diamond; default 2).
        element_size: radius of the structuring element (default 2).
        cluster_channels: channel names to compensate. Defaults to all channels
            in ``channels_df['channel']``.

    Returns:
        data_compen_scale_size_cells: compensated, size-normalised data of the kept cells.
        means, stdevs: per-channel mean and std over the kept cells.
        valid_indices: boolean mask over all cells. False for cells whose compensated,
            size-scaled signal summed over the cluster channels is below 0.1.

    Raises:
        ValueError: if ``feature_tiff_path`` ends with neither ``.tiff`` nor ``.tif``.
    """
    if cluster_channels is None:
        cluster_channels = channels_df['channel']
    # Channels to normalise. Currently all cluster channels, but kept separate so a
    # subset could be chosen; the indicator is taken over cluster_channels.
    normalized_channels = cluster_channels
    channel_norm_identity = np.isin(cluster_channels, normalized_channels).astype(float).reshape(-1, 1)
    cluster_channels_inds = np.where(np.isin(cluster_channels, channels_df['channel']))[0]

    # Bring the signal image to (rows, cols, channels).
    path_lower = str(feature_tiff_path).lower()
    if path_lower.endswith(".tiff"):
        counts_no_noise = np.transpose(filtered_img_data, (1, 2, 0))
    elif path_lower.endswith(".tif"):
        counts_no_noise = np.asarray(filtered_img_data)
    else:
        raise ValueError(
            f"Cannot infer image layout from file extension of {feature_tiff_path!r}; expected .tiff or .tif."
        )

    # Compute cell contact matrix
    cell_pair_map, cell_num, row_num, col_num = compute_cell_contact_matrix(mask)

    # Total shared boundary of each cell; epsilon avoids 0/0 for cells without contacts.
    cell_boundary_totals = np.sum(cell_pair_map, axis=0) + 1e-10
    # Entry [i, j] is divided by the boundary total of cell j (column-wise).
    # NOTE(review): confirm this orientation against the reference REDSEA implementation.
    cell_pair_norm = np.identity(cell_num) - cell_pair_map / cell_boundary_totals[np.newaxis, :]

    # Signal of each cell's boundary pixels
    boundary_data = process_boundary_signals(mask, counts_no_noise, cell_num, row_num, col_num,
                                             element_shape, element_size, cluster_channels)

    # Boundary signal correction and reinforcement
    normalized_data = np.dot(boundary_data.T, cell_pair_norm).T + data
    normalized_data = np.clip(normalized_data, 0, None)  # Ensure non-negative values

    # Composite the normalized channels with non-normalized channels
    rev_channel_norm_identity = 1 - channel_norm_identity
    final_normalized_data = (data * rev_channel_norm_identity.T) + (normalized_data * channel_norm_identity.T)

    # Scale by size and drop cells with almost no signal in the cluster channels
    data_compen_scale_size = final_normalized_data / cell_sizes
    signal = np.sum(data_compen_scale_size[:, cluster_channels_inds], axis=1)
    valid_indices = ~(signal < 0.1)
    data_compen_scale_size_cells = data_compen_scale_size[valid_indices, :]

    # Calculate means and standard deviations for each channel
    means = np.mean(data_compen_scale_size_cells, axis=0)
    stdevs = np.std(data_compen_scale_size_cells, axis=0)

    return data_compen_scale_size_cells, means, stdevs, valid_indices


# PARC functions

In [None]:
### HNSW functions
class Space:
    """Base class for a distance function over vectors of dimension ``dim``."""

    def __init__(self, dim):
        self.dim = dim

    def distance(self, vec1, vec2):
        """Return the distance between two vectors (implemented by subclasses)."""
        raise NotImplementedError("Distance function is not implemented.")


class L2Space(Space):
    """Euclidean distance."""

    def distance(self, vec1, vec2):
        return np.sqrt(np.sum((vec1 - vec2) ** 2))


class InnerProductSpace(Space):
    """Cosine distance ``1 - cos(angle)``; 1.0 when either vector is all zeros."""

    def distance(self, vec1, vec2):
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            # Cosine similarity is undefined for a zero vector; treat it as orthogonal.
            return 1.0
        cosine_similarity = np.dot(vec1, vec2) / (norm1 * norm2)
        return 1 - cosine_similarity


class HNSW:
    """
    Nearest-neighbour index with an hnswlib-like interface
    (``init_index`` / ``add_items`` / ``knn_query``).

    NOTE: despite the name, queries are answered by exact brute-force search over
    all stored points. No hierarchical graph is navigated. Each node's random
    ``level`` and the ``links`` adjacency lists (nearest earlier nodes, reverse
    links capped at ``M``) are bookkeeping only and are not used by queries.

    Args:
        space: ``'l2'`` (Euclidean) or ``'ip'`` (cosine distance).
        dim: vector dimension.
        M: cap on the number of reverse links recorded per node.
        ef_construction: number of nearest earlier nodes linked from a new node.
        max_level: upper bound on the level drawn by ``_random_level``.
    """

    def __init__(self, space='l2', dim=16, M=16, ef_construction=100, max_level=16):
        if space == 'l2':
            self.space = L2Space(dim)
        elif space == 'ip':
            self.space = InnerProductSpace(dim)
        else:
            raise ValueError("Unsupported space type. Use 'l2' or 'ip'.")
        self.space_name = space
        self.dim = dim
        self.data = []  # node dicts {'point', 'level', 'id'} in insertion order
        self.ef_construction = ef_construction
        self.M = M
        self.ef = ef_construction  # Default ef for queries (kept for API compatibility)
        self.mL = max_level  # previously referenced but never defined
        self.max_elements = None
        self.links = {}  # node id -> list of linked node ids

    def _random_level(self):
        """Draw a node level: keep incrementing with probability exp(-1), capped at ``mL``."""
        level = 0
        while np.random.rand() < np.exp(-1) and level < self.mL:
            level += 1
        return level

    def init_index(self, max_elements, ef_construction=None, M=None):
        """Record the expected capacity and optionally override ``ef_construction`` / ``M``."""
        if ef_construction:
            self.ef_construction = ef_construction
        if M:
            self.M = M
        self.max_elements = max_elements

    def _matrix(self):
        """Stack all stored points into an (n, dim) array (requires at least one point)."""
        return np.vstack([node['point'] for node in self.data])

    def _distances(self, point, points):
        """Vectorised ``self.space.distance(point, row)`` for every row of ``points``."""
        if self.space_name == 'ip':
            row_norms = np.linalg.norm(points, axis=1)
            point_norm = np.linalg.norm(point)
            dists = np.ones(len(points))
            defined = (row_norms != 0) & (point_norm != 0)
            dists[defined] = 1 - (points[defined] @ point) / (row_norms[defined] * point_norm)
            return dists
        return np.sqrt(np.sum((points - point) ** 2, axis=1))

    def add_items(self, data, ids=None):
        """
        Append points (one per row of ``data``) and record their links.

        ``ids`` defaults to consecutive integers continuing from the current size;
        ``knn_query`` reports these ids as labels.
        """
        points = np.asarray(data, dtype=float)
        if points.size == 0:
            return
        points = np.atleast_2d(points)
        start = len(self.data)
        if ids is None:
            ids = range(start, start + len(points))
        ids = list(ids)
        if len(ids) != len(points):
            raise ValueError("ids must contain exactly one entry per data point.")
        for point, point_id in zip(points, ids):
            self.data.append({'point': point, 'level': self._random_level(), 'id': point_id})
        all_points = self._matrix()  # stacked once per call, not once per insertion
        for position in range(start, len(self.data)):
            self._insert_point(position, all_points)

    def _insert_point(self, position, all_points=None):
        """Link the node at ``position`` in ``self.data`` to its nearest earlier nodes."""
        node = self.data[position]
        node_id = node['id']
        if position == 0:
            self.links.setdefault(node_id, [])
            return
        if all_points is None:
            all_points = self._matrix()
        dists = self._distances(node['point'], all_points[:position])
        nearest = np.argsort(dists, kind='stable')[:self.ef_construction]
        neighbour_ids = [self.data[i]['id'] for i in nearest]
        self.links[node_id] = neighbour_ids
        for other in neighbour_ids:
            other_links = self.links.setdefault(other, [])
            if node_id not in other_links and len(other_links) < self.M:
                other_links.append(node_id)

    def knn_query(self, data, k=1):
        """
        Return ``(labels, distances)`` of the ``k`` nearest stored points for each query row.

        Both arrays have shape (n_queries, k) and are ordered from nearest to farthest.
        With an empty index they are all zeros, as before.

        Raises:
            ValueError: if ``k`` exceeds the number of stored points.
        """
        queries = np.asarray(data, dtype=float)
        if queries.ndim == 1 and queries.size > 0:
            queries = queries[np.newaxis, :]
        labels = np.zeros((len(queries), k), dtype=int)
        distances = np.zeros((len(queries), k), dtype=float)
        if not self.data or len(queries) == 0:
            return labels, distances
        if k > len(self.data):
            raise ValueError(f"Cannot return k={k} neighbours from an index holding {len(self.data)} points.")
        points = self._matrix()
        ids = np.array([node['id'] for node in self.data])
        for row, query in enumerate(queries):
            dists = self._distances(query, points)
            nearest = np.argsort(dists, kind='stable')[:k]
            labels[row, :] = ids[nearest]
            distances[row, :] = dists[nearest]
        return labels, distances

In [None]:
### PARC functions
#combining HNSW and Leiden (as https://github.com/vtraag/leidenalg)
class PARC:
    """
    PARC-style clustering: build a pruned neighbour graph and run Leiden on it.

    The graph comes either from a precomputed sparse neighbour matrix
    (``neighbor_graph``, whose stored values are treated as distances) or from a
    kNN search over ``data`` with the ``HNSW`` index. Edges are pruned locally
    (distance >= row mean + ``dist_std_local`` * row std) and globally (Jaccard
    similarity at or below a threshold). Communities are then found with
    ``leidenalg``.

    Args:
        data: (n_cells, n_features) array, e.g. PCA coordinates.
        dist_std_local: local pruning strength in standard deviations.
        jac_std_global: ``'median'`` or a number of standard deviations below the
            mean Jaccard similarity used as the global pruning threshold.
        n_iter_leiden: Leiden iterations.
        random_seed: Leiden seed.
        distance: ``'l2'`` or ``'ip'`` for the kNN search.
        partition_type: ``"ModularityVP"`` or ``"RBVP"``. ``resolution_parameter``
            is only used with ``"RBVP"``.
        neighbor_graph: optional sparse (n_cells, n_cells) matrix of neighbour distances.
        hnsw_param_ef_construction: ``ef_construction`` for the kNN index.
        labels: initial labels; replaced by ``run_PARC``.
        knn: neighbours per cell when the graph is built from ``data``.
    """

    def __init__(self, data, true_label=None, dist_std_local=3, jac_std_global='median',
                 n_iter_leiden=5, random_seed=42, distance='l2', partition_type="ModularityVP",
                 resolution_parameter=1.0, neighbor_graph=None, hnsw_param_ef_construction=150,
                 labels=None, knn=30):
        if distance not in ('l2', 'ip'):
            raise ValueError("Unsupported space type. Use 'l2' or 'ip'.")
        self.data = data
        self.true_label = true_label
        self.dist_std_local = dist_std_local
        self.jac_std_global = jac_std_global
        self.n_iter_leiden = n_iter_leiden
        self.random_seed = random_seed
        self.distance = distance
        self.partition_type = partition_type  # Must be "ModularityVP" or "RBVP"
        self.resolution_parameter = resolution_parameter
        self.neighbor_graph = neighbor_graph
        self.hnsw_param_ef_construction = hnsw_param_ef_construction
        self.knn = knn  # previously read in construct_graph_using_hnsw but never set
        # The kNN index is only needed without a precomputed graph; built on demand.
        self.hnsw = None
        # `labels or {}` raised for numpy arrays (ambiguous truth value).
        self.labels = labels if labels is not None else {}

    def run_PARC(self):
        """Build the pruned graph, run Leiden and return the per-cell community labels."""
        if self.neighbor_graph is not None:
            graph = self.construct_graph_using_precomputed_neighbors()
        else:
            graph = self.construct_graph_using_hnsw()

        if self.partition_type == "ModularityVP":
            partition_type = leidenalg.ModularityVertexPartition
            partition_kwargs = {}
        elif self.partition_type == "RBVP":
            partition_type = leidenalg.RBConfigurationVertexPartition
            partition_kwargs = {'resolution_parameter': self.resolution_parameter}
        else:
            raise ValueError("Unsupported partition type")

        # Leiden runs on the unweighted, pruned graph (weights=None).
        partition = leidenalg.find_partition(graph, partition_type, weights=None,
                                             n_iterations=self.n_iter_leiden,
                                             seed=self.random_seed,
                                             **partition_kwargs)
        self.labels = np.array(partition.membership)

        return self.labels

    def construct_graph_using_precomputed_neighbors(self):
        """Build the pruned graph from ``neighbor_graph`` (stored values = distances)."""
        neighbors_csr = scipy.sparse.csr_matrix(self.neighbor_graph)
        n_cells = neighbors_csr.shape[0]
        row_list, col_list, weight_list = [], [], []

        for i in range(n_cells):
            start, stop = neighbors_csr.indptr[i], neighbors_csr.indptr[i + 1]
            if start == stop:
                continue  # isolated cell: avoids mean/std of an empty array
            neighbors = neighbors_csr.indices[start:stop]
            distances = neighbors_csr.data[start:stop]
            # Local pruning: drop neighbours much farther than this cell's typical neighbour.
            to_keep = distances < (np.mean(distances) + self.dist_std_local * np.std(distances))
            row_list.extend([i] * int(np.count_nonzero(to_keep)))
            col_list.extend(neighbors[to_keep].tolist())
            weight_list.extend((1 / (distances[to_keep] + 0.1)).tolist())  # Inverse distance as weight

        sparse_graph = scipy.sparse.csr_matrix((weight_list, (row_list, col_list)), shape=(n_cells, n_cells))
        return self.apply_global_pruning(sparse_graph)

    def construct_graph_using_hnsw(self):
        """Build the pruned graph from a kNN search over ``self.data``."""
        n_cells = len(self.data)
        # A fresh index per call, so repeated calls do not insert the data twice.
        self.hnsw = HNSW(space=self.distance, dim=self.data.shape[1], M=16,
                         ef_construction=self.hnsw_param_ef_construction)
        self.hnsw.init_index(max_elements=n_cells, ef_construction=self.hnsw_param_ef_construction)
        self.hnsw.add_items(self.data)
        # Cannot ask for more neighbours than there are points.
        neighbor_array, distance_array = self.hnsw.knn_query(self.data, k=min(self.knn, n_cells))

        row_list, col_list, weight_list = [], [], []

        for i in range(n_cells):
            dists = distance_array[i]
            rows = neighbor_array[i]
            to_keep = dists < (np.mean(dists) + self.dist_std_local * np.std(dists))

            for j, keep in enumerate(to_keep):
                if keep and i != rows[j]:  # Exclude self-loops
                    row_list.append(i)
                    col_list.append(rows[j])
                    weight_list.append(1 / (dists[j] + 0.1))  # Inverse distance as weight

        sparse_graph = scipy.sparse.csr_matrix((weight_list, (row_list, col_list)), shape=(n_cells, n_cells))
        return self.apply_global_pruning(sparse_graph)

    def apply_global_pruning(self, graph):
        """Convert ``graph`` to igraph and delete edges with low Jaccard similarity."""
        # COO keeps row, col and data aligned entry by entry.
        coo = graph.tocoo()
        G = ig.Graph(edges=list(zip(coo.row.tolist(), coo.col.tolist())), directed=False)
        G.es['weight'] = coo.data.tolist()

        # Calculate Jaccard similarities and apply threshold
        jaccard_similarities = G.similarity_jaccard(pairs=G.get_edgelist())
        if self.jac_std_global == 'median':
            threshold = np.median(jaccard_similarities)
        else:
            threshold = np.mean(jaccard_similarities) - self.jac_std_global * np.std(jaccard_similarities)

        # Delete edges at or below the similarity threshold
        edges_to_delete = [idx for idx, sim in enumerate(jaccard_similarities) if sim <= threshold]
        G.delete_edges(edges_to_delete)

        # Merge parallel edges (from symmetric entries); self-loops are kept (loops=False).
        G.simplify(multiple=True, loops=False)

        return G

    def convert_to_igraph(self, graph):
        """Convert a sparse matrix to an undirected igraph Graph with a ``weight`` edge attribute."""
        coo = graph.tocoo()
        G = ig.Graph(edges=list(zip(coo.row.tolist(), coo.col.tolist())), directed=False)
        G.es['weight'] = coo.data.tolist()
        return G

# UTAG functions

In [None]:
### UTAG helpers
def message_passing(adata, mode="l1_norm"):
    """
    Smooth ``adata.X`` over the spatial neighbour graph (UTAG message passing).

    Args:
        adata: AnnData with ``obsp["spatial_connectivities"]`` (n_obs x n_obs).
        mode: ``"l1_norm"`` adds the identity (self-loops) to the connectivities and
            L1-normalises each row. Any other value uses the raw connectivities.

    Returns:
        The same ``adata``, with ``X`` replaced by ``affinity @ X`` (modified in place).
    """
    connectivity_matrix = adata.obsp["spatial_connectivities"]
    if mode == "l1_norm":
        import scipy.sparse as sp

        n_obs = connectivity_matrix.shape[0]
        if sp.issparse(connectivity_matrix):
            # Stay sparse: adding a dense np.eye would materialise an n_obs x n_obs matrix.
            with_self = connectivity_matrix + sp.identity(n_obs, format="csr")
        else:
            with_self = np.asarray(connectivity_matrix) + np.eye(n_obs)
        # NOTE(review): if the connectivities already carry a unit diagonal (e.g. squidpy
        # spatial_neighbors(set_diag=True)), each cell's own weight is 2 before normalising.
        affinity = normalize(with_self, axis=1, norm="l1")
    else:
        affinity = connectivity_matrix
    adata.X = affinity @ adata.X
    return adata

def add_probabilities_to_centroid(adata, column):
    """
    Score every cell against each cluster centroid and store softmax probabilities.

    Each channel is min-max scaled to [0, 1], and the per-cluster mean of the scaled
    data serves as that cluster's centroid. Each cell's unscaled expression is
    dotted with every centroid, and a row-wise softmax turns the scores into
    probabilities.

    Args:
        adata: AnnData (or any object exposing ``to_df()``, ``obs`` and ``obsm``).
        column: ``obs`` column holding the cluster labels.

    Returns:
        ``adata``, with ``obsm[f"{column}_probabilities"]`` set to an
        (n_cells, n_clusters) array (modified in place).
    """
    output_key = f"{column}_probabilities"
    df = adata.to_df()
    col_min = df.min()
    # A constant channel has range 0. Dividing by it gives NaN, which turned every
    # probability into NaN, so such channels are scaled to 0 instead.
    value_range = (df.max() - col_min).replace(0, 1)
    scaled = (df - col_min) / value_range
    # observed=True: unused categorical levels would otherwise add all-NaN centroids.
    group_means = scaled.groupby(adata.obs[column], observed=True).mean()
    probabilities = softmax(df.dot(group_means.T), axis=1)
    adata.obsm[output_key] = probabilities
    return adata

def cluster_results(data, method, resolution, save_key):
    """
    Cluster ``data`` with Leiden or PARC and store the labels and centroid probabilities.

    Args:
        data: AnnData already processed by ``sc.tl.pca`` and ``sc.pp.neighbors``.
            Leiden uses the scanpy neighbour graph. PARC uses ``obsm["X_pca"]`` and
            the ``obsp["distances"]`` kNN matrix.
        method: ``"leiden"`` or ``"parc"`` (case-insensitive).
        resolution: passed to Leiden as ``resolution`` and to PARC as
            ``resolution_parameter``.
        save_key: prefix of the ``obs`` column. The full key is
            ``f"{save_key}_{method.lower()}_{resolution}"``.

    Raises:
        ValueError: for an unsupported ``method``.
    """
    method_name = method.upper()
    if method_name not in ("LEIDEN", "PARC"):
        raise ValueError(f"Unsupported clustering method: {method!r}. Use 'leiden' or 'parc'.")
    cluster_key = f"{save_key}_{method.lower()}_{resolution}"
    if method_name == 'LEIDEN':
        sc.tl.leiden(data, resolution=resolution, key_added=cluster_key)
    else:
        # PARC interprets neighbor_graph values as distances (it keeps small values and
        # weights edges by 1 / (d + 0.1)), so pass distances, not connectivities.
        # NOTE(review): with PARC's default partition_type "ModularityVP" the
        # resolution_parameter is not used.
        model = PARC(data.obsm["X_pca"], neighbor_graph=data.obsp["distances"], resolution_parameter=resolution)
        model.run_PARC()
        data.obs[cluster_key] = pd.Categorical(model.labels)
    add_probabilities_to_centroid(data, cluster_key)

In [None]:
def utag(adata,
         slide_key="slide",
         save_key="UTAG Label",
         max_dist=20.0,
         normalization_mode="l1_norm",
         pca_kwargs=None,
         clustering_methods=None,
         resolutions=None):
    """
    Run UTAG: spatial message passing, PCA and clustering, per slide if requested.

    Args:
        adata: AnnData with spatial coordinates in ``obsm['spatial']``; it is copied,
            not modified.
        slide_key: ``obs`` column identifying slides. If it is set and present,
            each slide is processed separately and the results are concatenated.
        save_key: prefix for the cluster label columns written to ``obs``.
        max_dist: radius for ``squidpy.gr.spatial_neighbors``.
        normalization_mode: passed to ``message_passing``.
        pca_kwargs: keyword arguments for ``sc.tl.pca`` (default ``{'n_comps': 10}``).
        clustering_methods: any of ``"leiden"``/``"parc"`` (default both, case-insensitive).
        resolutions: clustering resolutions (default ``[0.05, 0.1, 0.3, 1.0]``).

    Returns:
        AnnData with one ``obs`` column per (resolution, method) pair.

    Raises:
        ValueError: if an unsupported clustering method is requested.
    """
    # None defaults avoid sharing mutable default objects between calls.
    if pca_kwargs is None:
        pca_kwargs = {'n_comps': 10}
    if clustering_methods is None:
        clustering_methods = ["leiden", "parc"]
    if resolutions is None:
        resolutions = [0.05, 0.1, 0.3, 1.0]

    # Validate before any work; assert would be stripped under `python -O`.
    clustering_methods = [method.lower() for method in clustering_methods]
    unsupported = [method for method in clustering_methods if method not in ("leiden", "parc")]
    if unsupported:
        raise ValueError(f"Unsupported clustering method provided: {unsupported}")

    ad = adata.copy()

    split_by_slide = bool(slide_key) and slide_key in ad.obs
    if split_by_slide:
        slide_data = {slide: ad[ad.obs[slide_key] == slide].copy() for slide in ad.obs[slide_key].unique()}
    else:
        slide_data = {None: ad}  # Process all data together

    # Process each slide separately
    for slide, data in slide_data.items():
        print(f"Processing slide: {slide}")
        sq.gr.spatial_neighbors(data, radius=max_dist, coord_type="generic", set_diag=True)
        data = message_passing(data, mode=normalization_mode)
        slide_data[slide] = data
        sc.tl.pca(data, **pca_kwargs)
        sc.pp.neighbors(data)

        for resolution in resolutions:
            for method in clustering_methods:
                # cluster_results dispatches on "LEIDEN"/"PARC"; the lowercase names never
                # matched there. Its column key lower-cases the name again.
                cluster_results(data, method.upper(), resolution, save_key)

    # Combine slide results if processed separately
    return anndata.concat(list(slide_data.values())) if split_by_slide else slide_data[None]