In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from scipy.stats import wilcoxon
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

In [None]:
# Input and output locations, relative to the notebook's working directory.
data_dir = Path("data")
results_dir = Path("results")
fig_dir = results_dir.joinpath("figures")

In [None]:
# Create the figure output directory (including missing parents); does nothing
# if it already exists.
fig_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Display names for the result columns. Keys are result-file stems, values are
# the method names shown in tables and figures; columns that are not listed
# here keep their original name.
rename_methods = dict(
    [
        ("leiden", "Leiden HVG"),
        ("spatial_leiden", "SpatialLeiden HVG"),
        ("leiden_svg", "Leiden SVG"),
        ("spatial_leiden_svg", "SpatialLeiden SVG"),
    ]
)

In [None]:
# Ground-truth annotation file of every sample, indexed by sample name (the
# name of the directory that contains `labels.tsv`).
groundtruth = pd.DataFrame(
    [(path.parent.name, path) for path in data_dir.glob("*/labels.tsv")],
    columns=["sample", "groundtruth"],
)
groundtruth = groundtruth.set_index("sample")

# Result files grouped by sample directory: {sample: {file stem: path}}.
results = {}
for sample_dir in results_dir.glob("Br*"):
    results[sample_dir.name] = {tsv.stem: tsv for tsv in sample_dir.glob("*.tsv")}

# One row per sample, one column per result file plus `groundtruth`. The merge
# is an inner join on the sample name, so samples missing on either side are
# dropped. Columns are then renamed to their display names.
label_files = pd.DataFrame(results).T.merge(
    groundtruth, left_index=True, right_index=True
)
label_files = label_files.rename(columns=rename_methods)

# ARI and NMI

In [None]:
def calculate_score(fn, groundtruth, cluster, **kwargs):
    """Score a clustering against the ground truth with the metric ``fn``.

    Parameters
    ----------
    fn : callable
        Metric invoked as ``fn(labels_true, labels_pred, **kwargs)``.
    groundtruth : path-like
        Tab-separated file with the reference labels. The first column is used
        as the observation index and the ``label`` column holds the labels.
    cluster : path-like
        Tab-separated file with the predicted labels, same layout.
    **kwargs
        Forwarded unchanged to ``fn``.

    Returns
    -------
    Whatever ``fn`` returns.

    Notes
    -----
    Only observations present in both files are compared. Observations whose
    ground-truth label is missing are excluded: encoding them as category
    code ``-1`` would make the metric treat "not annotated" as a reference
    cluster of its own.
    """
    domains = pd.read_table(cluster, index_col=0)["label"].astype("category").cat.codes
    truth = (
        pd.read_table(groundtruth, index_col=0)["label"]
        .dropna()  # unannotated observations carry no reference label
        .astype("category")
        .cat.codes
    )
    common_index = domains.index.intersection(truth.index)
    return fn(truth.loc[common_index], domains.loc[common_index], **kwargs)


def get_ncluster(file):
    """Return the number of distinct labels in the ``label`` column of ``file``.

    ``file`` is read as a tab-separated table; missing labels are not counted.
    """
    labels = pd.read_table(file)["label"]
    return labels.nunique()

In [None]:
# Every column of `label_files` except the ground truth is a method.
methods = [column for column in label_files.columns if column != "groundtruth"]


def _score_with(metric):
    """Build a ``DataFrame.assign`` callable that scores each row with ``metric``."""

    def _column(df):
        return df.apply(
            lambda row: calculate_score(metric, row.groundtruth, row.filepath),
            axis=1,
        )

    return _column


# Long format: one row per (sample, method) with the ground-truth and result
# file paths, then one score column per metric. The sample stays in the index.
scores = label_files.melt(
    id_vars="groundtruth",
    value_vars=methods,
    var_name="Method",
    value_name="filepath",
    ignore_index=False,
).assign(
    ARI=_score_with(adjusted_rand_score),
    NMI=_score_with(normalized_mutual_info_score),
)

In [None]:
# Display order of the methods. Used to order the box plots and to select the
# columns of the per-sample tables, so every name must be a method column of
# `label_files` (after renaming).
order = [
    "Leiden HVG",
    "Leiden SVG",
    "SpatialLeiden HVG",
    "SpatialLeiden SVG",
    "SpaGCN",
    "BayesSpace",
]

In [None]:
# Sanity check: number of distinct labels per sample for the ground truth and
# for every method (each file is re-read by `get_ncluster`).
# NOTE: `DataFrame.map` is the element-wise apply of recent pandas versions
# (formerly `applymap`).
label_files.map(get_ncluster).sort_index()[["groundtruth"] + order]

