In [103]:
import scipy
import squidpy as sq
import scanpy as sc
import pandas as pd
import numpy as np
import typing as tp
import parmap # need to install parmap
import anndata
from tqdm import tqdm


def z_score(x):
    """
    Rescale an array-like object to the [0, 1] range (min-max scaling).

    NOTE: despite its name, this helper does not standardize: it neither
    subtracts the mean nor divides by the standard deviation. It subtracts
    ``x.min()`` and divides by ``x.max() - x.min()``. The name is kept
    because other code in this file calls it.

    Parameters
    ----------
    x:
        Object exposing ``min()``/``max()`` and element-wise arithmetic
        (e.g. a numpy array or a pandas Series/DataFrame). With pandas
        objects the reductions follow pandas' defaults (per column for a
        DataFrame).

    Returns
    -------
    Object of the same kind as `x` with values rescaled by its min/max.
    Where the maximum equals the minimum the result is 0/0, i.e. NaN for
    numpy/pandas floating-point inputs.
    """
    lower = x.min()
    span = x.max() - lower
    return (x - lower) / span

def sparse_matrix_dstack(
    matrices: tp.Sequence[scipy.sparse.csr_matrix],
) -> scipy.sparse.csr_matrix:
    """
    Diagonally stack sparse matrices into one block-diagonal CSR matrix.

    The i-th input occupies the i-th diagonal block of the output; every
    off-diagonal block is empty.

    Parameters
    ----------
    matrices: Sequence[scipy.sparse.csr_matrix]
        Sparse matrices to place along the diagonal, in order.

    Returns
    -------
    scipy.sparse.csr_matrix
        Block-diagonal matrix with float64 data. An empty input yields an
        empty ``(0, 0)`` matrix.
    """
    blocks = list(matrices)
    if not blocks:
        # Nothing to stack: return an empty matrix instead of failing.
        return scipy.sparse.csr_matrix((0, 0))
    # block_diag assembles the result in one pass, avoiding slice assignment
    # into CSR matrices (slow, and it emits SparseEfficiencyWarning).
    # dtype=float keeps the float64 output this helper has always produced.
    return scipy.sparse.block_diag(blocks, format="csr", dtype=float)

def _parallel_message_pass(
    ad,
    radius: int,
    coord_type: str,
    set_diag: bool,
    mode: str,
):
    """
    Build the spatial neighbor graph of one AnnData object and run message
    passing over it.

    Defined at module level so it can be handed to ``parmap.map`` as the
    per-item worker.

    Parameters
    ----------
    ad:
        AnnData object to process; it is modified and returned.
    radius, coord_type, set_diag:
        Forwarded unchanged to ``squidpy.gr.spatial_neighbors``.
    mode: str
        Forwarded to ``custom_message_passing``.

    Returns
    -------
    The object returned by ``custom_message_passing``.
    """
    sq.gr.spatial_neighbors(
        ad,
        radius=radius,
        coord_type=coord_type,
        set_diag=set_diag,
    )
    return custom_message_passing(ad, mode=mode)


def custom_message_passing(adata, mode: str = "l1_norm"):
    """
    Replace ``adata.X`` with the spatial adjacency matrix times ``adata.X``.

    Parameters
    ----------
    adata:
        Object carrying ``obsp["spatial_connectivities"]`` and ``X``.
    mode: str
        ``"l1_norm"`` L1-normalizes each row of the adjacency matrix before
        the product; any other value uses the adjacency matrix as is.

    Returns
    -------
    The same `adata` object, with ``X`` overwritten.
    """
    connectivities = adata.obsp["spatial_connectivities"]
    if mode != "l1_norm":
        # Plain adjacency multiplication, no normalization.
        weights = connectivities
    else:
        from sklearn.preprocessing import normalize

        weights = normalize(connectivities, axis=1, norm="l1")
    adata.X = weights @ adata.X
    return adata


def low_variance_filter(adata):
    """
    Keep only the variables whose ``var["std"]`` is strictly greater than
    the median of ``var["std"]``.

    Returns the result of indexing `adata` along its second axis with that
    boolean mask.
    """
    std_per_variable = adata.var["std"]
    above_median = std_per_variable > std_per_variable.median()
    return adata[:, above_median]


