# Multi-View Clustering — Alternative Gene Clusterings

This notebook runs the full pipeline:
1. Load the DLPFC Visium dataset  
2. Select top spatially variable genes  
3. Compute multiple similarity views:
   - expression
   - spatially filtered
   - MoG-transformed
   - weighted (combined)
4. Cluster each view independently  
5. Compare clusterings (ARI, NMI)  
6. Visualize representative genes per view  

Goal: identify **alternative clusterings** of genes that reflect different biological structures.

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment

# Calibrate project root: switch to the nearest ancestor of the current directory
# (including itself) that contains a `data/` directory, so the relative `data/...`
# paths below (and, presumably, the `src.*` imports) resolve. If no such directory
# exists, fail loudly and leave the cwd untouched instead of silently walking up to
# the filesystem root.
_cwd = Path.cwd()
_project_root = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / 'data').is_dir()),
    None,
)
if _project_root is None:
    raise FileNotFoundError(
        f"No 'data' directory found in {_cwd} or any of its parent directories"
    )
os.chdir(_project_root)

from src.data.data_loader import SpatialDataset
from src.clustering.multiview_clustering import MultiViewClustering
from src.visualization.plots import SpatialPlotter

# Session management
from src.utils.session import SessionManager
# Shared session for this run: every plot/metric path below is requested from it,
# and progress messages are logged to it under the "04_multiview" notebook tag.
session = SessionManager.get_or_create_session(profile="default")
session.log(
    "Starting notebook 04: Multi-view clustering",
    notebook="04_multiview",
)

# File suffixes routed to the session's plot directory; anything else is a metric.
_PLOT_SUFFIXES = ('.png', '.pdf', '.svg', '.jpg', '.jpeg')


def save_to_session(data, filename, save_func=np.save):
    """Save data to the current session directory and log the save.

    Figure files (``.png``, ``.pdf``, ``.svg``, ``.jpg``/``.jpeg``; matched
    case-insensitively) are placed via ``session.get_plot_path``; every other
    file is placed via ``session.get_metric_path``.

    Parameters
    ----------
    data : object
        Passed as the second argument to ``save_func``. May be ``None`` when
        ``save_func`` ignores it (e.g. a lambda that saves the current figure).
    filename : str
        File name inside the session directory.
    save_func : callable, default ``np.save``
        Called as ``save_func(path, data)``.

    Returns
    -------
    The path returned by the session's path getter for ``filename``.
    """
    is_plot = filename.lower().endswith(_PLOT_SUFFIXES)
    path = session.get_plot_path(filename) if is_plot else session.get_metric_path(filename)
    save_func(path, data)
    session.log(f"Saved {filename}", notebook="04_multiview")
    return path

## Load dataset

In [None]:
# The session config may override the dataset location; default is DLPFC sample 151673.
dataset_path = session.config.get("dataset_path", "data/DLPFC-151673")

dataset = SpatialDataset(dataset_path)
dataset.load()
print(f"Loaded dataset from: {dataset_path}")

# Keep a direct handle on the loaded object; later cells plot from it.
adata = dataset.adata
adata

## Select top spatially variable genes

In [None]:
# Gene-selection parameters, forwarded unchanged to
# SpatialDataset.select_top_spatially_variable_genes (see that method for exact semantics).
N_TOP = 300
MIN_GENE_EXPRESSION = 300
N_TOP_GENES = 3000  # NOTE(review): presumably a candidate pool before SVG ranking — confirm

top_genes = dataset.select_top_spatially_variable_genes(
    n_top=N_TOP,
    min_gene_expression=MIN_GENE_EXPRESSION,
    n_top_genes=N_TOP_GENES,
)

print(f"Selected {len(top_genes)} genes")
top_genes[:10]

## Run multi-view clustering

We compute:
- expression similarity
- spatial similarity
- MoG similarity
- weighted similarity

Then we cluster each view with the same clustering algorithm.

**Note**: Because `save_to_adata=True` is passed below, `MultiViewClustering` also writes the cluster labels to `adata.obs` and `adata.uns` for easy access.

In [None]:
# One clustering configuration (method, resolution, seed) is shared by every view,
# so differences between the resulting clusterings come from the similarity views.
mvc = MultiViewClustering(
    dataset,
    clustering_method="louvain",
    resolution=1.0,
    random_state=0,
    save_to_adata=True,  # also write the labels back onto dataset.adata
)