In [None]:
# ARI per sample (rows) and method (columns, in display order).
scores.pivot(columns="Method", values="ARI").sort_index()[order]

In [None]:
# Method with the highest ARI for each sample.
scores.pivot(columns="Method", values="ARI").idxmax(axis=1).sort_index()

In [None]:
# Outlier ("flier") markers of the box plots: small black dots.
flierprops = {"marker": ".", "markerfacecolor": "black"}

In [None]:
# Fig. 1b: distribution of ARI (left) and NMI (right) over samples, one box
# per method in display order.
fig, axs = plt.subplots(ncols=2, figsize=(4, 5))

for ax, metric in zip(axs, ("ARI", "NMI")):
    sns.boxplot(
        data=scores,
        x="Method",
        y=metric,
        hue="Method",
        order=order,
        flierprops=flierprops,
        ax=ax,
    )
    # Both panels use the same 0..1 y-range.
    ax.set(ylim=(0, 1))
    # Slanted method names, anchored at their right end so they stay under the box.
    ax.set_xticklabels(
        ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor"
    )

fig.tight_layout()
fig.savefig(fig_dir / "Fig1b.pdf", dpi=600)

In [None]:
# Wide score tables: samples as rows, methods as columns.
ari, nmi = (
    scores.pivot(columns="Method", values=metric) for metric in ("ARI", "NMI")
)

In [None]:
# Paired, one-sided Wilcoxon signed-rank test over samples (ARI, HVG):
# alternative="less" tests whether Leiden scores lower than SpatialLeiden.
wilcoxon(ari["Leiden HVG"], ari["SpatialLeiden HVG"], alternative="less")

In [None]:
# Paired, one-sided Wilcoxon signed-rank test over samples (ARI, SVG):
# alternative="less" tests whether Leiden scores lower than SpatialLeiden.
wilcoxon(ari["Leiden SVG"], ari["SpatialLeiden SVG"], alternative="less")

In [None]:
# Paired, one-sided Wilcoxon signed-rank test over samples (NMI, HVG):
# alternative="less" tests whether Leiden scores lower than SpatialLeiden.
wilcoxon(nmi["Leiden HVG"], nmi["SpatialLeiden HVG"], alternative="less")

In [None]:
# Paired, one-sided Wilcoxon signed-rank test over samples (NMI, SVG):
# alternative="less" tests whether Leiden scores lower than SpatialLeiden.
wilcoxon(nmi["Leiden SVG"], nmi["SpatialLeiden SVG"], alternative="less")

In [None]:
# Export the per-sample score tables as tab-separated files into the figure
# directory.
ari.to_csv(fig_dir / "ARI.tsv", sep="\t")
nmi.to_csv(fig_dir / "NMI.tsv", sep="\t")

# Spatial Plots (Br8100_151673)

In [None]:
# Sample shown in the spatial plots; used as the directory name under
# `data_dir` and as the row key of `label_files`.
sample = "Br8100_151673"

# Image resolution in pixels per micrometre; used to size the 1 mm scale bar.
# NOTE(review): presumably the spot diameter in full-resolution pixels divided
# by the corresponding physical size of 65 um -- confirm against the sample's
# scale factors.
px_per_um = 96.40082438014726 / 65

In [None]:
# Input files of the selected sample: spot coordinates and the H&E image.
coordinate_file = data_dir / sample / "coordinates.tsv"
image_file = data_dir / sample / "H_E.tiff"

In [None]:
from scipy.optimize import linear_sum_assignment

# Build one table for `sample` holding the spot coordinates, the ground-truth
# domains and the labels of every method. Method clusters are renamed to their
# best-matching ground-truth domain so the same id means the same region in
# every column.
data = pd.read_table(coordinate_file, index_col=0)

files = label_files.loc[sample, :].copy()
labels_df = pd.read_table(files.pop("groundtruth"), index_col=0).astype("category")
labels_df.columns = ["groundtruth"]
# Replace the domain names by consecutive integers 0..k-1 (category order).
labels_df["groundtruth"] = labels_df["groundtruth"].cat.rename_categories(
    {n: i for i, n in enumerate(labels_df.groundtruth.cat.categories)}
)

data = data.merge(labels_df, left_index=True, right_index=True)
n_domains = len(data["groundtruth"].cat.categories)

