In [None]:
import anndata as ad
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
from matplotlib import ticker
import scanpy as sc
import math

This notebook uses the run output saved in the `trained_model_umaps` directory.

# load trained model output

In [None]:
# Load the saved NicheCompass run output: the AnnData with the UMAP embedding and
# the `nichecompass_latent_cluster` labels used throughout this notebook.
anndata = ad.read_h5ad(
    "trained_model_umaps/anndata_umap_with_clusters.h5ad"
)

# preprocessing

## niche selection and filtering

In [None]:
# Niches with at most this many cells (summed over all datasets) are dropped from
# every downstream table and plot.
MIN_CELLS_PER_NICHE = 100_000

# Mapping niche label -> number of cells, ordered from largest to smallest niche.
niche_cell_counts = anndata.obs["nichecompass_latent_cluster"].value_counts().to_dict()
retained_niches = [
    niche for niche, n_cells in niche_cell_counts.items() if n_cells > MIN_CELLS_PER_NICHE
]
# Fail here rather than producing empty tables/figures further down.
assert retained_niches, f"no niche has more than {MIN_CELLS_PER_NICHE:,} cells"

# Boolean-mask subset of `anndata` (an AnnData view, not a copy, so the full
# dataset is not duplicated in memory).
anndata_filtered = anndata[anndata.obs["nichecompass_latent_cluster"].isin(retained_niches)]
print(f"retaining {len(anndata_filtered)} of {len(anndata)} cells following filtering")

In [None]:
# How many niches survive the size filter (shown as the cell output).
n_retained_niches = len(retained_niches)
n_retained_niches

# visualise results

In [None]:
# Fixed colour for each NicheCompass niche label "0".."16"; reused by the UMAP and
# spatial plots below so a niche has the same colour in every figure.
niche_color_map = {
    str(niche_index): hex_color
    for niche_index, hex_color in enumerate([
        "#66C5CC",  # 0
        "#F6CF71",  # 1
        "#F89C74",  # 2
        "#DCB0F2",  # 3
        "#87C55F",  # 4
        "#9EB9F3",  # 5
        "#FE88B1",  # 6
        "#C9DB74",  # 7
        "#8BE0A4",  # 8
        "#B497E7",  # 9
        "#D3B484",  # 10
        "#B3B3B3",  # 11
        "#276A8C",  # 12
        "#DAB6C4",  # 13
        "#9B4DCA",  # 14
        "#9D88A2",  # 15
        "#FF4D4D",  # 16
    ])
}

## gene program embedding

In [None]:
# Draw only a 1% random subsample so the UMAP renders quickly and stays legible.
# random_state=0 is scanpy's default, spelled out so the chosen cells are visibly fixed.
anndata_filtered_subsample = sc.pp.subsample(
    anndata_filtered,
    fraction=0.01,
    random_state=0,
    copy=True,
)

In [None]:
# UMAP of the NicheCompass gene-program embedding, coloured by source dataset.
fig = sc.pl.umap(
    anndata_filtered_subsample,
    color="dataset",
    title="NicheCompass GP embedding",
    size=1,
    frameon=False,
    return_fig=True,
)
# Save through the returned figure rather than pyplot's implicit "current figure".
# A tight bounding box keeps scanpy's default right-margin legend, which sits outside
# the axes, from being clipped in the SVG.
fig.savefig("r3c12_gp_embedding.svg", bbox_inches="tight")

In [None]:
# Same subsample, now coloured by NicheCompass niche using the shared niche palette.
sc.pl.umap(
    anndata_filtered_subsample,
    color="nichecompass_latent_cluster",
    palette=niche_color_map,
    size=1,
)

## niche composition

In [None]:
# Cell counts per (dataset, niche) pair: one row per dataset, one column per niche.
obs_filtered = anndata_filtered.obs
freq_table = pd.crosstab(
    index=obs_filtered["dataset"],
    columns=obs_filtered["nichecompass_latent_cluster"],
)
freq_table

In [None]:
# Stacked bar chart: one bar per niche, segments stacked by dataset, height = cell count.
fig, ax = plt.subplots()
freq_table.transpose().plot(
    kind="bar",
    stacked=True,
    rot=0,  # niche labels are short; keep them horizontal
    ylabel="Number of cells",
    xlabel="NicheCompass niche",
    ax=ax,
)

# Dashed horizontal gridlines only; remove the top/right spines.
ax.grid(which="major", axis="y", linestyle="--")
ax.grid(False, axis="x")
ax.spines[["right", "top"]].set_visible(False)
ax.spines[["left", "bottom"]].set_linewidth(1)
ax.spines[["left", "bottom"]].set_color("black")

# Thousands separators on the cell-count axis (e.g. 100,000).
ax.yaxis.set_major_formatter(
    ticker.FuncFormatter(lambda value, _pos: format(int(value), ","))
)

