In [None]:
import warnings

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy.spatial import cKDTree

# Suppress every warning for the rest of the session.
# NOTE(review): this is a blanket filter, so library diagnostics (for example
# warnings about constant input to a correlation test) are hidden as well.
warnings.filterwarnings("ignore")

In [None]:
# Load the annotated data object used by every cell of this notebook.
# The path is relative, so it resolves against the notebook's working directory.
adata = sc.read("../data/adata/human.h5ad")

In [None]:
# Normalize each cell to a total of 1e4, then apply the log(1 + x) transform.
# NOTE(review): neither call's return value is used, i.e. `adata` is modified in
# place -- re-running this cell without reloading `adata` would normalize and
# log-transform the data a second time.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
def get_expression(adata: ad.AnnData, key: str) -> np.ndarray:
    """
    Retrieves the per-observation values of a gene or an observation annotation.

    Gene names (``adata.var_names``) are checked first, so a gene takes precedence
    over an ``adata.obs`` column of the same name.

    Args:
        adata: An AnnData object containing expression data.
        key: The name of the gene or observation annotation to retrieve.

    Returns:
        A NumPy array with the requested values. For a gene, the values are read
        from ``adata.X`` (dense or sparse) and returned as a flat, one-dimensional copy.

    Raises:
        ValueError: If the key is not found in either the var_names or obs columns of the AnnData object.
    """

    if key in adata.var_names:
        values = adata[:, key].X
        # Sparse matrices expose no ``flatten``; densify the selected column first.
        if hasattr(values, "toarray"):
            values = values.toarray()
        # ``np.array`` makes an independent copy; ``ravel`` collapses the
        # (n_obs, 1) column into a 1-D vector.
        return np.array(values).ravel()
    elif key in adata.obs.columns:
        return np.array(adata.obs[key])
    else:
        raise ValueError(f"{key} not found in object")

In [None]:
from scipy.spatial import distance


def get_closest_cell(adata: ad.AnnData, subtype_1: str, subtype_2: str) -> np.ndarray:
    """
    Computes, for each cell of one subtype, the distance to the nearest cell of another subtype.

    Distances are Euclidean and are measured on the coordinates stored in
    ``adata.obsm["X_spatial"]``. Subtype labels are read from
    ``adata.obs["Immunocentric_Type"]``.

    Note that when ``subtype_1 == subtype_2`` every cell is its own nearest
    neighbour, so the result is all zeros.

    Args:
        adata: An AnnData object containing spatial coordinates and subtype annotations.
        subtype_1: The subtype whose cells are the query points.
        subtype_2: The subtype among whose cells the nearest neighbour is searched.

    Returns:
        A NumPy array with one entry per cell of ``subtype_1`` holding the distance
        to the closest cell of ``subtype_2``.

    Raises:
        ValueError: If either subtype is not found in the adata object.
    """

    cell_types = adata.obs["Immunocentric_Type"]
    observed_types = cell_types.unique()
    for label in (subtype_1, subtype_2):
        if label not in observed_types:
            raise ValueError(f"Immunocentric_Type {label} not found in adata")

    locations_1 = np.asarray(adata[cell_types == subtype_1].obsm["X_spatial"])
    locations_2 = np.asarray(adata[cell_types == subtype_2].obsm["X_spatial"])

    # A KD-tree nearest-neighbour query yields the same minimum Euclidean distance
    # as a full pairwise distance matrix without materialising all n1 x n2 entries.
    nearest_distance, _ = cKDTree(locations_2).query(locations_1, k=1)
    return nearest_distance

In [None]:
from scipy import stats


def correlation_between_distance_and_expression(
    adata: ad.AnnData, subtype: str, key: str, method: str = "spearman"
) -> pd.DataFrame:
    """
    Correlates, within one subtype, the value of a gene/annotation with the
    distance to the nearest cell of every subtype present in ``adata``.

    One row is produced per value of ``adata.obs["Immunocentric_Type"]``,
    including ``subtype`` itself. For that self-pair every nearest-neighbour
    distance is zero, so the correlation is undefined (typically NaN); callers
    usually drop the row where ``subtype_1 == subtype_2``.

    Args:
        adata: An AnnData object containing spatial coordinates, subtype annotations, and expression data.
        subtype: The subtype to focus on for expression and distance calculations.
        key: The name of the gene or observation annotation to retrieve expression values for.
        method: The correlation method to use, either "pearson" or "spearman" (default).

    Returns:
        A pandas DataFrame with columns 'subtype_1', 'subtype_2', 'pvalue', and 'correlation',
        representing the subtype pairs, p-values, and correlation coefficients.

    Raises:
        ValueError: If ``subtype`` is not found in the adata object, if ``key`` is neither
            a gene nor an obs column, or if an invalid method is specified.
    """

    cell_types = adata.obs["Immunocentric_Type"]
    observed_types = cell_types.unique()
    if subtype not in observed_types:
        raise ValueError(f"Immunocentric_Type {subtype} not found in adata")

    allowed_methods = ["pearson", "spearman"]
    if method not in allowed_methods:
        raise ValueError(
            f"Invalid correlation method: {method}. Allowed methods are: {', '.join(allowed_methods)}"
        )
    correlate = stats.pearsonr if method == "pearson" else stats.spearmanr

    # The expression vector depends only on `subtype` and `key`, not on the
    # comparison subtype, so it is extracted once instead of once per iteration.
    expression = get_expression(adata[cell_types == subtype], key=key)

    results = []
    for subtype_2 in observed_types:
        distances = get_closest_cell(adata, subtype_1=subtype, subtype_2=subtype_2)
        corr, pval = correlate(distances, expression)
        results.append(
            {
                "subtype_1": subtype,
                "subtype_2": subtype_2,
                "pvalue": pval,
                "correlation": corr,
            }
        )

    return pd.DataFrame(results)

