In [1]:

import os
import json
import random
from tqdm import tqdm
import numpy as np
from transformers import (
    GenerationConfig,
)
import torch
import torch.nn.functional as F
from fancy_einsum import einsum
import einops

from record_utils import record_activations, get_module
from hook_utils import HookWithCountThreshold
from explore_utils import *

In [2]:

cos = F.cosine_similarity

In [3]:

config = {
    "data_path": "data/train.parquet",
    "model_path": "checkpoints/TinyZero/v4/actor/global_step_300",
    "probe_path": "probe_checkpoints/probe_from_mlp/probe.pt",
    "batch_size": 64,
    "valid_size": 256,
    "max_prompt_length": 256,
    "max_response_length": 300,
    "n_layers": 36,
    "d_model": 2048,
    "seed": 42,
    "hook_config": {
        "hook_layers": list(range(24, 33)),
        "hook_target_char": " (",
        "hook_target_threshold": 0,
        "hook_scale": 20,
    },
}

seed_all(config["seed"])

In [4]:

actor = load_model(config["model_path"])
generation_config = GenerationConfig(do_sample=False)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:

_, valid_dataloader = get_dataloader(
    config["data_path"],
    config["batch_size"],
    config["max_prompt_length"],
    config["valid_size"],
    actor.tokenizer,
)

original dataset len: 327680
filter dataset len: 327680


In [6]:

probe_model = torch.load(config["probe_path"]).detach().cuda()
probe_model.shape

  probe_model = torch.load(config["probe_path"]).detach().cuda()


torch.Size([36, 2048, 2])

In [7]:


for layer_idx in range(probe_model.shape[0]):
    print(f"Layer {layer_idx}")
    print(cos(probe_model[layer_idx, :, 0], probe_model[layer_idx, :, 1], dim=0))

for layer_idx in range(probe_model.shape[0] - 1):
    print(f"Layer {layer_idx} vs. Layer {layer_idx + 1}")
    print(cos(probe_model[layer_idx, :, 1], probe_model[layer_idx + 1, :, 1], dim=0))

Layer 0
tensor(-0.9433, device='cuda:0')
Layer 1
tensor(-0.9787, device='cuda:0')
Layer 2
tensor(-0.9791, device='cuda:0')
Layer 3
tensor(-0.9797, device='cuda:0')
Layer 4
tensor(-0.9839, device='cuda:0')
Layer 5
tensor(-0.9843, device='cuda:0')
Layer 6
tensor(-0.9832, device='cuda:0')
Layer 7
tensor(-0.9833, device='cuda:0')
Layer 8
tensor(-0.9835, device='cuda:0')
Layer 9
tensor(-0.9833, device='cuda:0')
Layer 10
tensor(-0.9830, device='cuda:0')
Layer 11
tensor(-0.9822, device='cuda:0')
Layer 12
tensor(-0.9827, device='cuda:0')
Layer 13
tensor(-0.9833, device='cuda:0')
Layer 14
tensor(-0.9841, device='cuda:0')
Layer 15
tensor(-0.9826, device='cuda:0')
Layer 16
tensor(-0.9848, device='cuda:0')
Layer 17
tensor(-0.9848, device='cuda:0')
Layer 18
tensor(-0.9848, device='cuda:0')
Layer 19
tensor(-0.9846, device='cuda:0')
Layer 20
tensor(-0.9847, device='cuda:0')
Layer 21
tensor(-0.9840, device='cuda:0')
Layer 22
tensor(-0.9837, device='cuda:0')
Layer 23
tensor(-0.9837, device='cuda:0')
La

In [8]:


def get_mlp_value_vecs(model):
    mlp_value_vecs = [layer.mlp.down_proj.weight for layer in model.model.layers]
    # [n_layers, d_mlp (11008), d_model (2048)]
    return torch.stack(mlp_value_vecs, dim=0)

In [14]:

value_vecs = get_mlp_value_vecs(actor)
print(value_vecs.shape)

top_cos_scores = {0: [], 1: []}
for target_label in [0, 1]:
    for layer_idx, target_probe_layer in enumerate(range(36)):
        target_probe = probe_model[target_probe_layer, :, target_label]

        print(f"Layer {target_probe_layer}")
        cos_scores = cos(value_vecs[target_probe_layer], target_probe.unsqueeze(-1), dim=0)
        _topk = cos_scores.topk(k=20)
        _values = [x.item() for x in _topk.values]
        _idxs = [x.item() for x in _topk.indices]
        topk = list(
            zip(
                _values,
                _idxs,
                [target_probe_layer] * _topk.indices.shape[0],
            )
        )
        top_cos_scores[target_label].extend(topk)

sorted_scores_0 = sorted(top_cos_scores[0], key=lambda x: x[0], reverse=True)
sorted_scores_1 = sorted(top_cos_scores[1], key=lambda x: x[0], reverse=True)

