# Phase 5: Embedding Space Analysis

**Question:** What does the learned representation look like? Where does class separation emerge?

This notebook:
1. Extracts intermediate graph-level embeddings via forward hooks
2. Visualizes the embeddings at each layer with t-SNE and UMAP
3. Runs linear probing to measure where separability emerges
4. Computes CKA across layers and models
5. Analyzes within-class vs between-class cosine similarity

In [None]:
import sys
sys.path.insert(0, '../src')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

from calamr_interp.utils.data_loading import load_and_split
from calamr_interp.utils.model_loading import create_model, load_model_checkpoint, find_checkpoints
from calamr_interp.utils.visualization import setup_style, heatmap, COLORS
from calamr_interp.phase5_embeddings import (
    LayerEmbeddingExtractor,
    EmbeddingVisualizer,
    ProbingClassifier,
    CKAAnalysis,
    cosine_similarity_analysis,
)

# Apply the project's shared plotting style before any figure is created.
setup_style()

# Seed the global RNGs so a Restart & Run All is repeatable: the model below is
# built in-notebook, and the visualiser/prober further down also use seed 42.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

## 1. Load Model & Data

In [None]:
# Load the dataset splits and materialise the test split for the visualisations.
train_data, val_data, test_data = load_and_split()
test_list = list(test_data)
test_labels = np.array([d.y.item() for d in test_list])

# Labels are counted as binary: 1 = hallucination, 0 = truth.
# Cast to int so the summary prints whole numbers even if labels are stored as floats.
n_hallu = int(test_labels.sum())
n_truth = len(test_labels) - n_hallu
print(f'Test: {len(test_list)} graphs ({n_hallu} hallu, {n_truth} truth)')

# Set CHECKPOINT_PATH to a trained checkpoint (e.g. 'path/to/best_model.pt') to
# analyse learned representations. Left as None, the notebook falls back to
# create_model(), which does not load any trained weights here.
CHECKPOINT_PATH = None
MODEL_NAME = 'GraphTransformer'

if CHECKPOINT_PATH is not None:
    model = load_model_checkpoint(CHECKPOINT_PATH, MODEL_NAME, device=device)
else:
    print('WARNING: CHECKPOINT_PATH is not set; no trained weights were loaded, '
          'so the embeddings below do not reflect a trained model.')
    model = create_model(MODEL_NAME)

# Move to the target device and switch to inference mode before hooking layers.
model = model.to(device)
model.eval()
print(f'Model: {type(model).__name__}')

## 2. Extract Layer Embeddings

In [None]:
# Hook the model, run the test graphs through it, and collect one embedding
# array per hooked layer.
extractor = LayerEmbeddingExtractor(model, device)
extractor.register_hooks()
try:
    # Extract embeddings for test set
    layer_embeddings = extractor.extract_graph_embeddings(test_list)
finally:
    # Always detach the hooks, even if extraction fails, so they do not stay
    # attached to the shared `model` and fire again in later cells.
    extractor.clear_hooks()

print(f'Extracted embeddings from {len(layer_embeddings)} layers:')
for name, emb in layer_embeddings.items():
    print(f'  {name}: shape {emb.shape}')

## 3. t-SNE & UMAP Visualization

In [None]:
viz = EmbeddingVisualizer(seed=42)

