# Sparse Autoencoders: Interpreting the Llama-3.2-1B-Instruct Model

In [None]:
%pip install datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, pipeline
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
from huggingface_hub import notebook_login
import tqdm
import heapq
import pickle

In [None]:
# Llama-3.2-1B is a gated model on the Hugging Face Hub, so we need to log in
# (interactive notebook widget) before transformers can download it.
notebook_login()

In [None]:
# Load the tokenizer and the fp16 model, preferring a local checkout and
# falling back to the (gated) Hugging Face Hub copy.
import os

LOCAL_MODEL_PATH = '../Llama-3.2-1B-Instruct'
HUB_MODEL_ID = 'meta-llama/Llama-3.2-1B-Instruct'

# Decide explicitly instead of using a bare `except:`, which would also swallow
# KeyboardInterrupt, out-of-memory and genuine loading errors and silently
# retry with a network download.
model_path = LOCAL_MODEL_PATH if os.path.isdir(LOCAL_MODEL_PATH) else HUB_MODEL_ID
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map='auto',
    return_dict_in_generate=True,
    # forward() must return per-layer hidden states: the SAE is trained on hidden_states[-1]
    output_hidden_states=True,
)

# The tokenizer has no dedicated pad token, so batches are padded with EOS.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Smoke test: one forward pass plus a short generation on a toy prompt.
prompt = 'Hello LLaMa!'
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
    # keep the last entry of hidden_states for inspection below
    z = outputs.hidden_states[-1]
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    first_sequence = generated_ids[0]
    generated_text = tokenizer.decode(first_sequence, skip_special_tokens=True)

In [None]:
# Show the generation result, then the shape and values of the last hidden state.
for item in (generated_ids, generated_text, z.shape, z):
    print(item)

In [None]:
# simple encoder and decoder modules for the SAE

class Encoder(nn.Module):
    """Affine projection into the sparse code space followed by a nonlinearity.

    Args:
        in_dim: size of the input features (last dimension of ``z``).
        out_dim: size of the produced code.
        dtype: parameter dtype of the linear layer.
        activation_fn: elementwise nonlinearity applied to the projection.
    """

    def __init__(self, in_dim, out_dim, dtype, activation_fn=F.leaky_relu):
        super().__init__()
        self.enc = nn.Linear(in_dim, out_dim, bias=True, dtype=dtype)
        self.activation_fn = activation_fn
        nn.init.kaiming_uniform_(self.enc.weight, nonlinearity='relu')

    def forward(self, z):
        # z: b, L, in_dim  ->  h(z): b, L, out_dim
        # Call the submodule (not .forward) so registered forward hooks still run.
        return self.activation_fn(self.enc(z))

class Decoder(nn.Module):
    """Affine map from the sparse code back to the feature space.

    Args:
        in_dim: size of the code (last dimension of ``hz``).
        out_dim: size of the reconstructed features.
        dtype: parameter dtype of the linear layer.
    """

    def __init__(self, in_dim, out_dim, dtype):
        super().__init__()
        self.dec = nn.Linear(in_dim, out_dim, bias=True, dtype=dtype)

    def forward(self, hz):
        # hz: b, L, in_dim  ->  zhat: b, L, out_dim
        # Call the submodule (not .forward) so registered forward hooks still run.
        return self.dec(hz)


In [None]:
# standard SAE implementation

