In [1]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from data_handling import load_tinystories_data
from hook_manager import HookManager

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Checkpoint identifier shared by the model and its tokenizer so the two
# always come from the same pretrained release.
model_name = "roneneldan/TinyStories-33M"

# The two loads are independent of each other; both resolve `model_name`
# through the transformers `from_pretrained` machinery.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# Bare expression as the last line of the cell: the notebook displays the
# model's repr (its module tree) as the cell output.
model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_feat

In [5]:
# Location of the TinyStories validation text, relative to the notebook's
# working directory.
tinystories_val_path = 'data/tinystories_val.txt'
data = load_tinystories_data(tinystories_val_path)

## Get residual-stream activations

In [6]:
# One entry per transformer block, in layer order. Each entry is whatever
# `attach_residstream_hook` returns; it is concatenated with `torch.concat`
# below, so it is presumably a list of tensors the hook fills during forward
# passes -- TODO confirm against hook_manager.
layers_pre_attn = []

# Activations are only recorded here, never back-propagated through, so the
# forward passes run under `torch.no_grad()`. Without it every pass builds an
# autograd graph, and any captured tensor that is not detached by the hook
# keeps that graph alive for all 100 stories.
with torch.no_grad(), HookManager(model) as hook_manager:
    for layer in range(model.config.num_layers):
        layers_pre_attn.append(hook_manager.attach_residstream_hook(layer=layer))
    for idx, story in enumerate(data[:100]):
        print(idx)
        tokenized = tokenizer(story, return_tensors='pt')
        # Call the module itself rather than `.forward()`: `__call__` is the
        # supported entry point and also runs hooks registered on the model.
        model(tokenized.input_ids)

# Every layer's recordings stacked along dim 0, layer after layer.
all_resids = torch.concat([torch.concat(layer) for layer in layers_pre_attn])

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [7]:
# Stack everything recorded by the hook at index 2 into one tensor, then
# expose it as shuffled mini-batches of 32 rows.
layer_2_resids = torch.concat(layers_pre_attn[2])
ds = TensorDataset(layer_2_resids)
loader = DataLoader(
    dataset=ds,
    batch_size=32,
    shuffle=True,
)

In [8]:
from sae import SaeTrainer

In [9]:
# Keyword arguments for the trainer, gathered in one place so they can be
# inspected or tweaked before construction.
# `input_size` equals the 768-wide embeddings shown in the model printout;
# `hidden_size` is 16x that width.
# NOTE(review): the meaning of `k` is defined inside `sae.SaeTrainer` -- confirm there.
sae_trainer_kwargs = dict(
    input_size=768,
    hidden_size=768 * 16,
    k=32,
    learning_rate=0.001,
    device='cpu',
)
trainer = SaeTrainer(**sae_trainer_kwargs)

In [10]:
num_epochs = 5
for epoch_idx in range(num_epochs):
    # The loader wraps a single-tensor TensorDataset, so each batch is a
    # one-element sequence; unpack it directly.
    for idx, (resid_batch,) in enumerate(loader):
        activation = resid_batch.detach()
        # The same tensor is passed as both input and label.
        label = activation.detach()

        loss = trainer.train_step(activation, label)
        print(f'{epoch_idx}.{idx}\t\t{loss}')

0.0		0.5986819863319397
0.1		0.586446225643158
0.2		0.5844685435295105
0.3		0.5697855353355408
0.4		0.5604385733604431
0.5		0.5367621779441833
0.6		0.5444661378860474
0.7		0.536171019077301
0.8		0.519108772277832
0.9		0.5086182355880737
0.10		0.4949188232421875
0.11		0.5010320544242859
0.12		0.4603680372238159
0.13		0.46044185757637024
0.14		0.4428839385509491
0.15		0.44240161776542664
0.16		0.4468660056591034
0.17		0.40532755851745605
0.18		0.3950551748275757
0.19		0.39205053448677063
0.20		0.3803686201572418
0.21		0.38708925247192383
0.22		0.3510204255580902
0.23		0.3228052854537964
0.24		0.30815979838371277
0.25		0.31131234765052795
0.26		0.3229897916316986
0.27		0.2750348448753357
0.28		0.3065382242202759
0.29		0.2948538064956665
0.30		0.2720610499382019
0.31		0.286416620016098
0.32		0.26694604754447937
0.33		0.2803826928138733
0.34		0.26811525225639343
0.35		0.24681459367275238
0.36		0.2576916217803955
0.37		0.2553018033504486
0.38		0.2528240978717804
0.39		0.2773647904396057
0.40

KeyboardInterrupt: 

In [21]:
# How many leading columns of the `WT` parameter to export, one file each.
num_vectors = 10

# `torch.save` does not create missing directories, so make sure it exists.
Path('steering_vectors').mkdir(parents=True, exist_ok=True)

# Fetch the parameter through the public nn.Module accessor instead of the
# private `_parameters` dict, and detach it so the saved tensors carry no
# autograd state.
wt = trainer.model.get_parameter('WT').detach()

for i in range(num_vectors):
    # `wt[:, i]` is a view into the full matrix, and `torch.save` serialises a
    # view together with its entire underlying storage. Cloning first means
    # each file holds only the selected column.
    vector = wt[:, i].clone()

    torch.save(vector, f'steering_vectors/test_vector_{i}.pt')