def add_probabilities_to_centroid(
    adata, col: str, name_to_output: str = None
):
    """
    Store, for every observation, a softmax score against each group of
    ``adata.obs[col]``.

    The per-group profile is the mean of the ``z_score``-scaled data frame
    grouped by ``adata.obs[col]``. Scores are the softmax (across groups)
    of the product between the unscaled data frame and those profiles.

    Parameters
    ----------
    adata:
        Object exposing ``to_df()``, ``obs`` and ``obsm``.
    col: str
        Column of ``adata.obs`` holding the group labels.
    name_to_output: str, optional
        Key of ``adata.obsm`` to write to. Defaults to
        ``col + "_probabilities"``.

    Returns
    -------
    The same `adata` object, with the new ``obsm`` entry added.
    """
    from scipy.special import softmax

    output_key = col + "_probabilities" if name_to_output is None else name_to_output

    features = adata.to_df()
    group_profiles = z_score(features).groupby(adata.obs[col]).mean()
    scores = features @ group_profiles.T
    adata.obsm[output_key] = softmax(scores, axis=1)
    return adata

def utag(
    adata,
    channels_to_use = None,
    slide_key = "Slide",
    save_key: str = "UTAG Label",
    filter_by_variance: bool = False,
    max_dist: float = 20.0,
    normalization_mode: str = "l1_norm",
    keep_spatial_connectivity: bool = False,
    n_pcs = 10,
    apply_umap: bool = False,
    umap_kwargs: tp.Dict[str, tp.Any] = dict(),
    apply_clustering: bool = True,
    clustering_method: tp.Sequence[str] = ["leiden"],
    resolutions: tp.Sequence[float] = [0.05, 0.1, 0.3, 1.0],
    leiden_kwargs: tp.Dict[str, tp.Any] = None,
    parc_kwargs: tp.Dict[str, tp.Any] = None,
    parallel: bool = False,
    processes: int = 1,
    k=15,
    random_state=42,
):
    """
    Discover tissue architechture in single-cell imaging data
    by combining phenotypes and positional information of cells.

    Parameters
    ----------
    adata: AnnData
        AnnData object with spatial positioning of cells in obsm 'spatial' slot.
    channels_to_use: Optional[Sequence[str]]
        An optional sequence of strings used to subset variables to use.
        Default (None) is to use all variables.
    max_dist: float
        Maximum distance to cut edges within a graph.
        Should be adjusted depending on resolution of images.
        For imaging mass cytometry, where resolution is 1um, 20 often 
        gives good results. Default is 20.
    slide_key: {str, None}
        Key of adata.obs containing information on the batch structure 
        of the data. In general, for image data this will often be a variable 
        indicating the image so image-specific effects are removed from data.
        Default is "Slide".
    save_key: str
        Key to be added to adata object holding the UTAG clusters.
        Depending on the values of `clustering_method` and `resolutions`,
        the final keys will be of the form: {save_key}_{method}_{resolution}".
        Default is "UTAG Label".
    filter_by_variance: bool
        Whether to filter vairiables by variance.
        Default is False, which keeps all variables.
    max_dist: float
        Recommended values are between 20 to 50 depending on magnification.
        Default is 20.
    normalization_mode: str
        Method to normalize adjacency matrix.
        Default is "l1_norm", any other value will not use normalization.
    keep_spatial_connectivity: bool
        Whether to keep sparse matrices of spatial connectivity and distance 
        in the obsp attribute of the resulting anndata object. This could be 
        useful in downstream applications.
        Default is not to (False).
    pca_kwargs: Dict[str, Any]
        Keyword arguments to be passed to scanpy.pp.pca for dimensionality 
        reduction after message passing.
        Default is to pass n_comps=10, which uses 10 Principal Components.
    apply_umap: bool
        Whether to build a UMAP representation after message passing.
        Default is False.
    umap_kwargs: Dict[str, Any]
        Keyword arguments to be passed to scanpy.tl.umap for dimensionality 
        reduction after message passing. Default is 10.0.
    apply_clustering: bool
        Whether to cluster the message passed matrix.
        Default is True.
    clustering_method: Sequence[str]
        Which clustering method(s) to use for clustering of the 
        message passed matrix. Default is ["leiden"].
    resolutions: Sequence[float]
        What resolutions should the methods in `clustering_method` be run at.
        Default is [0.05, 0.1, 0.3, 1.0].
    leiden_kwargs: dict[str, Any]
        Keyword arguments to pass to scanpy.tl.leiden.
    parc_kwargs: dict[str, Any]
        Keyword arguments to pass to parc.PARC.
    parallel: bool
        Whether to run message passing part of algorithm in parallel.
        Will accelerate the process but consume more memory.
        Default is True.
    processes: int
        Number of processes to use in parallel.
        Default is to use all available (-1).

    Returns
    -------
    adata: AnnData
        AnnData object with UTAG domain predictions for each cell in adata.obs,
        column `save_key`.
    """
    ad = adata.copy()

    if channels_to_use:
        ad = ad[:, channels_to_use]

    if filter_by_variance:
        ad = low_variance_filter(ad)

    if isinstance(clustering_method, list):
        clustering_method = [m.upper() for m in clustering_method]
    elif isinstance(clustering_method, str):
        clustering_method = [clustering_method.upper()]
    else:
        print(
            """Invalid Clustering Method. Clustering Method Should Either be 
            a string or a list"""
        )
        return
    assert all(m in ["LEIDEN", "KMEANS"] for m in clustering_method)

    if "KMEANS" in clustering_method:
        from sklearn.cluster import KMeans

    print("Applying UTAG Algorithm...")
    if slide_key:
        ads = [
            ad[ad.obs[slide_key] == slide].copy() for slide in ad.obs[slide_key].unique()
        ]
        ad_list = parmap.map(
            _parallel_message_pass,
            ads,
            radius=max_dist,
            coord_type="generic",
            set_diag=True,
            mode=normalization_mode,
            pm_pbar=True,
            pm_parallel=parallel,
            pm_processes=processes,
        )
        ad_result = anndata.concat(ad_list)
        if keep_spatial_connectivity:
            ad_result.obsp["spatial_connectivities"] = sparse_matrix_dstack(
                [x.obsp["spatial_connectivities"] for x in ad_list]
            )
            ad_result.obsp["spatial_distances"] = sparse_matrix_dstack(
                [x.obsp["spatial_distances"] for x in ad_list]
            )
    else:
        sq.gr.spatial_neighbors(ad, radius=max_dist, coord_type="generic", 
                                set_diag=True)
        ad_result = custom_message_passing(ad, mode=normalization_mode)

    if apply_clustering:
        # if "n_comps" in pca_kwargs:
        #     if pca_kwargs["n_comps"] > ad_result.shape[1]:
        #         pca_kwargs["n_comps"] = ad_result.shape[1] - 1
        #         print(
        #             f"Overwriding provided number of PCA dimensions to match number of features: {pca_kwargs['n_comps']}"
        #         )
        if n_pcs == 0:
            print("0 components")
            sc.pp.neighbors(ad_result, n_pcs=0, n_neighbors=k, random_state=random_state)
        else:
            print("execute with principal components")
            sc.tl.pca(ad_result, n_comps=n_pcs)
            sc.pp.neighbors(ad_result, n_neighbors=k, n_pcs=n_pcs, random_state=random_state)

        if apply_umap:
            print("Running UMAP on Input Dataset...")
            sc.tl.umap(ad_result, **umap_kwargs)

        for resolution in resolutions:

            res_key1 = save_key + "_leiden_" + str(resolution)
            res_key3 = save_key + "_kmeans_" + str(resolution)
            if "LEIDEN" in clustering_method:
                print(f"Applying Leiden Clustering at Resolution: {resolution}...")
                kwargs = dict()
                kwargs.update(leiden_kwargs or {})
                sc.tl.leiden(
                    ad_result, resolution=resolution, key_added=res_key1, **kwargs
                )
                add_probabilities_to_centroid(ad_result, res_key1)

            if "KMEANS" in clustering_method:
                print(f"Applying K-means Clustering at Resolution: {resolution}...")
                k = int(np.ceil(resolution * 10))
                kmeans = KMeans(n_clusters=k, random_state=1).fit(ad_result.obsm["X_pca"])
                ad_result.obs[res_key3] = pd.Categorical(kmeans.labels_.astype(str))
                add_probabilities_to_centroid(ad_result, res_key3)

    return ad_result

