# Sparse Autoencoders: Interpreting the Llama-3.2-1B Model

In [1]:
%pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, pipeline
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
from huggingface_hub import notebook_login
import tqdm

In [3]:
# Llama 3.2-1B is a gated model, so we need to log in before transformers can download it.
# notebook_login() renders a Hugging Face login widget in the cell output.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Load the tokenizer and model: prefer a local checkout, fall back to the Hugging Face Hub.
model_path = '../Llama-3.2-1B-Instruct'
try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', return_dict_in_generate=True, output_hidden_states=True)
except Exception as load_error:
    # Catch `Exception` rather than using a bare `except:` so that Ctrl-C / SystemExit
    # still interrupt the cell, and report why the local load was skipped.
    print(f"Could not load from {model_path!r} ({type(load_error).__name__}); falling back to the Hugging Face Hub.")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.float16, device_map='auto', return_dict_in_generate=True, output_hidden_states=True)

# Reuse the EOS token as the padding token for batching.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

cuda


In [5]:
# Testing the Llama model: one plain forward pass (for hidden states) and one short generation.

inputs = tokenizer('Hello LLaMa!', return_tensors='pt').to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
    # z is the last entry of the returned hidden states.
    z = outputs.hidden_states[-1]
    # Pass pad_token_id explicitly so generate() does not have to pick one itself
    # (and emit the "Setting `pad_token_id` to `eos_token_id`" warning).
    generated_ids = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.pad_token_id)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [6]:
# Show, one per line: generated token ids, decoded text, hidden-state shape, hidden state.
print(generated_ids, generated_text, z.shape, z, sep='\n')

tensor([[128000,   9906,    445,   8921,  30635,      0,    358,   1097,  12304,
            311,    387,    264,    961,    315,    420,   4029,     13,    358,
           1097,  24450,    311,   4048,    323,   4430,    856,   6677,    449,
           3885,     13,    358,   1097,   5644,    311,   1520,   4320,    904,
           4860,    499,   1253,    617,     11,    323,    358,   1427,   4741,
            311,    279,   6776,    311,   7945,    499,    304,    904,   1648,
            358,    649]], device='cuda:0')
