# Multi-Checkpoint Embedding Space Visualization

This notebook extracts **768-dimensional embeddings** (taken right before the contrastive probing layer, i.e. the projection MLP) from four different finetuned checkpoints, using the **same validation dataset** for every model.

## Architecture Reminder
```
Input (96³) → Encoder → Features (768-dim, spatial) → Global Pool → 768-dim vector
                                                                         ↓
                                                            Projection MLP (Probing Layer)
                                                                         ↓
                                                                    128-dim embedding
```

**We extract at the 768-dim stage** (after encoder + pooling, before projection MLP)
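
As a rough illustration of the pooling step in the diagram, the sketch below averages a spatial feature map over its three spatial axes to obtain one 768-dim vector per sample. It uses NumPy instead of the real encoder, and the 3×3×3 spatial grid is an assumption made only for the example; the actual spatial size depends on the encoder.

```python
import numpy as np


def global_average_pool(features: np.ndarray) -> np.ndarray:
    """Average a (B, C, D, H, W) feature map over its spatial axes -> (B, C)."""
    if features.ndim != 5:
        raise ValueError(f"expected 5 dims (B, C, D, H, W), got {features.ndim}")
    return features.mean(axis=(2, 3, 4))


rng = np.random.default_rng(0)
# Hypothetical encoder output: batch of 2, 768 channels, 3x3x3 spatial grid.
features = rng.standard_normal((2, 768, 3, 3, 3))
pooled = global_average_pool(features)
print(pooled.shape)  # (2, 768)
```

This should match what `F.adaptive_avg_pool3d(features, 1).flatten(1)` computes on a tensor of the same shape.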

In [None]:
import sys
sys.path.append('/home/mgazda/Projects/UMBRA')

import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from typing import Dict, List, Tuple
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import pandas as pd
import re

# Import your models
from models.foundation import ContrastiveMAEPretrainer
from models.finetuning import FinetuningModule

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Configuration: Your 4 Checkpoints

**IMPORTANT**: Update these paths to point to your actual checkpoint files!

In [None]:
# Define your 4 checkpoint paths and their labels
CHECKPOINTS = {
    'MCL': '/home/mg873uh/data/Petros/ckpts/contrastive_modality-step=200000.ckpt',
    'CL': '/home/mg873uh/data/Petros/ckpts/contrastive_regular-step=200000.ckpt',
    'MAE + MCL': '/home/mg873uh/data/Petros/ckpts/combined_modality-step=200000.ckpt',
    'MAE + CL': '/home/mg873uh/data/Petros/ckpts/combined_regular-step=200000.ckpt',
}

# Validation dataset configuration (PRETRAIN data directory)
VAL_DATA_DIR = '/home/mg873uh/Projects_kb/data/pretrain_parsed'
SEED = 42  # Same seed used during pretraining to ensure same validation split
INPUT_SIZE = 96  # Input volume size (96x96x96)

# Output directory for embeddings
OUTPUT_DIR = Path('/home/mgazda/Projects/UMBRA/embeddings_analysis')
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

print(f"Output directory: {OUTPUT_DIR}")
print(f"\nCheckpoints to process:")
for name, path in CHECKPOINTS.items():
    print(f"  {name}: {path}")

## Validation Dataset Loading (PRETRAIN)

Load the validation dataset with **ContrastiveDataModule**, using the **exact same configuration** as during pretraining to ensure consistency.

**Key Points:**
- Uses flat file structure: `sub_{patient}_ses_{session}_{modality}.npy`
- 98% train / 2% validation split (same as pretraining)
- Excludes `scan_*` files by default (ignore_scan_label=True)
- Same seed = same validation patients across all checkpoints
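
The key points above can be sketched in plain Python. This is only an illustration under stated assumptions: `parse_filename` and `split_patients` are hypothetical helpers, the regular expression assumes patient and session IDs contain no underscores, and the real `ContrastiveDataModule` may build its split differently. The point is that a seeded shuffle over a sorted patient list yields the same validation patients on every run.

```python
import random
import re

# Assumed flat naming scheme: sub_{patient}_ses_{session}_{modality}.npy
FILENAME_RE = re.compile(
    r"^sub_(?P<patient>[^_]+)_ses_(?P<session>[^_]+)_(?P<modality>.+)\.npy$"
)


def parse_filename(name):
    """Return patient/session/modality for a matching file name, else None."""
    match = FILENAME_RE.match(name)
    return match.groupdict() if match else None


def split_patients(patient_ids, val_fraction=0.02, seed=42):
    """Deterministic patient-level split: same seed -> same validation patients."""
    patients = sorted(set(patient_ids))  # sort first so input order is irrelevant
    random.Random(seed).shuffle(patients)
    n_val = max(1, round(len(patients) * val_fraction))
    return patients[n_val:], patients[:n_val]


print(parse_filename("sub_0012_ses_01_t1.npy"))
# {'patient': '0012', 'session': '01', 'modality': 't1'}
print(parse_filename("scan_0001.npy"))  # None: scan_* files do not match

train, val = split_patients([f"p{i:03d}" for i in range(100)])
print(len(train), len(val))  # 98 2
```

Splitting by patient instead of by file keeps all sessions and modalities of one patient on the same side of the split.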

In [None]:
from data.contrastive_datamodule import ContrastiveDataModule
from data.transforms import get_contrastive_transforms

# Create validation transforms (no augmentation)
# Use conservative mode and val_mode for validation
val_transforms = get_contrastive_transforms(
    keys=("vol1", "vol2"),  # Contrastive pairs
    input_size=INPUT_SIZE,
    conservative_mode=True,  # Minimal augmentation
    val_mode=True,  # No augmentation for validation
    recon=False,  # Not needed for embedding extraction
)

# Create PRETRAIN data module with the same configuration as during pretraining
data_module = ContrastiveDataModule(
    data_dir=VAL_DATA_DIR,
    train_transforms=val_transforms,  # Not used for validation
    val_transforms=val_transforms,
    contrastive_mode="modality_pairs",  # or "regular" - doesn't matter for validation
    input_size=INPUT_SIZE,
    batch_size=1,  # Process one sample at a time for embedding extraction
    num_workers=0,
    seed=SEED,  # Critical: same seed as pretraining!
)

# Setup the data module
data_module.setup('fit')

# Get validation dataloader
val_loader = data_module.val_dataloader()

print(f"Validation dataset size: {len(val_loader.dataset)} samples")
print(f"Validation split: 98% train / 2% validation")
print(f"Validation split seed: {SEED}")
print(f"Data directory: {VAL_DATA_DIR}")
print(f"\nData structure expected (flat files):")
print(f"  {VAL_DATA_DIR}/")
print(f"    sub_XXX_ses_YYY_t1.npy, sub_XXX_ses_YYY_flair.npy, sub_XXX_ses_YYY_dwi.npy, etc.")
print(f"\nNote: 'scan_*' files are automatically excluded (ignore_scan_label=True by default)")
print(f"This validation set is FIXED across all 4 checkpoints (same as used during pretraining)")

## Embedding Extraction Function

Extract **768-dim embeddings** right before the contrastive probing layer.

In [None]:
def extract_embeddings_before_probing(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> Tuple[np.ndarray, List[str], List[str], List[str], List[str]]:
    """
    Extract 768-dim embeddings from encoder (before contrastive probing layer).
    
    Args:
        model: Loaded checkpoint model
        dataloader: Validation dataloader (from ContrastiveDataModule)
        device: Device to run on
    
    Returns:
        embeddings: (N, 768) array of embeddings
        labels: List of modality names (from filename)
        patient_ids: List of patient IDs
        session_ids: List of session IDs
        metadata: List of sample identifiers
    """
    model = model.to(device)
    model.eval()
    
    all_embeddings = []
    all_labels = []
    all_patient_ids = []
    all_session_ids = []
    all_metadata = []
    
    def _as_str(value, default: str) -> str:
        """Unwrap a collated metadata field (list, tuple, or tensor) to a plain string.

        Assumes batch_size=1, so only the first element is kept.
        """
        if isinstance(value, (list, tuple)):
            value = value[0] if len(value) > 0 else default
        if torch.is_tensor(value):
            value = value.flatten()[0].item()
        return str(value)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc='Extracting embeddings')):
            # Defaults, so the metadata string below is defined for every batch format
            patient_id, session_id, modality = f'unknown_{batch_idx}', 'unknown', 'unknown'

            # ContrastiveDataModule returns dict with 'vol1', 'vol2', 'patient', 'session'
            if isinstance(batch, dict):
                # Use vol1 for embedding extraction
                volume = batch['vol1'].to(device)

                # PyTorch's default collate wraps strings in lists and numbers in tensors
                patient_id = _as_str(batch.get('patient', patient_id), patient_id)
                session_id = _as_str(batch.get('session', session_id), session_id)

                # Extract modality from the file name if a path is available
                path1 = _as_str(batch.get('path1', ''), '')
                if path1:
                    stem = Path(path1).stem
                    # Drop a flat-layout prefix: sub_1_ses_1_t1 -> t1 (a nested .../t1.npy is unchanged)
                    stem = re.sub(r'^sub_[^_]+_ses_[^_]+_', '', stem)
                    # Remove numeric suffixes like t1_2 -> t1
                    modality = re.sub(r'_\d+$', '', stem) or 'unknown'
            else:
                # Fallback for unexpected format
                volume = batch[0].to(device) if isinstance(batch, (list, tuple)) else batch.to(device)

            all_labels.append(modality)
            all_patient_ids.append(patient_id)
            all_session_ids.append(session_id)

            # Extract features from encoder (last layer output)
            # For ContrastiveMAEPretrainer
            if hasattr(model, 'encoder'):
                features = model.encoder(volume)[-1]  # (B, 768, D, H, W)
            elif hasattr(model, 'model') and hasattr(model.model, 'encoder'):
                features = model.model.encoder(volume)[-1]  # For FinetuningModule
            else:
                raise AttributeError("Model does not have expected encoder structure")

            # Apply global average pooling (same as first step of projection head)
            features_pooled = F.adaptive_avg_pool3d(features, 1)  # (B, 768, 1, 1, 1)
            features_pooled = features_pooled.flatten(1)  # (B, 768)

            # Convert to numpy
            embedding = features_pooled.cpu().numpy()
            all_embeddings.append(embedding)
            all_metadata.append(f'{patient_id}_ses_{session_id}_{modality}')
    
    # Concatenate all embeddings
    embeddings = np.concatenate(all_embeddings, axis=0)
    
    print(f"Extracted {len(embeddings)} embeddings with shape {embeddings.shape}")
    print(f"Unique patients: {len(set(all_patient_ids))}")
    print(f"Modalities found: {set(all_labels)}")
    
    return embeddings, all_labels, all_patient_ids, all_session_ids, all_metadata

## Extract Embeddings from All 4 Checkpoints

Process each checkpoint and save embeddings.

In [None]:
# Dictionary to store all embeddings
all_checkpoint_embeddings = {}

for checkpoint_name, checkpoint_path in CHECKPOINTS.items():
    print(f"\n{'='*80}")
    print(f"Processing: {checkpoint_name}")
    print(f"Checkpoint: {checkpoint_path}")
    print(f"{'='*80}")
    
    # Check if checkpoint exists
    if not Path(checkpoint_path).exists():
        print(f"WARNING: Checkpoint not found at {checkpoint_path}")
        print("Skipping...")
        continue
    
    try:
        # Try loading as FinetuningModule first (most likely)
        try:
            model = FinetuningModule.load_from_checkpoint(checkpoint_path)
            print("Loaded as FinetuningModule")
        except Exception as e1:
            # Try loading as ContrastiveMAEPretrainer
            try:
                model = ContrastiveMAEPretrainer.load_from_checkpoint(checkpoint_path)
                print("Loaded as ContrastiveMAEPretrainer")
            except Exception as e2:
                print(f"Failed to load checkpoint:")
                print(f"  As FinetuningModule: {e1}")
                print(f"  As ContrastiveMAEPretrainer: {e2}")
                continue
        
        # Extract embeddings with patient IDs
        embeddings, labels, patient_ids, session_ids, metadata = extract_embeddings_before_probing(
            model=model,
            dataloader=val_loader,
            device=device,
        )
        
        # Store results
        all_checkpoint_embeddings[checkpoint_name] = {
            'embeddings': embeddings,
            'labels': labels,
            'patient_ids': patient_ids,
            'session_ids': session_ids,
            'metadata': metadata,
        }
        
        # Save individual checkpoint embeddings
        save_path = OUTPUT_DIR / f"embeddings_{checkpoint_name.replace(' ', '_').lower()}.npz"
        np.savez(
            save_path,
            embeddings=embeddings,
            labels=np.array(labels),
            patient_ids=np.array(patient_ids),
            session_ids=np.array(session_ids),
            metadata=np.array(metadata),
        )
        print(f"\nSaved embeddings to: {save_path}")
        
        # Clean up GPU memory
        del model
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"Error processing {checkpoint_name}: {e}")
        import traceback
        traceback.print_exc()
        continue

print(f"\n{'='*80}")
print(f"Extraction complete!")
print(f"Processed {len(all_checkpoint_embeddings)} checkpoints")
print(f"Results saved to: {OUTPUT_DIR}")
print(f"{'='*80}")

## Summary Statistics

In [None]:
from collections import Counter

print("\n" + "="*80)
print("EMBEDDING EXTRACTION SUMMARY")
print("="*80)

for checkpoint_name, data in all_checkpoint_embeddings.items():
    embeddings = data['embeddings']
    labels = data['labels']

    # Per-sample L2 norms, computed once and reused for both statistics
    norms = np.linalg.norm(embeddings, axis=1)

    print(f"\n{checkpoint_name}:")
    print(f"  Shape: {embeddings.shape}")
    print(f"  Mean norm: {norms.mean():.4f}")
    print(f"  Std norm: {norms.std():.4f}")

    # Label distribution
    label_counts = Counter(labels)
    print(f"  Label distribution:")
    for label, count in label_counts.items():
        print(f"    {label}: {count}")

print("\n" + "="*80)

## Visualization: PCA Comparison Across Checkpoints

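Visualize how the embedding spaces differ across the 4 checkpoints. PCA is fitted separately for each checkpoint, so the axes of one panel are not directly comparable with those of another.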

In [None]:
# IEEE publication style
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 10,
    'figure.dpi': 100,
    'savefig.dpi': 300,
})

# Create subplots for each checkpoint
n_checkpoints = len(all_checkpoint_embeddings)
fig, axes = plt.subplots(1, n_checkpoints, figsize=(5*n_checkpoints, 5))

if n_checkpoints == 1:
    axes = [axes]

for idx, (checkpoint_name, data) in enumerate(all_checkpoint_embeddings.items()):
    ax = axes[idx]
    
    embeddings = data['embeddings']
    labels = data['labels']
    
    # Apply PCA
    pca = PCA(n_components=2, random_state=42)
    embeddings_2d = pca.fit_transform(embeddings)
    
    # Plot
    unique_labels = sorted(list(set(labels)))
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))
    
    for label_idx, label in enumerate(unique_labels):
        mask = np.array(labels) == label
        ax.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            c=[colors[label_idx]],
            label=label,
            alpha=0.6,
            s=30,
            edgecolors='black',
            linewidth=0.5,
        )
    
    ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
    ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
    ax.set_title(f'{checkpoint_name}\n(768-dim embeddings)', fontweight='bold')
    ax.legend(loc='best', fontsize=8)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
save_path = OUTPUT_DIR / 'pca_comparison_all_checkpoints.png'
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved PCA comparison to: {save_path}")
plt.show()

## Visualization: t-SNE Comparison (Optional, slower)

Run t-SNE for more detailed clustering visualization. **Warning: This can be slow for large datasets!**

In [None]:
# Set to True to run t-SNE (can be slow!)
RUN_TSNE = False

if RUN_TSNE:
    fig, axes = plt.subplots(1, n_checkpoints, figsize=(5*n_checkpoints, 5))
    
    if n_checkpoints == 1:
        axes = [axes]
    
    for idx, (checkpoint_name, data) in enumerate(all_checkpoint_embeddings.items()):
        ax = axes[idx]
        
        embeddings = data['embeddings']
        labels = data['labels']
        
        print(f"Running t-SNE for {checkpoint_name}...")
        
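        # Apply t-SNE. perplexity must be smaller than the number of samples, and the
        # iteration count is left at its default (1000) because the keyword was renamed
        # from n_iter to max_iter in scikit-learn 1.5.
        perplexity = min(30, max(1, len(embeddings) - 1))
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, verbose=1)
        embeddings_2d = tsne.fit_transform(embeddings)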
        
        # Plot
        unique_labels = sorted(list(set(labels)))
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))
        
        for label_idx, label in enumerate(unique_labels):
            mask = np.array(labels) == label
            ax.scatter(
                embeddings_2d[mask, 0],
                embeddings_2d[mask, 1],
                c=[colors[label_idx]],
                label=label,
                alpha=0.6,
                s=30,
                edgecolors='black',
                linewidth=0.5,
            )
        
        ax.set_xlabel('t-SNE Component 1')
        ax.set_ylabel('t-SNE Component 2')
        ax.set_title(f'{checkpoint_name}\n(768-dim embeddings)', fontweight='bold')
        ax.legend(loc='best', fontsize=8)
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    save_path = OUTPUT_DIR / 'tsne_comparison_all_checkpoints.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved t-SNE comparison to: {save_path}")
    plt.show()
else:
    print("t-SNE visualization skipped. Set RUN_TSNE = True to enable.")

## Pairwise Embedding Space Comparison

Compute cosine similarity between the same samples across different checkpoints.
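
For reference, the per-sample comparison can also be written as one vectorized NumPy expression. The helper below is a sketch, not part of the notebook's pipeline: `rowwise_cosine_similarity` is a hypothetical name, and it assumes both arrays hold the same samples in the same order.

```python
import numpy as np


def rowwise_cosine_similarity(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Cosine similarity between a[k] and b[k] for every row k -> shape (N,)."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    dots = np.sum(a * b, axis=1)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dots / np.maximum(norms, eps)  # eps guards against all-zero rows


a = np.array([[1.0, 0.0], [1.0, 1.0]])
b = np.array([[2.0, 0.0], [-1.0, -1.0]])
sims = rowwise_cosine_similarity(a, b)
print(sims)         # approximately [ 1. -1.]
print(sims.mean())  # approximately 0.0
```

The mean of the returned vector corresponds to one entry of the similarity matrix. Values can be negative, which is worth keeping in mind when a color scale is clipped to the range 0 to 1.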

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

if len(all_checkpoint_embeddings) >= 2:
    print("\n" + "="*80)
    print("PAIRWISE CHECKPOINT COMPARISON")
    print("="*80)
    
    checkpoint_names = list(all_checkpoint_embeddings.keys())
    
    # Create similarity matrix
    n = len(checkpoint_names)
    similarity_matrix = np.zeros((n, n))
    
    for i, name_i in enumerate(checkpoint_names):
        emb_i = all_checkpoint_embeddings[name_i]['embeddings']
        
        for j, name_j in enumerate(checkpoint_names):
            emb_j = all_checkpoint_embeddings[name_j]['embeddings']
            
            # Compute average cosine similarity between corresponding samples
            # (assuming same order since we use same validation set)
            cos_sim = np.mean([cosine_similarity(emb_i[k:k+1], emb_j[k:k+1])[0, 0] 
                               for k in range(len(emb_i))])
            similarity_matrix[i, j] = cos_sim
    
    # Plot heatmap
    fig, ax = plt.subplots(figsize=(8, 7))
    
    im = ax.imshow(similarity_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
    
    # Set ticks
    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(checkpoint_names, rotation=45, ha='right')
    ax.set_yticklabels(checkpoint_names)
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Average Cosine Similarity', rotation=270, labelpad=20)
    
    # Add text annotations
    for i in range(n):
        for j in range(n):
            value = similarity_matrix[i, j]
            text_color = 'white' if value < 0.5 else 'black'
            ax.text(j, i, f'{value:.3f}',
                   ha="center", va="center", color=text_color, fontsize=10, weight='bold')
    
    ax.set_title('Pairwise Embedding Space Similarity\n(Higher = More similar representations)', 
                 fontweight='bold', pad=20)
    
    plt.tight_layout()
    save_path = OUTPUT_DIR / 'pairwise_similarity_matrix.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved similarity matrix to: {save_path}")
    plt.show()
    
    # Print numerical results
    print("\nPairwise Cosine Similarities:")
    for i in range(n):
        for j in range(i+1, n):
            print(f"  {checkpoint_names[i]:20s} vs {checkpoint_names[j]:20s}: {similarity_matrix[i, j]:.4f}")
else:
    print("Need at least 2 checkpoints for pairwise comparison")

## Per-Patient Analysis Across Checkpoints

Analyze how **the same patient** is embedded differently across the 4 checkpoints.

In [None]:
if len(all_checkpoint_embeddings) >= 2:
    print("\n" + "="*80)
    print("PER-PATIENT EMBEDDING DISTANCE ANALYSIS")
    print("="*80)

    checkpoint_names = list(all_checkpoint_embeddings.keys())

    # Get first checkpoint as reference for patient list
    ref_checkpoint = checkpoint_names[0]
    ref_patient_ids = all_checkpoint_embeddings[ref_checkpoint]['patient_ids']

    # Verify all checkpoints have same patients in same order
    all_same = True
    for cp_name in checkpoint_names[1:]:
        cp_patient_ids = all_checkpoint_embeddings[cp_name]['patient_ids']
        if cp_patient_ids != ref_patient_ids:
            print(f"WARNING: Patient order differs between {ref_checkpoint} and {cp_name}")
            all_same = False

    if all_same:
        print(f"\nAll checkpoints have the same {len(ref_patient_ids)} patients in the same order.")
        print("Computing per-patient embedding distances across checkpoints...\n")

        # Compute per-patient cosine distances across all checkpoint pairs
        unique_patients = sorted(list(set(ref_patient_ids)))

        # Create DataFrame to store results
        per_patient_results = []

        for patient_id in unique_patients:
            # Get indices for this patient across all checkpoints
            patient_indices = [i for i, pid in enumerate(ref_patient_ids) if pid == patient_id]

            if len(patient_indices) == 0:
                continue

            # For simplicity, take first session if patient has multiple
            idx = patient_indices[0]

            # Compute pairwise distances for this patient across checkpoints
            for i in range(len(checkpoint_names)):
                for j in range(i+1, len(checkpoint_names)):
                    cp_i = checkpoint_names[i]
                    cp_j = checkpoint_names[j]

                    emb_i = all_checkpoint_embeddings[cp_i]['embeddings'][idx]
                    emb_j = all_checkpoint_embeddings[cp_j]['embeddings'][idx]

                    # Compute cosine similarity
                    cos_sim = cosine_similarity(emb_i.reshape(1, -1), emb_j.reshape(1, -1))[0, 0]
                    # Compute L2 distance
                    l2_dist = np.linalg.norm(emb_i - emb_j)

                    per_patient_results.append({
                        'patient_id': patient_id,
                        'checkpoint_pair': f'{cp_i} vs {cp_j}',
                        'cosine_similarity': cos_sim,
                        'l2_distance': l2_dist,
                    })

        # Create DataFrame
        df_patient = pd.DataFrame(per_patient_results)

        # Summary statistics
        print("Average embedding distances per checkpoint pair:")
        print("-" * 80)
        for pair in df_patient['checkpoint_pair'].unique():
            df_pair = df_patient[df_patient['checkpoint_pair'] == pair]
            avg_cos = df_pair['cosine_similarity'].mean()
            std_cos = df_pair['cosine_similarity'].std()
            avg_l2 = df_pair['l2_distance'].mean()
            std_l2 = df_pair['l2_distance'].std()
            print(f"{pair:40s}")
            print(f"  Cosine Similarity: {avg_cos:.4f} ± {std_cos:.4f}")
            print(f"  L2 Distance:       {avg_l2:.4f} ± {std_l2:.4f}")
            print()

        # Save per-patient results
        df_patient.to_csv(OUTPUT_DIR / 'per_patient_distances.csv', index=False)
        print(f"Saved per-patient distances to: {OUTPUT_DIR / 'per_patient_distances.csv'}")

    else:
        print("Cannot perform per-patient analysis - patient order differs between checkpoints")
else:
    print("Need at least 2 checkpoints for per-patient comparison")

### Visualization: Patient Trajectories Across Embedding Spaces

Visualize how individual patients "move" in the embedding space across different checkpoints using PCA projection.

In [None]:
if len(all_checkpoint_embeddings) >= 2:
    checkpoint_names = list(all_checkpoint_embeddings.keys())
    ref_checkpoint = checkpoint_names[0]
    ref_patient_ids = all_checkpoint_embeddings[ref_checkpoint]['patient_ids']

    # Check if patient order is consistent
    all_same = all(
        all_checkpoint_embeddings[cp]['patient_ids'] == ref_patient_ids
        for cp in checkpoint_names
    )

    if all_same:
        print("Visualizing patient trajectories across embedding spaces...")

        # Concatenate all embeddings from all checkpoints for joint PCA
        all_embs = np.vstack([
            all_checkpoint_embeddings[cp]['embeddings']
            for cp in checkpoint_names
        ])

        # Fit PCA on all embeddings together
        pca = PCA(n_components=2, random_state=42)
        all_embs_2d = pca.fit_transform(all_embs)

        # Split back into per-checkpoint embeddings
        n_samples = len(ref_patient_ids)
        checkpoint_embs_2d = {}
        for i, cp in enumerate(checkpoint_names):
            start_idx = i * n_samples
            end_idx = (i + 1) * n_samples
            checkpoint_embs_2d[cp] = all_embs_2d[start_idx:end_idx]

        # Select a few patients to visualize
        unique_patients = sorted(list(set(ref_patient_ids)))
        n_patients_to_show = min(10, len(unique_patients))
        selected_patients = unique_patients[:n_patients_to_show]

        # Create visualization
        fig, axes = plt.subplots(1, 2, figsize=(16, 7))

        # Plot 1: Selected patients with trajectories
        ax = axes[0]
        colors = plt.cm.tab20(np.linspace(0, 1, n_patients_to_show))

        for patient_idx, patient_id in enumerate(selected_patients):
            # Get first occurrence index for this patient
            idx = next(i for i, pid in enumerate(ref_patient_ids) if pid == patient_id)

            # Get embeddings for this patient across all checkpoints
            patient_traj = []
            for cp in checkpoint_names:
                emb_2d = checkpoint_embs_2d[cp][idx]
                patient_traj.append(emb_2d)
            patient_traj = np.array(patient_traj)

            # Plot trajectory
            ax.plot(
                patient_traj[:, 0],
                patient_traj[:, 1],
                'o-',
                color=colors[patient_idx],
                label=f'Patient {patient_id}',
                linewidth=2,
                markersize=8,
                alpha=0.7,
            )

            # Add checkpoint labels (full names, because 'MAE + MCL' and 'MAE + CL'
            # share the same first three characters)
            for i, cp_name in enumerate(checkpoint_names):
                ax.annotate(
                    cp_name,
                    (patient_traj[i, 0], patient_traj[i, 1]),
                    fontsize=6,
                    ha='center',
                    va='bottom',
                )

        ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)', fontweight='bold')
        ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)', fontweight='bold')
        ax.set_title(
            f'Patient Trajectories Across {len(checkpoint_names)} Checkpoints\n'
            f'(Showing {n_patients_to_show} patients)',
            fontweight='bold'
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        ax.grid(True, alpha=0.3)

        # Plot 2: Checkpoint-colored view
        ax = axes[1]
        checkpoint_colors = plt.cm.Set1(np.linspace(0, 1, len(checkpoint_names)))

        for cp_idx, cp_name in enumerate(checkpoint_names):
            embs_2d = checkpoint_embs_2d[cp_name]
            ax.scatter(
                embs_2d[:, 0],
                embs_2d[:, 1],
                c=[checkpoint_colors[cp_idx]],
                label=cp_name,
                alpha=0.5,
                s=50,
                edgecolors='black',
                linewidth=0.5,
            )

        ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)', fontweight='bold')
        ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)', fontweight='bold')
        ax.set_title(
            'All Patients Colored by Checkpoint\n'
            '(Joint PCA projection)',
            fontweight='bold'
        )
        ax.legend(loc='best', fontsize=10)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        save_path = OUTPUT_DIR / 'patient_trajectories_across_checkpoints.png'
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nSaved patient trajectories to: {save_path}")
        plt.show()

    else:
        print("Cannot visualize trajectories - patient order differs between checkpoints")
else:
    print("Need at least 2 checkpoints for trajectory visualization")

### Distribution of Per-Patient Embedding Distances

Show how embedding distances vary across patients for each checkpoint pair.

In [None]:
if len(all_checkpoint_embeddings) >= 2 and 'df_patient' in locals():
    # Create violin plots showing distribution of distances per checkpoint pair
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Plot 1: Cosine Similarity Distribution
    ax = axes[0]
    checkpoint_pairs = df_patient['checkpoint_pair'].unique()

    violin_parts = ax.violinplot(
        [df_patient[df_patient['checkpoint_pair'] == pair]['cosine_similarity'].values
         for pair in checkpoint_pairs],
        positions=range(len(checkpoint_pairs)),
        showmeans=True,
        showmedians=True,
    )

    ax.set_xticks(range(len(checkpoint_pairs)))
    ax.set_xticklabels(checkpoint_pairs, rotation=45, ha='right')
    ax.set_ylabel('Cosine Similarity', fontweight='bold')
    ax.set_title('Distribution of Per-Patient Cosine Similarity\nAcross Checkpoint Pairs',
                 fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim([0, 1])

    # Plot 2: L2 Distance Distribution
    ax = axes[1]

    violin_parts = ax.violinplot(
        [df_patient[df_patient['checkpoint_pair'] == pair]['l2_distance'].values
         for pair in checkpoint_pairs],
        positions=range(len(checkpoint_pairs)),
        showmeans=True,
        showmedians=True,
    )

    ax.set_xticks(range(len(checkpoint_pairs)))
    ax.set_xticklabels(checkpoint_pairs, rotation=45, ha='right')
    ax.set_ylabel('L2 Distance', fontweight='bold')
    ax.set_title('Distribution of Per-Patient L2 Distance\nAcross Checkpoint Pairs',
                 fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    save_path = OUTPUT_DIR / 'per_patient_distance_distributions.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved distance distributions to: {save_path}")
    plt.show()

    # Print patients with highest/lowest embedding changes
    print("\n" + "="*80)
    print("PATIENTS WITH LARGEST EMBEDDING CHANGES")
    print("="*80)

    for pair in checkpoint_pairs:
        df_pair = df_patient[df_patient['checkpoint_pair'] == pair]

        print(f"\n{pair}:")

        # Most different (lowest cosine similarity)
        most_diff = df_pair.nsmallest(5, 'cosine_similarity')
        print("  Most different embeddings (lowest cosine similarity):")
        for _, row in most_diff.iterrows():
            print(f"    Patient {row['patient_id']}: cos_sim={row['cosine_similarity']:.4f}, L2={row['l2_distance']:.4f}")

        # Most similar (highest cosine similarity)
        most_sim = df_pair.nlargest(5, 'cosine_similarity')
        print("  Most similar embeddings (highest cosine similarity):")
        for _, row in most_sim.iterrows():
            print(f"    Patient {row['patient_id']}: cos_sim={row['cosine_similarity']:.4f}, L2={row['l2_distance']:.4f}")

    print("\n" + "="*80)
else:
    print("Per-patient distance data not available for visualization (run the per-patient analysis cell first)")

## Next Steps

Now that you have the 768-dim embeddings extracted from all 4 checkpoints, you can:

1. **Further Analysis**:
   - Cluster analysis (k-means, hierarchical clustering)
   - Linear probing evaluation
   - Representation quality metrics (silhouette score, etc.)
   
2. **Visualization**:
   - UMAP visualization (alternative to t-SNE)
   - Per-class embedding distributions
   - Attention maps visualization
   
3. **Comparison Metrics**:
   - CKA (Centered Kernel Alignment) between checkpoints
   - Representation stability analysis
   - Downstream task performance correlation

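All embeddings are saved in the `OUTPUT_DIR` configured above (`/home/mgazda/Projects/UMBRA/embeddings_analysis`), one `.npz` file per checkpoint.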
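
As a starting point for the CKA comparison listed under Comparison Metrics, here is a minimal sketch of linear CKA in NumPy. `linear_cka` is a hypothetical helper, not part of this project's code, and it assumes both inputs are `(N, d)` arrays holding the same samples in the same order. Unlike per-sample cosine similarity, linear CKA is invariant to orthogonal transformations and isotropic scaling of either embedding space, which matters when the encoders were trained independently.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between two (N, d) embedding matrices with matching rows."""
    if x.shape[0] != y.shape[0]:
        raise ValueError("both inputs must contain the same number of samples")
    # Center every feature so the score ignores constant offsets.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
emb_a = rng.standard_normal((50, 8))
# Rotate and rescale emb_a: a different coordinate system, same geometry.
q, _ = np.linalg.qr(rng.standard_normal((8, 8)))
emb_b = 3.0 * (emb_a @ q)
print(round(linear_cka(emb_a, emb_b), 6))  # 1.0
```

On the saved files, the two inputs would be the `embeddings` arrays loaded with `np.load` from two of the `.npz` files, after confirming that their `metadata` arrays match row by row.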