In [104]:
import multiprocessing

# Force the 'fork' start method where the platform offers it. Platforms
# without 'fork' (e.g. Windows) keep their default start method instead of
# failing with ValueError.
if "fork" in multiprocessing.get_all_start_methods():
    multiprocessing.set_start_method("fork", force=True)

def run_utag_clustering(
        adata,
        features=None,
        k=15,
        resolution=1,
        max_dist=20,
        n_pcs=10,
        random_state=42,
        n_jobs=1,
        n_iterations=5,
        slide_key="Slide",
        layer=None,
        output_annotation="UTAG",
        associated_table=None,
        parallel=False,
        **kwargs
):
    """
    Run UTAG clustering on the AnnData object.

    The labels are written to ``adata.obs[output_annotation]`` and the
    `features` argument is recorded in ``adata.uns["utag_features"]``;
    `adata` is modified in place.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object.
    features : list
        List of features (entries of ``adata.var_names``) to use for
        clustering or for PCA. Default (None) is to use all.
    k : int
        The number of nearest neighbor to be used in creating the graph.
        Default is 15.
    resolution : float
        Resolution parameter for the clustering, higher resolution produces
        more clusters. Default is 1.
    max_dist : float
        Maximum distance to cut edges within a graph. Default is 20.
    n_pcs : int
        Number of principal components to use for clustering; 0 skips PCA.
        Default is 10.
    random_state : int
        Random state for reproducibility. Seeds numpy's global generator
        (when not None) and is passed to the neighbor-graph and Leiden
        steps. Default is 42.
    n_jobs : int
        Number of processes used when `parallel` is True. Default is 1.
    n_iterations : int
        Number of iterations for the Leiden clustering. Default is 5.
    slide_key : str
        Key of adata.obs containing information on the batch structure
        of the data. In general, for image data this will often be a variable
        indicating the image so image-specific effects are removed from data.
        Default is "Slide".
    layer : str, optional
        Key of ``adata.layers`` to use as the input matrix instead of
        ``adata.X``. Default (None) uses ``adata.X``.
    output_annotation : str
        Column of ``adata.obs`` that receives the cluster labels.
        Default is "UTAG".
    associated_table : optional
        Accepted for interface compatibility; not used by this function.
    parallel : bool
        Whether to run the message passing step in parallel.
        Default is False.
    **kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    None
        `adata` is updated in place.

    Raises
    ------
    ValueError
        If `k` is not a positive integer, if any entry of `features` is not
        in ``adata.var_names``, or if `layer` is not in ``adata.layers``.
    """
    save_key = "UTAG Label"

    if not isinstance(k, int) or k <= 0:
        raise ValueError("`k` must be a positive integer")

    if features is None:
        selected_features = None
    else:
        selected_features = (
            [features] if isinstance(features, str) else list(features)
        )
        missing = [f for f in selected_features if f not in adata.var_names]
        if missing:
            raise ValueError(
                f"Features not found in `adata.var_names`: {missing}"
            )

    if layer is not None and layer not in adata.layers:
        raise ValueError(f"Layer '{layer}' not found in `adata.layers`")

    if random_state is not None:
        np.random.seed(random_state)

    # Work on a copy so only the final annotation touches the caller's data.
    adata_utag = adata.copy()
    if layer is not None:
        adata_utag.X = adata_utag.layers[layer].copy()

    utag_results = utag(
        adata_utag,
        channels_to_use=selected_features,
        slide_key=slide_key,
        save_key=save_key,
        max_dist=max_dist,
        normalization_mode='l1_norm',
        apply_clustering=True,
        clustering_method="leiden",
        resolutions=[resolution],
        leiden_kwargs={"n_iterations": n_iterations,
                       "random_state": random_state},
        n_pcs=n_pcs,
        parallel=parallel,
        processes=n_jobs,
        k=k,
        random_state=random_state,
    )

    # Column name follows utag's "{save_key}_{method}_{resolution}" scheme.
    cluster_column = f"{save_key}_leiden_{resolution}"
    # Series assignment aligns on the obs index, so the labels land on the
    # right cells even though utag regroups the cells by slide.
    adata.obs[output_annotation] = utag_results.obs[cluster_column].copy()
    adata.uns["utag_features"] = features

