In [2]:
# Testing 5-2 and 40-10 autoencoders on parameter decomposition with greedy and one-to-one matchings 

import sys
import os
import yaml
import random
import numpy as np
from dataclasses import dataclass, asdict
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.optimize import linear_sum_assignment
from tqdm import trange
import matplotlib.pyplot as plt

sys.path.append('../../')

from tqdm import tqdm
from Faithful_SAE.models import Faithful_SAE, ToySuperpositionAE
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from Faithful_SAE.train_ae import sample_sparse_batch, TrainingConfig, Config

In [3]:
def train_ae(cfg: Config, device: str) -> torch.nn.Module:
    """Train a dense ToySuperpositionAE to reconstruct sparse synthetic batches.

    Args:
        cfg: Config supplying ``input_dim``, ``latent_dim``, ``ae_lr``,
            ``sae_steps``, ``batch_size`` and ``sparsity_p``.
        device: Device the model and the sampled batches are placed on.

    Returns:
        The trained autoencoder, switched to eval mode.
    """
    model = ToySuperpositionAE(cfg.input_dim, cfg.latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.ae_lr)

    # NOTE(review): the number of steps is read from cfg.sae_steps even though
    # this loop trains the dense AE - confirm that sharing the budget is intended.
    for _ in tqdm(range(cfg.sae_steps), desc="train dense AE"):
        batch = sample_sparse_batch(
            cfg.batch_size,
            cfg.input_dim,
            p_extra=cfg.sparsity_p,
            device=device,
        )
        # Plain reconstruction objective: the AE output should equal its input.
        loss = F.mse_loss(model(batch), batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last step only (not an average over training).
    print(f"final dense-AE loss: {loss.item():.4e}")
    model.eval()
    return model

def train_sae(cfg, ae):
    """Train a Faithful_SAE against an already-trained dense autoencoder.

    Two terms are optimised, weighted by ``cfg.recon_lam`` and ``cfg.faithful_lam``:

    * reconstruction: MSE between the SAE output and ``ae.encode(x)``;
    * faithfulness: MSE between ``sae.effective_encoder()`` and
      ``ae.encoder_weights``.

    Args:
        cfg: Config supplying ``input_dim``, ``concept_dim``, ``latent_dim``,
            ``k``, ``device``, ``sae_lr``, ``sae_steps``, ``batch_size``,
            ``sparsity_p``, ``recon_lam`` and ``faithful_lam``.
        ae: The trained dense autoencoder used as a fixed target. Only the
            SAE's parameters are handed to the optimiser.

    Returns:
        The trained Faithful_SAE.
    """
    sae = Faithful_SAE(cfg.input_dim, cfg.concept_dim, cfg.latent_dim, k=cfg.k).to(cfg.device)
    opt = torch.optim.Adam(sae.parameters(), lr=cfg.sae_lr)

    # The AE is a frozen target. Detaching its weights keeps backward() from
    # accumulating gradients into ae.encoder_weights.grad on every step; the
    # gradients that reach the SAE parameters are unchanged.
    target_encoder = ae.encoder_weights.detach()

    best_loss = float('inf')
    pbar = trange(cfg.sae_steps, desc="train Faithful_SAE")
    for _ in pbar:
        x = sample_sparse_batch(cfg.batch_size, cfg.input_dim, p_extra=cfg.sparsity_p, device=cfg.device)
        # The SAE returns a pair; the second element (sparse latent) is not used here.
        sae_out, _sparse_latent = sae(x)

        with torch.no_grad():
            target_out = ae.encode(x)

        recon_loss = F.mse_loss(sae_out, target_out)
        faithful_loss = F.mse_loss(sae.effective_encoder(), target_encoder)
        loss = cfg.recon_lam * recon_loss + cfg.faithful_lam * faithful_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        # best_loss is only reported in the progress bar; no checkpoint is kept.
        best_loss = min(best_loss, loss.item())
        pbar.set_postfix(recon=recon_loss.item(), faithful=faithful_loss.item(), best=best_loss)
    return sae
    
def greedy_matching_ae_sae(ae_encoder, sae_components):
    """Match every AE encoder row to its most cosine-similar SAE component row.

    Each AE row independently takes the arg-max over *all* rows of *all*
    components, so the same SAE row may be chosen by several AE rows.

    Args:
        ae_encoder: Tensor whose rows are the AE encoder vectors.
        sae_components: 3-D tensor indexed as (component, feature row, hidden).

    Returns:
        ``(mean_similarity, matches)`` where ``matches`` holds one dict per AE
        row with keys ``ae_row_idx``, ``sae_component_idx``,
        ``sae_feature_idx``, ``similarity``, ``ae_vector`` and ``sae_vector``.
    """
    ae_rows = ae_encoder.detach().cpu().numpy()
    components = sae_components.detach().cpu().numpy()

    _, rows_per_component, hidden_dim = components.shape

    # Flatten (component, feature) into one candidate pool. Flat index
    # c * rows_per_component + f corresponds to components[c, f]. The copy keeps
    # the pool independent of the tensor's underlying memory.
    candidate_pool = components.reshape(-1, hidden_dim).copy()

    matches = []
    for row_idx, ae_vec in enumerate(ae_rows):
        sims = cosine_similarity(ae_vec.reshape(1, -1), candidate_pool)[0]
        winner = int(np.argmax(sims))
        comp_idx, feat_idx = divmod(winner, rows_per_component)

        matches.append({
            'ae_row_idx': row_idx,
            'sae_component_idx': comp_idx,
            'sae_feature_idx': feat_idx,
            'similarity': sims[winner],
            'ae_vector': ae_rows[row_idx],
            'sae_vector': candidate_pool[winner]
        })

    mean_similarity = np.mean([m['similarity'] for m in matches])

    return mean_similarity, matches

def hungarian_matching_ae_sae(ae_encoder, sae_components):
    """One-to-one assignment of AE rows to SAE components by cosine similarity.

    For every (AE row, component) pair the best-matching row inside that
    component is recorded; ``linear_sum_assignment`` then picks the assignment
    of AE rows to distinct components that maximises total similarity. When
    there are fewer components than AE rows, only as many AE rows as there are
    components receive a match.

    Args:
        ae_encoder: Tensor whose rows are the AE encoder vectors.
        sae_components: 3-D tensor indexed as (component, feature row, hidden).

    Returns:
        ``(mean_similarity, matches)`` with one dict per assigned AE row
        (keys ``ae_row_idx``, ``sae_component_idx``, ``sae_feature_idx``,
        ``similarity``, ``ae_vector``, ``sae_vector``).
    """
    ae_rows = ae_encoder.detach().cpu().numpy()
    components = sae_components.detach().cpu().numpy()

    n_components, _, _ = components.shape
    n_ae_rows = ae_rows.shape[0]

    # best_sim[i, c]  : highest cosine similarity between AE row i and any row of component c
    # best_feat[i, c] : index of the row inside component c that achieves it
    best_sim = np.zeros((n_ae_rows, n_components))
    best_feat = np.zeros((n_ae_rows, n_components), dtype=int)

    for i, ae_vec in enumerate(ae_rows):
        query = ae_vec.reshape(1, -1)
        for c, component_rows in enumerate(components):
            sims = cosine_similarity(query, component_rows)[0]
            top = np.argmax(sims)
            best_sim[i, c] = sims[top]
            best_feat[i, c] = top

    # linear_sum_assignment minimises cost, so negate to maximise similarity.
    # It accepts rectangular matrices directly, so no slicing is required.
    assigned_rows, assigned_components = linear_sum_assignment(-best_sim)

    matches = []
    for ae_idx, comp_idx in zip(assigned_rows, assigned_components):
        feat_idx = best_feat[ae_idx, comp_idx]
        matches.append({
            'ae_row_idx': ae_idx,
            'sae_component_idx': comp_idx,
            'sae_feature_idx': feat_idx,
            'similarity': best_sim[ae_idx, comp_idx],
            'ae_vector': ae_rows[ae_idx],
            'sae_vector': components[comp_idx, feat_idx]
        })

    mean_similarity = np.mean([m['similarity'] for m in matches])

    return mean_similarity, matches

def plot_matching_visualization(ae_encoder, matches, method_name, filename):
    """Draw AE encoder rows and their matched SAE rows as 2-D arrows and save to disk.

    Only the first two coordinates of every vector are plotted, so this is
    meant for encoders with a 2-D hidden space.

    NOTE(review): a later cell in this notebook defines a function with the
    same name; whichever definition ran last is the one in effect.

    Args:
        ae_encoder: Tensor whose rows are the AE encoder vectors.
        matches: Iterable of match dicts with keys ``ae_row_idx``,
            ``sae_vector``, ``sae_component_idx``, ``sae_feature_idx`` and
            either ``similarity`` (cosine matchers) or ``distance`` (L2 matchers).
        method_name: Name shown in the plot title.
        filename: Path the figure is written to.
    """
    ae_encoder_np = ae_encoder.detach().cpu().numpy()

    # One colour per AE row; a matched SAE vector reuses its AE row's colour.
    colors = plt.cm.tab10(np.linspace(0, 1, ae_encoder_np.shape[0]))

    fig = plt.figure(figsize=(10, 8))
    try:
        # AE encoder vectors: solid arrows from the origin.
        for i, ae_vector in enumerate(ae_encoder_np):
            plt.arrow(0, 0, ae_vector[0], ae_vector[1],
                      head_width=0.03, head_length=0.03,
                      fc=colors[i], ec=colors[i], linewidth=2,
                      linestyle='-', alpha=0.8, length_includes_head=True)

            plt.text(ae_vector[0] * 1.1, ae_vector[1] * 1.1, f'AE F{i}',
                     fontsize=10, ha='center', va='center',
                     color=colors[i], weight='bold')

        # Legend text keyed by AE row index, so the legend colour always agrees
        # with the arrow colour even if `matches` skips or reorders AE rows.
        legend_by_row = {}

        # Matched SAE vectors: dashed arrows from the origin.
        for match in matches:
            ae_row_idx = int(match['ae_row_idx'])
            sae_vector = match['sae_vector']
            component_idx = match['sae_component_idx']
            feature_idx = match['sae_feature_idx']

            # Cosine matchers store 'similarity', L2 matchers store 'distance'.
            if 'distance' in match:
                metric_str = f"dist={match['distance']:.3f}"
            else:
                metric_str = f"sim={match['similarity']:.3f}"

            plt.arrow(0, 0, sae_vector[0], sae_vector[1],
                      head_width=0.025, head_length=0.025,
                      fc=colors[ae_row_idx], ec=colors[ae_row_idx],
                      linewidth=2, linestyle='--', alpha=0.6,
                      length_includes_head=True)

            plt.text(sae_vector[0] * 1.1, sae_vector[1] * 1.1,
                     f'C{component_idx}F{feature_idx}',
                     fontsize=8, ha='center', va='center',
                     color=colors[ae_row_idx])

            legend_by_row[ae_row_idx] = (
                f'AE F{ae_row_idx} ↔ C{component_idx}F{feature_idx} ({metric_str})'
            )

        plt.axis('equal')
        plt.grid(True, alpha=0.3)

        # Faint axes through the origin.
        plt.axhline(y=0, color='k', linewidth=0.5, alpha=0.3)
        plt.axvline(x=0, color='k', linewidth=0.5, alpha=0.3)

        plt.xlabel('Hidden Dimension 1')
        plt.ylabel('Hidden Dimension 2')
        plt.title(f'{method_name} Vector Matching\n(Solid = AE Original, Dashed = SAE Component)')

        legend_elements = [
            plt.Line2D([0], [0], color=colors[row_idx], lw=2, label=label)
            for row_idx, label in sorted(legend_by_row.items())
        ]

        plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()

        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved visualization: {filename}")

def _print_component_usage(label, matches, num_components=None):
    """Print how many distinct SAE components `matches` uses and the heaviest reuse of one."""
    usage = {}
    for match in matches:
        comp_idx = match['sae_component_idx']
        usage[comp_idx] = usage.get(comp_idx, 0) + 1

    # Only show a denominator when the caller actually knows the total.
    total = "" if num_components is None else f"/{num_components}"
    print(f"   {label} - Unique components used: {len(usage)}{total}")
    print(f"   {label} - Max uses of single component: {max(usage.values(), default=0)}")


def _print_similarity_stats(label, matches):
    """Print min / max / std of the 'similarity' values stored in `matches`."""
    sims = [match['similarity'] for match in matches]
    if not sims:
        print(f"   {label} - no matches")
        return
    print(f"   {label} - Min: {min(sims):.4f}, Max: {max(sims):.4f}, Std: {np.std(sims):.4f}")


def print_matching_results(greedy_score, greedy_matches, hungarian_score, hungarian_matches,
                           num_components=None):
    """Print a side-by-side report of the greedy and Hungarian cosine matchings.

    Args:
        greedy_score: Mean similarity returned by the greedy matcher.
        greedy_matches: Match dicts from the greedy matcher; each needs the
            keys ``sae_component_idx`` and ``similarity``.
        hungarian_score: Mean similarity returned by the Hungarian matcher.
        hungarian_matches: Match dicts from the Hungarian matcher (same keys).
        num_components: Optional total number of SAE components. When given it
            is printed as the denominator of the "unique components used"
            lines; when omitted no denominator is printed (previously a fixed
            "/512" was shown regardless of the model).
    """
    print("=" * 80)
    print("AE vs SAE Matching Results")
    print("=" * 80)

    print(f"\n🎯 GREEDY MATCHING RESULTS:")
    print(f"   Mean Max Cosine Similarity: {greedy_score:.6f}")
    print(f"   Number of matches: {len(greedy_matches)}")

    print(f"\n🎯 HUNGARIAN MATCHING RESULTS (Component-Constrained):")
    print(f"   Mean Max Cosine Similarity: {hungarian_score:.6f}")
    print(f"   Number of matches: {len(hungarian_matches)}")

    print(f"\n📊 COMPARISON:")
    difference = greedy_score - hungarian_score
    print(f"   Difference: {difference:.6f}")
    if hungarian_score == 0:
        # A relative change against a zero baseline is undefined.
        print("   Relative improvement (Greedy vs Hungarian): n/a (Hungarian score is 0)")
    else:
        print(f"   Relative improvement (Greedy vs Hungarian): {(difference / hungarian_score * 100):.2f}%")

    print(f"\n📈 COMPONENT USAGE ANALYSIS:")
    _print_component_usage("Greedy", greedy_matches, num_components)
    _print_component_usage("Hungarian", hungarian_matches, num_components)

    print(f"\n📊 SIMILARITY STATISTICS:")
    _print_similarity_stats("Greedy", greedy_matches)
    _print_similarity_stats("Hungarian", hungarian_matches)

In [4]:
# Seed every RNG the notebook imports so that Restart & Run All reproduces the
# same trained models (and therefore the same matching scores and figures).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Every field except k comes from the TrainingConfig defaults - presumably the
# 40-input / 10-latent setup the variable name refers to; confirm against TrainingConfig.
cfg_40_10 = TrainingConfig(k=1) 
# Small 5-input / 2-latent model whose encoder rows can be plotted in 2-D.
cfg_5_2 = TrainingConfig(concepts=64, concept_dim=64, input_dim=5, latent_dim=2, k=1) 

# Dense autoencoders that serve as the fixed targets for the SAEs.
ae_40_10 = train_ae(cfg_40_10, cfg_40_10.device)
ae_5_2 = train_ae(cfg_5_2, cfg_5_2.device) 

# Faithful SAEs trained against the corresponding dense autoencoder.
real_sae_40_10 = train_sae(cfg_40_10, ae_40_10)
real_sae_5_2 = train_sae(cfg_5_2, ae_5_2)

train dense AE: 100%|██████████| 15000/15000 [00:10<00:00, 1498.04it/s]


final dense-AE loss: 3.3565e-03


train dense AE: 100%|██████████| 15000/15000 [00:10<00:00, 1456.92it/s]


final dense-AE loss: 2.8639e-03


train Faithful_SAE: 100%|██████████| 15000/15000 [00:21<00:00, 705.03it/s, best=0.0857, faithful=1.14e-6, recon=0.0248]
train Faithful_SAE: 100%|██████████| 15000/15000 [00:20<00:00, 715.79it/s, best=1.03e-5, faithful=3.01e-7, recon=0.001]   


In [11]:
# Zero the l1_lam field on both configs.
# NOTE(review): no code visible in this notebook reads l1_lam, and this cell's
# execution count (In [11]) is later than the training cell's (In [4]) - confirm
# whether this assignment has any effect on the models analysed below.
cfg_5_2.l1_lam = 0.00
cfg_40_10.l1_lam = 0.00

In [5]:

# --- 40-10 model: compare AE encoder rows with SAE component rows (cosine similarity) ---
print('='*100)
print("Results for 40-10 encoder...")
print('='*100)

# Unconstrained: every AE row takes its best SAE row, reuse allowed.
print("Computing Greedy Matching...")
greedy_score_40_10, greedy_matches_40_10 = greedy_matching_ae_sae(ae_40_10.encoder_weights, real_sae_40_10.components())

# Constrained: each AE row is assigned to a distinct SAE component.
print("Computing Hungarian Matching...")
hungarian_score_40_10, hungarian_matches_40_10 = hungarian_matching_ae_sae(ae_40_10.encoder_weights, real_sae_40_10.components())

print_matching_results(greedy_score_40_10, greedy_matches_40_10, hungarian_score_40_10, hungarian_matches_40_10)

# --- 5-2 model: same comparison, plus 2-D plots of the matched vectors ---
print('='*100)
print("Results for 5-2 encoder...")
print('='*100)

print("Computing Greedy Matching...")
greedy_score_5_2, greedy_matches_5_2 = greedy_matching_ae_sae(ae_5_2.encoder_weights, real_sae_5_2.components())

print("Computing Hungarian Matching...")
hungarian_score_5_2, hungarian_matches_5_2 = hungarian_matching_ae_sae(ae_5_2.encoder_weights, real_sae_5_2.components())

print_matching_results(greedy_score_5_2, greedy_matches_5_2, hungarian_score_5_2, hungarian_matches_5_2)

# Create visualizations for 5->2 encoder
# (only this model is plotted: its hidden space is 2-D, so the vectors can be drawn directly)
print("\nCreating visualizations for 5->2 encoder...")
plot_matching_visualization(ae_5_2.encoder_weights, greedy_matches_5_2, 
                            "Greedy", "greedy_5_2_figure.png")
plot_matching_visualization(ae_5_2.encoder_weights, hungarian_matches_5_2, 
                            "Hungarian", "hungarian_5_2_figure.png")

Results for 40-10 encoder...
Computing Greedy Matching...
Computing Hungarian Matching...
AE vs SAE Matching Results

🎯 GREEDY MATCHING RESULTS:
   Mean Max Cosine Similarity: 0.999564
   Number of matches: 40

🎯 HUNGARIAN MATCHING RESULTS (Component-Constrained):
   Mean Max Cosine Similarity: 0.999564
   Number of matches: 40

📊 COMPARISON:
   Difference: 0.000000
   Relative improvement (Greedy vs Hungarian): 0.00%

📈 COMPONENT USAGE ANALYSIS:
   Greedy - Unique components used: 40/512
   Greedy - Max uses of single component: 1
   Hungarian - Unique components used: 40/512
   Hungarian - Max uses of single component: 1

📊 SIMILARITY STATISTICS:
   Greedy - Min: 0.9991, Max: 0.9999, Std: 0.0002
   Hungarian - Min: 0.9991, Max: 0.9999, Std: 0.0002
Results for 5-2 encoder...
Computing Greedy Matching...
Computing Hungarian Matching...
AE vs SAE Matching Results

🎯 GREEDY MATCHING RESULTS:
   Mean Max Cosine Similarity: 0.999966
   Number of matches: 5

🎯 HUNGARIAN MATCHING RESULTS (Co

In [6]:
# Display the SAE's effective encoder; train_sae's faithfulness loss pulls it toward ae_5_2.encoder_weights.
real_sae_5_2.effective_encoder()

tensor([[ 0.3303,  1.0862],
        [-0.9053, -0.6939],
        [ 1.1320,  0.0304],
        [ 0.3719, -1.0694],
        [-0.9303,  0.6538]], device='cuda:0', grad_fn=<MmBackward0>)

In [7]:
# Display the dense AE's encoder weights - the target matrix of the faithfulness loss in train_sae.
ae_5_2.encoder_weights

Parameter containing:
tensor([[ 0.3299,  1.0852],
        [-0.9046, -0.6932],
        [ 1.1317,  0.0305],
        [ 0.3721, -1.0687],
        [-0.9305,  0.6529]], device='cuda:0', requires_grad=True)

In [10]:
def greedy_matching_ae_sae_l2(ae_encoder, sae_components):
    """Greedy one-to-one matching of AE rows to SAE component rows by L2 distance.

    AE rows are processed in index order. Each one takes the closest SAE row
    (over all components) that has not already been taken by an earlier AE
    row, so the result depends on row order and no SAE row is used twice.
    Different rows of the same component may still be matched to different
    AE rows. AE rows left without any available candidate are skipped.

    Args:
        ae_encoder: Tensor whose rows are the AE encoder vectors.
        sae_components: 3-D tensor indexed as (component, feature row, hidden).

    Returns:
        ``(mean_min_distance, matches)``. ``matches`` holds one dict per
        matched AE row with keys ``ae_row_idx``, ``sae_component_idx``,
        ``sae_feature_idx``, ``distance``, ``ae_vector`` and ``sae_vector``.
        ``mean_min_distance`` is NaN when nothing could be matched.
    """
    ae_encoder_np = ae_encoder.detach().cpu().numpy()
    sae_components_np = sae_components.detach().cpu().numpy()

    num_components, num_features, num_hidden = sae_components_np.shape

    # Flatten (component, feature) into one candidate pool; flat index
    # c * num_features + f corresponds to sae_components_np[c, f].
    all_sae_rows = sae_components_np.reshape(num_components * num_features, num_hidden).copy()
    row_available = np.ones(all_sae_rows.shape[0], dtype=bool)

    matches = []
    min_distances = []

    for ae_row_idx, ae_row in enumerate(ae_encoder_np):
        # Distance from this AE row to every candidate in one vectorised call.
        distances = np.linalg.norm(all_sae_rows - ae_row, axis=1)

        # Candidates must be unused and have a finite distance.
        eligible = row_available & np.isfinite(distances)
        if not eligible.any():
            continue

        # argmin returns the first minimum, i.e. the lowest flat index on ties.
        best_match_idx = int(np.argmin(np.where(eligible, distances, np.inf)))
        best_distance = distances[best_match_idx]
        component_idx, feature_idx = divmod(best_match_idx, num_features)

        matches.append({
            'ae_row_idx': ae_row_idx,
            'sae_component_idx': component_idx,
            'sae_feature_idx': feature_idx,
            'distance': best_distance,
            'ae_vector': ae_encoder_np[ae_row_idx],
            'sae_vector': all_sae_rows[best_match_idx]
        })

        row_available[best_match_idx] = False
        min_distances.append(best_distance)

    mean_min_distance = np.mean(min_distances) if min_distances else float('nan')

    return mean_min_distance, matches

def hungarian_matching_ae_sae_l2(ae_encoder, sae_components):
    """Optimal one-to-one assignment of AE rows to SAE components by L2 distance.

    For every (AE row, component) pair the closest row inside that component
    is recorded; ``linear_sum_assignment`` then assigns AE rows to distinct
    components so that the summed distance is minimal. When there are fewer
    components than AE rows, only as many AE rows as there are components
    receive a match and the rest are omitted from the result.

    Args:
        ae_encoder: Tensor whose rows are the AE encoder vectors.
        sae_components: 3-D tensor indexed as (component, feature row, hidden).

    Returns:
        ``(mean_min_distance, matches)``. ``matches`` holds one dict per
        assigned AE row with keys ``ae_row_idx``, ``sae_component_idx``,
        ``sae_feature_idx``, ``distance``, ``ae_vector`` and ``sae_vector``.
        ``mean_min_distance`` is NaN when nothing could be matched.
    """
    ae_encoder_np = ae_encoder.detach().cpu().numpy()
    sae_components_np = sae_components.detach().cpu().numpy()

    num_components, num_features, num_hidden = sae_components_np.shape
    num_ae_features = ae_encoder_np.shape[0]

    # component_best_distances[i, c]       : smallest distance from AE row i to any row of component c
    # component_best_feature_indices[i, c] : index of the row inside component c that achieves it
    component_best_distances = np.full((num_ae_features, num_components), np.inf)
    component_best_feature_indices = np.zeros((num_ae_features, num_components), dtype=int)

    if num_components > 0 and num_features > 0:
        component_range = np.arange(num_components)
        for ae_row_idx, ae_row in enumerate(ae_encoder_np):
            # (num_components, num_features) distances for this AE row in one call.
            distances = np.linalg.norm(sae_components_np - ae_row, axis=2)
            # NaN distances can never win, matching a strict "<" comparison.
            distances = np.where(np.isnan(distances), np.inf, distances)

            # argmin returns the first minimum, i.e. the lowest row index on ties.
            best_features = np.argmin(distances, axis=1)
            component_best_feature_indices[ae_row_idx] = best_features
            component_best_distances[ae_row_idx] = distances[component_range, best_features]

    # Distances are already costs to minimise. linear_sum_assignment accepts
    # rectangular matrices directly, so no slicing or branching is needed.
    ae_indices, component_indices = linear_sum_assignment(component_best_distances)

    matches = []
    min_distances = []

    for ae_idx, comp_idx in zip(ae_indices, component_indices):
        ae_idx, comp_idx = int(ae_idx), int(comp_idx)
        best_distance = component_best_distances[ae_idx, comp_idx]
        best_feature_idx = int(component_best_feature_indices[ae_idx, comp_idx])

        matches.append({
            'ae_row_idx': ae_idx,
            'sae_component_idx': comp_idx,
            'sae_feature_idx': best_feature_idx,
            'distance': best_distance,
            'ae_vector': ae_encoder_np[ae_idx],
            'sae_vector': sae_components_np[comp_idx, best_feature_idx]
        })

        min_distances.append(best_distance)

    mean_min_distance = np.mean(min_distances) if min_distances else float('nan')

    return mean_min_distance, matches

In [15]:
def plot_matching_visualization_styled(ae_encoder, matches, method_name, filename):
    """Plot each match's AE vector and SAE vector as 2-D arrows and save the figure.

    Unlike the row-coloured plot, colours here are assigned per *match* (in
    the order of ``matches``), cycling through a fixed ten-colour list. Only
    the first two coordinates of every vector are drawn.

    Args:
        ae_encoder: Not used for drawing (the AE vectors are taken from
            ``matches``); kept so the signature matches the other plotting
            helper.
        matches: Iterable of match dicts with keys ``ae_vector``,
            ``sae_vector``, ``ae_row_idx``, ``sae_component_idx``,
            ``sae_feature_idx`` and either ``distance`` or ``similarity``.
        method_name: Name shown in the plot title.
        filename: Path the figure is written to.
    """
    colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        legend_elements = []
        for i, match in enumerate(matches):
            color = colors[i % len(colors)]
            ae_label = f'AE R{match["ae_row_idx"]}'
            sae_label = f'C{match["sae_component_idx"]}R{match["sae_feature_idx"]}'

            # Original AE vector (solid arrow)
            orig_vec = match['ae_vector']
            ax.arrow(0, 0, orig_vec[0], orig_vec[1],
                     head_width=0.05, head_length=0.05,
                     fc=color, ec=color, linewidth=2, alpha=0.8)

            # Matched SAE vector (dashed arrow)
            sae_vec = match['sae_vector']
            ax.arrow(0, 0, sae_vec[0], sae_vec[1],
                     head_width=0.05, head_length=0.05,
                     fc=color, ec=color, linewidth=2, alpha=0.6, linestyle='--')

            # Labels at the arrow tips
            ax.text(orig_vec[0], orig_vec[1], ae_label,
                    fontsize=9, ha='center', va='bottom', color=color, weight='bold')
            ax.text(sae_vec[0], sae_vec[1], sae_label,
                    fontsize=9, ha='center', va='top', color=color, weight='bold')

            # L2 matchers store 'distance', cosine matchers store 'similarity'.
            if 'distance' in match:
                metric_str = f'dist={match["distance"]:.3f}'
            else:
                metric_str = f'sim={match["similarity"]:.3f}'

            legend_elements.append(plt.Line2D([0], [0], color=color, lw=2,
                                              label=f'{ae_label} ↔ {sae_label} ({metric_str})'))

        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='k', linewidth=0.5)
        ax.axvline(x=0, color='k', linewidth=0.5)

        ax.set_xlabel('Hidden Dimension 1', fontsize=12)
        ax.set_ylabel('Hidden Dimension 2', fontsize=12)
        ax.set_title(f'{method_name} Matching: AE Weights vs SAE Components\n(Solid = AE Original, Dashed = SAE Component)', fontsize=14)

        ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.3, 1))

        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved visualization: {filename}")

