In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from hook_manager import HookManager
from data_handling import load_tinystories_data
from sae import SaeTrainer

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
model_name = "roneneldan/TinyStories-33M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [5]:
data = load_tinystories_data('data/tinystories_val.txt')

In [59]:
dragon_indices = []

for i in range(len(data)):
    if 'dragon' in data[i]:
        dragon_indices.append(i)

## get activations

In [6]:
layers_pre_attn = []
input_ids = []

with HookManager(model) as hook_manager:
    for layer in range(model.config.num_layers):
        layers_pre_attn.append(hook_manager.attach_residstream_hook(layer=layer))
    for idx, story in enumerate(data[:100]):
        print(idx)
        tokenized = tokenizer(story, return_tensors='pt')
        model.forward(tokenized.input_ids)
        input_ids.append(tokenized.input_ids)

input_ids = torch.concat(input_ids, dim=1).squeeze()
all_resids = torch.concat([torch.concat(layer) for layer in layers_pre_attn])

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [51]:
layer_2_resids = torch.concat(layers_pre_attn[3])
ds = TensorDataset(layer_2_resids)
loader = DataLoader(ds, shuffle=True, batch_size=32)

In [52]:
trainer = SaeTrainer(
    input_size=768,
    hidden_size=768*4,
    k=16,
    learning_rate=0.001,
    device='cpu'
)

In [53]:
num_epochs = 5
for epoch_idx in range(num_epochs):
    for idx, batch in enumerate(loader):
        activation = batch[0].detach()
        label = activation.detach()

        loss = trainer.train_step(activation, label)
        print(f'{epoch_idx}.{idx}\t\t{loss}')

0.0		0.6504125595092773
0.1		0.6391193270683289
0.2		0.6269657611846924
0.3		0.6326513886451721
0.4		0.6266652345657349
0.5		0.6290397644042969
0.6		0.6178308129310608
0.7		0.6095291972160339
0.8		0.6132170557975769
0.9		0.5985453128814697
0.10		0.5982585549354553
0.11		0.58291095495224
0.12		0.5629100799560547
0.13		0.5748870968818665
0.14		0.5643284916877747
0.15		0.55426424741745
0.16		0.5461624264717102
0.17		0.5454410910606384
0.18		0.5496053099632263
0.19		0.5356634855270386
0.20		0.5138247013092041
0.21		0.4964899718761444
0.22		0.4960956871509552
0.23		0.5032385587692261
0.24		0.5007567405700684
0.25		0.5028956532478333
0.26		0.49064719676971436
0.27		0.4727928340435028
0.28		0.464013934135437
0.29		0.44800910353660583
0.30		0.4379093647003174
0.31		0.42408815026283264
0.32		0.4355497658252716
0.33		0.4229515790939331
0.34		0.4048696458339691
0.35		0.42837968468666077
0.36		0.4077020585536957
0.37		0.39351725578308105
0.38		0.40966978669166565
0.39		0.40337321162223816
0.40		0.

KeyboardInterrupt: 

In [54]:

with torch.no_grad():
    activity_per_token = layer_2_resids @ trainer.model._parameters['WT']

    topk_acts = torch.topk(activity_per_token, k=10, dim=0)

In [55]:
def color_text(text, scalar):
    scalar = max(-1, min(1, scalar))
    
    # Calculate RGB values based on scalar
    # Blue (-10) to white (0) to red (10)
    if scalar < 0:
        # From blue to white
        normalized = 1 + scalar / 1  # 0 at -10, 1 at 0
        r = int(255 * normalized)
        g = int(255 * normalized)
        b = 255
    else:
        # From white to red
        normalized = scalar / 1  # 0 at 0, 1 at 10
        r = 255
        g = int(255 * (1 - normalized))
        b = int(255 * (1 - normalized))
    
    # calculate perceived brightness using common luminance formula
    # if luminance is below threshold, use white text; otherwise use black
    luminance = (0.299 * r + 0.587 * g + 0.114 * b) / 255
    text_color = "255;255;255" if luminance < 0.5 else "0;0;0"
    
    return f"\033[48;2;{r};{g};{b}m\033[38;2;{text_color}m{text}\033[0m"

In [56]:
activity_per_token.shape

torch.Size([16841, 3072])

