In [2]:
%pip install transformer_lens

Collecting transformer_lens
  Downloading transformer_lens-2.16.1-py3-none-any.whl.metadata (12 kB)
Collecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-1.10.0-py3-none-any.whl.metadata (19 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting einops>=0.6.0 (from transformer_lens)
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer_lens)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting numpy<2,>=1.24 (from transformer_lens)
  Downloading n

In [1]:
from transformer_lens import HookedEncoderDecoder
import transformer_lens.utils as utils
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

import torch

torch.set_grad_enabled(False)


torch.autograd.grad_mode.set_grad_enabled(mode=False)

## Loading the Model in TransformerLens

Please download the model first: https://cloud.anja.re/s/Qpo8CZ6yRzDH7ZF

In [4]:
# !wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2F&files=DSI-large-TriviaQA.zip"
# !unzip "download?path=%2F&files=DSI-large-TriviaQA.zip"
checkpoint = "../DSI-large-TriviaQA"
device = utils.get_device()

OFFICIAL_MODEL_NAMES.append(checkpoint)

hf_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(checkpoint, device=device)
model = HookedEncoderDecoder.from_pretrained(checkpoint, hf_model=hf_model, device=device)

tokenizer_t5 = AutoTokenizer.from_pretrained('google-t5/t5-large')


# Our model has a new token for each document id that we trained it on.

# token id of first document that was added
# first_added_doc_id = len(tokenizer_t5)
# # token id of the last document that was added
# last_added_doc_id = len(tokenizer_t5) + (len(tokenizer) - len(tokenizer_t5))
# del tokenizer_t5


If using T5 for interpretability research, keep in mind that T5 has some significant architectural differences to GPT. The major one is that T5 is an Encoder-Decoder modelAlso, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm


Loaded pretrained model ../DSI-large-TriviaQA into HookedTransformer


## Loading data that was collected by Shir
Not needed in each run

In [3]:
import torch, numpy as np, pickle
from torch.serialization import add_safe_globals, safe_globals

path = "ids_with_more_than_threshold_correct_queries.json"

obj = torch.load(path, weights_only=False)  # risks arbitrary code exec if untrusted

# add_safe_globals works process-wide:
add_safe_globals([np.core.multiarray.scalar])
obj = torch.load(path, weights_only=False)

# OR: use a one-off context (no global change):
with safe_globals([np.core.multiarray.scalar]):
    obj = torch.load(path, weights_only=False)

print(type(obj))

<class 'dict'>


## Loading training and validation data
This part generates the data loader for the training and validation data which are used to get the statistics and the activated neurons dictionary.

In [5]:
#wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2FGenIR-Data&files=TriviaQAData.zip"
#unzip "download?path=%2FGenIR-Data&files=TriviaQAData.zip"

import json
from torch.utils.data import Dataset, DataLoader

class QuestionsDataset(Dataset):
    def __init__(self, inputs, targets, ids):
        self.inputs = inputs
        self.targets = targets
        self.ids = ids

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx], self.ids[idx]

    def collate_fn(self, batch):
        input_texts, target_texts, ids = zip(*batch)
        return input_texts, target_texts, ids

# with open("../TriviaQAData/test_queries_trivia_qa.json", mode='r') as f:
#   test_data = json.load(f)

with open("../TriviaQAData/train_queries_trivia_qa.json", mode='r') as f:
  train_data = json.load(f)
with open("../TriviaQAData/val_queries_trivia_qa.json", mode='r') as f:
  val_data = json.load(f)
data = train_data + val_data

queries = [entry['query'] for entry in data]
ground_truths = [entry['relevant_docs'] for entry in data]
query_ids = [entry['id'] for entry in data]
dataset = QuestionsDataset(queries, ground_truths, query_ids)
data_loader = DataLoader(dataset, batch_size=16, shuffle = False, collate_fn=dataset.collate_fn)

## Creating activated neurons dictionary
This dictionary includes the correctly answered queries, and for each one we have a list of triggered neurons in each layer's MLP, and the correct document ID.

In [6]:
def extract_activated_neurons(hooks, layer_indices):
    result_dict = {}
    for layer_id in layer_indices:
        hook_post = hooks[f'decoder.{layer_id}.mlp.hook_post']
        activated_neuron_indices = (hook_post > 0).nonzero(as_tuple=False)
        result_dict[f'layer_{layer_id}'] = activated_neuron_indices
    return result_dict

def pad_relevant_docs(relevant_docs):
    relevant_docs = [[int(item) for item in sublist] for sublist in relevant_docs]
    max_len = 0
    for sublist in relevant_docs:
        if len(sublist) > max_len:
            max_len = len(sublist)
    padded_relevant_docs = []
    for sublist in relevant_docs:
        # Calculate how many padding elements are needed
        num_padding = max_len - len(sublist)
        # Create the padded sublist
        padded_sublist = sublist + [-1] * num_padding
        padded_relevant_docs.append(padded_sublist)
    return padded_relevant_docs

def extract_doc_ids_from_output(decoder_output):
    doc_out_ids = []
    for out in doc_out:
        doc_id = re.findall(r"@DOC_ID_([0-9]+)@", decoder_output)
        assert len(doc_id) <= 1
        doc_out_ids.append(doc_id[0] if len(doc_id) else '-1')

In [9]:
import regex as re
from collections import defaultdict