class SAE(nn.Module):
    """Sparse autoencoder: z -> hz = E(z) (wide code) -> zhat = D(hz).

    Args:
        feature_dim: width of the hidden states being reconstructed.
        sparse_dim: width of the code.
        alpha: weight of the L1 sparsity penalty in ``loss``.
        dtype: parameter dtype for the encoder and decoder.
    """

    def __init__(self, feature_dim, sparse_dim, alpha, dtype=torch.float16):
        super().__init__()
        self.E = Encoder(feature_dim, sparse_dim, dtype)
        self.D = Decoder(sparse_dim, feature_dim, dtype)
        self.alpha = alpha

    def forward(self, z):
        """Encode and decode ``z`` (b, L, feature_dim).

        Returns ``(zhat, hz)``: zhat is (b, L, feature_dim), hz is (b, L, sparse_dim).
        """
        hz = self.E(z)
        zhat = self.D(hz)
        return zhat, hz

    def loss(self, z, zhat, hz, attention_mask):
        """Masked reconstruction MSE plus ``alpha`` times the mean per-token L1 of ``hz``.

        Only positions where ``attention_mask`` is non-zero contribute, so the
        value does not depend on how much padding the batch contains.

        Args:
            z: target activations, (b, L, feature_dim).
            zhat: reconstructions, same shape as ``z``.
            hz: codes, (b, L, sparse_dim).
            attention_mask: (b, L); non-zero for real tokens, 0 for padding.
        """
        mask = attention_mask.unsqueeze(-1).to(z.dtype)  # b, L, 1
        # Clamp so a batch with no real tokens gives 0 instead of NaN.
        n_tokens = mask.sum().clamp(min=1)

        # Mean squared error over real elements only; F.mse_loss on masked tensors
        # would also average in the zeroed-out padding positions.
        reconstruction_loss = ((zhat - z).square() * mask).sum() / (n_tokens * z.shape[-1])

        # L1 of the code per real token, averaged over tokens, so its weight relative
        # to the MSE term does not change with batch size or padding length.
        sparsity_regularization = self.alpha * (hz.abs() * mask).sum() / n_tokens
        return reconstruction_loss + sparsity_regularization


In [None]:
# Custom DataLoader
class TokenizedDataset(Dataset):
    """Map-style dataset over pre-tokenized ``input_ids`` with a padding collate function.

    Args:
        dataset: mapping whose ``'input_ids'`` entry is a sequence of 1-D
            token-id tensors (one per sample, variable length).
        pad_token_id: id used to right-pad batches; when None, the global
            ``tokenizer.pad_token_id`` is used.
        device: device collated batches are moved to; when None, the global
            ``device`` is used.
    """

    def __init__(self, dataset, pad_token_id=None, device=None):
        self.dataset = dataset
        self.pad_token_id = pad_token_id
        self.device = device

    def __len__(self):
        return len(self.dataset['input_ids'])

    def __getitem__(self, idx):
        return self.dataset['input_ids'][idx]

    def collate_fn(self, data):
        """Right-pad a list of 1-D token tensors into a batch dict.

        Returns ``{'input_ids': (b, L), 'attention_mask': (b, L)}`` on the target
        device. The mask is built from each sequence's true length rather than by
        comparing against the pad id: this notebook pads with EOS, so a comparison
        would also mask any genuine EOS tokens inside the inputs.
        """
        pad_id = tokenizer.pad_token_id if self.pad_token_id is None else self.pad_token_id
        target = device if self.device is None else self.device
        input_ids = pad_sequence(data, batch_first=True, padding_value=pad_id)
        attention_mask = pad_sequence(
            [torch.ones(len(seq), dtype=torch.long) for seq in data],
            batch_first=True,
            padding_value=0,
        )
        return {'input_ids': input_ids.to(target), 'attention_mask': attention_mask.to(target)}

In [None]:
# Load the SFT/chat training split of the Llama Nemotron dataset, preferring a
# local copy and falling back to the Hugging Face Hub.
import os

LOCAL_DATASET_PATH = '../Llama-Nemotron-Post-Training-Dataset/SFT/chat'

# Decide explicitly instead of a bare `except:`, which would also hide
# KeyboardInterrupt and genuine errors in the local copy.
if os.path.isdir(LOCAL_DATASET_PATH):
    dataset = load_dataset(LOCAL_DATASET_PATH, split='train').with_format('torch')
else:
    dataset = load_dataset('nvidia/Llama-Nemotron-Post-Training-Dataset', 'SFT', data_dir='SFT/chat').with_format('torch')['train']

In [None]:
def tokenize_raw_data(x):
    """Tokenize the content of the first message of every example in a batch.

    ``x`` is a batch passed by ``Dataset.map(batched=True)``; its ``'input'``
    column holds, per example, a list of message dicts with a ``'content'`` key.
    Returns the tokenizer's encoding of those first-message strings.
    """
    first_messages = []
    for conversation in x['input']:
        first_messages.append(conversation[0]['content'])
    return tokenizer(first_messages)

# Trim the dataset for performance: keep a random 10% of the rows (fixed seed for
# reproducibility), then only examples whose first message is at most 50 characters.
trim_dataset = dataset.train_test_split(test_size=0.9, seed=0)['train']
# Filter the subsampled split. Filtering `dataset` here would throw the split
# above away and process every row.
trim_dataset = trim_dataset.filter(lambda sample: len(sample['input'][0]['content']) <= 50)