# t-SNE at each layer, laid out on a grid of at most MAX_COLS panels per row.
# Every extracted layer gets a panel; with four or fewer layers the layout is a
# single row, exactly as wide as before.
MAX_COLS = 4
n_layers = len(layer_embeddings)
n_cols = max(1, min(n_layers, MAX_COLS))
n_rows = max(1, -(-n_layers // n_cols))  # ceiling division without importing math

# squeeze=False always returns a 2-D array of Axes, so the single-panel case
# needs no special handling.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.ravel()

for panel_ax, (layer_name, emb) in zip(axes, layer_embeddings.items()):
    coords = viz.tsne(emb)
    viz.plot_scatter(coords, test_labels, title=f't-SNE: {layer_name}', method='t-SNE', ax=panel_ax)

# Hide any unused panels in the last row of the grid.
for panel_ax in axes[n_layers:]:
    panel_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# UMAP at each layer, laid out on a grid of at most MAX_COLS panels per row.
# Every extracted layer gets a panel; with four or fewer layers the layout is a
# single row, exactly as wide as before.
MAX_COLS = 4
n_layers = len(layer_embeddings)  # derived here so the cell does not rely on leftover state
n_cols = max(1, min(n_layers, MAX_COLS))
n_rows = max(1, -(-n_layers // n_cols))  # ceiling division without importing math

# squeeze=False always returns a 2-D array of Axes, so the single-panel case
# needs no special handling.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.ravel()

for panel_ax, (layer_name, emb) in zip(axes, layer_embeddings.items()):
    coords = viz.umap(emb)
    viz.plot_scatter(coords, test_labels, title=f'UMAP: {layer_name}', method='UMAP', ax=panel_ax)

# Hide any unused panels in the last row of the grid.
for panel_ax in axes[n_layers:]:
    panel_ax.set_visible(False)

plt.tight_layout()
plt.show()

## 4. Linear Probing

In [None]:
# Probe accuracy at each layer
# Use full dataset for more reliable probing
all_data = list(train_data) + list(val_data) + list(test_data)
all_labels = np.array([d.y.item() for d in all_data])

# Re-extract embeddings over the combined splits with a fresh set of hooks.
extractor2 = LayerEmbeddingExtractor(model, device)
extractor2.register_hooks()
try:
    all_embeddings = extractor2.extract_graph_embeddings(all_data)
finally:
    # Detach the hooks even if extraction fails so they do not stay attached
    # to the shared `model`.
    extractor2.clear_hooks()

# Fit a probe on each layer's embeddings; the result table is shown as the
# cell output.
prober = ProbingClassifier(seed=42)
probe_results = prober.probe(all_embeddings, all_labels)
probe_results

In [None]:
# Plot probing accuracy and F1 (mean with std error bars) across layers.
fig, ax = plt.subplots(figsize=(10, 5))
layers = probe_results['layer'].tolist()
acc = probe_results['accuracy_mean'].values
acc_std = probe_results['accuracy_std'].values
f1 = probe_results['f1_mean'].values
f1_std = probe_results['f1_std'].values

x = range(len(layers))
ax.errorbar(x, acc, yerr=acc_std, label='Accuracy', marker='o', capsize=3, color=COLORS['primary'])
ax.errorbar(x, f1, yerr=f1_std, label='F1', marker='s', capsize=3, color=COLORS['hallu'])

# Draw the chance reference line BEFORE building the legend: legend() only
# collects artists that already exist, so adding the line afterwards would
# leave its 'Chance' label out of the legend.
ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Chance')

ax.set_xticks(x)
ax.set_xticklabels(layers, rotation=45, ha='right')
ax.set_ylabel('Score')
ax.set_title('Linear Probing: Where Does Class Separation Emerge?')
ax.legend()
plt.tight_layout()
plt.show()

print("A sharp increase in probing accuracy identifies where the model 'commits' to its decision.")

## 5. CKA Analysis

In [None]:
# Pairwise CKA between the layers of this model, rendered as a heatmap whose
# rows and columns are both labelled with the layer names.
cka_result = CKAAnalysis.compute_cka_matrix(all_embeddings)
cka_matrix, layer_names = cka_result

heatmap_options = dict(
    title='CKA: Representation Similarity Across Layers',
    cmap='viridis',
    center=None,
    fmt='.2f',
)
fig = heatmap(cka_matrix, layer_names, layer_names, **heatmap_options)
plt.show()

## 6. Cosine Similarity Analysis

In [None]:
# Collect one record of within-class / between-class cosine similarity per
# layer, tagging each record with the layer it came from.
sim_results = []
for layer_name, emb in all_embeddings.items():
    layer_sim = cosine_similarity_analysis(emb, all_labels)
    layer_sim['layer'] = layer_name
    sim_results.append(layer_sim)

sim_df = pd.DataFrame(sim_results)

# One line per similarity type: (column, marker/line format, legend label, colour).
curve_specs = [
    ('within_class_sim', 'o-', 'Within-class', COLORS['truth']),
    ('between_class_sim', 's-', 'Between-class', COLORS['hallu']),
]

fig, ax = plt.subplots(figsize=(10, 5))
x = range(len(sim_df))
for column, line_fmt, legend_label, line_color in curve_specs:
    ax.plot(x, sim_df[column], line_fmt, label=legend_label, color=line_color)

ax.set_xticks(x)
ax.set_xticklabels(sim_df['layer'], rotation=45, ha='right')
ax.set_ylabel('Cosine Similarity')
ax.set_title('Within-Class vs Between-Class Similarity Across Layers')
ax.legend()
plt.tight_layout()
plt.show()

# Text summary of the per-layer separation ratio.
ratio_table = sim_df[['layer', 'separation_ratio']]
print("\nSeparation ratio (within/between) should increase across layers:")
print(ratio_table.to_string(index=False))