# Create a defaultdict where new keys will get the 'default_item'
result_dict = defaultdict(lambda: {})
count = 0
for batch in data_loader:
    # if count < 161:
    #     count += 1
    #     continue
    layer_indices = range(17,24)
    inputs, relevant_docs, queries = batch
    inputs = list(inputs)
    padded_relevant_docs = pad_relevant_docs(relevant_docs)
    logits, hooks = model.run_with_cache(inputs)
    doc_out = tokenizer.batch_decode(torch.argmax(logits, dim=-1).squeeze(1))
    # doc_out_ids = re.findall(r"@DOC_ID_([0-9]+)@", doc_out)
    doc_out_ids = []
    for out in doc_out:
        doc_id = re.findall(r"@DOC_ID_([0-9]+)@", out)
        assert len(doc_id) <= 1
        doc_out_ids.append(doc_id[0] if len(doc_id) else '-1')
    print(f'Batch {count}/{len(data_loader)}:', doc_out_ids, len(doc_out_ids))
    assert(len(doc_out_ids) == len(inputs))
    padded_relevant_docs = torch.tensor(padded_relevant_docs, dtype=int, device=device)
    doc_out_ids = torch.tensor(list(map(int, doc_out_ids)), dtype=int, device=device).unsqueeze(1)
    correctly_answered = (padded_relevant_docs == doc_out_ids).any(dim=1)
    # correct_inputs = []
    # correct_queries = []
    correct_doc_out_ids = doc_out_ids[correctly_answered]
    correct_relevant_docs = padded_relevant_docs[correctly_answered]
    # print(hooks['decoder.22.mlp.hook_post'].shape)
    activated_neurons_dict = extract_activated_neurons(hooks, layer_indices)
    # Precompute which examples are correct (CPU list to drive Python loops)
    correct_idxs = torch.nonzero(correctly_answered, as_tuple=False).squeeze(1).tolist()
    
    # Build fast slice maps per layer: batch_id -> (start, end) into that layer's index tensor
    layer_slice_maps = {}
    for layer in layer_indices:
        idx = activated_neurons_dict[f'layer_{layer}']  # [K, 3] = (batch, pos, neuron)
        if idx.numel() == 0:
            layer_slice_maps[layer] = {}
            continue
    
        # If you're NOT sure idx is grouped by batch, uncomment the next two lines:
        # order = torch.argsort(idx[:, 0])          # sort once by batch id
        # idx = idx[order]
    
        b = idx[:, 0]
        K = idx.size(0)
        # boundaries where batch id changes
        change = torch.nonzero(b[1:] != b[:-1], as_tuple=False).squeeze(1) + 1  # shape [U-1]
        starts = torch.cat([b.new_zeros(1), change])                            # [U]
        ends   = torch.cat([change, b.new_full((1,), K, dtype=change.dtype, device=change.device)])  # [U]
        uniq_b = b[starts]
    
        # Move tiny vectors to CPU once, build a dict for O(1) lookup
        layer_slice_maps[layer] = {
            int(bi): (int(s), int(e))
            for bi, s, e in zip(uniq_b.tolist(), starts.tolist(), ends.tolist())
        }
    
    # Now fill result_dict without scanning
    for i in correct_idxs:
        current_input_dict = {}
        for layer in layer_indices:
            s_e = layer_slice_maps[layer].get(i)
            if s_e is None:
                current_input_dict[f'layer_{layer}'] = []
                continue
            s, e = s_e
            # Grab neuron ids in one slice, then move once to CPU
            neurons = activated_neurons_dict[f'layer_{layer}'][s:e, -1].cpu().tolist()
            # OPTIONAL: deduplicate to shrink output (uncomment if desired)
            # neurons = sorted(set(neurons))
            current_input_dict[f'layer_{layer}'] = neurons
    
        # Make sure to store plain Python types (no tensors) for JSON
        result_dict[queries[i]] = {
            'activated_neurons': current_input_dict,
            'input': inputs[i],
            'correct_doc_id': int(doc_out_ids[i].item()),
            'relevant_docs': [int(x) for x in relevant_docs[i]],
        }
    count += 1
    if count % 50 == 1:
        with open("activated_neurons_train_val_copy.json", "w") as f:
            json.dump(dict(result_dict), f, indent=4,
                      default=lambda o: o.detach().cpu().tolist() if isinstance(o, torch.Tensor)
                      else (o.tolist() if isinstance(o, np.ndarray) else o))
# Save the dictionary to a JSON file
with open("activated_neurons_train_val_copy.json", "w") as f:
    json.dump(result_dict, f, indent=4) # indent for pretty printing
# print(hooks['decoder.17.mlp.hook_post'][0][0][23])
# print(activated_neurons_dict['layer_17'][:300])