results = mvc.run(top_genes)
similarities, clusterings, comparisons = (
    results[key] for key in ("similarities", "clusterings", "comparisons")
)

print("\nOK Multi-view clustering complete!")
print(f"   Views: {list(similarities.keys())}")
print("   Cluster labels saved to adata.obs and adata.uns['gene_clusters']")

## Verify cluster labels integration

In [None]:
# Check that MultiViewClustering(save_to_adata=True) wrote its labels back.
cluster_cols = [col for col in dataset.adata.obs.columns if col.startswith('cluster_')]
print(f"Cluster columns in adata.obs: {len(cluster_cols)}")
print(cluster_cols[:5])

# Check gene cluster labels. Report explicitly when they are missing instead of
# printing nothing, and cross-check them against the labels returned by mvc.run().
gene_clusters = dataset.adata.uns.get('gene_clusters')
if gene_clusters is None:
    print("\nWARNING: adata.uns['gene_clusters'] not found - labels were not saved to adata.uns")
else:
    print(f"\nOK Gene clusters saved to adata.uns['gene_clusters']")
    for view, labels in gene_clusters.items():
        n_clusters = len(np.unique(labels))
        consistent = view not in clusterings or np.array_equal(
            np.asarray(labels), np.asarray(clusterings[view])
        )
        note = "" if consistent else "  (differs from in-memory labels!)"
        print(f"   {view:12s}: {n_clusters} clusters{note}")

## Compare clusterings (ARI / NMI)

We build matrices of pairwise ARI (adjusted Rand index) and NMI (normalized mutual information) scores between the clusterings of all pairs of views.

In [None]:
views = list(clusterings.keys())


def pairwise_metric_matrix(metric):
    """Return a views x views float DataFrame of `metric` ("ARI"/"NMI") taken from `comparisons`."""
    rows = [
        [comparisons[row_view][col_view][metric] for col_view in views]
        for row_view in views
    ]
    return pd.DataFrame(rows, index=views, columns=views, dtype=float)


ari_matrix = pairwise_metric_matrix("ARI")
nmi_matrix = pairwise_metric_matrix("NMI")

print("ARI Matrix:")
print(ari_matrix.round(3))
print("\nNMI Matrix:")
print(nmi_matrix.round(3))

In [None]:
# One heatmap per agreement score; both panels share the same layout.
metric_matrices = {"ARI": ari_matrix, "NMI": nmi_matrix}

fig, axes = plt.subplots(1, len(metric_matrices), figsize=(12, 5))

for ax, (metric, matrix) in zip(axes, metric_matrices.items()):
    # Colour scale fixed to [0, 1]; ARI can be slightly negative for worse-than-chance
    # agreement, which is clipped to the bottom colour (the annotation still shows it).
    im = ax.imshow(matrix.values, cmap="viridis", vmin=0, vmax=1)
    ax.set_xticks(range(len(views)))
    ax.set_yticks(range(len(views)))
    ax.set_xticklabels(views, rotation=45, ha='right')
    ax.set_yticklabels(views)
    ax.set_title(f"{metric} between views", fontsize=14, fontweight='bold')
    fig.colorbar(im, ax=ax, label=metric)

    # Write each value into its cell so the figure can be read without the colour bar.
    for i in range(len(views)):
        for j in range(len(views)):
            value = matrix.iat[i, j]
            ax.text(j, i, f"{value:.2f}", ha="center", va="center", fontsize=9,
                    color="w" if value < 0.5 else "k")

fig.tight_layout()
# Save this specific figure (not whatever pyplot considers "current").
save_to_session(None, "nb04_ari_nmi_heatmaps.png",
                save_func=lambda path, _: fig.savefig(path, dpi=150, bbox_inches='tight'))
plt.show()
plt.close(fig)

## Visualize representative genes per view

For each view, we:
- pick a cluster
- pick a representative gene
- show its full diagnostic plot

This helps see how different views emphasize different spatial patterns.

In [None]:
plotter = SpatialPlotter(adata, dataset.filter_bank)