# dataset keys: input, output, category, license, reasoning, generator, used_in_training, version, system_prompt
# map() adds the tokenizer's columns (input_ids, attention_mask) for the first input message
encoded_dataset = trim_dataset.map(tokenize_raw_data, batched=True)

# Keep only the token ids; the attention mask is rebuilt per batch by the collate function.
samples = {'input_ids': encoded_dataset['input_ids']}

tokenized_dataset = TokenizedDataset(samples)
dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=tokenized_dataset.collate_fn)

In [None]:
# Train function
def train(llm, sae, dataloader, epochs, optimizer):
    """Train ``sae`` to reconstruct the last hidden layer of a frozen ``llm``.

    Args:
        llm: model called as ``llm(**batch)`` whose output exposes ``hidden_states``;
            used for inference only and never updated.
        sae: module with ``forward(z) -> (zhat, hz)`` and
            ``loss(z, zhat, hz, attention_mask)``.
        dataloader: yields dicts with ``input_ids`` and ``attention_mask``.
        epochs: number of passes over ``dataloader``.
        optimizer: optimizer over ``sae``'s parameters.

    Returns:
        List with the mean training loss of each epoch.
    """
    llm.eval()
    epoch_losses = []
    for epoch in tqdm.trange(epochs, desc="training", unit="epoch"):
        with tqdm.tqdm(dataloader, desc=f"epoch {epoch + 1}", unit="batch", total=len(dataloader), position=0, leave=True) as batch_iterator:
            sae.train()
            total_loss = 0.0
            mean_loss = 0.0
            for i, batch in enumerate(batch_iterator):
                # The LLM is frozen. Without no_grad, loss.backward() would also
                # back-propagate into (and accumulate .grad on) every LLM weight and
                # keep the LLM's whole activation graph alive.
                with torch.no_grad():
                    output = llm(**batch)
                z = output.hidden_states[-1].to(torch.float32)  # b, L, feature_dim

                optimizer.zero_grad()
                zhat, hz = sae(z)
                loss = sae.loss(z, zhat, hz, batch['attention_mask'])
                loss.backward()
                optimizer.step()

                current_loss = loss.item()  # single device sync per step
                total_loss += current_loss
                mean_loss = total_loss / (i + 1)
                batch_iterator.set_postfix(mean_loss=mean_loss, current_loss=current_loss)
            epoch_losses.append(mean_loss)
    return epoch_losses

In [None]:
feature_dim = 2048  # hidden size of this Llama model
# width of the sparse code: 8x the feature width (paper recommends 8-32x)
sparse_dim = 8 * feature_dim
alpha = 1e-3  # L1 sparsity weight; hyperparameter to tune

In [None]:
# Training: fp32 SAE on the same device, AdamW with its default hyperparameters.
epochs = 2
sae = SAE(feature_dim, sparse_dim, alpha, dtype=torch.float32).to(device=device)
optimizer = torch.optim.AdamW(sae.parameters())
train(model, sae, dataloader, epochs, optimizer)

In [None]:
# Persist only the SAE's parameters (its state dict), not the module object.
torch.save(obj=sae.state_dict(), f='llama-sae.pt')

In [None]:
# load the sae model
sae = SAE(feature_dim, sparse_dim, alpha, dtype=torch.float32).to(device=device)
# map_location lets a checkpoint saved during a GPU run be loaded on a CPU-only
# machine; load_state_dict then copies the tensors into sae's parameters.
state_dict = torch.load('llama-sae.pt', map_location=device, weights_only=True)
sae.load_state_dict(state_dict)

In [None]:
# Compare which sparse SAE features fire most strongly on a sad vs. a happy sentence.
def _print_top_sparse_features(text, k=5):
    """Run ``text`` through the LLM and the SAE and print, for every token
    position, the indices of the ``k`` most active sparse features."""
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
        z = outputs.hidden_states[-1].to(dtype=torch.float32)
        _, hz = sae(z)  # b, L, sparse_dim
    print(hz.shape)
    print(inputs.input_ids)
    # Rank the SAE code hz, not the raw hidden state z: the point is to see which
    # sparse features are active for each token.
    print(torch.topk(hz, k, dim=2).indices)


