# Final Plots & Summary — Spatial Feature Clustering

This notebook generates publication-quality figures and summary tables for:

- Spatial Weighted Similarity results  
- Multi-View Clustering results  
- Cluster comparisons  
- Representative gene visualizations  
- Similarity heatmaps  
- ARI/NMI matrices  
- Spatial coherence metrics  

These figures will be used in the thesis and presentation.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc

import os
from pathlib import Path

# Calibrate project root
# Walk up from the current working directory until a directory containing a
# 'data' folder is found. The loop also stops at the filesystem root (where a
# directory is its own parent); nothing is raised if no 'data' folder exists,
# so in that case the kernel is simply left at the root.
while not (Path.cwd() / 'data').exists() and Path.cwd().parent != Path.cwd():
    os.chdir('..')

# Project-local imports, placed after the directory change above
# (presumably so the `src` package resolves from the project root — TODO confirm).
# NOTE(review): ClusteringEvaluator is not referenced in the cells visible here.
from src.data.data_loader import SpatialDataset
from src.visualization.plots import SpatialPlotter
from src.evaluation.metrics import ClusteringEvaluator

# Session management
# `session` is the shared handle the rest of the notebook uses for config
# lookups (session.config), metric/plot paths and log messages.
from src.utils.session import SessionManager
session = SessionManager.get_or_create_session(profile='default')
session.log("Starting notebook 05: Final plots", notebook="05_final")

def load_from_session(filename):
    """Read a previously saved result from the current session.

    The file is located through ``session.get_metric_path``. Names ending in
    ``.csv`` are read with pandas, using the first column as the index; every
    other file is handed to ``np.load``.
    """
    is_csv = filename.endswith('.csv')
    path = session.get_metric_path(filename)
    if is_csv:
        return pd.read_csv(path, index_col=0)
    return np.load(path)

def save_plot_to_session(filename):
    """Write the current matplotlib figure into the session's plot location.

    The target is resolved with ``session.get_plot_path``; the figure is saved
    at 200 dpi with a tight bounding box, the save is logged, and the resolved
    path is returned.
    """
    target = session.get_plot_path(filename)
    plt.savefig(target, bbox_inches='tight', dpi=200)
    session.log(f"Saved {filename}", notebook="05_final")
    return target

## Load dataset and previously saved results

In [None]:
# Dataset location comes from the session config; the DLPFC sample is the fallback
dataset_path = session.config.get("dataset_path", "data/DLPFC-151673")
dataset = SpatialDataset(dataset_path)
dataset.load()
adata = dataset.adata
print(f"Loaded dataset from: {dataset_path}")

# Results saved by notebook 04 into the current session
top_genes = load_from_session("top_genes_multiview.npy")

# One similarity matrix per view, loaded in a fixed order
S_expr, S_spatial, S_mog, S_weighted = [
    load_from_session(f"similarity_{view}.npy")
    for view in ("expression", "spatial", "mog", "weighted")
]

# Cluster assignment of each top gene under the weighted similarity
labels_weighted = load_from_session("cluster_labels_weighted.npy")

# Pairwise agreement between views
ari_matrix, nmi_matrix = [
    load_from_session(f"{metric}_matrix.csv") for metric in ("ari", "nmi")
]

n_weighted_clusters = len(np.unique(labels_weighted))
print(f"OK Loaded all results from session: {session.session_id}")
print(f"   Top genes: {len(top_genes)}")
print(f"   Clusters in weighted view: {n_weighted_clusters}")

## Setup neighbors and compute spatial profiles

We need to:
1. Build neighbor graph for spatial coherence computation
2. Reconstruct one spatial profile per gene cluster by averaging, at each spot, the expression of the genes in that cluster

In [None]:
# Build (or reuse) a neighbor graph over the spot coordinates.
# Moran's I is only a *spatial* coherence score when the graph connects spots
# that are close on the tissue. A plain sc.pp.neighbors(adata) call builds the
# graph from the expression representation instead, so the graph is built here
# from adata.obsm["spatial"] and kept under its own key. That way a
# pre-existing expression-based graph in adata.obsp["connectivities"] is
# neither picked up by mistake nor overwritten.
SPATIAL_GRAPH_KEY = "spatial_connectivities"
if SPATIAL_GRAPH_KEY not in adata.obsp:
    # key_added="spatial" stores the result as obsp["spatial_connectivities"]
    sc.pp.neighbors(adata, n_neighbors=6, use_rep="spatial", key_added="spatial")
    print("OK Neighbor graph computed")
