In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
import re
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import os

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T")
model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T")


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]

In [3]:
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, sparsity_param: float = 0.1):
        super().__init__()
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param
        self.activation = nn.ReLU()
        
        
    def forward(self, x):
        encoded = self.activation(self.encoder(x))
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def get_sparsity_penalty(self, encoded):
        avg_activation = torch.mean(encoded, dim=0)
        kl_div = torch.sum(self.sparsity_param * torch.log(self.sparsity_param / avg_activation) + 
                          (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - avg_activation)))
        return kl_div

In [4]:
def extract_microparameters(text: str) -> Dict:
    """Extract statistical parameters from text."""
    paragraphs = text.split('\n\n')
    sentences = re.split('[.!?]+', text)
    words = text.split()
    
    return {
        'n_paragraphs': len(paragraphs),
        'n_sentences': len(sentences),
        'n_words': len(words),
        'avg_words_per_sentence': len(words) / len(sentences),
        'avg_sentences_per_paragraph': len(sentences) / len(paragraphs),
        'vocabulary_size': len(set(words)),
        'word_frequency': Counter(words)
    }

In [5]:
def get_embeddings(texts):
    embeddings = []
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
        with torch.no_grad():
            outputs = model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :].squeeze()
        embeddings.append(embedding)
    return torch.stack(embeddings)

In [6]:
def analyze_active_neurons(encoded_data: torch.Tensor) -> Dict:
    """Analyze which neurons are most active across texts."""
    activation_patterns = (encoded_data > 0).float()
    neuron_activity = torch.sum(activation_patterns, dim=0)
    
    # Get top active neurons and their activation counts
    top_neurons = torch.argsort(neuron_activity, descending=True)
    
    return {
        'neuron_activity': neuron_activity,
        'top_neurons': top_neurons,
        'activation_patterns': activation_patterns
    }

In [8]:
# Load and process texts
files = ['comedy.txt', 'fantasy.txt', 'power_strategy.txt', 'romance_tragedy.txt', 'romance.txt']
texts = []
for filename in files:
    filepath = os.path.join('data', filename)
    with open(filepath, 'r', encoding='utf-8') as f:
        texts.append(f.read())

# Extract microparameters
micro_params = [extract_microparameters(text) for text in texts]

# Get Gemma embeddings
embeddings = get_embeddings(texts)

# Initialize and train SAE
input_size = embeddings.shape[1]  # Gemma's embedding dimension
hidden_size = 256
sae = SparseAutoencoder(input_size, hidden_size)
optimizer = optim.Adam(sae.parameters())
criterion = nn.MSELoss()

# Training loop
n_epochs = 100
for epoch in range(n_epochs):
    optimizer.zero_grad()
    encoded, decoded = sae(embeddings)
    reconstruction_loss = criterion(decoded, embeddings)
    sparsity_loss = sae.get_sparsity_penalty(encoded)
    loss = reconstruction_loss + 0.1 * sparsity_loss
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Analyze results
with torch.no_grad():
    encoded_data, _ = sae(embeddings)
    neuron_analysis = analyze_active_neurons(encoded_data)

# Print microparameters
for i, params in enumerate(micro_params):
    print(f"\nText {i+1} microparameters:")
    for key, value in params.items():
        if key != 'word_frequency':
            print(f"{key}: {value}")

# Print neuron activity
print("\nNeuron activation analysis:")
top_k = 10
top_neurons = neuron_analysis['top_neurons'][:top_k]
print(f"Top {top_k} most active neurons:", top_neurons.tolist())
print("Activation counts:", neuron_analysis['neuron_activity'][top_neurons].tolist())

# Save results
torch.save({
    'sae_state': sae.state_dict(),
    'embeddings': embeddings,
    'encoded_data': encoded_data,
    'micro_params': micro_params,
    'neuron_analysis': neuron_analysis
}, 'text_analysis_results.pt')

# Visualization
plt.figure(figsize=(12, 6))
plt.bar(range(top_k), neuron_analysis['neuron_activity'][top_neurons].numpy())
plt.title('Top Neuron Activations')
plt.xlabel('Neuron Index')
plt.ylabel('Activation Count')
plt.savefig('neuron_activations.png')
plt.close()

Epoch 10, Loss: nan
Epoch 20, Loss: nan
Epoch 30, Loss: nan
Epoch 40, Loss: nan
Epoch 50, Loss: nan
Epoch 60, Loss: nan
Epoch 70, Loss: nan
Epoch 80, Loss: nan
Epoch 90, Loss: nan
Epoch 100, Loss: nan

Text 1 microparameters:
n_paragraphs: 2245
n_sentences: 6444
n_words: 114126
avg_words_per_sentence: 17.710428305400374
avg_sentences_per_paragraph: 2.8703786191536746
vocabulary_size: 13919

Text 2 microparameters:
n_paragraphs: 2130
n_sentences: 9658
n_words: 163818
avg_words_per_sentence: 16.961896873058603
avg_sentences_per_paragraph: 4.534272300469484
vocabulary_size: 18826

Text 3 microparameters:
n_paragraphs: 341
n_sentences: 1618
n_words: 52978
avg_words_per_sentence: 32.74289245982695
avg_sentences_per_paragraph: 4.744868035190616
vocabulary_size: 8472

Text 4 microparameters:
n_paragraphs: 1068
n_sentences: 3441
n_words: 29000
avg_words_per_sentence: 8.427782621331009
avg_sentences_per_paragraph: 3.2219101123595504
vocabulary_size: 6954

Text 5 microparameters:
n_paragraphs: 2