In [6]:
from pathlib import Path

import datasets  # type: ignore[missingTypeStubs, import-untyped]
import torch
from datasets import IterableDataset  # type: ignore[missingTypeStubs]
from transformer_lens.HookedTransformer import HookedTransformer  # type: ignore[import]

from mechint import hooks
from mechint.device import get_device  # type: ignore[import]
from mechint.mas import algorithm, html
from mechint.mas.algorithm import MASLayer, MASParams

In [11]:
# MAS run configuration, one setting per line so each value is easy to tweak.
params = MASParams(
    sample_overlap=128,
    num_max_samples=16,
    sample_length_pre=96,
    sample_length_post=32,
    samples_to_check=4096,
)

device = get_device()

model: HookedTransformer = HookedTransformer.from_pretrained("gelu-1l", device=device.torch())  # type: ignore[reportUnknownVariableType]

# The dataset is streamed rather than downloaded in full.
dataset: IterableDataset = datasets.load_dataset(  # type: ignore[reportUnknownMemberType]
    "monology/pile-uncopyrighted", streaming=True, split="train", trust_remote_code=True
)

# Single layer under study; 2048 is presumably the number of neurons at this
# hook point (it matches the leading dimension asserted below) -- TODO confirm.
hook_point = "blocks.0.mlp.hook_post"
layers = [MASLayer.from_hook_id(hook_point, 2048)]

mas_store = algorithm.run(model, dataset, layers, params, device)

mas_samples, mas_activations = mas_store.feature_samples(), mas_store.feature_activations()

# Sanity checks: both result tensors must have the expected shape and contain
# only finite values.
assert mas_samples.shape == mas_activations.shape == (2048, 16, 128)
assert torch.isfinite(mas_samples).all() and torch.isfinite(mas_activations).all()

Loaded pretrained model gelu-1l into HookedTransformer


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Using device: cuda
Model context size: 1024
0%
1%
2%
3%
4%
5%
6%
7%
8%
9%
10%
11%
12%
13%
14%
15%
16%
17%
18%
19%
20%
21%
22%
23%
24%
25%
26%
27%
28%
29%
30%
31%
32%
33%
34%
35%
36%
37%
38%
39%
40%
41%
42%
43%
44%
45%
46%
47%
48%
49%
50%
51%
52%
53%
54%
55%
56%
57%
58%
59%
60%
61%
62%
63%
64%
65%
66%
67%
68%
69%
70%
71%
72%
73%
74%
75%
76%
77%
78%
79%
80%
81%
82%
83%
84%
85%
86%
87%
88%
89%
90%
91%
92%
93%
94%
95%
96%
97%
98%
99%
Time taken: 473.43s
Time taken per sample: 115.58ms
Model time: 47.49s (10.03%)
MAS time: 407.87s (86.15%)


In [12]:
# Indices (first axis of mas_samples / mas_activations) to export for inspection.
indices = [423, 512, 1502, 30]

# The output directory is identical for every index, so create it once up front
# instead of on every loop iteration.
output_dir = Path("outputs") / "mas_test"
output_dir.mkdir(parents=True, exist_ok=True)

for index in indices:
    # Page built from the activations stored by the MAS run. The HTML is
    # generated before the file is touched so a failure in generate_html does
    # not leave an empty file behind.
    mas_html = html.generate_html(model, mas_samples[index], mas_activations[index])
    # Explicit UTF-8: token text may contain non-ASCII characters, and the
    # platform default encoding is not guaranteed to handle them.
    (output_dir / f"mas_{index}.html").write_text(mas_html, encoding="utf-8")

    # Page for the same samples, using activations recomputed through
    # hooks.neuron_activations, so the two files can be compared.
    activations = hooks.neuron_activations(model, hook_point, mas_samples[index], index, device)
    recomputed_html = html.generate_html(model, mas_samples[index], activations)
    (output_dir / f"{index}.html").write_text(recomputed_html, encoding="utf-8")

In [28]:
# For each selected index, recompute the activations and compare the argmax
# position per sample against the one stored by the MAS run. Stop at the first
# index where any sample disagrees and print diagnostics for it.
for i in indices:
    activations = hooks.neuron_activations(model, hook_point, mas_samples[i, :, :], i, device)

    # Compute each argmax once and reuse it for both the check and the report.
    recomputed_argmax = activations.argmax(dim=1)
    stored_argmax = mas_activations[i].argmax(dim=1)
    mismatch = recomputed_argmax != stored_argmax

    if mismatch.any():
        # f-string so the actual index is printed, not the literal "{i}".
        print(f"Neuron: {i}")
        # Column 0: recomputed argmax, column 1: argmax stored by the MAS run.
        print(torch.stack((recomputed_argmax, stored_argmax), dim=1))
        # Call to_str_tokens instead of printing the bound method itself.
        # Shown for the first disagreeing sample only, since to_str_tokens is
        # given a single 1-D sequence here.
        # NOTE(review): assumes mas_samples holds integer token ids -- confirm.
        first_bad = int(mismatch.nonzero()[0, 0])
        print(model.to_str_tokens(mas_samples[i, first_bad]))
        break

tensor([[99, 96],
        [96, 96],
        [96, 96],
        [94, 96],
        [96, 96],
        [94, 96],
        [96, 96],
        [96, 96],
        [99, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96]], device='cuda:0')
tensor([[ 96,  96],
        [ 96,  96],
        [ 96,  96],
        [ 30,  96],
        [ 96,  96],
        [125,  96],
        [ 96,  96],
        [ 96,  96],
        [ 96,  96],
        [ 96,  96],
        [ 96,  96],
        [ 96,  96],
        [ 96,  96],
        [ 87,  96],
        [ 96,  96],
        [ 69,  96]], device='cuda:0')
tensor([[99, 96],
        [95, 96],
        [96, 96],
        [32, 32],
        [48, 48],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [96, 96],
        [19, 19],
        [96, 96],
        [96, 96],
        [71, 71],
        [96, 96]], device='cuda:0')