for view_name, labels in clusterings.items():
    labels = np.asarray(labels)  # so `labels == cid` is elementwise even for list input
    print(f"\n{'='*60}")
    print(f"View: {view_name}")
    print(f"{'='*60}")

    # Labels are per gene, positionally aligned with top_genes.
    assert len(labels) == len(top_genes), (
        f"{view_name}: {len(labels)} labels for {len(top_genes)} genes"
    )
    if labels.size == 0:
        print("No genes clustered in this view; skipping.")
        continue

    # np.unique only returns labels that occur, so every cluster is non-empty.
    # Show the largest cluster: the most populated pattern in this view.
    cluster_ids, cluster_sizes = np.unique(labels, return_counts=True)
    largest = np.argmax(cluster_sizes)
    cid = cluster_ids[largest]

    # Representative = first gene (in top_genes order) assigned to that cluster.
    gene_idx_in_top = np.flatnonzero(labels == cid)[0]
    gene_id = top_genes[gene_idx_in_top]

    print(f"Cluster {cid} ({cluster_sizes[largest]} genes), representative gene: {gene_id}")

    plot_path = session.get_plot_path(f"nb04_view_{view_name}_gene_{gene_id}.png")
    plotter.full_gene_diagnostic_plot(gene_id, save=True, path=plot_path)

## Inspect two most different views

We pick two views with low ARI and inspect a few genes that change cluster.

In [None]:
# Pick the pair of distinct views whose clusterings agree least (lowest ARI), as
# described above, instead of hard-coding view names that may not exist.
if len(views) < 2:
    raise ValueError("Need at least two views to compare clusterings.")
ari_values = ari_matrix.to_numpy(dtype=float, copy=True)
np.fill_diagonal(ari_values, np.nan)  # ignore self-comparisons
row, col = np.unravel_index(np.nanargmin(ari_values), ari_values.shape)
view_a, view_b = views[row], views[col]
print(f"Least similar views: {view_a} vs {view_b} (ARI = {ari_values[row, col]:.3f})")

labels_a = np.asarray(clusterings[view_a])
labels_b = np.asarray(clusterings[view_b])

# Cluster IDs are arbitrary per view (cluster 0 in one view is unrelated to cluster 0
# in another), so a raw `labels_a != labels_b` is meaningless. Match each cluster of
# view_b to a cluster of view_a by maximum gene overlap (Hungarian assignment on the
# contingency table), then flag genes whose two clusters are not matched to each other.
ids_a, inv_a = np.unique(labels_a, return_inverse=True)
ids_b, inv_b = np.unique(labels_b, return_inverse=True)
contingency = np.zeros((ids_a.size, ids_b.size), dtype=int)
np.add.at(contingency, (inv_a.ravel(), inv_b.ravel()), 1)
matched_rows, matched_cols = linear_sum_assignment(contingency, maximize=True)
partner_in_a = np.full(ids_b.size, -1)  # -1: cluster of view_b left unmatched
partner_in_a[matched_cols] = matched_rows
diff_mask = partner_in_a[inv_b.ravel()] != inv_a.ravel()
changed_genes = np.asarray(top_genes)[diff_mask]

print(f"Genes that changed cluster: {len(changed_genes)}")
print(f"First 10: {changed_genes[:10]}")

In [None]:
for gene_id in changed_genes[:5]:
    # Position of the gene in top_genes (first match); both label arrays share it.
    gene_pos = np.flatnonzero(top_genes == gene_id)[0]
    cluster_a, cluster_b = labels_a[gene_pos], labels_b[gene_pos]
    print(f"\nGene {gene_id} | {view_a}=cluster_{cluster_a}, {view_b}=cluster_{cluster_b}")

    plot_path = session.get_plot_path(f"nb04_changed_gene_{gene_id}.png")
    plotter.full_gene_diagnostic_plot(gene_id, save=True, path=plot_path)

## Save multi-view results

In [None]:
# Gene indices, per-view similarity matrices and per-view labels as .npy files.
save_to_session(top_genes, "top_genes_multiview.npy")

for view, similarity in similarities.items():
    save_to_session(similarity, f"similarity_{view}.npy")

for view, view_labels in clusterings.items():
    save_to_session(view_labels, f"cluster_labels_{view}.npy")

# ARI/NMI tables use the same helper (metric path + "Saved ..." log entry), written as CSV.
ari_path = save_to_session(ari_matrix, "ari_matrix.csv",
                           save_func=lambda path, frame: frame.to_csv(path))
nmi_path = save_to_session(nmi_matrix, "nmi_matrix.csv",
                           save_func=lambda path, frame: frame.to_csv(path))

print(f"\nOK All results saved to session: {session.session_id}")
print(f"   Location: {session.run_dir}")