Batch 0/4368: ['7160', '48669', '43055', '4774', '33431', '43042', '21052', '17424', '26635', '25619', '70804', '18974', '28653', '50362', '17548', '32908'] 16
Batch 1/4368: ['1695', '18243', '51049', '36082', '3454', '57396', '10645', '64189', '6528', '25675', '69264', '29694', '41713', '15642', '17503', '72535'] 16
Batch 2/4368: ['67153', '68874', '3086', '69231', '5329', '17150', '42550', '58776', '31816', '24961', '47414', '35182', '72261', '5115', '29151', '6210'] 16
Batch 3/4368: ['52421', '3086', '73647', '26530', '21044', '14463', '63442', '57888', '53329', '50595', '35103', '72512', '31194', '40961', '1085', '46365'] 16
Batch 4/4368: ['513', '16569', '60612', '49223', '32137', '47414', '53554', '50265', '5706', '28894', '33422', '15663', '23511', '69551', '43484', '69050'] 16
Batch 5/4368: ['3004', '49060', '19647', '71016', '65800', '9856', '24503', '39484', '33323', '48975', '6684', '11463', '54996', '69916', '59777', '34335'] 16
Batch 6/4368: ['61759', '62998', '9181', '538

## Creating activation stats with triggering queries
This section takes the result dict which is the activated neurons dictionary(we have it in a JSON) and creates the actionation statistics for each neuron.

In [12]:
import json
from collections import defaultdict

def build_activation_stats_with_entries(result_dict, layer_indices=None, round_to=6, sort_entries=True):
    """
    result_dict: { query_str: {
        "activated_neurons": {"layer_17": [nids...], "layer_18": [...] , ...},
        ...
    }}

    Returns:
      {
        "total_entries": N,
        "by_layer": {
          "layer_22": {
            "307": {"count": c, "percentage": c/N, "entries": [q1, q2, ...]},
            ...
          },
          ...
        },
        "global": {
          "22_307": {"count": c, "percentage": c/N, "entries": [q1, q2, ...]},
          ...
        }
      }
    """
    # Decide which layers to aggregate
    if layer_indices is None:
        first = next(iter(result_dict.values()), None)
        if first is None:
            return {"total_entries": 0, "by_layer": {}, "global": {}}
        layer_keys = list(first.get("activated_neurons", {}).keys())
    else:
        layer_keys = [f"layer_{L}" for L in layer_indices]

    total_entries = len(result_dict)
    if total_entries == 0:
        return {"total_entries": 0, "by_layer": {}, "global": {}}

    # Counters & per-neuron entry collectors
    by_layer_counts   = {lk: defaultdict(int) for lk in layer_keys}
    by_layer_entries  = {lk: defaultdict(set) for lk in layer_keys}
    global_counts     = defaultdict(int)
    global_entries    = defaultdict(set)

    # Aggregate (count once per entry; store query once per neuron)
    for q, entry in result_dict.items():
        per_layer = entry.get("activated_neurons", {})
        for lk in layer_keys:
            neurons = per_layer.get(lk, [])
            if not neurons:
                continue
            for n in set(int(x) for x in neurons):  # dedup within the entry
                by_layer_counts[lk][n] += 1
                by_layer_entries[lk][n].add(q)
                L = lk.split("_")[1]
                gk = f"{L}_{n}"
                global_counts[gk] += 1
                global_entries[gk].add(q)

    # Build JSON-safe dicts with percentages + sorted entries
    by_layer_out = {}
    for lk, cnts in by_layer_counts.items():
        out = {}
        for n, c in cnts.items():
            entries_list = sorted(by_layer_entries[lk][n]) if sort_entries else list(by_layer_entries[lk][n])
            out[str(n)] = {
                "count": c,
                "percentage": round(c / total_entries, round_to),
                "entries": entries_list,
            }
        by_layer_out[lk] = out

    global_out = {}
    for gk, c in global_counts.items():
        entries_list = sorted(global_entries[gk]) if sort_entries else list(global_entries[gk])
        global_out[gk] = {
            "count": c,
            "percentage": round(c / total_entries, round_to),
            "entries": entries_list,
        }

    return {
        "total_entries": total_entries,
        "by_layer": by_layer_out,
        "global": global_out,
    }

# --- usage ---
stats = build_activation_stats_with_entries(result_dict, layer_indices=range(17, 24))
with open("neuron_activation_stats_with_entries.json", "w", encoding="utf-8") as f:
    json.dump(stats, f, indent=2, ensure_ascii=False)


## Loading dictionaries from the JSON files
This part for any processing that we will be doing

In [44]:
import json

def _to_int(x):
    """Best-effort int conversion for scalars, 0-d tensors/np scalars, or single-item lists."""
    try:
        # torch.Tensor / numpy scalar path
        if hasattr(x, "item"):
            return int(x.item())
        return int(x)
    except Exception:
        try:
            import numpy as np
            if isinstance(x, np.generic):
                return int(x)
        except Exception:
            pass
        if isinstance(x, (list, tuple)) and len(x) == 1:
            try:
                return int(x[0])
            except Exception:
                return None
        return None

def load_result_and_stats(
    result_path: str = "activated_neurons_train_val.json",
    stats_path: str = "neuron_activation_stats.json",
):
    # ---- Load result_dict ----
    with open(result_path, "r", encoding="utf-8") as f:
        rd = json.load(f)

    # Support checkpoint format: {"meta": {...}, "data": {...}}
    if isinstance(rd, dict) and "data" in rd and isinstance(rd["data"], dict):
        rd = rd["data"]

    # Normalize types to plain Python ints/lists
    for q, rec in list(rd.items()):
        if not isinstance(rec, dict):
            continue

        # correct_doc_id -> int
        cdi = rec.get("correct_doc_id", None)
        if cdi is not None:
            val = _to_int(cdi)
            if val is None and isinstance(cdi, (list, tuple)) and cdi:
                val = _to_int(cdi[0])
            rec["correct_doc_id"] = val

        # relevant_docs -> list[int]
        rds = rec.get("relevant_docs", None)
        if rds is not None:
            rec["relevant_docs"] = [v for x in rds if (v := _to_int(x)) is not None]

        # activated_neurons -> {layer_key: list[int]}
        an = rec.get("activated_neurons", {})
        if isinstance(an, dict):
            for lk, lst in list(an.items()):
                if isinstance(lst, list):
                    an[lk] = [v for x in lst if (v := _to_int(x)) is not None]
        rec["activated_neurons"] = an
        rd[q] = rec

    # ---- Load stats ----
    with open(stats_path, "r", encoding="utf-8") as f:
        stats = json.load(f)

    # Ensure expected keys exist
    stats.setdefault("total_entries", len(rd))
    stats.setdefault("by_layer", {})
    stats.setdefault("global", {})

    return rd, stats

In [43]:
result_dict, stats = load_result_and_stats(
    "activated_neurons_train_val_copy.json",
    "neuron_activation_stats_with_entries.json"
)

## Helping functions to extract useful data from the activated neurons and the neuron activation stats dictionaries.

In [29]:
def per_layer_sorted_percentages(stats, layer_indices=None, descending=True):
    """
    stats: output from build_activation_stats_with_entries(...) (or the earlier stats builder)
    Returns dict: { "layer_22": [(nid, pct, count), ...], ... } sorted by pct.
    """
    total = stats.get("total_entries", 0) or 1  # avoid div-by-zero
    by_layer = stats.get("by_layer", {})
    layer_keys = [f"layer_{L}" for L in layer_indices] if layer_indices is not None else list(by_layer.keys())

    out = {}
    for lk in layer_keys:
        rows = []
        layer = by_layer.get(lk, {})
        for nid_str, rec in layer.items():
            c = int(rec.get("count", 0))
            pct = c / total
            rows.append((int(nid_str), pct, c))
        rows.sort(key=lambda x: x[1], reverse=descending)
        out[lk] = rows
    return out


def sample_from_sorted(sorted_layers, layer_key, mode="head", n=10, center=0.5):
    """
    mode: 'head' | 'tail' | 'middle'
    - head: top-n highest pct
    - tail: bottom-n lowest pct
    - middle: n around the given center (0..1) in the sorted list
    """
    lst = sorted_layers.get(layer_key, [])
    if not lst or n <= 0:
        return []

    if mode == "head":
        return lst[:n]
    if mode == "tail":
        return lst[-n:]

    # middle
    idx = int(round(center * (len(lst) - 1)))
    half = n // 2
    start = max(0, idx - half)
    end = min(len(lst), start + n)
    start = max(0, end - n)
    return lst[start:end]


def select_percentage_band(sorted_layers, layer_key, min_pct=None, max_pct=None, n=None):
    """
    Filter neurons in a percentage band [min_pct, max_pct], then take first n (already sorted).
    Percentages are in [0,1].
    """
    lst = sorted_layers.get(layer_key, [])
    def ok(x):
        p = x[1]
        return (min_pct is None or p >= min_pct) and (max_pct is None or p <= max_pct)
    res = [t for t in lst if ok(t)]
    return res if n is None else res[:n]

def neuron_correct_doc_ids(stats, result_dict, neuron_id, layer_id, unique=True):
    """
    stats: output from build_activation_stats_with_entries (has 'entries' per neuron)
    result_dict: your original dict {query: {"correct_doc_id": ..., "activated_neurons": {...}}}
    neuron_id: int (neuron index)
    layer_id: int like 22, or string like "layer_22"
    unique: if True, deduplicate doc IDs while preserving order

    Returns: list[int] of correct document IDs for this neuron in the given layer
    """
    # Normalize layer key
    if isinstance(layer_id, int):
        layer_key = f"layer_{layer_id}"
    else:
        layer_key = layer_id if str(layer_id).startswith("layer_") else f"layer_{layer_id}"

    # Get the queries that triggered this neuron (once per entry)
    by_layer = stats.get("by_layer", {})
    neuron_rec = by_layer.get(layer_key, {}).get(str(int(neuron_id)), {})
    queries = neuron_rec.get("entries", [])
    if not queries:
        return []

    def _to_int(x):
        # Robust conversion for int / str / numpy scalar / torch tensor
        try:
            # torch.Tensor path
            if hasattr(x, "item"):
                return int(x.item())
            return int(x)
        except Exception:
            try:
                import numpy as np
                if isinstance(x, np.generic):
                    return int(x)
            except Exception:
                pass
            return None

    doc_ids = []
    seen = set()
    for q in queries:
        rec = result_dict.get(q, {})
        did = _to_int(rec.get("correct_doc_id"))
        if did is None:
            continue
        if unique:
            if did in seen:
                continue
            seen.add(did)
        doc_ids.append(did)

    return doc_ids


In [17]:
# Build once
sorted_layers = per_layer_sorted_percentages(stats, layer_indices=range(17, 24))

In [24]:
# 1) Take top 20 neurons in layer 22
top20_L22 = sample_from_sorted(sorted_layers, "layer_22", mode="head", n=20)

# 2) Take bottom 5 neurons in layer 20
bottom5_L20 = sample_from_sorted(sorted_layers, "layer_20", mode="tail", n=5)

# 3) Take 12 neurons from the middle band of layer 23 (centered at median)
mid12_L23 = sample_from_sorted(sorted_layers, "layer_23", mode="middle", n=12, center=0.5)

# 4) Get neurons in layer 19 with trigger percentage between 5% and 10%
band_L19 = select_percentage_band(sorted_layers, "layer_19", min_pct=0.05, max_pct=0.10)

# If we prefer 0–100%
top10_L22_percent = [(nid, pct*100, cnt) for nid, pct, cnt in top10_L22]

In [39]:
# sample_from_sorted(sorted_layers, "layer_17", mode="tail", n=20)
len(neuron_correct_doc_ids(stats, result_dict, neuron_id=1387, layer_id=17, unique=False))

11

## Old code for getting activated neurons - slow implementation

In [15]:
# # import regex as re
# from collections import defaultdict

# # Create a defaultdict where new keys will get the 'default_item'
# result_dict = defaultdict(lambda: {})
# count = 0
# for batch in data_loader:
#     # if count < 161:
#     #     count += 1
#     #     continue
#     layer_indices = range(17,24)
#     inputs, relevant_docs, queries = batch
#     inputs = list(inputs)
#     padded_relevant_docs = pad_relevant_docs(relevant_docs)
#     logits, hooks = model.run_with_cache(inputs)
#     doc_out = tokenizer.batch_decode(torch.argmax(logits, dim=-1).squeeze(1))
#     # doc_out_ids = re.findall(r"@DOC_ID_([0-9]+)@", doc_out)
#     doc_out_ids = []
#     for out in doc_out:
#         doc_id = re.findall(r"@DOC_ID_([0-9]+)@", out)
#         assert len(doc_id) <= 1
#         doc_out_ids.append(doc_id[0] if len(doc_id) else '-1')
#     print(f'Batch {count}/{len(data_loader)}:', doc_out_ids, len(doc_out_ids))
#     assert(len(doc_out_ids) == len(inputs))
#     padded_relevant_docs = torch.tensor(padded_relevant_docs, dtype=int, device=device)
#     doc_out_ids = torch.tensor(list(map(int, doc_out_ids)), dtype=int, device=device).unsqueeze(1)
#     correctly_answered = (padded_relevant_docs == doc_out_ids).any(dim=1)
#     # correct_inputs = []
#     # correct_queries = []
#     correct_doc_out_ids = doc_out_ids[correctly_answered]
#     correct_relevant_docs = padded_relevant_docs[correctly_answered]
#     # print(hooks['decoder.22.mlp.hook_post'].shape)
#     activated_neurons_dict = extract_activated_neurons(hooks, layer_indices)
#     for i in range(len(inputs)):
#         reached_idx = defaultdict(lambda: 0)
#         if correctly_answered[i]:
#             current_input_dict = {}
#             # correct_inputs.append(inputs[i])
#             # correct_queries.append(queries[i])
#             for layer in layer_indices:
#                 activated_neurons_tmp_list = []
#                 while reached_idx[f'layer_{layer}'] < len(activated_neurons_dict[f'layer_{layer}']):
#                     if i == activated_neurons_dict[f'layer_{layer}'][reached_idx[f'layer_{layer}']][0]:
#                         activated_neurons_tmp_list.append(activated_neurons_dict[f'layer_{layer}'][reached_idx[f'layer_{layer}']][-1].item()) #add the neuron idx
#                         reached_idx[f'layer_{layer}'] += 1
#                         # print(f'stopped here and next idx is:{activated_neurons_dict[f"layer_{layer}"][reached_idx[f"layer_{layer}"]][0]}')
#                     else:
#                         if reached_idx[f'layer_{layer}'] == 0: #first iteration therefore should skip the items until we reach it
#                             while activated_neurons_dict[f"layer_{layer}"][reached_idx[f"layer_{layer}"]][0] != i:
#                                 reached_idx[f'layer_{layer}'] += 1
#                             continue
#                         break
#                 current_input_dict[f'layer_{layer}'] = activated_neurons_tmp_list
#             result_dict[queries[i]] = { 'activated_neurons': current_input_dict, "input": inputs[i], 'correct_doc_id': doc_out_ids[i], 'relevant_docs': relevant_docs[i] }
#     count += 1
#     if count % 50 == 1:
#         with open("activated_neurons_train_val.json", "w") as f:
#             json.dump(dict(result_dict), f, indent=4,
#                       default=lambda o: o.detach().cpu().tolist() if isinstance(o, torch.Tensor)
#                       else (o.tolist() if isinstance(o, np.ndarray) else o))
# # Save the dictionary to a JSON file
# with open("activated_neurons_train_val.json", "w") as f:
#     json.dump(result_dict, f, indent=4) # indent for pretty printing
# # print(hooks['decoder.17.mlp.hook_post'][0][0][23])
# # print(activated_neurons_dict['layer_17'][:300])

## Hook functions for changing a specific neuron activation value in a specific layer

In [291]:
def make_mlp_hook_function(target_token_pos, target_neuron_index, new_activation_value):
    def modify_mlp_neuron_hook(
        activation_tensor: torch.Tensor, 
        hook
    ) -> torch.Tensor:
        """
        A hook function to modify a specific MLP neuron's activation.
        activation_tensor shape: [batch, position, n_mlp_neurons]
        """
        print(f"Hook fired at {hook.name}. Original activation value at "
              f"pos {target_token_pos}, neuron {target_neuron_index}: "
              f"{activation_tensor[:, target_token_pos, target_neuron_index]}")
              # f"{activation_tensor[0, target_token_pos, target_neuron_index].item():.4f}")
    
        # Modify the specific neuron's activation in-place
        # We use [0] for batch dimension assuming a single prompt
        activation_tensor[:, target_token_pos, target_neuron_index] = new_activation_value
        # activation_tensor[0, target_token_pos, target_neuron_index] = new_activation_value
    
        print(f"Modified activation to: "
              f"{activation_tensor[:, target_token_pos, target_neuron_index]}")
              # f"{activation_tensor[0, target_token_pos, target_neuron_index].item():.4f}")
    
        return activation_tensor # Always return the modified tensor

    return modify_mlp_neuron_hook

def run_model_with_activation_hook(model, prompt, mlp_hook_name, neuron_index, neuron_new_value):
    # mlp_hook_name = f"blocks.{target_layer}.mlp.hook_post"
    # Now, run with the hook
    modified_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(mlp_hook_name, make_mlp_hook_function(0, neuron_index, neuron_new_value))]
    )
    hook_result = tokenizer.batch_decode(torch.argmax(modified_logits, dim=-1).squeeze(-1))

    logits = model(prompt)
    orig_result = tokenizer.batch_decode(torch.argmax(logits, dim=-1).squeeze(-1))
    
    print(f'original result:{orig_result}, and after using the hook:{hook_result}')
    correct_count = 0
    for i in range(len(orig_result)):
        if orig_result[i] == hook_result[i]:
            correct_count += 1
    print(f'Total correct answered:{correct_count}/ {len(orig_result)}')