sae.eval()
_print_top_sparse_features('I am very sad and disappointed.')
_print_top_sparse_features('I am very happy and energetic.')


In [None]:
# Interpreting the model
# The encoder is nn.Linear(feature_dim, sparse_dim), so its weight has shape
# (sparse_dim, feature_dim): one row per sparse feature, one column per hidden dim.
encoder_matrix = sae.E.enc.weight
k = 5

# For each sparse feature (row): the k hidden dimensions with the largest
# encoder weights (positive correlation) ...
topk_hidden = encoder_matrix.topk(k, dim=1).indices
# ... and the k with the smallest weights (negative correlation).
botk_hidden = encoder_matrix.topk(k, dim=1, largest=False).indices

for hidden_indices in (topk_hidden, botk_hidden):
    print(hidden_indices)

In [None]:
# Retrieve the top-k dataset samples for each sparse feature.
#
# For every sample and sparse feature we compute the feature's mean activation
# over the sample's real (non-padding) tokens and the position of its largest
# activation. A running top-k (by mean activation) per feature is kept on the CPU
# and merged with each batch using vectorized torch ops, instead of several GPU
# ops plus a Python heap operation per (sample, feature) pair.
#
# Result: topk_samples[d] is a list of up to k tuples
#     (mean_activation: float, max_token_position: int, input_ids: list[int])
# sorted by decreasing mean activation. Only plain Python values are stored:
# comparing tuples that hold tensors can reach the input_ids tensors, whose
# element-wise comparison raises "Boolean value of Tensor with more than one
# value is ambiguous".


def _feature_stats(hz, attention_mask):
    """Return per-sample, per-feature (mean activation, argmax token position).

    hz: (b, L, sparse_dim); attention_mask: (b, L), non-zero for real tokens.
    Both results have shape (b, sparse_dim); padding positions are ignored.
    """
    mask = attention_mask.unsqueeze(-1).to(hz.dtype)  # b, L, 1
    lengths = mask.sum(dim=1).clamp(min=1)  # b, 1
    mean_act = (hz * mask).sum(dim=1) / lengths
    # a padding position can never be the location of the maximum
    max_pos = hz.masked_fill(mask == 0, float('-inf')).argmax(dim=1)
    return mean_act, max_pos


best_vals = torch.empty(sparse_dim, 0)
best_sample_ids = torch.empty(sparse_dim, 0, dtype=torch.long)
best_positions = torch.empty(sparse_dim, 0, dtype=torch.long)
seen_input_ids = []  # unpadded token ids of every processed sample, indexed by sample id

sae.eval()
with torch.no_grad(), tqdm.tqdm(dataloader, desc="retrieving samples", unit="batch", total=len(dataloader), position=0, leave=True) as batch_iterator:
    for batch in batch_iterator:
        output = model(**batch)
        z = output.hidden_states[-1].to(torch.float32)  # b, L, feature_dim
        _, hz = sae(z)  # b, L, sparse_dim
        mean_act, max_pos = _feature_stats(hz, batch['attention_mask'])

        # Use the real batch size: the last batch may be smaller than batch_size.
        batch_size = hz.shape[0]
        first_id = len(seen_input_ids)
        for row_ids, row_mask in zip(batch['input_ids'], batch['attention_mask']):
            seen_input_ids.append(row_ids[row_mask.bool()].tolist())
        batch_ids = torch.arange(first_id, first_id + batch_size).expand(sparse_dim, -1)

        # Merge this batch's candidates (transposed to sparse_dim x b) into the running top-k.
        vals = torch.cat([best_vals, mean_act.T.cpu()], dim=1)
        sample_ids = torch.cat([best_sample_ids, batch_ids], dim=1)
        positions = torch.cat([best_positions, max_pos.T.cpu()], dim=1)
        keep = vals.topk(min(k, vals.shape[1]), dim=1).indices
        best_vals = vals.gather(1, keep)
        best_sample_ids = sample_ids.gather(1, keep)
        best_positions = positions.gather(1, keep)

topk_samples = [
    [(val, pos, seen_input_ids[sid]) for val, pos, sid in zip(vals_row, pos_row, ids_row)]
    for vals_row, pos_row, ids_row in zip(best_vals.tolist(), best_positions.tolist(), best_sample_ids.tolist())
]

with open('topk-samples.pkl', 'wb') as f:
    pickle.dump(topk_samples, f)