# 02 — Baseline Gene Clustering (Expression-Only)

This notebook implements the baseline pipeline:

1. Load the DLPFC dataset  
2. Select top spatially variable genes  
3. Compute **expression-only** similarity (Pearson, Spearman, Cosine)  
4. Cluster genes using Louvain  
5. Evaluate clustering quality  
6. Visualize representative genes  
7. Save baseline metrics  

This baseline will be compared against:
- Spatial Weighted Similarity  
- Multi-View Clustering  

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import os
from pathlib import Path

# Calibrate project root: make the nearest directory (the current one or an
# ancestor) that contains a `data/` directory the working directory, so that
# relative paths such as "data/DLPFC-151673" resolve from there.
# NOTE(review): the `src.*` imports below presumably also rely on the kernel
# keeping the working directory on sys.path -- confirm.
_start_dir = Path.cwd()
_project_root = next(
    (d for d in (_start_dir, *_start_dir.parents) if (d / 'data').is_dir()),
    None,
)
if _project_root is not None:
    os.chdir(_project_root)
else:
    # Leave the working directory untouched rather than at the filesystem
    # root (where an unbounded upward walk would stop), and say so: later
    # relative paths and `src` imports may fail otherwise with unclear errors.
    import warnings
    warnings.warn(
        f"No directory containing 'data/' found at or above {_start_dir}; "
        "working directory left unchanged."
    )

from src.data.data_loader import SpatialDataset
from src.similarity.alternative_similarity import (
    PearsonSimilarity,
    SpearmanSimilarity,
    CosineSimilarity
)
from src.clustering.gene_clustering import GeneClustering
from src.evaluation.metrics import ClusteringEvaluator
from src.visualization.plots import SpatialPlotter

# Session management: obtain the session for the 'default' profile
# (presumably reused if one already exists, created otherwise -- see
# SessionManager.get_or_create_session). The `session` object is used below
# for config lookup (`session.config`), output paths
# (`get_plot_path` / `get_metric_path`) and logging (`session.log`).
from src.utils.session import SessionManager
session = SessionManager.get_or_create_session(profile='default')
session.log("Starting notebook 02: Baseline clustering", notebook="02_baseline")

# File extensions that are routed to the session's plot directory; every
# other filename is treated as a metric/array artifact.
_PLOT_EXTENSIONS = frozenset({".png", ".pdf", ".svg", ".jpg", ".jpeg"})


def save_to_session(data, filename, save_func=np.save):
    """Save ``data`` under the current session and log the write.

    Parameters
    ----------
    data
        Object handed to ``save_func`` as its second argument.
    filename : str
        Bare file name. Its extension (case-insensitive) selects the
        destination: image extensions go to ``session.get_plot_path``,
        anything else to ``session.get_metric_path``.
    save_func : callable, default ``np.save``
        Called as ``save_func(path, data)``. Pass a custom callable for
        non-array data, e.g. ``lambda p, _: plt.savefig(p)`` for figures.

    Returns
    -------
    The path obtained from the session for ``filename``.

    Notes
    -----
    ``np.save`` appends ``.npy`` when the path lacks that suffix, so with the
    default ``save_func`` the returned path can differ from the file actually
    written if ``filename`` does not end in ``.npy``.
    """
    if Path(filename).suffix.lower() in _PLOT_EXTENSIONS:
        path = session.get_plot_path(filename)
    else:
        path = session.get_metric_path(filename)
    save_func(path, data)
    session.log(f"Saved {filename}", notebook="02_baseline")
    return path

## Load dataset

In [None]:
# Dataset location comes from the session config, falling back to
# "data/DLPFC-151673" (relative to the current working directory).
dataset_path = session.config.get("dataset_path", "data/DLPFC-151673")
dataset = SpatialDataset(dataset_path)
dataset.load()

# Loaded data object (presumably an AnnData); later cells use it for plotting.
adata = dataset.adata
print(f"Loaded dataset from: {dataset_path}")
adata  # last expression: shown as the cell output

## Select top spatially variable genes

In [None]:
# Select the genes to cluster. The result is used below as row indices into
# the gene-by-spot expression matrix and to look up gene identifiers.
# NOTE(review): the exact meaning of these thresholds lives in SpatialDataset;
# presumably `n_top` is the number of genes returned and `n_top_genes` a
# larger pre-filter pool -- confirm.
top_genes = dataset.select_top_spatially_variable_genes(
    n_top=300,
    min_gene_expression=300,
    n_top_genes=3000
)

len(top_genes), top_genes[:10]  # cell output: gene count and first ten entries

## Compute baseline similarities

In [None]:
# Build the gene-by-spot matrix for the selected genes only. Slicing the
# selected columns before densifying avoids materialising the full dense
# spots x genes matrix, and the `toarray` check also accepts an already-dense X.
# NOTE(review): assumes `dataset.adata.X` is (spots x genes) with `top_genes`
# as column indices, as the former `X.toarray().T[top_genes]` implied.
_selected = dataset.adata.X[:, top_genes]
if hasattr(_selected, "toarray"):  # scipy sparse matrix / array
    _selected = _selected.toarray()
# Transpose to genes x spots; keep a C-contiguous copy for downstream code.
X = np.ascontiguousarray(np.asarray(_selected).T)

In [None]:
# Expression-only baselines: each measure is applied to X, whose rows are the
# selected genes, and presumably yields a gene-by-gene similarity matrix.
pearson_sim, spearman_sim, cosine_sim = (
    measure().compute(X)
    for measure in (PearsonSimilarity, SpearmanSimilarity, CosineSimilarity)
)