def get_affected_prompts(model, queries_dict, mlp_hook_name, layer_index, neuron_index, neuron_new_value):
    layer_activated_neuron_inputs, layer_activated_neurons_correct_doc_ids = [], []
    for key in queries_dict:
        if neuron_index in queries_dict[key]['activated_neurons'][f'layer_{layer_id}']:
            layer_activated_neuron_inputs.append(results_copy[key]['input'])
            layer_activated_neurons_correct_doc_ids.append(results_copy[key]['correct_doc_id'])
    return run_model_with_activation_hook(model, layer_activated_neuron_inputs, mlp_hook_name, neuron_index, neuron_new_value)

In [303]:
hook_layer_id = 23
hook_neuron_id = 3079
mlp_hook_name = f'decoder.{layer_id}.mlp.hook_post'
hook_new_value = 0.0
# prompt = "For which county does Jonathan Trott play cricket?"
# run_model_with_activation_hook(model, prompt, mlp_hook_name, hook_neuron_id, hook_new_value)
get_affected_prompts(model, results_copy, mlp_hook_name, hook_layer_id, hook_neuron_id, hook_new_value)

Hook fired at decoder.17.mlp.hook_post. Original activation value at pos 0, neuron 3079: tensor([1.6953e+01, 4.2768e+00, 1.2313e+01, 2.8081e+00, 1.0876e+01, 2.3650e+00,
        4.9777e+00, 1.6263e+00, 3.2894e+00, 6.7698e+00, 6.8917e+00, 6.1033e+00,
        4.2770e-01, 3.2351e+00, 6.8706e+00, 2.7763e+00, 1.5245e+00, 4.8938e+00,
        1.4285e+01, 2.7602e+00, 2.1249e-01, 1.1253e+00, 1.3686e+01, 1.7106e+00,
        8.0304e+00, 1.6549e+01, 1.8619e+01, 3.0922e+01, 4.3147e+00, 1.4808e+01,
        1.1608e+01, 1.4639e+01, 1.0112e+00, 3.3695e+00, 7.2568e-01, 1.6056e+00,
        2.3771e-02, 1.1469e+00, 1.7134e+00, 3.3589e+00, 4.6800e+00, 7.8607e+00,
        2.9926e+00, 1.2749e+01, 4.1194e-02, 2.3855e-01, 1.2381e+01, 3.8436e+00,
        4.6980e-01, 2.3818e+01, 1.2509e+00, 6.4127e+00, 6.4346e-02, 1.7243e+01,
        9.1733e+00, 3.5950e+00, 2.2917e+01, 1.2488e+00, 7.8565e+00, 4.2466e+00,
        3.7660e+00, 4.8297e+00, 1.4659e+01, 1.1176e+01, 2.2397e+01, 3.8647e-02,
        5.9159e+00, 8.9398e+00,

## Draft cells - just for trying stuff

In [254]:
neuron_id = 3031
layer_id = 17
# layer_17_neuron_3027 = []
layer_activated_neuron_inputs = []
layer_activated_neurons_correct_doc_ids = []
for key in results_copy:
    if neuron_id in results_copy[key]['activated_neurons'][f'layer_{layer_id}']:
        layer_activated_neuron_inputs.append(results_copy[key]['input'])
        layer_activated_neurons_correct_doc_ids.append(results_copy[key]['correct_doc_id'])
        # layer_17_neuron_3027.append([results_copy[key]['input'], results_copy[key]['correct_doc_id']])

In [290]:
# print(layer_activated_neuron_inputs)
# model(layer_activated_neuron_inputs)
tmp = torch.Tensor([0,0]).item()

RuntimeError: a Tensor with 2 elements cannot be converted to Scalar

In [237]:
#['For which county does Jonathan Trott play cricket?', 37981],
logits = model("For which county does Jonathan Trott play cricket?")
tokenizer.decode(torch.argmax(logits, dim=-1).squeeze(-1))

'@DOC_ID_37981@'

In [224]:
print(result_dict[next(iter(result_dict.keys()))]['correct_doc_id'].item())
import copy
print(results_copy[next(iter(results_copy.keys()))])
# results_copy
with open("activated_neurons.json", "w") as f:
    json.dump(results_copy, f, indent=4)

21871
{'activated_neurons': {'layer_17': [10, 22, 28, 42, 43, 51, 56, 59, 79, 89, 107, 108, 128, 130, 134, 137, 157, 184, 221, 245, 277, 286, 349, 361, 381, 382, 396, 406, 408, 431, 442, 457, 501, 502, 558, 564, 576, 587, 595, 603, 622, 626, 639, 641, 652, 672, 678, 679, 686, 694, 695, 702, 719, 727, 734, 765, 775, 789, 793, 807, 812, 819, 853, 906, 921, 942, 946, 1002, 1021, 1058, 1067, 1078, 1093, 1095, 1100, 1134, 1136, 1150, 1210, 1218, 1242, 1250, 1260, 1276, 1316, 1380, 1392, 1420, 1432, 1485, 1517, 1536, 1539, 1610, 1644, 1663, 1683, 1699, 1700, 1705, 1712, 1731, 1733, 1761, 1767, 1832, 1862, 1874, 1890, 1915, 1923, 1940, 1979, 2002, 2013, 2014, 2016, 2047, 2099, 2110, 2123, 2146, 2153, 2184, 2211, 2239, 2283, 2352, 2383, 2388, 2398, 2407, 2419, 2428, 2429, 2441, 2449, 2458, 2460, 2462, 2499, 2502, 2510, 2520, 2526, 2572, 2573, 2579, 2605, 2643, 2648, 2666, 2688, 2692, 2699, 2700, 2712, 2736, 2756, 2757, 2768, 2796, 2797, 2799, 2812, 2824, 2845, 2872, 2873, 2874, 2893, 2917, 293

In [86]:
# print(out[1]['decoder.22.mlp.hook_post'].squeeze(0, 1)[336])
# print(out[1]['decoder.22.hook_mlp_out'].shape)
# print((out[1]['decoder.22.mlp.hook_post'].squeeze(0,1) > 0).nonzero(as_tuple=True))
print(range(10))
print(extract_activated_neurons(out, range(17,24)))

range(0, 10)
{'layer_17': tensor([[   0,    0,    3],
        [   0,    0,   22],
        [   0,    0,   42],
        [   0,    0,   49],
        [   0,    0,   92],
        [   0,    0,  103],
        [   0,    0,  129],
        [   0,    0,  133],
        [   0,    0,  162],
        [   0,    0,  215],
        [   0,    0,  254],
        [   0,    0,  280],
        [   0,    0,  288],
        [   0,    0,  321],
        [   0,    0,  346],
        [   0,    0,  377],
        [   0,    0,  402],
        [   0,    0,  459],
        [   0,    0,  467],
        [   0,    0,  478],
        [   0,    0,  479],
        [   0,    0,  481],
        [   0,    0,  494],
        [   0,    0,  501],
        [   0,    0,  503],
        [   0,    0,  509],
        [   0,    0,  515],
        [   0,    0,  516],
        [   0,    0,  522],
        [   0,    0,  527],
        [   0,    0,  530],
        [   0,    0,  536],
        [   0,    0,  563],
        [   0,    0,  570],
        [   0,    0,  

In [59]:
print(out[1].keys())

dict_keys(['hook_embed', 'encoder.0.hook_resid_pre', 'encoder.0.ln1.hook_scale', 'encoder.0.ln1.hook_normalized', 'encoder.0.attn.hook_q', 'encoder.0.attn.hook_k', 'encoder.0.attn.hook_v', 'encoder.0.attn.hook_attn_scores', 'encoder.0.attn.hook_pattern', 'encoder.0.attn.hook_z', 'encoder.0.hook_attn_out', 'encoder.0.hook_resid_mid', 'encoder.0.ln2.hook_scale', 'encoder.0.ln2.hook_normalized', 'encoder.0.mlp.hook_pre', 'encoder.0.mlp.hook_post', 'encoder.0.hook_mlp_out', 'encoder.0.hook_resid_post', 'encoder.1.hook_resid_pre', 'encoder.1.ln1.hook_scale', 'encoder.1.ln1.hook_normalized', 'encoder.1.attn.hook_q', 'encoder.1.attn.hook_k', 'encoder.1.attn.hook_v', 'encoder.1.attn.hook_attn_scores', 'encoder.1.attn.hook_pattern', 'encoder.1.attn.hook_z', 'encoder.1.hook_attn_out', 'encoder.1.hook_resid_mid', 'encoder.1.ln2.hook_scale', 'encoder.1.ln2.hook_normalized', 'encoder.1.mlp.hook_pre', 'encoder.1.mlp.hook_post', 'encoder.1.hook_mlp_out', 'encoder.1.hook_resid_post', 'encoder.2.hook_r

In [42]:
idx = 4
input_question = next(iter(data_loader))[0][idx]
input_relevant_docs = next(iter(data_loader))[1][idx]
logits = model(input_question)
res = tokenizer.decode(torch.argmax(logits, dim=-1).squeeze(-1))
print(f'Model predicted document:{res}, and right answer is:{input_relevant_docs}')

Model predicted document:@DOC_ID_52288@, and right answer is:['67981', '52288']


In [5]:
correct_queries = []
decoder_input = torch.tensor([[0]])
for input_texts, target_texts, ids in data_loader:
  input_tokens = tokenizer(input_texts, return_tensors='pt', padding=True)['input_ids']
  logits = model.forward(input_tokens, decoder_input)
  # print(torch.argmax(logits, dim=-1).squeeze(-1))
  res = tokenizer.decode(torch.argmax(logits, dim=-1).squeeze(-1))
  doc_ids = [s for s in res.replace('@','_').split(sep='_') if s.isdigit()]
  correct_queries += [(id, query, predicted, truth) for id,query,predicted,truth in zip(ids, input_texts, doc_ids, target_texts) if predicted in truth]


# for entry in training_data:
#   id, query, relevant_docs = entry
#   input_tokens = tokenizer(query, return_tensors='pt')['input_ids']
#   decoder_input = torch.tensor([[0]])

#   logits, cache = model.run_with_cache(input_tokens, decoder_input, remove_batch_dim=True)
#   res = tokenizer.decode(torch.argmax(logits, dim=-1)[0][0])
#   if res in relevant_docs:
#     correct_queries.append(entry)



In [17]:
ids = [entry[0] for entry in correct_queries]
queries = [entry[1] for entry in correct_queries]
truths = [entry[3] for entry in correct_queries]

correct_dataset = QuestionsDataset(queries, truths, ids)
dl = DataLoader(correct_dataset, batch_size=16, shuffle = False, collate_fn=dataset.collate_fn)

decoder_input = torch.tensor([[0]])
cached_mlps = {}
for input_texts, target_texts, ids in dl:
  input_tokens = tokenizer(input_texts, return_tensors='pt', padding=True)['input_ids']
  _, cache = model.run_with_cache(input_tokens, decoder_input)
  for layer in range(18, 24):
      cached_mlps[f"layer_{layer}"] = torch.cat((cached_mlps.setdefault(f"layer_{layer}", torch.Tensor()), cache[f"decoder.{layer}.mlp.hook_post"]), dim = 0)

torch.save(cached_mlps,"cached_mlp_from_correct_queries.json")




In [42]:
# cached = torch.load("cached_mlp_from_correct_queries.json")
# cached_with_query_id = {q_id : {key : cached_layer.squeeze(1)[index] for key,cached_layer in cached.items()} for index, q_id in enumerate(ids)} 
# queries_predicted = {l[0] : l[2] for l in correct_queries}
queries_predicted
# TODO: what do we need to save for each doc id and query
# dict: doc-id -> num-of-valid-queries
# doc-id -> (activations, valid-queries)
# TODO: make sure all docs were indexed



{'QTest0': '21871',
 'QTest2': '70062',
 'QTest4': '52288',
 'QTest5': '38019',
 'QTest6': '8466',
 'QTest7': '73330',
 'QTest8': '9181',
 'QTest9': '70053',
 'QTest10': '52421',
 'QTest11': '13600',
 'QTest12': '60198',
 'QTest14': '46780',
 'QTest15': '50855',
 'QTest17': '68189',
 'QTest18': '11358',
 'QTest20': '42034',
 'QTest21': '66108',
 'QTest22': '4579',
 'QTest23': '38155',
 'QTest24': '59353',
 'QTest25': '36612',
 'QTest27': '52886',
 'QTest28': '52588',
 'QTest29': '34722',
 'QTest30': '23458',
 'QTest31': '7944',
 'QTest32': '23003',
 'QTest34': '11482',
 'QTest35': '56292',
 'QTest38': '45635',
 'QTest39': '29055',
 'QTest41': '57923',
 'QTest42': '59708',
 'QTest45': '39093',
 'QTest46': '45883',
 'QTest47': '59708',
 'QTest48': '51159',
 'QTest49': '21815',
 'QTest50': '56658',
 'QTest51': '55325',
 'QTest52': '3839',
 'QTest55': '4181',
 'QTest56': '29055',
 'QTest57': '34767',
 'QTest58': '67193',
 'QTest59': '59878',
 'QTest60': '18097',
 'QTest61': '68811',
 'QTes

In [None]:
query = "test query"

input_tokens = tokenizer(query, return_tensors='pt')['input_ids']
decoder_input = torch.tensor([[0]])

logits, cache = model.run_with_cache(input_tokens, decoder_input, remove_batch_dim=True)

In [None]:
logits

In [None]:
# Prediction from the logits
torch.argmax(logits, dim=-1), tokenizer.decode(torch.argmax(logits, dim=-1)[0][0])

## Examining the activations

The activations of each component in the transformer are stored in the `cache` object. It's basically a dict from which you choose which component to look at.

Here, we print all possible component keys for layer 0 in the decoder:

In [None]:
for key in cache.keys():
  if key.startswith('decoder.0.'):
    print(key)

We choose to look at the output of the MLP in layer 19 of the decoder:

In [None]:
cache['decoder.19.hook_mlp_out'], cache['decoder.19.hook_mlp_out'].shape

Take a look at where the MLP hooks are computed: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/components/mlps/mlp.py

`hook_pre`: Before activation,
`hook_post`: After applying activation