Hello LLaMa! I am excited to be a part of this community. I am eager to learn and share my knowledge with others. I am ready to help answer any questions you may have, and I look forward to the opportunity to assist you in any way I can
torch.Size([1, 6, 2048])
tensor([[[ 0.9790,  0.1805,  0.6196,  ..., -0.8418, -0.2700,  0.0598],
         [ 0.6118,  4.0039,  2.5605,  ..., -4.4180, -4.5781, -0.6304],
         [-0.8062,  1.3691, -0.7114,  ..., -1.9092, -1.6943,  0.05

In [28]:
# simple encoder and decoder modules for the SAE

class Encoder(nn.Module):
    """Linear layer followed by an activation; maps features to the SAE's sparse code.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        dtype: dtype of the linear layer's parameters.
        activation_fn: callable applied to the linear layer's output
            (defaults to ``torch.relu``).
    """

    def __init__(self, in_dim, out_dim, dtype, activation_fn=torch.relu):
        super().__init__()
        self.enc = nn.Linear(in_dim, out_dim, bias=True, dtype=dtype)
        self.activation_fn = activation_fn

    def forward(self, z):
        """Return ``activation_fn(enc(z))``.

        z: b, L, in_dim
        returns h(z): b, L, out_dim
        """
        # Call the submodule itself (not `.forward`) so that hooks registered on it run.
        return self.activation_fn(self.enc(z))

class Decoder(nn.Module):
    """Single linear layer; maps the SAE's sparse code back to feature space.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        dtype: dtype of the linear layer's parameters.
    """

    def __init__(self, in_dim, out_dim, dtype):
        super().__init__()
        self.dec = nn.Linear(in_dim, out_dim, bias=True, dtype=dtype)

    def forward(self, hz):
        """Return ``dec(hz)``.

        hz: b, L, in_dim
        returns zhat: b, L, out_dim
        """
        # Call the submodule itself (not `.forward`) so that hooks registered on it run.
        return self.dec(hz)


In [38]:
# standard SAE implementation

class SAE(nn.Module):
    """Sparse autoencoder: an encoder/decoder pair with an L1-penalised reconstruction loss.

    Args:
        feature_dim: size of the last dimension of the activations to reconstruct.
        sparse_dim: size of the last dimension of the sparse code.
        alpha: weight of the L1 sparsity term in `loss`.
        dtype: parameter dtype passed to the encoder and decoder.
    """

    def __init__(self, feature_dim, sparse_dim, alpha, dtype=torch.float16):
        super().__init__()
        self.E = Encoder(feature_dim, sparse_dim, dtype)
        self.D = Decoder(sparse_dim, feature_dim, dtype)
        self.alpha = alpha

    def forward(self, z):
        """Encode and reconstruct `z`.

        z: b, L, feature_dim
        returns zhat: b, L, feature_dim
        returns hz: b, L, sparse_dim
        """
        # Call the submodules themselves (not `.forward`) so that hooks registered on them run.
        hz = self.E(z)
        zhat = self.D(hz)
        return zhat, hz

    def loss(self, z, zhat, hz):
        """Return ``sum((z - zhat)^2) + alpha * sum(|hz|)`` as a scalar tensor.

        Both terms are summed (not averaged) over every element of their inputs.
        """
        # Sum of squares directly: same value as squaring the L2 norm, without the sqrt.
        reconstruction_loss = (z - zhat).pow(2).sum()
        sparsity_regularization = self.alpha * hz.abs().sum()
        return reconstruction_loss + sparsity_regularization


In [7]:
# Custom DataLoader
class TokenizedDataset(Dataset):
    """Dataset over pre-tokenized sequences, with a padding collate function.

    Args:
        dataset: mapping with an 'input_ids' entry holding one token-id tensor per sample.
    """

    def __init__(self, dataset):
        # dataset is a dictionary containing 'input_ids': [tensors]
        self.dataset = dataset

    def __len__(self):
        """Number of tokenized samples."""
        return len(self.dataset['input_ids'])

    def __getitem__(self, idx):
        """Return the token-id tensor of sample `idx`."""
        return self.dataset['input_ids'][idx]

    def collate_fn(self, data):
        """Right-pad a list of 1-D token-id tensors into a batch on `device`.

        Returns a dict with 'input_ids' (batch, max_len) and an int64
        'attention_mask' of the same shape (1 = real token, 0 = padding).
        Relies on the module-level `tokenizer` and `device`.
        """
        input_ids = pad_sequence(data, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
        # Build the mask from the true sequence lengths rather than from
        # `input_ids != pad_token_id`: the pad token is the EOS token, so a
        # token-identity test would also mask genuine EOS tokens inside a sequence.
        lengths = torch.tensor([sequence.size(0) for sequence in data], device=input_ids.device)
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).to(torch.long)
        return { 'input_ids': input_ids, 'attention_mask': attention_mask }

In [8]:
# Load Llama Nemotron dataset (chat SFT split): prefer a local copy, fall back to the Hub.

try:
    dataset = load_dataset('../Llama-Nemotron-Post-Training-Dataset/SFT/chat', split='train').with_format('torch')
except Exception as load_error:
    # Catch `Exception` rather than using a bare `except:` so that Ctrl-C / SystemExit
    # still interrupt the cell, and report why the local load was skipped.
    print(f"Could not load the local dataset ({type(load_error).__name__}); falling back to the Hugging Face Hub.")
    dataset = load_dataset('nvidia/Llama-Nemotron-Post-Training-Dataset', 'SFT', data_dir='SFT/chat').with_format('torch')['train']

README.md:   0%|          | 0.00/6.39k [00:00<?, ?B/s]

chat.jsonl:   0%|          | 0.00/255M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [11]:
def tokenize_raw_data(x):
    """Tokenize the first input message of every example in a batch.

    `x` is a batched mapping whose 'input' entry is a list of message lists;
    only the 'content' of each example's first message is tokenized.
    Returns whatever the module-level `tokenizer` returns for that list of strings.
    """
    first_messages = []
    for messages in x['input']:
        first_messages.append(messages[0]['content'])
    encoded = tokenizer(first_messages)
    # To also keep output ids, add: encoded['output_ids'] = tokenizer(x['output'])['input_ids']
    return encoded

# trim for performance: keep only samples whose first input message is at most 75 characters.
# Only this length filter is applied; chain a seeded `train_test_split` onto its result
# if a smaller subset is needed.
trim_dataset = dataset.filter(lambda sample: len(sample['input'][0]['content']) <= 75)

# dataset keys: input, output, category, license, reasoning, generator, used_in_training, version, system_prompt
encoded_dataset = trim_dataset.map(tokenize_raw_data, batched=True) # added input_ids, attention_mask (for input), and (maybe) output_ids

# retrieve just tokenized data
samples = { 'input_ids': encoded_dataset['input_ids'] } # attention mask is all 1s of same size tensor as input_ids, so don't need to store it

tokenized_dataset = TokenizedDataset(samples)
dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=tokenized_dataset.collate_fn)

# Sanity check on a single batch. Inference only, so skip building the autograd graph.
with torch.no_grad():
    for data in dataloader:
        output = model(**data)
        print(data['input_ids'].shape)
        print(output.hidden_states[-1])
        print(output.hidden_states[-1].shape)
        break

Filter:   0%|          | 0/39792 [00:00<?, ? examples/s]

Map:   0%|          | 0/12170 [00:00<?, ? examples/s]

torch.Size([4, 13])
tensor([[[ 0.9785,  0.1799,  0.6201,  ..., -0.8433, -0.2710,  0.0602],
         [ 0.0933,  0.8760, -1.0986,  ..., -0.0638, -3.0508, -3.0488],
         [ 0.9404, -2.8809,  2.2129,  ...,  2.6602, -1.2959, -2.9141],
         ...,
         [-0.5083,  0.6226, -2.7168,  ..., -1.9639,  2.4199,  1.7969],
         [-1.7188, -0.2842, -1.7139,  ...,  0.7266,  0.7007,  0.5767],
         [-2.3223, -1.2734, -2.0703,  ...,  0.0380,  1.3584,  2.5625]],

        [[ 0.9785,  0.1799,  0.6201,  ..., -0.8433, -0.2710,  0.0602],
         [ 0.0676,  2.9238,  0.0240,  ..., -0.9106, -2.8926, -0.0579],
         [-2.0527,  2.4766,  2.6914,  ..., -1.7305, -1.7139,  0.8452],
         ...,
         [ 3.2969,  1.4150,  0.7979,  ..., -2.1465,  1.5781, -2.9258],
         [ 3.0059,  1.7744,  0.8647,  ..., -2.3984,  1.5430, -2.7715],
         [ 2.1387,  2.2051,  1.3291,  ..., -2.0723,  1.7598, -2.2734]],

        [[ 0.9785,  0.1799,  0.6201,  ..., -0.8433, -0.2710,  0.0602],
         [-0.4783,  0.857

In [42]:
# Train function
def train(llm, sae, dataloader, epochs, optimizer):
    """Train `sae` to reconstruct the last hidden state of `llm`.

    Args:
        llm: model called as ``llm(**batch)``; its output must expose ``hidden_states``.
        sae: module returning ``(zhat, hz)`` when called and providing ``loss(z, zhat, hz)``.
        dataloader: sized iterable of keyword-argument dicts for `llm`.
        epochs: number of passes over `dataloader`.
        optimizer: optimizer that updates the SAE parameters.

    Only the SAE is trained: the LLM forward pass runs without autograd, so
    `loss.backward()` neither traverses the LLM nor fills gradients on its parameters.
    Every position of every sequence in the batch contributes to the loss.
    """
    for epoch in tqdm.trange(epochs, desc="training", unit="epoch"):
        with tqdm.tqdm(dataloader, desc=f"epoch {epoch + 1}", unit="batch", total=len(dataloader), position=0, leave=True) as batch_iterator:
            sae.train()
            total_loss = 0.0
            for i, batch in enumerate(batch_iterator):
                # The LLM is a frozen feature extractor here; no graph is needed for it.
                with torch.no_grad():
                    output = llm(**batch)
                z = output.hidden_states[-1].to(torch.float32) # b, L, feature_dim

                optimizer.zero_grad()

                # Call the module itself (not `.forward`) so that hooks registered on it run.
                zhat, hz = sae(z)

                loss = sae.loss(z, zhat, hz)
                loss.backward()

                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                batch_iterator.set_postfix(mean_loss=total_loss / (i + 1), current_loss=batch_loss)

In [43]:
# Training run: hyperparameters first, then model, optimizer and the training call.
feature_dim = 2048  # hidden size of this Llama model
sparse_dim = 8 * feature_dim  # paper recommends 8-32x of feature dim for the SAE sparse dim
alpha = 0.05  # sparsity weight; hyperparameter, tune
epochs = 1

sae = SAE(feature_dim, sparse_dim, alpha, dtype=torch.float32)
sae = sae.to(device)

optimizer = torch.optim.Adam(params=sae.parameters())

train(llm=model, sae=sae, dataloader=dataloader, epochs=epochs, optimizer=optimizer)

epoch 1: 100%|██████████| 3043/3043 [06:26<00:00,  7.87batch/s, current_loss=1.35e+5, mean_loss=1.14e+7]
training: 100%|██████████| 1/1 [06:26<00:00, 386.88s/epoch]