else:
    print("OK Neighbor graph already exists")
spatial_graph = adata.obsp[SPATIAL_GRAPH_KEY]

# Reconstruct spatial profiles for weighted clusters
unique_clusters = np.unique(labels_weighted)
spatial_coherence = {}

for cid in unique_clusters:
    # Genes assigned to this cluster
    genes_in_cluster = top_genes[labels_weighted == cid]

    # Mean over the cluster's genes (axis=1) gives one value per spot.
    # np.asarray(...).ravel() flattens both the np.matrix produced by a sparse
    # X and the plain array produced by a dense X.
    cluster_avg = np.asarray(adata[:, genes_in_cluster].X.mean(axis=1)).ravel()

    # Keep the profile in adata.obs so later cells can plot it
    col_name = f"cluster_weighted_{cid}"
    adata.obs[col_name] = cluster_avg

    # Moran's I of the profile on the spatial graph (graph passed explicitly,
    # since the AnnData form always reads obsp["connectivities"])
    score = sc.metrics.morans_i(spatial_graph, cluster_avg)
    spatial_coherence[f"Cluster {cid}"] = score

print(f"\nOK Spatial profiles reconstructed for {len(unique_clusters)} clusters")
print(f"   Spatial coherence: {spatial_coherence}")

## Similarity matrices (Expression, Spatial, MoG, Weighted)

In [None]:
# Each similarity matrix is reordered by the cluster labels of its own view,
# so genes of the same cluster sit next to each other on both axes.
labels_expr = load_from_session("cluster_labels_expression.npy")
labels_spatial = load_from_session("cluster_labels_spatial.npy")
labels_mog = load_from_session("cluster_labels_mog.npy")

matrices = {
    "Expression": (S_expr, labels_expr),
    "Spatial": (S_spatial, labels_spatial),
    "MoG": (S_mog, labels_mog),
    "Weighted": (S_weighted, labels_weighted),
}

fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.ravel()

for ax, (name, (M, lbls)) in zip(axes, matrices.items()):
    order = np.argsort(lbls)
    # Apply the same permutation to rows and columns
    im = ax.imshow(M[np.ix_(order, order)], cmap="viridis")
    ax.set_title(f"{name} Similarity (reordered)", fontsize=14, fontweight='bold')
    fig.colorbar(im, ax=ax)

plt.tight_layout()
save_plot_to_session("all_similarity_matrices.png")
plt.show()

## ARI and NMI matrices between views

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Both panels share the same layout; only the title and the table differ.
# The color scale is fixed to [0, 1], so any value outside that range is
# drawn at the end of the colormap.
panels = [("ARI Matrix", ari_matrix), ("NMI Matrix", nmi_matrix)]

for ax, (title, scores) in zip(axes, panels):
    im = ax.imshow(scores.values.astype(float), cmap="viridis", vmin=0, vmax=1)
    ticks = range(len(scores))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(scores.columns, rotation=45, ha='right')
    ax.set_yticklabels(scores.index)
    ax.set_title(title, fontsize=14, fontweight='bold')
    fig.colorbar(im, ax=ax)

plt.tight_layout()
save_plot_to_session("ari_nmi_matrices.png")
plt.show()

## Spatial coherence (Moran's I) per cluster

This measures how spatially structured each gene cluster is.

In [None]:
# Bar chart of Moran's I for each weighted-similarity cluster.
# Materialize the dict views so matplotlib receives plain sequences.
cluster_names = list(spatial_coherence.keys())
cluster_scores = [float(score) for score in spatial_coherence.values()]

fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(cluster_names, cluster_scores, color='steelblue')
ax.set_xlabel("Cluster ID", fontsize=12)
ax.set_ylabel("Average Moran's I", fontsize=12)
ax.set_title("Spatial Coherence per Cluster (Weighted Similarity)", fontsize=14, fontweight='bold')