In [105]:
# Run 1: UTAG clustering with n_pcs=0.
# NOTE(review): the path below is absolute and machine-specific; consider a
# configurable data directory.
adata = sc.read(
    "/Users/bombina2/github/multiplex-analysis-web-apps/input/healthy_lung_adata.h5ad"
)

run_utag_clustering(
    adata,
    features=None,
    k=15,
    resolution=0.1,
    max_dist=20,
    n_pcs=0,
    random_state=42,
    n_jobs=1,
    n_iterations=1,
    slide_key="roi",
    layer=None,
    output_annotation="UTAG",
    associated_table=None,
    parallel=False,
)

print(adata)

# Keep the labels both as a pandas Series and as a plain list for comparison.
first_run_clusters = adata.obs["UTAG"].copy()
run1 = first_run_clusters.tolist()

Applying UTAG Algorithm...


100%|██████████| 26/26 [00:00<00:00, 31.91it/s]


0 components
Applying Leiden Clustering at Resolution: 0.1...
AnnData object with n_obs × n_vars = 71946 × 33
    obs: 'sample', 'obj_id', 'X_centroid', 'Y_centroid', 'roi', 'Diseased State', 'Age', 'Patient', 'Image Location', 'cell type', 'slide', 'topological_domain', 'id', 'domain', 'UTAG Label', 'UTAG'
    var: 'mean', 'std'
    uns: 'UTAG Label_colors', 'cell type_colors', 'cell_type_colors', 'domain_colors', 'neighbors', 'pca', 'umap', 'utag_features'
    obsm: 'X_pca', 'X_umap', 'spatial'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'