def plot_matching_visualization(ae_encoder, matches, method_name, filename):
    """
    Plot AE encoder vectors and their matched SAE vectors as arrows from origin.
    For 5->2 encoder, plots vectors in 2D space (only the first two coordinates
    of every vector are drawn).

    NOTE(review): this redefines a function of the same name from an earlier
    cell; once this cell has run, this is the definition in effect.

    Args:
        ae_encoder: Tensor whose rows are the AE encoder vectors.
        matches: Iterable of match dicts with keys ``ae_row_idx``,
            ``sae_vector``, ``sae_component_idx``, ``sae_feature_idx`` and
            either ``distance`` (L2 matchers) or ``similarity`` (cosine matchers).
        method_name: Name shown in the plot title.
        filename: Path the figure is written to.
    """
    ae_encoder_np = ae_encoder.detach().cpu().numpy()

    # Create color map for each AE row
    colors = plt.cm.tab10(np.linspace(0, 1, ae_encoder_np.shape[0]))

    fig = plt.figure(figsize=(10, 8))
    try:
        # Plot AE encoder vectors as solid arrows from origin
        for i, ae_vector in enumerate(ae_encoder_np):
            plt.arrow(0, 0, ae_vector[0], ae_vector[1],
                      head_width=0.03, head_length=0.03,
                      fc=colors[i], ec=colors[i], linewidth=2,
                      linestyle='-', alpha=0.8, length_includes_head=True)

            # Add label at the end of the arrow
            plt.text(ae_vector[0] * 1.1, ae_vector[1] * 1.1, f'AE F{i}',
                     fontsize=10, ha='center', va='center',
                     color=colors[i], weight='bold')

        # Legend text keyed by AE row index, so the legend colour always agrees
        # with the arrow colour even if `matches` skips or reorders AE rows.
        legend_by_row = {}

        # Plot matched SAE vectors as dashed arrows from origin
        for match in matches:
            ae_row_idx = int(match['ae_row_idx'])
            sae_vector = match['sae_vector']
            component_idx = match['sae_component_idx']
            feature_idx = match['sae_feature_idx']

            # Label the metric by what it actually is: L2 matchers store
            # 'distance', cosine matchers store 'similarity'.
            if 'distance' in match:
                metric_str = f"dist={match['distance']:.3f}"
            else:
                metric_str = f"sim={match['similarity']:.3f}"

            plt.arrow(0, 0, sae_vector[0], sae_vector[1],
                      head_width=0.025, head_length=0.025,
                      fc=colors[ae_row_idx], ec=colors[ae_row_idx],
                      linewidth=2, linestyle='--', alpha=0.6,
                      length_includes_head=True)

            # Add label for SAE vector
            plt.text(sae_vector[0] * 1.1, sae_vector[1] * 1.1,
                     f'C{component_idx}F{feature_idx}',
                     fontsize=8, ha='center', va='center',
                     color=colors[ae_row_idx])

            legend_by_row[ae_row_idx] = (
                f'AE F{ae_row_idx} ↔ C{component_idx}F{feature_idx} ({metric_str})'
            )

        # Set equal aspect ratio and add grid
        plt.axis('equal')
        plt.grid(True, alpha=0.3)

        # Add axes through origin
        plt.axhline(y=0, color='k', linewidth=0.5, alpha=0.3)
        plt.axvline(x=0, color='k', linewidth=0.5, alpha=0.3)

        plt.xlabel('Hidden Dimension 1')
        plt.ylabel('Hidden Dimension 2')
        plt.title(f'{method_name} Vector Matching\n(Solid = AE Original, Dashed = SAE Component)')

        # Create custom legend, one entry per matched AE row in row order
        legend_elements = [
            plt.Line2D([0], [0], color=colors[row_idx], lw=2, label=label)
            for row_idx, label in sorted(legend_by_row.items())
        ]

        plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()

        # Save the figure
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved visualization: {filename}")