torch.Size([36, 2048, 11008])
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6
Layer 7
Layer 8
Layer 9
Layer 10
Layer 11
Layer 12
Layer 13
Layer 14
Layer 15
Layer 16
Layer 17
Layer 18
Layer 19
Layer 20
Layer 21
Layer 22
Layer 23
Layer 24
Layer 25
Layer 26
Layer 27
Layer 28
Layer 29
Layer 30
Layer 31
Layer 32
Layer 33
Layer 34
Layer 35
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6
Layer 7
Layer 8
Layer 9
Layer 10
Layer 11
Layer 12
Layer 13
Layer 14
Layer 15
Layer 16
Layer 17
Layer 18
Layer 19
Layer 20
Layer 21
Layer 22
Layer 23
Layer 24
Layer 25
Layer 26
Layer 27
Layer 28
Layer 29
Layer 30
Layer 31
Layer 32
Layer 33
Layer 34
Layer 35


In [15]:

seen = []
for elem in sorted_scores_0[:100]:
    cos_score, mlp_idx, probe_layer_idx = elem
    curr = (probe_layer_idx, mlp_idx)
    if curr in seen:
        continue
    seen.append(curr)
    print(curr)
    print(cos_score)
    print(
        unembed_text(
            actor.model.layers[probe_layer_idx].mlp.down_proj.weight[:, mlp_idx],
            actor.lm_head.weight,
            actor.tokenizer,
            k=10,
        )
    )

(0, 1426)
0.37973541021347046
['陉', '_FE', '愉', 'featured', ' span', 'riv', '荑', '-prev', 'oten', ' bre']
(3, 5756)
0.2656635642051697
[' Mey', '.flat', ' instructor', ' deferred', '蔑', ' Scratch', '包', ' flats', 'vasion', 'cipher']
(2, 10899)
0.24624955654144287
['oxide', '.SDK', '旄', '.acquire', '.sdk', 'SES', 'ettle', 'Coeff', ' SDK', '斛']
(2, 617)
0.24192531406879425
['-topic', '_signed', 'afs', 'ters', ' Weeks', '窨', 'topics', ' Signed', 'IGNED', 'ffa']
(30, 6404)
0.241890549659729
['不是', '并非', ' not', '不是一个', '也不是', 'not', ' NOT', '\tnot', '并不是', '不再是']
(22, 6443)
0.23765268921852112
['dex', 'heel', ' neck', 'edm', '脖', 'ベル', '增加值', 'ardu', '.evaluate', 'SYSTEM']
(1, 3363)
0.22943776845932007
['眚', 'issan', 'ote', '出', ' ', ' out', '为中心', '', 'byter', ' by']
(4, 7380)
0.22756074368953705
['\')"\n', '\xa0', '\'"\n', 'Dead', '\'")\n', '…\n', 'emia', 'oline', '\n', '死']
(4, 2484)
0.2183368057012558
['uss', 'cap', ' carr', 'adir', ' Mats', ' Disc', 'os', ' cap', 'aged', 'ideos']
(19,

In [16]:

seen = []
for elem in sorted_scores_1[:100]:
    cos_score, mlp_idx, probe_layer_idx = elem
    curr = (probe_layer_idx, mlp_idx)
    if curr in seen:
        continue
    seen.append(curr)
    print(curr)
    print(probe_layer_idx)
    print(cos_score)
    print(
        unembed_text(
            actor.model.layers[probe_layer_idx].mlp.down_proj.weight[:, mlp_idx],
            actor.lm_head.weight,
            actor.tokenizer,
            k=10,
        )
    )

(23, 3143)
23
0.4406369924545288
['闼', 'elif', ' piger', '.Suppress', '_BLOCKS', '锥', '煳', ' TSR', ' Thumb', 'hä']
(26, 6475)
26
0.3136760890483856
['倒是', '不失', '适度', 'successful', '.success', '却是', ' successful', '没事', '完好', '还不错']
(31, 4851)
31
0.30563002824783325
['CAF', 'bió', 'WithValue', 'PFN', 'isque', 'CTS', 'nda', ' المص', ' Tup', '("(%']
(19, 5818)
19
0.28585493564605713
['不用担心', '大胆', '管', ' #{', '#{', ' feu', '愿意', ' cott', '_#{', ' %{']
(1, 8025)
1
0.2599819004535675
['tip', ' Parkway', '-match', '进门', ' SAL', 'lage', 'zew', 'texts', ' nors', '_FUNCTIONS']
(12, 10661)
12
0.2495301067829132
['nde', 'ndo', 'фин', ' Comple', 'PasswordField', '.prepareStatement', '.Complete', '喉', 'climate', '拉开']
(26, 3665)
26
0.2468251883983612
['урс', 'swick', '然而', 'ulton', '最新', '�', '半个', 'iox', '然', '但是']
(0, 1591)
0
0.23041090369224548
['ellant', 'lando', ' paying', 'ext', 'уществ', ' bench', '_escape', ' afterward', ' escape', ' landlord']
(19, 9952)
19
0.2284160852432251
['大胆', ' url