In [57]:
for i, indices in enumerate(topk_acts.indices.T):
    print(f'-------##{i}##--------')
    for idx in indices:
        lo = max(0, idx - 10)
        hi = min(len(input_ids), idx + 10)
        decoded = [
            color_text(tokenizer.decode(input_id), activity_per_token[lo + idy, i])
            for idy, input_id in enumerate(input_ids[lo:hi])
        ]
        print(''.join(decoded))
    if i > 100:
        break

-------##0##--------
[48;2;255;246;246m[38;2;0;0;0m pain[0m[48;2;255;240;240m[38;2;0;0;0m.[0m[48;2;224;224;255m[38;2;0;0;0m Her[0m[48;2;254;254;255m[38;2;0;0;0m mom[0m[48;2;208;208;255m[38;2;0;0;0mmy[0m[48;2;255;149;149m[38;2;0;0;0m came[0m[48;2;255;246;246m[38;2;0;0;0m to[0m[48;2;255;182;182m[38;2;0;0;0m help[0m[48;2;255;219;219m[38;2;0;0;0m her[0m[48;2;171;171;255m[38;2;0;0;0m and[0m[48;2;255;62;62m[38;2;255;255;255m put[0m[48;2;255;215;215m[38;2;0;0;0m some[0m[48;2;255;206;206m[38;2;0;0;0m cool[0m[48;2;255;232;232m[38;2;0;0;0m water[0m[48;2;255;237;237m[38;2;0;0;0m on[0m[48;2;241;241;255m[38;2;0;0;0m her[0m[48;2;255;210;210m[38;2;0;0;0m finger[0m[48;2;206;206;255m[38;2;0;0;0m to[0m[48;2;242;242;255m[38;2;0;0;0m make[0m[48;2;255;175;175m[38;2;0;0;0m it[0m
[48;2;215;215;255m[38;2;0;0;0m asked[0m[48;2;209;209;255m[38;2;0;0;0m her[0m[48;2;230;230;255m[38;2;0;0;0m to[0m[48;2;244;244;255m[38;2;0;0;0m come[0m[48;2;189

In [42]:
idz =34

for idx in torch.topk(activity_per_token[:, idz], k=30).indices:
    lo = max(0, idx - 10)
    hi = min(len(input_ids), idx + 10)
    decoded = [
        color_text(tokenizer.decode(input_id), activity_per_token[lo + idy, idz])
        for idy, input_id in enumerate(input_ids[lo:hi])
    ]
    print(''.join(decoded))

[48;2;59;59;255m[38;2;255;255;255m good[0m[48;2;0;0;255m[38;2;255;255;255m helper[0m[48;2;0;0;255m[38;2;255;255;255m,[0m[48;2;0;0;255m[38;2;255;255;255m Lily[0m[48;2;0;0;255m[38;2;255;255;255m."[0m[48;2;0;0;255m[38;2;255;255;255m Lily[0m[48;2;0;0;255m[38;2;255;255;255m smiled[0m[48;2;215;215;255m[38;2;0;0;0m and[0m[48;2;191;191;255m[38;2;0;0;0m said[0m[48;2;0;0;255m[38;2;255;255;255m,[0m[48;2;255;0;0m[38;2;255;255;255m "[0m[48;2;255;0;0m[38;2;255;255;255mI[0m[48;2;138;138;255m[38;2;0;0;0m love[0m[48;2;135;135;255m[38;2;0;0;0m helping[0m[48;2;0;0;255m[38;2;255;255;255m you[0m[48;2;255;41;41m[38;2;255;255;255m,[0m[48;2;0;0;255m[38;2;255;255;255m Mom[0m[48;2;0;0;255m[38;2;255;255;255mmy[0m[48;2;0;0;255m[38;2;255;255;255m."[0m[48;2;0;0;255m[38;2;255;255;255m And[0m
[48;2;84;84;255m[38;2;255;255;255m see[0m[48;2;4;4;255m[38;2;255;255;255m."[0m[48;2;0;0;255m[38;2;255;255;255m [0m[48;2;184;184;255m[38;2;0;0;0m The[0m[48;

In [36]:
torch.topk(activity_per_token[:, idz], k=20).indices

tensor([ 5105, 12395,  9768,  3724,  8384,  8010,  3399,  1666, 10646,  2022,
         2495,  2028,  3359,  5422,  6217, 15939, 11019,  6925,  7626, 10264])

In [50]:
for i in range(100):
    vector = trainer.model._parameters['WT'][:,i]

    torch.save(vector, f'steering_vectors/SAE_vectors/latent_{i}.pt')