In [None]:
def get_batchwise_correlation_between_distance_and_expression(
    adata: ad.AnnData, subtype: str, key: str, method: str = "spearman"
) -> pd.DataFrame:
    """
    Calculates correlation between distance and expression for a specific subtype
    separately in every batch and combines the results into a single DataFrame.

    Batches are taken from the categories of ``adata.obs["batch"]``. A batch that
    contains no cell of ``subtype`` (including an unused category) cannot yield a
    correlation; it is skipped with a warning instead of aborting the whole run.

    Args:
        adata: An AnnData object containing spatial coordinates, subtype annotations, expression data, and batch information.
        subtype: The subtype to focus on for expression and distance calculations.
        key: The name of the gene or observation annotation to retrieve expression values for.
        method: The correlation method to use, either "pearson" or "spearman" (default).

    Returns:
        A pandas DataFrame containing correlation results for all processed batches,
        with columns 'subtype_1', 'subtype_2', 'pvalue', 'correlation', and 'batch'.
        'batch' is categorical and carries the full category list of ``adata.obs["batch"]``.

    Raises:
        ValueError: If ``subtype`` does not occur in ``adata`` at all, if no batch
            contains it, or if the per-batch computation rejects ``key`` or ``method``.
    """

    batch_labels = adata.obs["batch"]
    batch_categories = batch_labels.cat.categories

    is_subtype = (adata.obs["Immunocentric_Type"] == subtype).to_numpy()
    if not is_subtype.any():
        raise ValueError(f"Immunocentric_Type {subtype} not found in adata")

    results = []
    for b in batch_categories:
        in_batch = (batch_labels == b).to_numpy()
        if not (in_batch & is_subtype).any():
            # No focal cells here: nothing to correlate, and the per-batch
            # function would raise. Skip, but tell the user.
            warnings.warn(
                f"Skipping batch {b}: no cells of Immunocentric_Type {subtype}"
            )
            continue
        df = correlation_between_distance_and_expression(
            adata[in_batch], subtype=subtype, key=key, method=method
        )
        df["batch"] = b
        results.append(df)

    if not results:
        raise ValueError(
            f"No batch contains cells of Immunocentric_Type {subtype}"
        )

    df = pd.concat(results, ignore_index=True)
    df["batch"] = pd.Categorical(df["batch"], categories=batch_categories)
    return df

# Figure 5g

### Get Data

In [None]:
genes = ["SLAMF6", "GZMK", "TCF7", "EOMES", "KLF2", "ITGAE", "GZMA", "NT5E"]

# Run the batch-wise distance/expression correlation once per gene for the
# CD8 T-cell subtype, tag each result with its gene, and stack the tables.
per_gene_frames = []
for gene in genes:
    df_gene = get_batchwise_correlation_between_distance_and_expression(
        adata, subtype="CD8 T-Cell", key=gene
    )
    df_gene["gene"] = gene
    per_gene_frames.append(df_gene)

df = pd.concat(per_gene_frames)

df.head()

### Make plot

In [None]:
# Drop the rows that compare the subtype with itself.
df = df.loc[df["subtype_1"] != df["subtype_2"]]

# Average the correlation over batches, then arrange it as a
# subtype (rows) x gene (columns) matrix.
mat = (
    df.groupby(["gene", "subtype_2"], as_index=False)["correlation"]
    .mean()
    .pivot(index="subtype_2", columns="gene", values="correlation")
)

from scipy.cluster import hierarchy


def get_order(x):
    """
    Returns the row order of ``x`` given by hierarchical clustering with optimal leaf ordering.

    Args:
        x: A 2-D array-like of observations (rows) by features (columns).

    Returns:
        An array of row indices of ``x``, in the order of the leaves of the
        optimally ordered linkage tree.
    """
    linkage_matrix = hierarchy.linkage(x)
    ordered_linkage = hierarchy.optimal_leaf_ordering(linkage_matrix, x)
    return hierarchy.leaves_list(ordered_linkage)


# Order rows and columns by hierarchical clustering before drawing.
row_order = get_order(mat)
col_order = get_order(mat.T)
mat_ordered = mat.iloc[row_order, col_order]

fig, ax = plt.subplots(figsize=(4, 5))
sns.heatmap(mat_ordered, cmap="coolwarm_r", annot=False, linewidths=0.5, ax=ax)
fig.tight_layout()