# Save the explicit figure with a tight bbox, consistent with the other saved figures.
fig.savefig("r3c12_number_of_cells.svg", bbox_inches="tight")

## spatial distribution

In [None]:
# Representative sections (values of obs["section"]) used for the spatial plots below.
merfish_section_label, starmap_section_label = "C57BL6J-1.083", "well11"

In [None]:
fig, axs = plt.subplots(1, 2)

# MERFISH section. `.copy()` materialises the subset: this object is mutated later
# (e.g. an `is_cluster` obs column), which on an AnnData view would trigger an
# implicit view-to-actual conversion.
merfish_selected_section_anndata = anndata_filtered[anndata_filtered.obs["section"] == merfish_section_label].copy()
sc.pl.spatial(merfish_selected_section_anndata, spot_size=20, title="MERFISH", color="nichecompass_latent_cluster", palette=niche_color_map, ax=axs[0], return_fig=False, show=False, frameon=False)
# Only one legend is shown for the whole figure; it is drawn on the right panel.
axs[0].legend().set_visible(False)


def rotate_origin_only(xy, radians):
    """Rotate the point ``xy = (x, y)`` about the origin (0, 0) by ``radians``.

    With this sign convention a positive angle rotates clockwise; e.g. a quarter
    turn (pi/2) maps (x, y) to approximately (y, -x). Returns ``[x', y']`` as a list.
    """
    x, y = xy
    xx = x * math.cos(radians) + y * math.sin(radians)
    yy = -x * math.sin(radians) + y * math.cos(radians)

    return [xx, yy]


# STARmap PLUS section, rotated by pi/2 about the origin before plotting
# (NOTE(review): presumably to match the MERFISH orientation; confirm).
# `.copy()` so overwriting obsm["spatial"] acts on an actual AnnData, not a view.
starmap_selected_section_anndata = anndata_filtered[anndata_filtered.obs["section"] == starmap_section_label].copy()
spatial_coordinates = starmap_selected_section_anndata.obsm["spatial"].tolist()
rotated_spatial_coordinates = [rotate_origin_only(xy, math.pi / 2) for xy in spatial_coordinates]
starmap_selected_section_anndata.obsm["spatial"] = np.array(rotated_spatial_coordinates)
sc.pl.spatial(starmap_selected_section_anndata, spot_size=0.12, title="STARmap PLUS", color="nichecompass_latent_cluster", palette=niche_color_map, ax=axs[1], return_fig=False, show=False, frameon=False)

# Replace scanpy's legend on the right panel with one patch per entry in niche_color_map.
legend_elements = [matplotlib.patches.Patch(facecolor=color, edgecolor=color, label=niche) for niche, color in niche_color_map.items()]

leg = axs[1].legend(handles=legend_elements, loc="right", bbox_to_anchor=(1.5, 0.5), frameon=False)

# The legend is anchored at x=1.5 in axes coordinates, beyond the figure's right
# edge; bbox_inches="tight" enlarges the saved canvas so it is not clipped.
fig.savefig("r3c12_slide_overview.svg", bbox_inches="tight")

Next, we iterate over the retained niches and highlight each one in turn on both sections to better visualise its spatial distribution.

In [None]:
# Cells in the highlighted niche are blue; all other cells are light grey.
color_map = {"True": "blue", "False": "lightgrey"}


def _plot_niche_highlight(section_anndata, niche, spot_size, title, ax):
    """Draw ``section_anndata`` on ``ax``, highlighting the cells assigned to ``niche``.

    Overwrites the ``is_cluster`` column of ``section_anndata.obs`` with the strings
    "True"/"False" (membership of ``niche``), plots it with ``color_map`` and hides
    the legend.
    """
    # String labels so scanpy treats the column as categorical and `color_map` keys match.
    section_anndata.obs["is_cluster"] = (
        section_anndata.obs["nichecompass_latent_cluster"] == niche
    ).astype("str")
    sc.pl.spatial(section_anndata, spot_size=spot_size, return_fig=False, title=title, color="is_cluster", show=False, ax=ax, palette=color_map, frameon=False)
    ax.legend().set_visible(False)


for selected_nichecompass_latent_cluster in retained_niches:

    fig, axs = plt.subplots(1, 2)

    _plot_niche_highlight(merfish_selected_section_anndata, selected_nichecompass_latent_cluster, 20, "MERFISH", axs[0])
    _plot_niche_highlight(starmap_selected_section_anndata, selected_nichecompass_latent_cluster, 0.12, "STARmap PLUS", axs[1])

    fig.suptitle(f"niche {selected_nichecompass_latent_cluster}")

    fig.savefig(f"r3c12_slide_detail_{selected_nichecompass_latent_cluster}.svg", bbox_inches="tight")
    # Display and release this figure now. Otherwise every iteration leaves a large
    # scatter figure open until the cell ends (memory growth; matplotlib warns
    # once more than 20 figures are open).
    plt.show()
    