# Crosscoders
Attempting to recreate the SAE model Anthropic wrote about recently.

It's an SAE that is layer- and model-agnostic.

I'm not sure about the model-agnostic part, but I have an idea about the layer-agnostic part.

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM

from datasets import load_dataset

In [None]:
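# Prefer Apple's MPS backend (used for the outputs shown here); fall back to CUDA or CPU
DEVICE = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")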

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct").to(DEVICE)
model.eval()

tokens = tokenizer.encode("print(\"Hello", return_tensors="pt").to(DEVICE)
generation = model.generate(tokens, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(generation[0], skip_special_tokens=True))

We can get the hidden states simply by setting the output_hidden_states=True flag.

In [4]:
out = model(tokens, output_hidden_states=True)

# train SAE on these
len(out.hidden_states), out.hidden_states[0].shape


(29, torch.Size([1, 3, 1536]))

We can also register hooks to modify the activations as we please.

In [None]:
layer = model.model.layers[0].mlp.up_proj

hook = layer.register_forward_hook(lambda module, input, output: output * 0)

generation = model.generate(tokens, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(generation[0], skip_special_tokens=True))

hook.remove()

generation = model.generate(tokens, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(generation[0], skip_special_tokens=True))
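
To see the hook mechanism in isolation, here is a minimal sketch on a toy nn.Linear rather than the Qwen model. The layer, its size and the input are made up for illustration. It relies on the fact that a forward hook which returns a value replaces the module's output, and that calling remove() on the returned handle restores the original behaviour.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_layer = nn.Linear(4, 4)  # hypothetical stand-in for a projection such as up_proj
x = torch.randn(2, 4)

baseline = toy_layer(x)

# A forward hook that returns a tensor replaces the module's output
handle = toy_layer.register_forward_hook(lambda module, inputs, output: output * 0)
ablated = toy_layer(x)

# Removing the hook restores the original forward pass
handle.remove()
restored = toy_layer(x)

print(ablated.abs().sum().item())          # 0.0
print(torch.allclose(baseline, restored))  # True
```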

In [6]:
N_LAYERS = model.config.num_hidden_layers
HIDDEN_SIZE = model.config.hidden_size

In [None]:
# these models are specifically trained on specific chat formats
# the (system, user, assistant) format is just a fancy string in the end
def format_to_chat(example):
    messages = [
        {"role": "system", "content": "You are Qwen, a helpful coding assistant."},
        {"role": "user", "content": example['prompt']},
        {"role": "assistant", "content": example['response']}
    ]
    
    chat_format = tokenizer.apply_chat_template(messages, tokenize=True)
    return {"chat_ids": chat_format}


train_ds = load_dataset("nampdn-ai/tiny-codes", split="train").shuffle(seed=42)

train_ds = train_ds.select(range(100000))  # for laptop side testing
train_ds = train_ds.map(format_to_chat, num_proc=8)


val_ds = train_ds.take(len(train_ds) // 20)
train_ds = train_ds.skip(len(train_ds) // 20)

len(train_ds), len(val_ds)

In [8]:
x = torch.tensor(train_ds[0]["chat_ids"])[None].to(DEVICE) # add batch dimension
attn = torch.tril(torch.ones(1, x.shape[1], x.shape[1])).bool().to(DEVICE)

out = model(x, attention_mask=attn, output_hidden_states=True)

out.hidden_states[0].reshape(-1, HIDDEN_SIZE).shape

torch.Size([411, 1536])

To inform the SAE of which layer it is being applied to, we encode the layer position into the activations via a sinusoidal positional encoding.

In [9]:
import math

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, hidden_size, max_len):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_len = max_len
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2) * (-math.log(10000.0) / hidden_size))
        
        pe = torch.zeros(max_len, hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)

    def forward(self, x, pos):
        return x + self.pe[pos]
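
As a quick sanity check of what this encoding adds, the sketch below rebuilds the same table as a plain function (layer_encoding_table is a hypothetical helper name, and the sizes are toy values). It shows that rows tagged with the same layer index receive the same offset, and that layer 0 gets zeros in the sine slots and ones in the cosine slots.

```python
import math
import torch

def layer_encoding_table(hidden_size, max_len):
    """Build a sin/cos table with one row per layer index (same formula as the class above)."""
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_size, 2) * (-math.log(10000.0) / hidden_size))
    pe = torch.zeros(max_len, hidden_size)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = layer_encoding_table(hidden_size=8, max_len=4)  # toy sizes
acts = torch.zeros(3, 8)            # three fake activation rows
layers = torch.tensor([0, 2, 2])    # layer index of each row
encoded = acts + pe[layers]         # what forward(x, pos) computes

print(encoded.shape)                         # torch.Size([3, 8])
print(pe[0])                                 # layer 0: sine entries are 0, cosine entries are 1
print(torch.equal(encoded[1], encoded[2]))   # True: same layer, same offset
```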

The sparse autoencoder itself is quite simple. It's a traditional autoencoder, but with a twist: instead of 'compressing' the input through a bottleneck, it's overcomplete - meaning it has a larger hidden size than the input size.

There are different ways of achieving sparsity in these activations. One popular formulation uses a top_k activation function to ensure only the k largest feature activations remain active, while other methods directly optimize for sparsity through various loss terms.
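
The two families mentioned above can be sketched on toy tensors as follows. This is only an illustration, not the training code of this notebook: the tensor sizes are made up, and l1_coeff is a hypothetical coefficient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pre_acts = torch.randn(4, 16)  # toy encoder outputs: 4 rows, 16 features
k = 3

# Option 1: top-k. Keep the k largest values per row and zero out the rest.
top = torch.topk(pre_acts, k=k, dim=-1)
topk_acts = torch.zeros_like(pre_acts).scatter_(-1, top.indices, top.values)

# Option 2: L1 penalty. Keep all ReLU activations, but add their mean L1 norm to the loss.
l1_coeff = 1e-3  # hypothetical coefficient
relu_acts = F.relu(pre_acts)
l1_penalty = l1_coeff * relu_acts.abs().sum(dim=-1).mean()

print((topk_acts != 0).sum(dim=-1))  # k active features in every row
print(l1_penalty.item())
```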

The model is trained to reconstruct its input from these sparse activations using a reconstruction loss. We can also evaluate it by measuring the "recovered" loss: comparing the model's next-token negative log-likelihood before and after splicing the SAE reconstruction in place of the original activations. This tells us how well the SAE preserves the information the model actually uses.
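
One common way to turn that comparison into a single number is a normalized "loss recovered" score. The sketch below assumes a third reference point that the text above does not mention: the negative log-likelihood when the chosen activation is ablated (for example zeroed). The function name and the numbers are illustrative only.

```python
def loss_recovered(clean_nll, spliced_nll, ablated_nll):
    """Fraction of the clean-vs-ablated NLL gap that the SAE reconstruction closes.

    1.0 means splicing in the reconstruction costs nothing;
    0.0 means it is no better than ablating the activation.
    """
    return (ablated_nll - spliced_nll) / (ablated_nll - clean_nll)

# Made-up numbers, purely for illustration
score = loss_recovered(clean_nll=2.0, spliced_nll=2.3, ablated_nll=5.0)
print(round(score, 3))  # 0.9
```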

In [10]:
K = 32

class TopKSAE(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers, top_k):
        super().__init__()
        self.input_size, self.hidden_size, self.n_layers, self.top_k = input_size, hidden_size, n_layers, top_k
        
        self.pos_enc = SinusoidalPositionalEncoding(input_size, max_len=n_layers)
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        
    def encode(self, x, pos):
        x = self.pos_enc(x, pos)
        x = self.encoder(x)
        top_k = torch.topk(x, k=self.top_k)
        return top_k.indices, top_k.values
    
    def decode(self, indices, values):
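        # Allocate the dense activation vector directly on the right device and dtype
        sparse = torch.zeros(indices.shape[0], self.hidden_size, device=values.device, dtype=values.dtype)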
        sparse.scatter_(1, indices, values)
        return self.decoder(sparse)

    def forward(self, x, pos):
        indices, values = self.encode(x, pos)
        return self.decode(indices, values)
    


sae = TopKSAE(HIDDEN_SIZE, HIDDEN_SIZE * 4, N_LAYERS, K).to(DEVICE)
activations = out.hidden_states[0].reshape(-1, HIDDEN_SIZE)
positions = torch.tensor([0] * len(activations))
sae(activations, positions).shape


torch.Size([411, 1536])

I am borrowing an idea from diffusion models: the SAE is told which layer of the transformer it is being applied to (in diffusion, the analogous signal is the noise level of the image that needs to be denoised).

The SinusoidalPositionalEncoding above implements this as a sinusoidal positional encoding over layers. One could also use learnable embeddings, but I fear that would "divide" the expressiveness of the SAE between these embeddings and the SAE itself. Using a static sinusoidal encoding sounds like a better option.
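
For comparison, the learnable alternative might look like the following sketch. LearnedLayerEmbedding is a hypothetical name and is not used anywhere else in this notebook; it takes the same (x, pos) arguments as the sinusoidal module, but adds one trainable vector per layer.

```python
import torch
import torch.nn as nn

class LearnedLayerEmbedding(nn.Module):
    """Hypothetical alternative: one trainable vector per layer index."""
    def __init__(self, hidden_size, n_layers):
        super().__init__()
        self.emb = nn.Embedding(n_layers, hidden_size)

    def forward(self, x, pos):
        return x + self.emb(pos)

enc = LearnedLayerEmbedding(hidden_size=8, n_layers=4)  # toy sizes
x = torch.zeros(3, 8)
pos = torch.tensor([0, 2, 2])
out = enc(x, pos)

print(out.shape)                                  # torch.Size([3, 8])
print(sum(p.numel() for p in enc.parameters()))   # 32 extra trainable parameters
```

The static table adds no trainable parameters, whereas this variant adds n_layers * hidden_size of them, which is the trade-off the paragraph above is worried about.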

Now to train

Firstly, I dislike the habit of running an entire training epoch before validating (especially since datasets can be huge), so I usually interleave a few validation steps with training, as follows.

In [11]:
EPOCHS = 0  # 0 skips the training loop below; set to >= 1 to train
BATCH_SIZE = 32
MAX_SEQ_LEN = 1024


def cycle(dl):
    while True: yield from dl

def collate_fn(batch):
    input_ids = [torch.tensor(item["chat_ids"][:MAX_SEQ_LEN]) for item in batch]
    input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True)
    B, L = input_ids.shape
    
    return {
        "input_ids": input_ids,
        "attention_mask": torch.tril(torch.ones(B, 1, L, L)).float()
    }

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=collate_fn)


N_STEPS_PER_TRAIN_EPOCH = len(train_loader)
N_STEPS_PER_VAL_EPOCH = len(val_loader)

INTERLEAVE_EVERY = 100
N_VALIDATION_STEPS = N_STEPS_PER_VAL_EPOCH // (N_STEPS_PER_TRAIN_EPOCH // INTERLEAVE_EVERY)

SAVE_EVERY = N_STEPS_PER_TRAIN_EPOCH // 100  # i.e. roughly 100 checkpoints per epoch


train_iter = cycle(train_loader)
val_iter = cycle(val_loader)
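
The interleaving idea can be seen in isolation with toy stand-ins for the batches. The strings and the small constants below are made up for illustration; only the cycle helper mirrors the cell above.

```python
def cycle(iterable):
    # Restart the iterable whenever it is exhausted
    while True:
        yield from iterable

train_iter = cycle(["t0", "t1", "t2", "t3"])  # toy stand-ins for training batches
val_iter = cycle(["v0", "v1"])                # toy stand-ins for validation batches

interleave_every = 2    # validate every 2 training steps
n_validation_steps = 1  # validation batches consumed each time

schedule = []
for step in range(6):
    schedule.append(next(train_iter))
    if step % interleave_every == 0:
        for _ in range(n_validation_steps):
            schedule.append(next(val_iter))

print(schedule)
# ['t0', 'v0', 't1', 't2', 'v1', 't3', 't0', 'v0', 't1']
```

Because both iterators cycle, the validation set is consumed gradually across the training epoch instead of all at once at the end.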

Secondly, it feels almost wasteful to only use one random layer per item.

But perhaps using multiple layers per position would actually hurt training.

Adjacent layers process the same token through successive transformations (and are linked by skip connections), which creates highly correlated features. This correlation is perhaps stronger than the one between different positions in the sequence. Using multiple layers per position would likely reduce our batch diversity, as we'd be training on variations of the same underlying features rather than truly independent samples.

In [None]:
import wandb
from tqdm import tqdm

wandb.init(project="sae-training")

optim = Adam(sae.parameters(), lr=3e-4, weight_decay=1e-4)

train_loss, val_loss = float("inf"), float("inf")
train_losses, val_losses = [], []

def sae_step(batch):
    B, L = batch["input_ids"].shape
    with torch.no_grad():
        out = model(batch["input_ids"].to(DEVICE), attention_mask=batch["attention_mask"].to(DEVICE), output_hidden_states=True)

    # Generate random layer for each position in the batch/sequence
    random_layers = torch.randint(0, N_LAYERS, (B * L,))
    
    # Gather hidden states from random layers for each position
    all_hidden = torch.stack(out.hidden_states)  # [n_layers, batch_size, seq_len, hidden_size]
    all_hidden = all_hidden.permute(1, 2, 0, 3)  # [batch_size, seq_len, n_layers, hidden_size]
    
    # Create indices for gathering
    batch_indices = torch.arange(B).repeat_interleave(L)
    seq_indices = torch.arange(L).repeat(B)
    
    # Gather the hidden states using the random layers
    inputs = all_hidden[batch_indices, seq_indices, random_layers]  # [batch_size*seq_len, hidden_size]
    
    reconstructed = sae(inputs, random_layers)
    
    return F.mse_loss(reconstructed, inputs)

step_iter = tqdm(range(EPOCHS * N_STEPS_PER_TRAIN_EPOCH), desc="Training")
for step in step_iter:
    batch = next(train_iter)
    
    loss = sae_step(batch)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    train_losses.append(loss.item())
    train_loss = loss.item()
    
    wandb.log({"train_loss": train_loss}, step=step)
    
    step_iter.set_postfix({"Train Loss": f"{train_loss:.4f}", "Val Loss": f"{val_loss:.4f}"})
    
    if step % INTERLEAVE_EVERY == 0:
        interleaved_losses = []
        for _ in range(N_VALIDATION_STEPS):
            batch = next(val_iter)
            with torch.no_grad():
                loss = sae_step(batch)
                interleaved_losses.append(loss.item())
                
                val_loss = loss.item()
                
                step_iter.set_postfix({"Train Loss": f"{train_loss:.4f}", "Val Loss": f"{val_loss:.4f}"})

        val_loss = sum(interleaved_losses) / len(interleaved_losses)
        val_losses.append(val_loss)
        wandb.log({"val_loss": val_loss}, step=step)
        
    if step % SAVE_EVERY == 0 and step > 0:
        torch.save(sae.state_dict(), f"sae_{step}.pth")
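
The per-position layer sampling inside sae_step relies on advanced indexing, which is easy to get subtly wrong. The toy check below uses small made-up sizes and random tensors in place of real hidden states, and compares the vectorized gather against explicit Python loops.

```python
import torch

torch.manual_seed(0)
n_layers, B, L, H = 3, 2, 4, 5  # toy sizes
hidden_states = tuple(torch.randn(B, L, H) for _ in range(n_layers))

# Same steps as in sae_step: stack, move the layer axis, then pick one layer per position
all_hidden = torch.stack(hidden_states).permute(1, 2, 0, 3)  # [B, L, n_layers, H]
random_layers = torch.randint(0, n_layers, (B * L,))
batch_indices = torch.arange(B).repeat_interleave(L)
seq_indices = torch.arange(L).repeat(B)
inputs = all_hidden[batch_indices, seq_indices, random_layers]  # [B*L, H]

# Reference: the same selection written with explicit loops
expected = torch.stack([
    hidden_states[random_layers[b * L + t].item()][b, t]
    for b in range(B)
    for t in range(L)
])

print(inputs.shape)                    # torch.Size([8, 5])
print(torch.equal(inputs, expected))   # True
```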


In [None]:
from matplotlib import pyplot as plt
plt.figure(figsize=(10, 6))
plt.plot(range(len(train_losses)), train_losses, label='Train Loss', alpha=0.5)
# Ensure validation points align with training points by using same length arrays
val_steps = range(0, len(train_losses), INTERLEAVE_EVERY)
val_losses_padded = val_losses + [val_losses[-1]] * (len(val_steps) - len(val_losses))
plt.plot(val_steps, val_losses_padded, label='Validation Loss', linewidth=2)
plt.xlabel('Training Steps')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

The loss is still trending downwards; I only trained this for 20-30 minutes on my Mac.

Now we can try to see whether the SAE is picking up on something interesting.

In [15]:
subset = val_ds.shuffle(seed=42).select(range(10))
subset = [torch.tensor(x["chat_ids"])[None] for x in subset]

In [16]:
sae.load_state_dict(torch.load("sae_12342.pth", map_location=DEVICE), strict=True)

<All keys matched successfully>

In [17]:
LAYER = N_LAYERS // 2  # maybe the middle is interesting

with torch.no_grad():
    activations = [model(x.to(DEVICE), output_hidden_states=True).hidden_states[LAYER].reshape(-1, HIDDEN_SIZE) for x in tqdm(subset)]

    activations = torch.cat(activations, dim=0)
    positions = torch.tensor([LAYER] * len(activations))
    
    indices, values = sae.encode(activations, positions)


100%|██████████| 10/10 [01:08<00:00,  6.85s/it]


In [18]:
values.min(), values.max(), values.mean(), values.std()

(tensor(0.5195, device='mps:0'),
 tensor(377.1314, device='mps:0'),
 tensor(4.7483, device='mps:0'),
 tensor(11.3045, device='mps:0'))

In [19]:
# One bin per SAE feature (the dictionary has sae.hidden_size = HIDDEN_SIZE * 4 entries)
counts = torch.bincount(indices.flatten(), minlength=sae.hidden_size)

sorted_indices = torch.sort(counts, descending=True).indices
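
The frequency ranking can be checked on a toy example. The index tensor and the dictionary size below are made up; the two calls are the same ones used in the cell above.

```python
import torch

indices = torch.tensor([[0, 3], [3, 5], [3, 0]])  # toy top-k indices: 3 tokens, k = 2
dict_size = 8                                     # toy dictionary size

# Count how often each feature appears among the top-k, then rank by frequency
counts = torch.bincount(indices.flatten(), minlength=dict_size)
ranked = torch.sort(counts, descending=True).indices

print(counts.tolist())      # [2, 0, 0, 3, 0, 1, 0, 0]
print(ranked[:3].tolist())  # [3, 0, 5]
```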


In [25]:
def highlight_feature_activations(text, feature_idx, model, sae, tokenizer):
    RESET = "\033[0m"
    def get_color(value):
        normalized = max(0.0, min(1.0, math.log(value + 1e-5)))
        
        text_red = int(30 + 225 * normalized)  # increased range (30-255)
        
        bg_red = 255
        bg_green = bg_blue = int(255 - (80 * normalized))  # doubled the reduction (255-175)
        
        return f"\033[48;2;{bg_red};{bg_green};{bg_blue}m\033[38;2;{text_red};0;0m"
    
    # Tokenize and get model outputs
    tokens = tokenizer.encode(text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(tokens, output_hidden_states=True)
        activations = outputs.hidden_states[LAYER]
        activations = activations.reshape(-1, model.config.hidden_size)
        positions = torch.tensor([LAYER] * len(activations))
        
        indices, values = sae.encode(activations, positions)
        
        # Record the activation of feature_idx at each token where it is among the top-k features
        feature_values = torch.zeros(indices.shape[0], dtype=torch.float, device=values.device)
        for i, (idx, val) in enumerate(zip(indices, values)):
            match = idx == feature_idx
            if match.any():
                feature_values[i] = val[match].max()
    
    # Decode tokens and apply colors
    tokens = tokens[0].cpu().numpy()
    decoded = []
    for i, token in enumerate(tokens):
        word = tokenizer.decode([token])
        if feature_values[i] > 0:  # Only highlight tokens where the feature was among the top-k
            color = get_color(feature_values[i].item())
            decoded.append(f"{color}{word}{RESET}")
        else:
            decoded.append(word)
    
    return "".join(decoded)

def highlight_feature(text, feature_idx, model, sae, tokenizer):
    example_text = tokenizer.decode(text[0], skip_special_tokens=False)
    colored_text = highlight_feature_activations(example_text, feature_idx, model, sae, tokenizer)
    print(colored_text)

Now we can at last highlight the features.

The most frequently active features appear to "overfire" in this small experiment.

In [21]:
highlight_feature(subset[1], sorted_indices[0], model, sae, tokenizer)

«<|im_start|>»«system»«
»«You»« are»« Q»«wen»«,»« a»« helpful»« coding»« assistant»«.»«<|im_end|>»«
»«<|im_start|>»«user»«
»«Design»« a»« Go»« function»« snippet»

Some features I have naively labelled

In [26]:
highlight_feature(subset[0], sorted_indices[2], model, sae, tokenizer)

<|im_start|>system
You are Qwen, a helpful coding assistant.<|im_end|>
<|im_start|>user
Build a Julia module snippet that« Updates»« Extreme» Handling« personal»« items»«:» Keeping« Personal»« Items» Clean« for»« Engineer»« for»« Beginners». Incorporate« if»/else or switch/case statements to« handle»« different»« cases»« based»« on»« the»

In [27]:
highlight_feature(subset[2], sorted_indices[65], model, sae, tokenizer)

<|im_start|>system
You are Q«wen», a helpful coding assistant.<|im_end|>
<|im_start|>user
«Develop» a Rust« program» snippet to« Transform» Extreme Online Shopping: Shipping for Engineer for Experts. Incorporate if/else or« switch»«/c»ase« statements» to handle various cases related to the Consent.« Dry»«-run», ensure your control flow logic is clear and« well»-commented.<|im_end|>
<|im_start|>assistant
Here is a possible implementation of this problem in Rust:


```rust 
fn« transform»_extreme_online_shopping(order_details : OrderDetails)« ->» String {
    let 

In [29]:
highlight_feature(subset[9], sorted_indices[37], model, sae, tokenizer)

«<|im_start|>»«system»«
»You are Qwen, a helpful coding assistant.<|im_end|>
<|im_start|>user
Develop a« C»# program snippet to Calculate High Security: Fraud Detection for Engineer for Professionals. Incorporate if/else or switch«/c»ase statements to handle various cases related to the Ethics. Dry-run, ensure your control« flow» logic is clear and well-commented.<|im_end|>
<|im_start|>assistant
Here's some sample code written in C#:

    ```«c»sharp
    public static double CalculateHighSecurityFraudDetectionForEngine«ers»(List«<Transaction»«>»« transactions»«)»

In [30]:
highlight_feature(subset[8], sorted_indices[43], model, sae, tokenizer)

<|im_start|>system
You are Qwen, a helpful« coding» assistant.<|im_end|>
<|im_start|>user
«Create»« a» Neo«4»j database« and»« Cy»«pher»« script» snippet that Updates Extreme Nose« care»«:» Preventing Nosebleeds for Engineer for Experts. Use if/else or switch«/c»«ase» statements to« condition»«ally» perform different actions based on the Privacy. Dry-run, then« include»« comments»« that»

In [31]:
highlight_feature(subset[7], sorted_indices[34], model, sae, tokenizer)

<|im_start|>system
You are Qwen, a helpful coding« assistant».<|im_end|>
<|im_start|>user«
»Create« a»« relation» database and« SQL» script snippet that Updates Extreme Nail care: Trimming« N»«ails» for Decision Making for Experts. Use if«/»else or switch/c«ase» statements to conditionally perform different actions based on the Consent. Dry-run, then include comments that outline the control flow« and» how you handle different scenarios.<|im_end|>«
»<|im_start|>«assistant»«
»Here is a sample« re»