In [1]:

import os
import json
import random
from tqdm import tqdm
import numpy as np
from transformers import (
    GenerationConfig,
)
import torch
import torch.nn.functional as F
from fancy_einsum import einsum
import einops

from record_utils import record_activations, get_module
from hook_utils import HookWithCountThreshold
from explore_utils import *

In [2]:

cos = F.cosine_similarity

In [24]:

config = {
    "data_path": "data/train.parquet",
    "model_path": "checkpoints/TinyZero/v4/actor/global_step_300",
    "probe_path": "probe_checkpoints/v2/probe.pt",
    "batch_size": 64,
    "valid_size": 256,
    "max_prompt_length": 256,
    "max_response_length": 300,
    "n_layers": 36,
    "d_model": 2048,
    "seed": 42,
    "hook_config": {
        "hook_layers": list(range(24, 33)),
        "hook_target_char": " (",
        "hook_target_threshold": 0,
        "hook_scale": 20,
    },
}

seed_all(config["seed"])

In [4]:

actor = load_model(config["model_path"])
generation_config = GenerationConfig(do_sample=False)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [25]:

_, valid_dataloader = get_dataloader(
    config["data_path"],
    config["batch_size"],
    config["max_prompt_length"],
    config["valid_size"],
    actor.tokenizer,
)

original dataset len: 327680
filter dataset len: 327680


In [6]:

probe_model = torch.load(config["probe_path"]).detach().cuda()
probe_model.shape

  probe_model = torch.load(config["probe_path"]).detach().cuda()


torch.Size([36, 2048, 2])

In [7]:


for layer_idx in range(probe_model.shape[0]):
    print(f"Layer {layer_idx}")
    print(cos(probe_model[layer_idx, :, 0], probe_model[layer_idx, :, 1], dim=0))

for layer_idx in range(probe_model.shape[0] - 1):
    print(f"Layer {layer_idx} vs. Layer {layer_idx + 1}")
    print(cos(probe_model[layer_idx, :, 1], probe_model[layer_idx + 1, :, 1], dim=0))

Layer 0
tensor(-0.9812, device='cuda:0')
Layer 1
tensor(-0.9841, device='cuda:0')
Layer 2
tensor(-0.9874, device='cuda:0')
Layer 3
tensor(-0.9926, device='cuda:0')
Layer 4
tensor(-0.9930, device='cuda:0')
Layer 5
tensor(-0.9941, device='cuda:0')
Layer 6
tensor(-0.9937, device='cuda:0')
Layer 7
tensor(-0.9935, device='cuda:0')
Layer 8
tensor(-0.9939, device='cuda:0')
Layer 9
tensor(-0.9932, device='cuda:0')
Layer 10
tensor(-0.9938, device='cuda:0')
Layer 11
tensor(-0.9935, device='cuda:0')
Layer 12
tensor(-0.9932, device='cuda:0')
Layer 13
tensor(-0.9938, device='cuda:0')
Layer 14
tensor(-0.9936, device='cuda:0')
Layer 15
tensor(-0.9939, device='cuda:0')
Layer 16
tensor(-0.9939, device='cuda:0')
Layer 17
tensor(-0.9940, device='cuda:0')
Layer 18
tensor(-0.9944, device='cuda:0')
Layer 19
tensor(-0.9940, device='cuda:0')
Layer 20
tensor(-0.9944, device='cuda:0')
Layer 21
tensor(-0.9942, device='cuda:0')
Layer 22
tensor(-0.9946, device='cuda:0')
Layer 23
tensor(-0.9942, device='cuda:0')
La

In [8]:


def get_mlp_value_vecs(model):
    mlp_value_vecs = [layer.mlp.down_proj.weight for layer in model.model.layers]
    # [n_layers, d_mlp (11008), d_model (2048)]
    return torch.stack(mlp_value_vecs, dim=0)

In [9]:

value_vecs = get_mlp_value_vecs(actor)
print(value_vecs.shape)

top_cos_scores = {0: [], 1: []}
for target_label in [0, 1]:
    for target_probe_layer in range(24, 36):
        target_probe = probe_model[target_probe_layer, :, target_label]

        for layer_idx in range(0, target_probe_layer + 1):
            print(f"Layer {layer_idx}")
            cos_scores = cos(value_vecs[layer_idx], target_probe.unsqueeze(-1), dim=0)
            _topk = cos_scores.topk(k=100)
            _values = [x.item() for x in _topk.values]
            _idxs = [x.item() for x in _topk.indices]
            topk = list(
                zip(
                    _values,
                    _idxs,
                    [target_probe_layer] * _topk.indices.shape[0],
                    [layer_idx] * _topk.indices.shape[0],
                )
            )
            top_cos_scores[target_label].extend(topk)