In [16]:
# Repeat the 5-2 matching with L2 distance instead of cosine similarity.
# NOTE(review): this reuses the names greedy_score_5_2 / greedy_matches_5_2 /
# hungarian_score_5_2 / hungarian_matches_5_2 from the cosine cell, so after this
# cell they hold L2 results (scores are mean distances, lower is better).
print("Computing Greedy Matching...")
greedy_score_5_2, greedy_matches_5_2 = greedy_matching_ae_sae_l2(ae_5_2.encoder_weights, real_sae_5_2.components())

print("Computing Hungarian Matching...")
hungarian_score_5_2, hungarian_matches_5_2 = hungarian_matching_ae_sae_l2(ae_5_2.encoder_weights, real_sae_5_2.components())

# NOTE(review): the output filenames are the same as in the cosine cell, so the
# earlier PNGs are overwritten by these L2 figures.
print("\nCreating visualizations for 5->2 encoder...")
plot_matching_visualization(ae_5_2.encoder_weights, greedy_matches_5_2, 
                            "Greedy", "greedy_5_2_figure.png")
plot_matching_visualization(ae_5_2.encoder_weights, hungarian_matches_5_2, 
                            "Hungarian", "hungarian_5_2_figure.png")


Computing Greedy Matching...
Computing Hungarian Matching...

Creating visualizations for 5->2 encoder...
Saved visualization: greedy_5_2_figure.png
Saved visualization: hungarian_5_2_figure.png