pearson_sim.shape, spearman_sim.shape, cosine_sim.shape

## Cluster each baseline similarity matrix

In [None]:
# One Louvain clusterer (resolution 1.0) applied to each similarity matrix.
# NOTE(review): if Louvain is randomized inside GeneClustering, labels may vary
# between runs unless it is seeded -- confirm.
clusterer = GeneClustering(method="louvain", resolution=1.0)

labels_pearson, labels_spearman, labels_cosine = (
    clusterer.cluster(sim) for sim in (pearson_sim, spearman_sim, cosine_sim)
)

# Cell output: the distinct cluster ids found for each baseline.
tuple(np.unique(lbl) for lbl in (labels_pearson, labels_spearman, labels_cosine))

## Evaluate clustering quality

In [None]:
# Quality metrics for each labelling, computed against the expression matrix X
# (not against the similarity matrices).
evaluator = ClusteringEvaluator()

metrics_pearson, metrics_spearman, metrics_cosine = (
    evaluator.compute_basic_metrics(X, lbl)
    for lbl in (labels_pearson, labels_spearman, labels_cosine)
)

metrics_pearson, metrics_spearman, metrics_cosine

## Compare baseline clusterings (ARI / NMI)

In [None]:
# Pairwise agreement between the three baseline labellings, in the order
# Pearson/Spearman, Pearson/Cosine, Spearman/Cosine.
label_pairs = (
    (labels_pearson, labels_spearman),
    (labels_pearson, labels_cosine),
    (labels_spearman, labels_cosine),
)
ari_ps, ari_pc, ari_sc = (
    evaluator.compare_clusterings(left, right) for left, right in label_pairs
)

ari_ps, ari_pc, ari_sc

## Representative genes per cluster (Pearson baseline)

In [None]:
# NOTE(review): the plotter uses `adata` captured right after loading, while
# the expression matrix was taken from `dataset.adata` after gene selection.
# If gene selection replaces `dataset.adata`, confirm the `top_genes` ids still
# refer to the same object the plotter sees.
plotter = SpatialPlotter(adata, dataset.filter_bank)

clusters = np.unique(labels_pearson)

for cid in clusters:
    print(f"\n=== Cluster {cid} ===")
    # The representative is simply the first gene (in `top_genes` order)
    # assigned to this cluster.
    gene_idx = np.where(labels_pearson == cid)[0][0]
    gene_id = top_genes[gene_idx]
    print(f"Representative gene: {gene_id}")

    plot_name = f"nb02_baseline_cluster_{cid}_gene_{gene_id}.png"
    plot_path = session.get_plot_path(plot_name)
    plotter.full_gene_diagnostic_plot(gene_id, save=True, path=plot_path)
    # Log like every other saved artifact in this notebook.
    session.log(f"Saved {plot_name}", notebook="02_baseline")

## Similarity matrices (Pearson, Spearman, Cosine)

In [None]:
matrices = {
    "Pearson": (pearson_sim, labels_pearson),
    "Spearman": (spearman_sim, labels_spearman),
    "Cosine": (cosine_sim, labels_cosine),
}

# One heatmap per similarity measure, with rows and columns grouped by cluster
# label so clusters appear as blocks along the diagonal. Three matrices on a
# 2x2 grid leave the bottom-right panel empty.
fig = plt.figure(figsize=(12, 12))

for panel, (measure, (sim, cluster_labels)) in enumerate(matrices.items(), start=1):
    order = np.argsort(cluster_labels)
    reordered = sim[order][:, order]
    ax = fig.add_subplot(2, 2, panel)
    image = ax.imshow(reordered, cmap="viridis")
    ax.set_title(f"{measure} Similarity (reordered by cluster)")
    fig.colorbar(image, ax=ax)

fig.tight_layout()
save_to_session(
    None,
    "nb02_baseline_similarity_matrices.png",
    save_func=lambda out_path, _unused: plt.savefig(out_path, dpi=150, bbox_inches='tight'),
)
plt.show()
plt.close()

## Save baseline results

In [None]:
# Persist the baseline inputs and outputs through `save_to_session`, so each
# file is written to the path the session provides and gets a matching
# "Saved <filename>" log entry.
save_to_session(top_genes, "baseline_top_genes.npy")
save_to_session(pearson_sim, "baseline_similarity_pearson.npy")
save_to_session(spearman_sim, "baseline_similarity_spearman.npy")
save_to_session(cosine_sim, "baseline_similarity_cosine.npy")

save_to_session(labels_pearson, "baseline_labels_pearson.npy")
save_to_session(labels_spearman, "baseline_labels_spearman.npy")
save_to_session(labels_cosine, "baseline_labels_cosine.npy")

# Save metrics DataFrame: one row per similarity measure (columns presumably
# taken from the metric dict keys). Routed through save_to_session with a CSV
# writer instead of duplicating its path lookup and logging.
metrics_df = pd.DataFrame(
    [metrics_pearson, metrics_spearman, metrics_cosine],
    index=["pearson", "spearman", "cosine"],
)
metrics_path = save_to_session(
    metrics_df,
    "baseline_metrics.csv",
    save_func=lambda path, df: df.to_csv(path),
)

print(f"\nOK All baseline results saved to session: {session.session_id}")
print(f"   Location: {session.run_dir}")