In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics.pairwise import cosine_similarity

class MatrixBindingDataset:
    def __init__(self, n_dim=256, n_features=100, p=0.1, n_samples=10000):
        """
        Generate synthetic data with matrix bindings
        
        Args:
            n_dim: dimension of feature space
            n_features: number of base features
            p: probability of feature activation (Bernoulli)
            n_samples: number of samples to generate
        """
        self.n_dim = n_dim
        self.n_features = n_features
        self.p = p
        
        # Generate random orthonormal base features
        features_raw = np.random.randn(n_features, n_dim)
        q, r = np.linalg.qr(features_raw.T)
        self.base_features = q.T  # Shape: [n_features, n_dim]
        
        # Generate random orthogonal binding matrix
        binding_raw = np.random.randn(n_dim, n_dim)
        q, r = np.linalg.qr(binding_raw)
        self.binding_matrix = q  # Shape: [n_dim, n_dim]
        
        # Generate dataset
        self.data = []
        self.content_vectors = []
        self.binding_vectors = []
        
        for _ in range(n_samples):
            # Generate sparse coefficients
            content_coef = (np.random.random(n_features) < p).astype(float)
            binding_coef = (np.random.random(n_features) < p).astype(float)
            
            # Create content and binding vectors
            content = content_coef @ self.base_features  # [n_features] @ [n_features, n_dim]
            binding = binding_coef @ self.base_features
            
            # Apply matrix binding
            bound = content + self.binding_matrix @ binding
            
            self.data.append(bound)
            self.content_vectors.append(content_coef)
            self.binding_vectors.append(binding_coef)
            
        self.data = np.stack(self.data)
        self.content_vectors = np.stack(self.content_vectors)
        self.binding_vectors = np.stack(self.binding_vectors)

class TorchDataset(Dataset):
    def __init__(self, data):
        self.data = torch.FloatTensor(data)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.ReLU()
        )
        
        self.decoder = nn.Linear(latent_dim, input_dim)
        
    def forward(self, x):
        h = self.encoder(x)
        x_hat = self.decoder(h)
        return x_hat, h

def train_sae(dataset, latent_dim, lambda_l1=0.1, batch_size=128, n_epochs=100, device='cuda'):
    """Train Sparse Autoencoder"""
    torch_dataset = TorchDataset(dataset.data)
    dataloader = DataLoader(torch_dataset, batch_size=batch_size, shuffle=True)
    
    model = SparseAutoencoder(dataset.n_dim, latent_dim).to(device)
    optimizer = optim.Adam(model.parameters())
    
    for epoch in range(n_epochs):
        total_loss = 0
        total_l1 = 0
        
        for batch in dataloader:
            batch = batch.to(device)
            
            # Forward pass
            x_hat, h = model(batch)
            
            # Compute losses
            rec_loss = nn.MSELoss()(x_hat, batch)
            l1_loss = torch.mean(torch.abs(h))
            
            loss = rec_loss + lambda_l1 * l1_loss
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += rec_loss.item()
            total_l1 += l1_loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}")
            print(f"Reconstruction Loss: {total_loss/len(dataloader):.4f}")
            print(f"L1 Loss: {total_l1/len(dataloader):.4f}\n")
    
    return model

def evaluate_feature_recovery(dataset, model, device='cuda'):
    """Evaluate how well the SAE recovers ground truth features"""
    # Get decoder weights
    decoder_weights = model.decoder.weight.detach().cpu().numpy()
    
    # Compute cosine similarity between learned features and ground truth
    # Transpose decoder weights to match dimensions: [n_dim, latent_dim] -> [latent_dim, n_dim]
    similarities = cosine_similarity(decoder_weights.T, dataset.base_features)
    
    # For each ground truth feature, find best matching learned feature
    max_similarities = np.max(np.abs(similarities), axis=0)
    
    # Compute statistics
    mean_recovery = np.mean(max_similarities)
    perfect_recovery = np.mean(max_similarities > 0.95)
    
    print(f"Mean feature recovery score: {mean_recovery:.3f}")
    print(f"Fraction of perfectly recovered features: {perfect_recovery:.3f}")
    
    return max_similarities

# Example usage
if __name__ == "__main__":
    # Generate synthetic data
    dataset = MatrixBindingDataset(
        n_dim=256,
        n_features=100,
        p=0.1,
        n_samples=10000
    )
    
    # Train SAE
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = train_sae(
        dataset,
        latent_dim=1000,  # Typically 2-4x number of features
        lambda_l1=0.1,
        n_epochs=100,
        device=device
    )
    
    # Evaluate feature recovery
    similarities = evaluate_feature_recovery(dataset, model, device)

Epoch 10/100
Reconstruction Loss: 0.0023
L1 Loss: 0.0450

Epoch 20/100
Reconstruction Loss: 0.0012
L1 Loss: 0.0277

Epoch 30/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0147

Epoch 40/100
Reconstruction Loss: 0.0001
L1 Loss: 0.0101

Epoch 50/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0082

Epoch 60/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0071

Epoch 70/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0064

Epoch 80/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0058

Epoch 90/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0055

Epoch 100/100
Reconstruction Loss: 0.0002
L1 Loss: 0.0052

Mean feature recovery score: 0.998
Fraction of perfectly recovered features: 1.000


In [7]:
!pip install numpy torch scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting scipy>=1.6.0 (from scikit-learn)
  Downloading scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Using cached threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.5/13.5 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hUsing cached joblib-1.4.2-py3-none-any.whl (301 kB)
Downloading scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (40.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.6/40.6 MB[0m [31m68.2 MB/s[0m eta [36m0:00