sorted_scores_0 = sorted(top_cos_scores[0], key=lambda x: x[0], reverse=True)
sorted_scores_1 = sorted(top_cos_scores[1], key=lambda x: x[0], reverse=True)

torch.Size([36, 2048, 11008])
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6
Layer 7
Layer 8
Layer 9
Layer 10
Layer 11
Layer 12
Layer 13
Layer 14
Layer 15
Layer 16
Layer 17
Layer 18
Layer 19
Layer 20
Layer 21
Layer 22
Layer 23
Layer 24
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6
Layer 7
Layer 8
Layer 9
Layer 10
Layer 11
Layer 12
Layer 13
Layer 14
Layer 15
Layer 16
Layer 17
Layer 18
Layer 19
Layer 20
Layer 21
Layer 22
Layer 23
Layer 24
Layer 25
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6
Layer 7
Layer 8
Layer 9
Layer 10
Layer 11
Layer 12
Layer 13
Layer 14
Layer 15
Layer 16
Layer 17
Layer 18
Layer 19
Layer 20
Layer 21
Layer 22
Layer 23
Layer 24
Layer 25
Layer 26
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6
Layer 7
Layer 8
Layer 9
Layer 10
Layer 11
Layer 12
Layer 13
Layer 14
Layer 15
Layer 16
Layer 17
Layer 18
Layer 19
Layer 20
Layer 21
Layer 22
Layer 23
Layer 24
Layer 25
Layer 26
Layer 27
Layer 0
Layer 1
Layer 2
Layer 3
Layer 4
Layer 5
Layer 6


In [11]:

seen = []
for elem in sorted_scores_0[:100]:
    cos_score, mlp_idx, probe_layer_idx, layer_idx = elem
    curr = (layer_idx, mlp_idx)
    if curr in seen:
        continue
    seen.append(curr)
    print(curr)
    print(cos_score)
    print(
        unembed_text(
            actor.model.layers[layer_idx].mlp.down_proj.weight[:, mlp_idx],
            actor.lm_head.weight,
            actor.tokenizer,
            k=10,
        )
    )

