# Spatial Weighted Similarity

This notebook runs the full pipeline:
1. Load the Visium dataset  
2. Select top spatially variable genes  
3. Compute spatial‑weighted similarity  
4. Cluster genes  
5. Visualize representative genes per cluster
6. Visualize similarity matrix
7. Save results

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import os
from pathlib import Path

# Calibrate project root: step up one directory at a time until the working
# directory contains a 'data' entry, or the filesystem root is reached.
while not ((Path.cwd() / 'data').exists() or Path.cwd().parent == Path.cwd()):
    os.chdir(os.pardir)

from src.data.data_loader import SpatialDataset
from src.similarity.spatial_weighted_similarity import SpatialWeightedSimilarity
from src.clustering.gene_clustering import GeneClustering
from src.visualization.plots import SpatialPlotter

# Session management
from src.utils.session import SessionManager
# Obtain the session object; the cells below use it for config lookup,
# output-path resolution and logging.
# NOTE(review): `get_or_create_session` presumably reuses an existing session
# for this profile instead of always creating one — confirm in SessionManager.
session = SessionManager.get_or_create_session(profile='default')
# Record the start of this notebook in the session log.
session.log("Starting notebook 03: Spatial weighted similarity", notebook="03_spatial")

def save_to_session(data, filename, save_func=np.save):
    """Write ``data`` into the current session and return the path used.

    Args:
        data: Object handed to ``save_func`` as its second argument.
        filename: Target file name. Names ending in ``.png`` are resolved with
            ``session.get_plot_path``; all others with ``session.get_metric_path``.
        save_func: Callable invoked as ``save_func(path, data)``. Defaults to
            ``np.save``.

    Returns:
        The path returned by the session's path resolver.
    """
    # Pick the resolver first so only the relevant session method is looked up.
    resolve_path = session.get_plot_path if filename.endswith('.png') else session.get_metric_path
    target = resolve_path(filename)
    save_func(target, data)
    # Log only after the save call has returned without raising.
    session.log(f"Saved {filename}", notebook="03_spatial")
    return target

# 1. Load dataset

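We load the Visium dataset from the path given by the session config key `dataset_path`, falling back to the DLPFC sample at `data/DLPFC-151673`.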

In [None]:
# The dataset location comes from the session config; the DLPFC-151673 folder
# is the fallback when no "dataset_path" entry is set.
dataset_path = session.config.get("dataset_path", "data/DLPFC-151673")
dataset = SpatialDataset(dataset_path)
dataset.load()

# Keep a direct handle on the loaded data object; it is passed to the plotter
# in the visualization cell.
# NOTE(review): presumably an AnnData object — confirm in SpatialDataset.
adata = dataset.adata
print(f"Loaded dataset from: {dataset_path}")
# Bare last expression so the notebook renders the object's summary.
adata

# 2. Select top spatially variable genes

We use the HVG-based selector implemented in `SpatialDataset`.

In [None]:
# Selection settings gathered in one place so they are easy to tune.
# NOTE(review): the exact meaning of each keyword is defined by
# SpatialDataset.select_top_spatially_variable_genes — confirm there.
gene_selection_kwargs = {
    "n_top": 300,
    "min_gene_expression": 300,
    "n_top_genes": 3000,
}
top_genes = dataset.select_top_spatially_variable_genes(**gene_selection_kwargs)

# Show how many genes were selected, plus a preview of the first ten.
len(top_genes), top_genes[:10]

# 3. Compute spatial‑weighted similarity

We combine:
- expression similarity
- spatially filtered similarity
- MoG-transformed similarity

Note: the weights `alpha` (expression), `beta` (spatial), and `gamma` (MoG) can be tuned.

In [None]:
# Mixing weights for the three similarity components (they sum to 1.0 here).
component_weights = {
    "alpha": 0.5,  # expression
    "beta": 0.3,   # spatial
    "gamma": 0.2,  # MoG
}
sws = SpatialWeightedSimilarity(dataset, **component_weights)

# Similarity matrix over the selected genes; its shape is displayed as a quick check.
S = sws.compute_similarity_matrix(top_genes)
S.shape

# 4. Cluster genes

We use Louvain clustering on the similarity graph.


In [None]:
# Clustering helper configured for the "louvain" method at resolution 1.0.
# NOTE(review): no seed is passed here; if GeneClustering does not fix one
# internally, labels may differ between runs — confirm.
clusterer = GeneClustering(method="louvain", resolution=1.0)
# Later cells index `top_genes` and `S` by position in `labels`, so one label
# per selected gene is assumed — TODO confirm against GeneClustering.cluster.
labels = clusterer.cluster(S)

# Distinct cluster ids together with the number of genes in each.
np.unique(labels, return_counts=True)

# 5. Visualize representative genes per cluster

For each cluster, we pick one representative gene (the first selected gene assigned to that cluster) and plot:
- raw spatial map
- MoG-transformed map

In [None]:
plotter = SpatialPlotter(adata, dataset.filter_bank)

clusters = np.unique(labels)

for c in clusters:
    print(f"\n=== Cluster {c} ===")
    # Positions (into `top_genes`) of the genes carrying this label; the first
    # one is used as the cluster's representative.
    member_positions = np.where(labels == c)[0]
    gene_id = top_genes[member_positions[0]]
    print(f"Representative gene index: {gene_id}")

    # The diagnostic figure is written under the session's plot location.
    diagnostic_name = f"nb03_weighted_cluster_{c}_gene_{gene_id}.png"
    plot_path = session.get_plot_path(diagnostic_name)
    plotter.full_gene_diagnostic_plot(gene_id, save=True, path=plot_path)

# 6. Visualize similarity matrix

A heatmap of the similarity matrix, with genes reordered by cluster label, helps reveal its structure.

In [None]:
# Order genes by cluster label so each cluster occupies a contiguous block.
# The default argsort kind is not guaranteed to be stable, so genes sharing a
# label could end up in an implementation-dependent order; a stable sort keeps
# them in their original `top_genes` order.
sort_idx = np.argsort(labels, kind="stable")
S_sorted = S[sort_idx][:, sort_idx]

# Explicit figure/axes handles so saving and closing target this figure
# rather than whatever pyplot considers "current".
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(S_sorted, cmap="viridis")
fig.colorbar(im, ax=ax, label="similarity")
ax.set_title("Spatial Weighted Similarity Matrix (reordered by cluster)")
ax.set_xlabel("gene (sorted by cluster)")
ax.set_ylabel("gene (sorted by cluster)")

# `save_to_session` calls save_func(path, data); the data argument is unused
# because the figure itself is what gets written.
save_to_session(None, "nb03_weighted_similarity_heatmap.png",
                save_func=lambda p, _: fig.savefig(p, dpi=150, bbox_inches='tight'))
plt.show()
plt.close(fig)

# 7. Save results

We save:
- similarity matrix
- cluster labels
- selected gene indices

In [None]:
# Persist each result through the session helper, in a fixed order:
# selected genes, similarity matrix, cluster labels.
for payload, filename in (
    (top_genes, "top_genes.npy"),
    (S, "similarity_matrix.npy"),
    (labels, "cluster_labels.npy"),
):
    save_to_session(payload, filename)

# Tell the reader where the outputs ended up.
print(f"\nOK All results saved to session: {session.session_id}")
print(f"   Location: {session.run_dir}")