In [None]:
%pip install git+https://github.com/maximilian-heeg/UCell.git

In [None]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import warnings
import seaborn as sns

warnings.filterwarnings("ignore")

In [None]:
import matplotlib.colors as clr

# Blue -> red palette (11 anchor colours) used as a continuous colormap.
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
# Red -> blue variant. ``reversed()`` mirrors the evenly spaced anchors (same
# result as from_list on the reversed list) and names it "Zissou_r", so the two
# colormaps no longer share the same name.
colormap_r = colormap.reversed()

In [None]:
# Load the AnnData object; the path is relative to the notebook's working directory.
adata = sc.read_h5ad(filename="../data/adata/tgfb.h5ad")

In [None]:
# Normalize each cell to 10,000 total counts, then log1p-transform.
# The untouched X is stashed in a dedicated layer on the first run and restored
# on every run, so re-executing this cell does not normalize / log-transform the
# data a second time. (Costs one extra copy of X in memory.)
if "raw_counts" not in adata.layers:
    adata.layers["raw_counts"] = adata.X.copy()
adata.X = adata.layers["raw_counts"].copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
# Number of detected genes per P14 cell. normalize_total/log1p keep zeros at
# zero, so "X > 0" still counts detected genes after normalization.
# (X > 0).sum(axis=1) is 1-D for dense X but an (n, 1) matrix for sparse X;
# asarray + ravel gives a flat array in both cases.
is_p14 = (adata.obs["Subtype"] == "Cd8_T-Cell_P14").to_numpy()
n_expressed = np.asarray((adata.X > 0).sum(axis=1)).ravel()

ax = sns.histplot(n_expressed[is_p14])
ax.set_xlabel("Number of expressed genes")
ax.set_title("Cd8_T-Cell_P14");

In [None]:
import ucell

# Gene signature scored with UCell. Per UCell's convention, a trailing "+"
# marks a positively weighted gene and a trailing "-" a negatively weighted one.
# The resulting score is read back later from adata.obs["UCell_DEG"].
signatures = dict(
    DEG=[
        # expected up
        "Itgae+",
        "Cxcr6+",
        "Cd160+",
        "P2rx7+",
        # expected down
        "Klf2-",
        "Il18rap-",
        "S100a4-",
        "Mki67-",
    ]
)

# NOTE(review): maxRank=100 presumably chosen from the expressed-genes histogram above.
ucell.add_scores(adata, signatures, maxRank=100, seed=42)

In [None]:
import numpy as np
import anndata as ad


def get_expression(adata: ad.AnnData, key: str) -> np.ndarray:
    """
    Retrieves expression values for a given gene or observation annotation from an AnnData object.

    Genes take precedence: if ``key`` is both a gene name and an ``obs`` column,
    the gene's values from ``adata.X`` are returned.

    Args:
        adata: An AnnData object containing expression data.
        key: The name of the gene or observation annotation to retrieve.

    Returns:
        A 1-D NumPy array (a copy) with one value per cell.

    Raises:
        ValueError: If the key is not found in either the var_names or obs columns of the AnnData object.
    """

    if key in adata.var_names:
        values = adata[:, key].X
        # X may be a scipy sparse matrix (which has no .flatten) or an np.matrix
        # (whose .flatten stays 2-D), so densify first and then ravel.
        if hasattr(values, "toarray"):
            values = values.toarray()
        return np.array(values).ravel()
    if key in adata.obs.columns:
        return np.array(adata.obs[key])
    raise ValueError(f"{key} not found in object")

In [None]:
from scipy.spatial import distance


def get_closest_cell(adata: ad.AnnData, subtype_1: str, subtype_2: str) -> np.ndarray:
    """
    For every cell of ``subtype_1``, returns the Euclidean distance to the nearest cell of ``subtype_2``.

    Coordinates come from ``adata.obsm["X_spatial"]`` and cell types from
    ``adata.obs["Subtype"]``. When ``subtype_1 == subtype_2`` every cell is its
    own nearest neighbour, so all returned distances are 0.

    Args:
        adata: An AnnData object containing spatial coordinates and subtype annotations.
        subtype_1: The subtype whose cells the distances are measured from.
        subtype_2: The subtype whose closest cell is searched for.

    Returns:
        A 1-D NumPy array with one distance per ``subtype_1`` cell, in the order
        those cells appear in ``adata``.

    Raises:
        ValueError: If either subtype is not found in the adata object.
    """
    from scipy.spatial import cKDTree

    present = adata.obs["Subtype"].unique()
    if subtype_1 not in present:
        raise ValueError(f"Subtype {subtype_1} not found in adata")
    if subtype_2 not in present:
        raise ValueError(f"Subtype {subtype_2} not found in adata")

    locations_1 = np.asarray(adata[adata.obs["Subtype"] == subtype_1].obsm["X_spatial"])
    locations_2 = np.asarray(adata[adata.obs["Subtype"] == subtype_2].obsm["X_spatial"])

    # A k-d tree nearest-neighbour query gives the same minimum Euclidean
    # distances as cdist(...).min(axis=1) without materialising the full
    # len(locations_1) x len(locations_2) distance matrix.
    tree = cKDTree(locations_2)
    distances_subtype, _ = tree.query(locations_1, k=1)
    return distances_subtype

In [None]:
from scipy import stats
import pandas as pd


def correlation_between_distance_and_expression(
    adata: ad.AnnData, subtype: str, key: str, method: str = "spearman"
) -> pd.DataFrame:
    """
    Calculates, for cells of one subtype, the correlation between a gene/annotation
    value and the distance to the closest cell of every subtype in ``adata``.

    Args:
        adata: An AnnData object containing spatial coordinates, subtype annotations, and expression data.
        subtype: The subtype whose cells are correlated (expression and distances are taken from these cells).
        key: The name of the gene or observation annotation to retrieve expression values for.
        method: The correlation method to use, either "pearson" or "spearman" (default).

    Returns:
        A pandas DataFrame with columns 'subtype_1', 'subtype_2', 'pvalue', and 'correlation',
        one row per subtype in ``adata``. The row with ``subtype_2 == subtype`` has
        all-zero distances, so its correlation is NaN.

    Raises:
        ValueError: If ``subtype`` is not found in the adata object, if ``key`` is
            unknown, or if an invalid method is specified.
    """

    if subtype not in adata.obs["Subtype"].unique():
        raise ValueError(f"Subtype {subtype} not found in adata")

    correlators = {"pearson": stats.pearsonr, "spearman": stats.spearmanr}
    if method not in correlators:
        raise ValueError(
            f"Invalid correlation method: {method}. Allowed methods are: {', '.join(correlators)}"
        )
    correlate = correlators[method]

    # The expression of the focus cells does not depend on subtype_2: fetch it once.
    expression = get_expression(adata[adata.obs["Subtype"] == subtype], key=key)

    results = []
    for subtype_2 in adata.obs["Subtype"].unique():
        distances = get_closest_cell(adata, subtype_1=subtype, subtype_2=subtype_2)
        corr, pval = correlate(distances, expression)
        results.append(
            {
                "subtype_1": subtype,
                "subtype_2": subtype_2,
                "pvalue": pval,
                "correlation": corr,
            }
        )

    return pd.DataFrame(results)

In [None]:
def get_batchwise_correlation_between_distance_and_expression(
    adata: ad.AnnData, subtype: str, key: str, method: str = "spearman"
) -> pd.DataFrame:
    """
    Runs correlation_between_distance_and_expression separately on every batch
    and stacks the per-batch tables.

    Args:
        adata: An AnnData object with spatial coordinates, subtype annotations,
            expression data and a categorical ``obs["batch"]``.
        subtype: The subtype whose cells are correlated.
        key: The gene or observation annotation to correlate with distance.
        method: "pearson" or "spearman" (default), forwarded unchanged.

    Returns:
        A pandas DataFrame with columns 'subtype_1', 'subtype_2', 'pvalue',
        'correlation' and 'batch'; 'batch' is categorical with the same category
        order as ``adata.obs["batch"]``.
    """
    batch_categories = adata.obs["batch"].cat.categories

    per_batch = [
        correlation_between_distance_and_expression(
            adata[adata.obs["batch"] == batch_name],
            subtype=subtype,
            key=key,
            method=method,
        ).assign(batch=batch_name)
        for batch_name in batch_categories
    ]

    combined = pd.concat(per_batch, ignore_index=True)
    combined["batch"] = pd.Categorical(combined["batch"], categories=batch_categories)
    return combined

In [None]:
def get_ks_statistics(adata: ad.AnnData, subtype: str) -> pd.DataFrame:
    """
    Compares, between the two batches, the distribution of distances from each
    ``subtype`` cell to the closest cell of every subtype.

    For each subtype present in ``adata`` a two-sample Kolmogorov-Smirnov test is
    run on the per-batch distance distributions. If that subtype is missing from
    either batch there is nothing to compare, and its row is filled with NaN
    instead of raising.

    Args:
        adata: An AnnData object with ``obs["Subtype"]``, a categorical
            ``obs["batch"]`` and ``obsm["X_spatial"]``.
        subtype: The subtype whose cells the distances are measured from.

    Returns:
        A pandas DataFrame with one row per subtype and the following columns:
            - subtype_1: ``subtype``.
            - subtype_2: The subtype the distances are measured to.
            - ks: The Kolmogorov-Smirnov statistic.
            - p: The p-value for the Kolmogorov-Smirnov test.
            - "<batch1>-<batch2>": Median distance in the first batch minus the
                median distance in the second batch.

    Raises:
        ValueError: If the specified subtype is not found in the data (also raised
            via get_closest_cell if it is absent from one batch), or if there are
            not exactly two batches present.

    Examples:
        >>> results = get_ks_statistics(adata, "subtype1")
        >>> print(results)
    """
    from scipy import stats

    if subtype not in adata.obs["Subtype"].unique():
        raise ValueError(f"Subtype {subtype} not found in adata")

    batches = adata.obs["batch"].cat.categories
    if len(batches) != 2:
        raise ValueError(f"There must be exactly two batches, found {len(batches)}")

    print(f"Comparing {batches[0]} and {batches[1]}")
    # f-string rather than "+" so non-string batch labels also work.
    diff_column = f"{batches[0]}-{batches[1]}"

    # The per-batch subsets do not depend on subtype_2, so build them once.
    adata_batches = {b: adata[adata.obs["batch"] == b] for b in batches}
    subtypes_per_batch = {
        b: set(adata_batches[b].obs["Subtype"].unique()) for b in batches
    }

    results = []
    for subtype_2 in adata.obs["Subtype"].unique():
        row = {
            "subtype_1": subtype,
            "subtype_2": subtype_2,
            "ks": np.nan,
            "p": np.nan,
            diff_column: np.nan,
        }
        # get_closest_cell raises when subtype_2 has no cells in a batch; such
        # subtypes keep the NaN row instead of aborting the whole table.
        if all(subtype_2 in subtypes_per_batch[b] for b in batches):
            distances = {
                b: get_closest_cell(
                    adata_batches[b], subtype_1=subtype, subtype_2=subtype_2
                )
                for b in batches
            }
            stat, p = stats.ks_2samp(distances[batches[0]], distances[batches[1]])
            row["ks"] = stat
            row["p"] = p
            row[diff_column] = np.median(distances[batches[0]]) - np.median(
                distances[batches[1]]
            )
        results.append(row)

    return pd.DataFrame(results)

In [None]:
def make_ks_plot(df_ks, order, ax, diff_column="WT-KO", ks_cutoff=0.075):
    """
    Draws one horizontal bar per subtype showing its KS statistic.

    Bars above ``ks_cutoff`` are coloured "closer" when ``row[diff_column] > 0``
    and "further" otherwise. Bars at or below the cutoff, including NaN KS
    values, are coloured "similar". With the table from get_ks_statistics,
    ``diff_column`` holds median(first batch) - median(second batch) of the
    distances, so "closer" means the subtype is nearer in the second batch.

    Args:
        df_ks: KS table with columns 'subtype_2', 'ks' and ``diff_column``.
        order: Subtypes to draw, in the order they are added to the categorical y-axis.
        ax: Matplotlib axes to draw on.
        diff_column: Name of the median-difference column (default "WT-KO").
        ks_cutoff: KS statistic threshold, also drawn as a vertical line (default 0.075).

    Raises:
        IndexError: If a subtype in ``order`` has no row in ``df_ks``.
    """
    ax.grid(axis="y", linestyle="dashed", dashes=(2, 5), zorder=1)

    ks_colors = {"closer": "#E07524", "further": "#92CADE", "similar": "#BCBEC0"}
    ax.axvline(ks_cutoff)
    for subtype in order:
        row = df_ks[df_ks["subtype_2"] == subtype].iloc[0]

        if row["ks"] > ks_cutoff:
            category = "closer" if row[diff_column] > 0 else "further"
        else:
            category = "similar"

        ax.barh(subtype, row["ks"], 0.8, color=ks_colors[category], zorder=2)

    ax.set_xlabel("KS statistic")

    # Custom legend: one marker per colour category
    handles = [
        plt.Line2D([], [], marker="o", color=color, label=label, ls="")
        for label, color in ks_colors.items()
    ]
    labels = list(ks_colors)
    ax.legend(handles, labels, loc="lower right", title="KS statistics")

In [None]:
def make_dot_plot(df, order, ax, batches=("WT", "KO"), colors=("black", "#63ABB9")):
    """
    Draws one row per subtype with a dot for each batch's correlation, joined by a grey line.

    Args:
        df: Correlation table with columns 'subtype_2', 'batch' and 'correlation'
            (e.g. from get_batchwise_correlation_between_distance_and_expression).
        order: Subtypes to plot, in the order they are added to the categorical y-axis.
        ax: Matplotlib axes to draw on.
        batches: Batch labels to plot. Each subtype's value is looked up by this
            label, not by row position.
        colors: Marker colour for each entry of ``batches``.
    """
    ax.grid(axis="y", linestyle="dashed", dashes=(2, 5), zorder=1)

    for subtype in order:
        rows = df[df["subtype_2"] == subtype]
        # Select by batch label: row order depends on how df was sorted and on the
        # batch category order, so positional values[0]/values[1] can mislabel.
        points = []
        for batch, color in zip(batches, colors):
            match = rows.loc[rows["batch"] == batch, "correlation"]
            if len(match) > 0:
                points.append((match.iloc[0], color))

        if len(points) > 1:
            ax.plot(
                [value for value, _ in points],
                [subtype] * len(points),
                color="gray",
                linestyle="-",
                linewidth=1,
                zorder=2,
            )
        for value, color in points:
            ax.scatter(value, subtype, color=color, zorder=3)

    # Add a few custom labels to the x axis
    xt = ax.get_xticks()
    xt_labels = [f"{x:.1f}" for x in xt.tolist()]
    xt_labels[-1] = xt_labels[-1] + "\nStronger signature \nif close to cell"
    xt_labels[0] = xt_labels[0] + "\nWeaker signature \nif close to cell"
    ax.set_xticks(xt)
    ax.set_xticklabels(xt_labels)

    # Create the custom legend
    handles = [
        plt.Line2D([], [], marker="o", color=color, label=str(batch), ls="")
        for batch, color in zip(batches, colors)
    ]
    labels = [str(batch) for batch in batches]
    ax.legend(handles, labels, loc="lower right", title="Genotype")

In [None]:
def make_expression_plot(
    adata,
    batch,
    order,
    ax,
    genes=None,
):
    """
    Draws a scanpy dot plot of ``genes`` per subtype for the cells of one batch.

    Args:
        adata: AnnData with ``obs["batch"]`` and ``obs["Subtype"]``.
        batch: Batch label whose cells are shown.
        order: Subtypes to show, in display order. Other subtypes are dropped.
        ax: Matplotlib axes to draw on.
        genes: Genes to show. Defaults to Tgfb1, Tgfb2, Tgfb3, Itgb6, Itgb8,
            Itgav, Ltbp1 and Ltbp3.
    """
    # Build the default inside the function: a list default is shared across calls.
    if genes is None:
        genes = ["Tgfb1", "Tgfb2", "Tgfb3", "Itgb6", "Itgb8", "Itgav", "Ltbp1", "Ltbp3"]

    keep = (adata.obs["batch"] == batch) & adata.obs["Subtype"].isin(order)
    adata_sub = adata[keep]

    sc.pl.dotplot(
        adata_sub,
        var_names=genes,
        groupby="Subtype",
        categories_order=order,
        ax=ax,
        cmap=colormap,
        show=False,
    )

In [None]:
# def make_heatmap(df, order, ax):
#     df_matrix = df.pivot(index="subtype_2", columns="batch", values="correlation")

#     df_matrix = df_matrix.reindex(order)
#     df_matrix = df_matrix.iloc[::-1]
#     sns.heatmap(df_matrix, cmap=colormap, ax=ax)


# def make_heatmap_difference(df, order, ax):
#     df_matrix = df.pivot(index="subtype_2", columns="batch", values="correlation")

#     df_matrix = df_matrix.reindex(order)
#     df_matrix = df_matrix.iloc[::-1]

#     df_difference = pd.DataFrame({"difference": (-df_matrix["WT"] + df_matrix["KO"])})

#     sns.heatmap(df_difference, ax=ax, cmap="RdBu_r")

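# Figure 4g

Correlation between the P14 UCell `DEG` signature and the distance from each P14 cell to the closest cell of every other subtype, per batch, alongside KS statistics comparing the distance distributions between batches.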

### Compute distance-signature correlations and KS statistics

In [None]:
df = get_batchwise_correlation_between_distance_and_expression(
    adata, "Cd8_T-Cell_P14", "UCell_DEG"
)

df = (
    df.loc[
        # Remove the unknown cell types
        ~df["subtype_2"].str.startswith("Unknown")
        # Remove P14 -> P14: every P14 cell is its own closest P14 (distance 0),
        # so the correlation is NaN
        & (df["subtype_1"] != df["subtype_2"])
    ]
    # Negate the distance correlation so that a positive value means a stronger
    # signature when the P14 cell is closer to the given subtype.
    # .assign returns a new frame, avoiding assignment into a filtered slice.
    .assign(correlation=lambda d: d["correlation"] * -1)
    .sort_values(by=["batch", "correlation"])
)
df

In [None]:
df_ks = get_ks_statistics(adata, "Cd8_T-Cell_P14")
df_ks = df_ks.loc[
    # Remove the unknown cell types
    ~df_ks["subtype_2"].str.startswith("Unknown")
    # Remove P14 -> P14: each P14 cell is its own closest P14, so both batches
    # have all-zero distances and the KS comparison is uninformative (KS = 0)
    & (df_ks["subtype_1"] != df_ks["subtype_2"])
]
df_ks.head()

## Plot

In [None]:
# df is sorted by batch, then correlation, so unique() yields the subtypes
# ordered by their correlation in the first batch category.
order = df["subtype_2"].unique()

fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3, figsize=(18, 8), gridspec_kw={"width_ratios": [5, 3, 1]}
)

make_dot_plot(df, order, ax1)
# NOTE(review): reversed presumably because the scanpy dot plot lists groups
# top-to-bottom while the categorical y-axes of the other panels fill bottom-to-top.
make_expression_plot(adata, "WT", order[::-1], ax2)
make_ks_plot(df_ks, order, ax3)

fig.tight_layout()

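# Figure 4h

Distribution of distances from each P14 cell to the closest cell of selected fibroblast subtypes, compared between batches.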

In [None]:
import matplotlib


def plot_histogram(
    adata: ad.AnnData,
    subtype_1: str,
    subtype_2: str,
    ax: matplotlib.axes.Axes,
    ymax: float = 0.004,
    xmax: float = 2000.0,
    xmin: float = -100.0,
) -> None:
    """
    Plots, for each of the two batches, a kernel density estimate of the distance
    from every ``subtype_1`` cell to its closest ``subtype_2`` cell, annotated
    with a two-sample Kolmogorov-Smirnov test between the batches.

    Args:
        adata: AnnData object containing expression data and metadata.
        subtype_1: Subtype whose cells the distances are measured from.
        subtype_2: Subtype whose closest cell is used.
        ax: Matplotlib axes object to plot on.
        ymax: Maximum y-axis value for the plot (default: 0.004).
        xmax: Maximum x-axis value for the plot (default: 2000).
        xmin: Minimum x-axis value for the plot (default: -100).

    Raises:
        ValueError: If either subtype is not found in adata or if there are not exactly two batches.

    Returns:
        None
    """
    from scipy import stats

    if subtype_1 not in adata.obs["Subtype"].unique():
        raise ValueError(f"Subtype {subtype_1} not found in adata")
    if subtype_2 not in adata.obs["Subtype"].unique():
        raise ValueError(f"Subtype {subtype_2} not found in adata")

    batches = adata.obs["batch"].cat.categories
    if len(batches) != 2:
        raise ValueError(f"There must be exactly two batches, found {len(batches)}")

    distances = {
        b: get_closest_cell(
            adata[adata.obs["batch"] == b], subtype_1=subtype_1, subtype_2=subtype_2
        )
        for b in batches
    }

    stat, p = stats.ks_2samp(distances[batches[0]], distances[batches[1]])

    for b in batches:
        sns.kdeplot(distances[b], ax=ax, label=b)

    ax.text(
        0.98,
        0.98,
        f"KS = {stat:.3f}\np-value = {p:3.2}",
        horizontalalignment="right",
        verticalalignment="top",
        transform=ax.transAxes,
    )
    ax.legend(loc="lower right")
    ax.set_ylim(0, ymax)
    ax.set_xlim(xmin, xmax)
    ax.set_xlabel("Distance to closest cell")
    ax.set_title(f"Distance from {subtype_1}\nto the closest {subtype_2}")

In [None]:
# Fibroblast subtypes shown in the 2 x 2 grid, row by row
subtypes = [
    "Complement_Fibroblast",
    "Fibroblast_Pdgfra+",
    "Fibroblast_Ncam1",
    "Fibroblast_Apoe+",
]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))

for subtype, ax in zip(subtypes, axes.ravel()):
    plot_histogram(adata, "Cd8_T-Cell_P14", subtype, ax)

fig.tight_layout()