(32, 5650)
0.2273489236831665
[' não', ' nicht', '不', '不在', ' không', '不会', ' не', ' 不', '不會', '不是一个']
(24, 5159)
0.19371235370635986
['圬', 'ableObject', 'utow', 'lops', ' \n \n', 'enci', 'ENTE', 'バイ', '调', 'fce']
(32, 767)
0.1869061291217804
[' không', '没有', ' nicht', ' not', '不会', '不能', ' neither', ' não', 'ไม', '沒有']
(30, 6404)
0.1857713907957077
['不是', '并非', ' not', '不是一个', '也不是', 'not', ' NOT', '\tnot', '并不是', '不再是']
(26, 744)
0.18252520263195038
['未能', '不够', ' nicht', '不像', '达不到', '不清楚', '不具备', '不到位', '不符', '还不']
(31, 10127)
0.1824290007352829
[' not', ' không', ' nicht', '不会', '不能', '不是', ' не', 'not', '\tnot', ' não']
(26, 6619)
0.16895824670791626
['缺乏', '缺少', '不方便', ' lacks', '难以', '未能', '无法', ' lack', '得不到', '不符合']
(30, 10722)
0.1651017963886261
['不会', '不會', '也不会', '都不会', ' doesn', '并不会', ' neither', ' doesnt', ' nicht', ' never']
(23, 10504)
0.1642017960548401
['yro', '_Global', '糠', '有色', ' genie', '�回', ' kd', 'onio', '但如果', '冉']
(27, 9766)
0.16189239919185638
['是不可能', ' 

In [12]:

seen = []
for elem in sorted_scores_1[:100]:
    cos_score, mlp_idx, probe_layer_idx, layer_idx = elem
    curr = (layer_idx, mlp_idx)
    if curr in seen:
        continue
    seen.append(curr)
    print(curr)
    print(probe_layer_idx)
    print(cos_score)
    print(
        unembed_text(
            actor.model.layers[layer_idx].mlp.down_proj.weight[:, mlp_idx],
            actor.lm_head.weight,
            actor.tokenizer,
            k=10,
        )
    )

(23, 3143)
24
0.26678231358528137
['闼', 'elif', ' piger', '.Suppress', '_BLOCKS', '锥', '煳', ' TSR', ' Thumb', 'hä']
(26, 6475)
26
0.26015180349349976
['倒是', '不失', '适度', 'successful', '.success', '却是', ' successful', '没事', '完好', '还不错']
(26, 3665)
26
0.23566988110542297
['урс', 'swick', '然而', 'ulton', '最新', '�', '半个', 'iox', '然', '但是']
(28, 10153)
32
0.21838121116161346
['uito', 'Todo', 'mdi', 'それで', 'ISR', 'jal', 'odd', 'jn', 'todo', '_lite']
(22, 4606)
24
0.19218069314956665
['-banner', '<small', 'Rare', '(fin', '.bits', 'ocode', '大大小', 'lij', ' Finch', '_fin']
(31, 3311)
32
0.18836161494255066
['eworthy', 'uele', 'edio', '圄', 'aroo', '��', '迨', 'uptools', 'etheless', ' INTERN']
(26, 4334)
26
0.18598763644695282
['iston', 'StackSize', 'ISIBLE', 'てくれ', '|int', 'ут', '几分', ' TER', 'TERN', '藻']
(29, 6676)
29
0.18184325098991394
[' yes', ' Yes', 'Bindable', ' exactly', 'Yes', '"Yes', 'yes', ' Yep', ' Exactly', ' included']
(35, 8199)
35
0.1802186369895935
['uele', 'nis', '饥饿', '一路', 'wing'

In [None]:

print(
    cos(
        actor.model.layers[25].mlp.down_proj.weight[:, 7613],
        actor.model.layers[25].mlp.down_proj.weight[:, 1688],
        dim=0,
    )
)

print(
    cos(
        actor.model.layers[25].mlp.down_proj.weight[:, 7613],
        probe_model[25, :, 1],
        dim=0,
    )
)

print(
    cos(
        actor.model.layers[25].mlp.down_proj.weight[:, 1688],
        probe_model[25, :, 1],
        dim=0,
    )
)

print(
    cos(
        actor.model.layers[25].mlp.down_proj.weight[:, 1688]
        + actor.model.layers[25].mlp.down_proj.weight[:, 7613],
        probe_model[25, :, 1],
        dim=0,
    )
)

In [None]:

mlp_idxs = [1688, 7613, 9748, 3521, 2929, 8947]

for offset in range(1, len(mlp_idxs)):
    added = actor.model.layers[25].mlp.down_proj.weight[:, mlp_idxs[:offset]].sum(dim=1)

    print(
        unembed_text(
            added,
            # actor.model.layers[25].mlp.down_proj.weight[:, 1688]
            # + actor.model.layers[25].mlp.down_proj.weight[:, 7613],
            actor.lm_head.weight,
            actor.tokenizer,
            k=10,
        )
    )
    print(cos(added, probe_model[25, :, 1], dim=0))

# [28, 10153]
# [29, 6676]

From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


torch.Size([4, 9, 300, 11008])


In [34]:

mlp_layers = list(range(24, 36))
record_module_names = [f"model.layers.{i}.mlp.act_fn" for i in mlp_layers]
max_new_tokens = config["max_response_length"]
tokenizer = actor.tokenizer
token_open = tokenizer.encode(" (")[0]  # 320
token_not = tokenizer.encode("not")[0]  # 1921
token_this = tokenizer.encode("this")[0]  # 574


not_acts = []
this_acts = []
for batch_idx, batch in enumerate(valid_dataloader):

    input_ids = batch["input_ids"].cuda()
    attention_mask = batch["attention_mask"].cuda()

    with record_activations(actor, record_module_names) as recording:
        output = actor.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            generation_config=generation_config,
            output_scores=False,  # this is potentially very large
            return_dict_in_generate=True,
            use_cache=True,
        )

    # len(recording["model.layers.0"]): max_response_length
    # recording["model.layers.0"][0].shape: [batch, prompt_length, d_model]
    # recording["model.layers.0"][1].shape: [batch, 1, d_model]
    recording = {
        layer_name: torch.cat(acts, dim=1) for layer_name, acts in recording.items()
    }

    # recording["model.layers.0"].shape:
    # [batch, prompt_length + max_new_tokens, d_mlp]
    seq = output.sequences
    response = seq[:, -max_new_tokens:]
    response_text = tokenizer.batch_decode(seq, skip_special_tokens=True)

    # [batch, n_layers, response_length, d_mlp]
    activations = torch.stack(
        [acts[:, -max_new_tokens:] for acts in recording.values()], dim=1
    )
    print(resid_stream.shape)

    mask_not = (response[:, :-1] == token_open) & (response[:, 1:] == token_not)
    mask_this = (response[:, :-1] == token_open) & (response[:, 1:] == token_this)
    batch_idx_not, timesteps_not = torch.where(mask_not)
    batch_idx_this, timesteps_this = torch.where(mask_this)

    batch_idx_not = batch_idx_not
    batch_idx_this = batch_idx_this

    overlap_batches = torch.tensor(
        sorted(
            list(set(batch_idx_not.tolist()).intersection(set(batch_idx_this.tolist())))
        )
    ).cuda()
    batch_mask_not = torch.isin(batch_idx_not, overlap_batches)
    batch_mask_this = torch.isin(batch_idx_this, overlap_batches)

    # TODO: probe_timestep_offset.
    filtered_timesteps_not = {
        b_idx: timesteps_not[(batch_idx_not == b_idx)]
        for b_idx in overlap_batches.tolist()
    }
    filtered_timesteps_this = {
        b_idx: timesteps_this[(batch_idx_this == b_idx)]
        for b_idx in overlap_batches.tolist()
    }

    for b_idx in filtered_timesteps_not.keys():
        _not_timesteps = filtered_timesteps_not[b_idx].tolist()
        not_acts.append(
            activations[
                b_idx,
                :,
                _not_timesteps,
            ].cpu()
        )
        _this_timesteps = filtered_timesteps_this[b_idx].tolist()
        this_acts.append(
            activations[
                b_idx,
                :,
                _this_timesteps,
            ].cpu()
        )

not_acts = torch.cat(not_acts, dim=1)
this_acts = torch.cat(this_acts, dim=1)

torch.Size([4, 9, 300, 11008])
torch.Size([4, 9, 300, 11008])
torch.Size([4, 9, 300, 11008])
torch.Size([4, 9, 300, 11008])


In [43]:

_not_acts = not_acts.clone()
_this_acts = this_acts.clone()

# [_layers, d_mlp]
_not_acts = _not_acts.mean(dim=1).cuda()
_this_acts = _this_acts.mean(dim=1).cuda()

_value_vecs = value_vecs[mlp_layers]

print(_not_acts.shape)
print(_value_vecs.shape)

# [layers, d_model, d_mlp]
neg_scaled_value_vecs = _not_acts.unsqueeze(1) * _value_vecs
pos_scaled_value_vecs = _this_acts.unsqueeze(1) * _value_vecs

# [layers, d_mlp]
dot_prods = einsum(
    "layers d_model d_mlp, layers d_model -> layers d_mlp",
    pos_scaled_value_vecs,
    probe_model[mlp_layers, :, 1],
)

for layer_idx in range(dot_prods.shape[0]):
    top_idxs = dot_prods[layer_idx].topk(k=10).indices

    curr_layer = mlp_layers[layer_idx]
    for _idx in top_idxs.tolist():
        print(f"Layer {curr_layer} Index {_idx}")
        curr_value_vecs = actor.model.layers[curr_layer].mlp.down_proj.weight[:, _idx]

        print(unembed_text(curr_value_vecs, actor.lm_head.weight, actor.tokenizer, k=10))

torch.Size([12, 11008])
torch.Size([12, 2048, 11008])
Layer 24 Index 10695
['ofil', 'ilen', '的好', 'illes', ' Toro', 'igans', 'inoa', 'halten', 'odo', 'elda']
Layer 24 Index 1612
['蛸', 'breaking', '丕', 'lá', 'tí', '(spec', 'ece', 'ót', '-spec', 'break']
Layer 24 Index 1689
['iest', '最受欢迎', '正常使用', '也正是', '喷', ' implicit', '妲', '最基本', ' MOST', '前台']
Layer 24 Index 9559
['###', '(es', ' ###', 'InSection', 'ighth', '字号', '不负', '_cou', '남', '�']
Layer 24 Index 3584
['后来', '正式', ' finally', 'finally', '到来', '来了', '实际', 'になると', '日正式', '实质']
Layer 24 Index 3437
['适', ' exactly', '也正是', ' precisely', 'needed', '恰好', '刚好', '正是', ' needed', '正是因为']
Layer 24 Index 9376
[' henne', 'ѣ', ' клуб', 'iamo', ' Cruiser', ' turnovers', ' WaitForSeconds', 'نصوص', '瞭解', 'preced']
Layer 24 Index 2558
[' Auto', 'Auto', ' })(', 'apolis', ' immunity', '识别', ' Soph', '�', 'auto', '&id']
Layer 24 Index 5259
['红线', "(',')\n", '.Companion', 'getQuery', 'rim', " '~", '怼', '홈', '聋', ' humour']
Layer 24 Index 7990
['pe