In [106]:
# Summary view of the labels captured from the first run.
print(first_run_clusters)

0        5
1        5
2        5
3        0
4        0
        ..
71941    0
71942    0
71943    0
71944    0
71945    0
Name: UTAG, Length: 71946, dtype: category
Categories (10, object): ['0', '1', '2', '3', ..., '6', '7', '8', '9']


In [107]:
# Prints the complete first-run label list (one entry per observation).
print(run1)

['5', '5', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '0', '0', '0', '0', '0', '5', '0', '0', '0', '0', '0', '0', '5', '0', '0', '5', '5', '0', '0', '5', '0', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '0', '0', '0', '0', '5', '0', '5', '0', '0', '0', '5', '0', '5', '0', '0', '0', '0', '5', '0', '5', '0', '0', '0', '0', '0', '5', '0', '0', '0', '0', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '5', '0', '0', '0', '0', '5', '5', '0', '5', '0', '0', '0', '0', '5', '0', '0', '0', '0', '0', '0', '0', '0', '5', '0', '0', '0', '5', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '0', '0', '0', '0', '5', '0', '0', '0', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '0', '0', '0', '0', '5', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '5', '5', '0', '0', '0', '0', '5', '0', '0', '0',

In [108]:
# Drop the annotated object; the data is reloaded from disk before the second run.
del adata

In [109]:
# Reload the same input file so the second run starts from unannotated data.
adata = sc.read("/Users/bombina2/github/multiplex-analysis-web-apps/input/healthy_lung_adata.h5ad")

In [110]:
# Run 2: same settings as run 1 except n_pcs=2.
run_utag_clustering(
    adata,
    features=None,
    k=15,
    resolution=0.1,
    max_dist=20,
    n_pcs=2,
    random_state=42,
    n_jobs=1,
    n_iterations=1,
    slide_key="roi",
    layer=None,
    output_annotation="UTAG",
    associated_table=None,
    parallel=False,
)

Applying UTAG Algorithm...


100%|██████████| 26/26 [00:00<00:00, 31.82it/s]


execute with principal components
Applying Leiden Clustering at Resolution: 0.1...


In [111]:
# Keep the second run's labels as a pandas Series and as a plain list.
second_run_clusters = adata.obs["UTAG"].copy()
run2 = second_run_clusters.tolist()

In [112]:
# Summary view of the labels captured from the second run.
second_run_clusters

0        5
1        5
2        5
3        0
4        0
        ..
71941    0
71942    0
71943    0
71944    0
71945    0
Name: UTAG, Length: 71946, dtype: category
Categories (10, object): ['0', '1', '2', '3', ..., '6', '7', '8', '9']

In [113]:
# Displays the complete second-run label list (one entry per observation).
run2

['5',
 '5',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '5',
 '5',
 '0',
 '0',
 '5',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '5',
 '0',
 '0',
 '0',
 '5',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '5',
 '0',
 '0',
 '0',
 '0',
 '5',
 '5',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '5',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '5',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0',
 '0'

In [114]:
# Element-wise equality of the label lists from the n_pcs=0 and n_pcs=2 runs;
# True means both configurations produced identical labels for every observation.
run1 == run2

True