for name, file in files.items():
    labels_df = pd.read_table(file, index_col=0).astype("category")
    labels_df.columns = [name]

    # match clusters: one-to-one assignment of clusters to ground-truth
    # domains that maximises the number of spots on which both agree
    contingency_table = pd.crosstab(labels_df[name], data["groundtruth"])
    row_ind, col_ind = linear_sum_assignment(contingency_table, maximize=True)
    mapping = dict(
        zip(contingency_table.index[row_ind], contingency_table.columns[col_ind])
    )
    # A method may produce more clusters than there are domains. Those
    # unmatched clusters would otherwise keep their old name, which can equal
    # an id just assigned to another cluster and make `rename_categories`
    # raise on duplicates. Give them fresh ids beyond the domain range.
    unmatched = [c for c in labels_df[name].cat.categories if c not in mapping]
    mapping.update({c: n_domains + i for i, c in enumerate(unmatched)})
    labels_df[name] = labels_df[name].cat.rename_categories(mapping)

    data = data.merge(labels_df, left_index=True, right_index=True)

In [None]:
from PIL import Image

# Disable Pillow's decompression-bomb size check (removes the
# DecompressionBombWarning); only appropriate for trusted image files.
Image.MAX_IMAGE_PIXELS = None

# Read the pixels inside a context manager so the file handle is closed again.
with Image.open(image_file) as img:
    image = np.array(img)

# Bounding box of the spot coordinates, in image pixels.
x_crop = (data["x"].min(), data["x"].max())
y_crop = (data["y"].min(), data["y"].max())

# Crop the image to that box (rows are y, columns are x) and shift the spot
# coordinates so they refer to the cropped image.
# NOTE: this overwrites `data`. Re-running this cell without re-running the
# cell that builds `data` shifts the coordinates a second time.
image = image[slice(*y_crop), slice(*x_crop)]
data = data.assign(x=lambda df: df["x"] - x_crop[0], y=lambda df: df["y"] - y_crop[0])

In [None]:
# Styling shared by all spatial scatter panels: small markers without outline.
scatter_kwargs = {"s": 5, "linewidth": 0}
# Fix the hue order to the ground-truth categories so that the same label gets
# the same colour in every panel.
scatter_kwargs["hue_order"] = data["groundtruth"].cat.categories


def remove_tick_and_label(ax):
    """Strip tick labels, axis labels and the left/bottom tick marks from ``ax``."""
    blank = dict(xticklabels=[], yticklabels=[], xlabel=None, ylabel=None)
    ax.set(**blank)
    ax.tick_params(left=False, bottom=False)


def scatter_labels(df, name, ax):
    """Draw the spots of ``df`` on ``ax``, coloured by the column ``name``.

    ``df`` needs the columns ``x``, ``y`` and ``name``. The panel is titled
    ``name``, drawn with equal aspect and with the y-axis running downwards
    (limits ``(max, 0)``). The legend is replaced by an empty, frameless one
    and all ticks and axis labels are removed.
    """
    sns.scatterplot(data=df, x="x", y="y", hue=name, ax=ax, **scatter_kwargs)
    x_max, y_max = df["x"].max(), df["y"].max()
    ax.set(title=name, aspect=1, xlim=(0, x_max), ylim=(y_max, 0))
    ax.legend([], [], frameon=False)
    remove_tick_and_label(ax)


# Fig. 1a: H&E image plus one spatial panel for the ground truth and for each
# method, on a 2 x 4 grid with shared axes.
fig, axs = plt.subplots(nrows=2, ncols=4, sharex=True, sharey=True, figsize=(8, 5))

# Top-left panel: the cropped H&E image with a 1 mm scale bar
# (1000 um converted to pixels).
he_ax = axs[0, 0]
he_ax.imshow(image)
he_ax.set(title="H & E")
remove_tick_and_label(he_ax)
scale_bar = AnchoredSizeBar(
    he_ax.transData, px_per_um * 1_000, r"1 mm", loc="lower right", frameon=False
)
he_ax.add_artist(scale_bar)

# Label column -> (row, column) of its panel. HVG variants go in the top row
# and SVG variants below them; the ground truth sits under the image.
panel_positions = {
    "groundtruth": (1, 0),
    "Leiden HVG": (0, 1),
    "Leiden SVG": (1, 1),
    "SpatialLeiden HVG": (0, 2),
    "SpatialLeiden SVG": (1, 2),
    "SpaGCN": (0, 3),
    "BayesSpace": (1, 3),
}
for name, (row, col) in panel_positions.items():
    scatter_labels(data, name, axs[row, col])

fig.tight_layout(h_pad=2)
fig.savefig(fig_dir / "Fig1a.pdf", dpi=600)