# Moran's I can be negative; a fixed lower limit of 0 would make such bars
# invisible. Extend the axis below zero only when needed, so the usual 0-1
# view is unchanged for all-positive scores.
lowest_score = min(cluster_scores, default=0.0)
if lowest_score < 0:
    ax.set_ylim(lowest_score * 1.1, 1)
    ax.axhline(0, color='black', linewidth=0.8)
else:
    ax.set_ylim(0, 1)

ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
save_plot_to_session("spatial_coherence_bar.png")
plt.show()

## Plot spatial distribution of each cluster

Each subplot shows one gene cluster's average expression overlaid on the H&E tissue image.
Spots are colored by the cluster's mean expression intensity (across all genes in that cluster).
This reveals the spatial structure captured by each gene module.

In [None]:
import json
from matplotlib.image import imread

# Hi-res tissue image and its scale factor live in the dataset's "spatial" folder
spatial_dir = Path(dataset_path) / "spatial"
tissue_img_path = spatial_dir / "tissue_hires_image.png"
scale_json_path = spatial_dir / "scalefactors_json.json"

with open(scale_json_path) as f:
    scale_factors = json.load(f)
scale = scale_factors["tissue_hires_scalef"]
tissue_img = imread(str(tissue_img_path))

# Spot coordinates multiplied by the hi-res scale factor so they line up with
# the image; identical for every panel, so computed once.
coords = adata.obsm["spatial"]
spot_x = coords[:, 0] * scale
spot_y = coords[:, 1] * scale

n_clusters = len(unique_clusters)
fig, axes = plt.subplots(1, n_clusters, figsize=(7 * n_clusters, 6))
if n_clusters == 1:
    # plt.subplots returns a bare Axes for a single panel; wrap it for the loop
    axes = [axes]

for ax, cid in zip(axes, unique_clusters):
    values = adata.obs[f"cluster_weighted_{cid}"].values
    morans = spatial_coherence[f"Cluster {cid}"]

    # Faded tissue image underneath, spots colored by mean cluster expression on top
    ax.imshow(tissue_img, alpha=0.4)
    sc_plot = ax.scatter(
        spot_x,
        spot_y,
        c=values,
        cmap="magma",
        s=6,
        alpha=0.85,
        edgecolors="none",
    )
    ax.set_title(
        f"Cluster {cid}  —  Moran's I = {morans:.3f}",
        fontsize=13, fontweight="bold",
    )
    ax.axis("off")
    fig.colorbar(sc_plot, ax=ax, shrink=0.65, label="Avg expression")

plt.tight_layout()
save_plot_to_session("nb05_cluster_spatial_profiles.png")
plt.show()
plt.close()

## Representative genes per cluster (Weighted Similarity)

We show one gene per cluster (the first gene assigned to that cluster) using the full diagnostic plot.

In [None]:
plotter = SpatialPlotter(adata, dataset.filter_bank)
banner = '=' * 60

for cid in unique_clusters:
    print(f"\n{banner}")
    print(f"Cluster {cid}")
    print(banner)

    # The "representative" is simply the first gene assigned to this cluster
    first_member = np.flatnonzero(labels_weighted == cid)[0]
    gene_id = top_genes[first_member]
    print(f"Representative gene: {gene_id}")

    plot_path = session.get_plot_path(f"nb05_final_cluster_{cid}_gene_{gene_id}.png")
    plotter.full_gene_diagnostic_plot(gene_id, save=True, path=plot_path)

## Inspect differences between the two most different views

In [None]:
# Find the pair of views with the lowest ARI (diagonal excluded).
# The search starts from +inf rather than 1: with a start value of 1 and a
# strict '<', a matrix whose off-diagonal entries are all exactly 1 would
# leave `pair` unset and break the cells that unpack it.
min_ari = np.inf
pair = None

for i in ari_matrix.index:
    for j in ari_matrix.columns:
        if i == j:
            continue
        ari_ij = float(ari_matrix.loc[i, j])
        # NaN never compares less-than, so missing entries are skipped
        if ari_ij < min_ari:
            min_ari = ari_ij
            pair = (i, j)

if pair is None:
    raise ValueError(
        "No valid off-diagonal entry in the ARI matrix; "
        "at least two distinct views with a numeric ARI are required."
    )

print(f"Most different views: {pair}")
print(f"ARI: {min_ari:.3f}")

In [None]:
# Load labels for the two most different views
from scipy.optimize import linear_sum_assignment

view_a, view_b = pair
labels_a = load_from_session(f"cluster_labels_{view_a}.npy")
labels_b = load_from_session(f"cluster_labels_{view_b}.npy")

if labels_a.shape != labels_b.shape:
    raise ValueError(
        f"Label arrays differ in shape: {view_a}={labels_a.shape}, {view_b}={labels_b.shape}"
    )

# Cluster ids coming from two separate clusterings are not guaranteed to
# correspond (id 0 in one view may be id 2 in the other), so comparing the raw
# ids would count a pure relabeling as a change. Match the clusters of the two
# views first, by maximizing the number of shared genes, and call a gene
# "changed" only when its pair of clusters is not one of the matched pairs.
ids_a, codes_a = np.unique(labels_a, return_inverse=True)
ids_b, codes_b = np.unique(labels_b, return_inverse=True)
codes_a, codes_b = codes_a.ravel(), codes_b.ravel()

# overlap[r, c] = number of genes in cluster ids_a[r] of view A and ids_b[c] of view B
overlap = np.zeros((ids_a.size, ids_b.size), dtype=np.int64)
np.add.at(overlap, (codes_a, codes_b), 1)

# One-to-one matching with maximum total overlap; with unequal cluster counts
# the surplus clusters stay unmatched and all their genes count as changed.
matched_rows, matched_cols = linear_sum_assignment(overlap, maximize=True)
is_matched = np.zeros(overlap.shape, dtype=bool)
is_matched[matched_rows, matched_cols] = True

print("Matched clusters (by maximum gene overlap):")
for r, c in zip(matched_rows, matched_cols):
    print(f"   {view_a} cluster {ids_a[r]} <-> {view_b} cluster {ids_b[c]} ({overlap[r, c]} shared genes)")

# Find genes that changed cluster
diff_mask = ~is_matched[codes_a, codes_b]
changed_genes = top_genes[diff_mask]

print(f"\nGenes that changed cluster: {len(changed_genes)}")
print(f"First 10: {changed_genes[:10]}")

In [None]:
# Diagnostic plots for the first five genes flagged as having changed cluster.
# The printed ids are the raw labels stored for each view.
for gene_id in changed_genes[:5]:
    gene_mask = top_genes == gene_id
    cluster_a = labels_a[gene_mask][0]
    cluster_b = labels_b[gene_mask][0]
    print(f"\nGene {gene_id}: {view_a}=cluster_{cluster_a}, {view_b}=cluster_{cluster_b}")
    plot_path = session.get_plot_path(f"nb05_changed_gene_{gene_id}.png")
    plotter.full_gene_diagnostic_plot(gene_id, save=True, path=plot_path)

## Summary tables for the thesis

In [None]:
# One row per weighted-similarity cluster: id, number of genes, Moran's I
cluster_sizes = [np.sum(labels_weighted == cid) for cid in unique_clusters]
cluster_morans = [spatial_coherence[f"Cluster {cid}"] for cid in unique_clusters]

summary = pd.DataFrame({
    "Cluster ID": unique_clusters,
    "Size (# genes)": cluster_sizes,
    "Spatial Coherence (Moran's I)": cluster_morans,
})

summary

In [None]:
# Write the summary table alongside the other session metrics (row index omitted)
summary.to_csv(session.get_metric_path("cluster_summary.csv"), index=False)
session.log("Saved cluster_summary.csv", notebook="05_final")
print(f"OK Summary table saved to session: {session.session_id}")

## Save weighted similarity matrix as figure

In [None]:
# Reorder rows and columns so genes of the same weighted cluster are adjacent
order = np.argsort(labels_weighted)
S_sorted = S_weighted[np.ix_(order, order)]

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(S_sorted, cmap="viridis")
fig.colorbar(im, ax=ax, label="Similarity")
ax.set_title("Weighted Similarity Matrix (reordered by cluster)", fontsize=14, fontweight='bold')
fig.tight_layout()
save_plot_to_session("weighted_similarity_matrix.png")
plt.show()

print(f"\nOK All publication figures saved to session: {session.session_id}")
print(f"   Location: {session.plots_dir}")