In [1]:
# load subject model
# load SAEs without attaching them to the model
# for now just use the Islam feature and explanation
# load a scorer. The prompt should have the input as well this time
# (for now) on random pretraining data, evaluate gpt2 with a hook that 
# adds a multiple of the Islam feature to the appropriate residual stream layer and position
# Get the pre- and post-intervention output distributions of gpt2
# (TODO: check if all the Islam features just have similar embeddings)
# Show this to the scorer and get a score (scorer should be able to have a good prior without being given the clean output distribution)
# Also get a simplicity score for the explanation

In [1]:
import pandas as pd
from pathlib import Path
import json

results_dir = "/mnt/ssd-1/gpaulo/SAE-Zoology/results/gpt2_simulation/all_at_once"
results = dict()
for fname in Path(results_dir).iterdir():
    with open(fname, "r") as f:
        r = json.load(f)
    last = fname.stem.split(".")[-1]
    layer = int(last.split("_")[0])
    feat = int(last[last.index("_feature") + len("_feature"):])
    results[fname.stem] = {"ev_correlation_score": r["ev_correlation_score"], "layer": layer, "feature": feat}
input_scores_df = pd.DataFrame(results).T
input_scores_df["layer"] = input_scores_df["layer"].astype(int)
input_scores_df["feature"] = input_scores_df["feature"].astype(int)
input_scores_df = input_scores_df.sort_values("ev_correlation_score", ascending=False)
unq_layers = input_scores_df["layer"].unique()
input_scores_df

Unnamed: 0,ev_correlation_score,layer,feature
.transformer.h.2_feature0,0.970093,2,0
.transformer.h.2_feature19,0.966378,2,19
.transformer.h.0_feature0,0.952401,0,0
.transformer.h.4_feature4,0.952061,4,4
.transformer.h.0_feature5,0.949993,0,5
.transformer.h.2_feature4,0.941871,2,4
.transformer.h.2_feature11,0.930066,2,11
.transformer.h.4_feature19,0.918787,4,19
.transformer.h.0_feature14,0.906342,0,14
.transformer.h.0_feature3,0.89708,0,3


In [2]:
import json
import random

with open("pile.jsonl", "r") as f:
    pile = random.sample([json.loads(line) for line in f.readlines()], 10000)

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = "cuda:0"

subject_name = "gpt2"
subject = AutoModelForCausalLM.from_pretrained(subject_name).to(device)
subject_tokenizer = AutoTokenizer.from_pretrained(subject_name)
subject_tokenizer.pad_token = subject_tokenizer.eos_token
subject.config.pad_token_id = subject_tokenizer.eos_token_id

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
scorer_name = "meta-llama/Meta-Llama-3.1-8B"
scorer = AutoModelForCausalLM.from_pretrained(scorer_name).to(torch.bfloat16).to(device)
scorer_tokenizer = AutoTokenizer.from_pretrained(scorer_name)
scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
scorer.config.pad_token_id = scorer_tokenizer.eos_token_id
scorer.generation_config.pad_token_id = scorer_tokenizer.eos_token_id

# explainer is the same model as the scorer
explainer = scorer
explainer_tokenizer = scorer_tokenizer


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.73it/s]


In [5]:
from dataclasses import dataclass
import copy

@dataclass
class ExplainerInterventionExample:
    prompt: str
    top_tokens: list[str]
    top_p_increases: list[float]

    def text(self) -> str:
        tokens_str = ", ".join(f"'{tok}' (+{round(p, 3)})" for tok, p in zip(self.top_tokens, self.top_p_increases))
        return f"<PROMPT>{self.prompt}</PROMPT>\nMost increased tokens: {tokens_str}"
    
@dataclass
class ExplainerNeuronFormatter:
    intervention_examples: list[ExplainerInterventionExample]
    explanation: str | None = None

    def text(self) -> str:
        text = "\n\n".join(example.text() for example in self.intervention_examples)
        text += "\n\nExplanation: "
        if self.explanation is not None:
            text += self.explanation
        return text


def get_explainer_prompt(neuron_prompter: ExplainerNeuronFormatter, few_shot_examples: list[ExplainerNeuronFormatter] | None = None) -> str:
    prompt = "We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.\n\n" \
        "For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.\n\n" \
        "The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.\n\n"
    
    i = 1
    for few_shot_example in few_shot_examples or []:
        assert few_shot_example.explanation is not None
        prompt += f"Neuron {i}\n" + few_shot_example.text() + "\n\n"
        i += 1

    prompt += f"Neuron {i}\n"
    prompt += neuron_prompter.text()

    return prompt


fs_examples = [
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="My favorite food is",
                top_tokens=[" oranges", " bananas", " apples"],
                top_p_increases=[0.81, 0.09, 0.02]
            ),
            ExplainerInterventionExample(
                prompt="Whenever I would see",
                top_tokens=[" fruit", " a", " apples", " red"],
                top_p_increases=[0.09, 0.06, 0.06, 0.5]
            )
        ],
        explanation="fruits"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="Once upon a time",
                top_tokens=[" there was", " a", " a time"],
                top_p_increases=[0.22, 0.2, 0.05]
            )
        ],
        explanation="storytelling"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="He owned the watch for a long time. While he never said it was",
                top_tokens=[" hers", " hers", " hers"],
                top_p_increases=[0.09, 0.06, 0.06, 0.5]
            ),
            ExplainerInterventionExample(
                prompt="For some reason",
                top_tokens=[" she", " her", " hers"],
                top_p_increases=[0.14, 0.01, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="insurance does not cover",
                top_tokens=[" her", " women", " her's"],
                top_p_increases=[0.10, 0.02, 0.01]
            )
        ],
        explanation="she/her pronouns"
    )
]

neuron_prompter = copy.deepcopy(fs_examples[0])
neuron_prompter.explanation = None
print(get_explainer_prompt(neuron_prompter, fs_examples))


We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.

For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.

The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.

Neuron 1
<PROMPT>My favorite food is</PROMPT>
Most increased tokens: ' oranges' (+0.81), ' bananas' (+0.09), ' apples' (+0.02)

<PROMPT>Whenever I would see</PROMPT>
Most increased tokens: ' fruit' (+0.09), ' a' (+0.06), ' apples' (+0.06), ' red' (+0.5)

Explanation: fruits

Neuron 2
<PROMPT>Once upon a time</PROMPT>
Most increased tokens: ' there was' (+0.22), ' a' (+0.2), ' a time' (+0.05)

Explanation: storytelling

Neuron 3
<PROMPT>He owned the watch for a long time. While he never said it was</PROMPT>
Most increased tokens: ' hers' (+0.09), ' her

In [6]:
def get_scorer_simplicity_prompt(explanation):
    prefix = "Explanation\n\n"
    return f"{prefix}{explanation}{scorer_tokenizer.eos_token}", prefix

def get_scorer_predictiveness_prompt(prompt, explanation, few_shot_prompts=None, few_shot_explanations=None, few_shot_tokens=None):
    if few_shot_explanations is not None:
        assert few_shot_tokens is not None and few_shot_prompts is not None
        assert len(few_shot_explanations) == len(few_shot_tokens) == len(few_shot_prompts)
        few_shot_prompt = "\n\n".join(get_scorer_predictiveness_prompt(pr, expl) + token for pr, expl, token in zip(few_shot_prompts, few_shot_explanations, few_shot_tokens)) + "\n\n"
    else:
        few_shot_prompt = ""
    return few_shot_prompt + f"Explanation: {explanation}\n<PROMPT>{prompt}</PROMPT>"

few_shot_prompts = ["My favorite food is", "From west to east, the westmost of the seven", "He owned the watch for a long time. While he never said it was"]
few_shot_explanations = ["fruits and vegetables", "ateg", "she/her pronouns"]
few_shot_tokens = [" oranges", "WAY", " hers"]
print(get_scorer_predictiveness_prompt(few_shot_prompts[0], few_shot_explanations[0], few_shot_prompts, few_shot_explanations, few_shot_tokens))

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT> oranges

Explanation: ateg
<PROMPT>From west to east, the westmost of the seven</PROMPT>WAY

Explanation: she/her pronouns
<PROMPT>He owned the watch for a long time. While he never said it was</PROMPT> hers

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT>


In [7]:
from functools import partial

def intervene(module, input, output, intervention_strength=10.0, position=-1):
    hiddens = output[0]  # the later elements of the tuple are the key value cache
    hiddens[:, position, :] += intervention_strength * feat.to(hiddens.device)

def get_texts(n, seed=42):
    random.seed(seed)
    texts = []
    for _ in range(n):
        # sample a random text from the pile, and stop it at a random token position, less than 64 tokens
        text = random.choice(pile)["text"]
        text = text.replace("\n", "\\n")
        tokenized_text = subject_tokenizer.encode(text, add_special_tokens=False, max_length=64, truncation=True)
        stop_pos = random.randint(1, min(len(tokenized_text) - 1, 63))
        text = subject_tokenizer.decode(tokenized_text[:stop_pos])
        texts.append(text)
    return texts

n_explainer_texts = 3
n_scorer_texts = 10
# explainer_texts = get_texts(n_explainer_texts)
# explainer_texts = ["Current religion:", "A country that is", "Many people believe that"]
# scorer_texts = get_texts(n_scorer_texts)

In [8]:
scorer_vocab = scorer_tokenizer.get_vocab()
subject_vocab = subject_tokenizer.get_vocab()

# Pre-compute the mapping of subject tokens to scorer tokens
subject_to_scorer = {}
text_subject_to_scorer = {}
for subj_tok, subj_id in subject_vocab.items():
    if subj_tok in scorer_vocab:
        subject_to_scorer[subj_id] = scorer_vocab[subj_tok]
        text_subject_to_scorer[subj_tok] = subj_tok
    else:
        for i in range(len(subj_tok) - 1, 0, -1):
            if subj_tok[:i] in scorer_vocab:
                subject_to_scorer[subj_id] = scorer_vocab[subj_tok[:i]]
                text_subject_to_scorer[subj_tok] = subj_tok[:i]
                break
        else:
            raise ValueError(f"No scorer token found for {subj_tok}")
subject_ids = torch.tensor(list(subject_to_scorer.keys()), device=device)
scorer_ids = torch.tensor(list(subject_to_scorer.values()), device=device)

In [14]:
from sae_auto_interp.autoencoders.OpenAI.model import Autoencoder
weight_dir = "/mnt/ssd-1/gpaulo/SAE-Zoology/weights/gpt2_128k"

path = f"{weight_dir}/{layer}.pt"
state_dict = torch.load(path)
ae = Autoencoder.from_state_dict(state_dict=state_dict)
print(f"{ae.activation.k=}")

ae.activation.k=32


In [15]:
from collections import Counter
import torch
from sae_auto_interp.autoencoders.OpenAI.model import Autoencoder
from itertools import product
from tqdm.auto import tqdm
import time

weight_dir = "/mnt/ssd-1/gpaulo/SAE-Zoology/weights/gpt2_128k"

all_results = []

total_iterations = 10 * 7  # 10 feature indices, 7 layer options
for feat_idx, feat_layer in tqdm(product(range(200, 250), [2, 6, 11]), total=total_iterations):
    scorer_intervention_strengths = [10, 32, 100, 320, 1000]
    explainer_intervention_strength = 32

    path = f"{weight_dir}/{layer}.pt"
    state_dict = torch.load(path)
    ae = Autoencoder.from_state_dict(state_dict=state_dict)
    feat = ae.decoder.weight[:, feat_idx].to(device)
    encoder_feat = ae.encoder.weight[feat_idx, :].to(device)
    del ae

    # find examples where the feature activates
    n_candidate_texts = 100

    # Remove any hooks
    for l in range(len(subject.transformer.h)):
        subject.transformer.h[layer]._forward_hooks.clear()

    texts = get_texts(n_candidate_texts)
    subtexts = []
    subtext_acts = []
    for text in texts:
        input_ids = subject_tokenizer(text, return_tensors="pt").input_ids.to(device)
        with torch.inference_mode():
            out = subject(input_ids, output_hidden_states=True)
            # hidden_states is actually one longer than the number of layers, because it includes the input embeddings
            h = out.hidden_states[layer + 1].squeeze(0)
            feat_acts = h @ encoder_feat

        for i in range(1, len(feat_acts) + 1):
            reassembled_text = subject_tokenizer.decode(input_ids[0, :i])
            subtexts.append(reassembled_text)
            subtext_acts.append(feat_acts[i - 1])

    # get top k
    # Sort subtexts by activation and get top k
    sorted_indices = sorted(range(len(subtext_acts)), key=lambda i: subtext_acts[i], reverse=True)
    top_k_indices = sorted_indices[:n_scorer_texts + n_explainer_texts]

    # Get top k subtexts and their activations
    top_k_subtexts = [subtexts[i] for i in top_k_indices]
    top_k_activations = [subtext_acts[i] for i in top_k_indices]

    # Print top k results
    print("Top subtexts with highest feature activation:")
    for i, (subtext, activation) in enumerate(zip(top_k_subtexts, top_k_activations), 1):
        print(f"{i}. Activation: {activation:.4f}")
        print(f"   Text: {subtext}")
        print()

    random.shuffle(top_k_subtexts)
    scorer_texts = top_k_subtexts[:n_scorer_texts]
    explainer_texts = top_k_subtexts[n_scorer_texts:]

    # get explanation
    def get_subject_logits(text, layer, intervention_strength=0.0, position=-1):
        for l in range(len(subject.transformer.h)):
            subject.transformer.h[l]._forward_hooks.clear()
        subject.transformer.h[layer].register_forward_hook(partial(intervene, intervention_strength=intervention_strength, position=-1))

        inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(device)
        with torch.inference_mode():
            outputs = subject(**inputs)

        return outputs.logits[0, -1, :]

    intervention_examples = []
    for text in explainer_texts:
        clean_logits = get_subject_logits(text, feat_layer, intervention_strength=0.0)
        intervened_logits = get_subject_logits(text, feat_layer, intervention_strength=explainer_intervention_strength)
        top_probs = (intervened_logits.softmax(dim=-1) - clean_logits.softmax(dim=-1)).topk(10)
        # top_logits = intervened_logits.topk(10)
        top_tokens = [subject_tokenizer.decode(i) for i in top_probs.indices]
        top_p_increases = top_probs.values.tolist()
        intervention_examples.append(
            ExplainerInterventionExample(
                prompt=text,
                top_tokens=top_tokens,
                top_p_increases=top_p_increases
            )
        )

    neuron_prompter = ExplainerNeuronFormatter(
        intervention_examples=intervention_examples
    )

    # TODO: improve the few-shot examples
    explainer_prompt = get_explainer_prompt(neuron_prompter, fs_examples)
    explainer_input_ids = explainer_tokenizer(explainer_prompt, return_tensors="pt").input_ids.to(device)
    with torch.inference_mode():
        samples = explainer.generate(explainer_input_ids, max_new_tokens=100, eos_token_id=explainer_tokenizer.encode("\n\n")[-1], num_return_sequences=10)[:, explainer_input_ids.shape[1]:]
    explanations = Counter([explainer_tokenizer.decode(sample).split("\n\n")[0].strip() for sample in samples])
    explanation = explanations.most_common(1)[0][0]
    print(explanations)

    predictiveness_scores = []
    max_intervened_probs = []
    for scorer_intervention_strength in tqdm(scorer_intervention_strengths):
        
        predictiveness_score = torch.tensor(0.0, device=device)
        max_intervened_prob = 0.0
        
        for text in scorer_texts:
            
            intervened_probs = get_subject_logits(text, feat_layer, intervention_strength=scorer_intervention_strength).softmax(dim=-1)
            max_intervened_prob = max(max_intervened_prob, intervened_probs.max().item())

            # get the explanation predictiveness
            scorer_predictiveness_prompt = get_scorer_predictiveness_prompt(text, explanation, few_shot_prompts, few_shot_explanations, few_shot_tokens)
            scorer_input_ids = scorer_tokenizer(scorer_predictiveness_prompt, return_tensors="pt").input_ids.to(device)
            with torch.inference_mode():
                scorer_logits = scorer(scorer_input_ids).logits[0, -1, :]
                scorer_logp = scorer_logits.log_softmax(dim=-1)
            
            predictiveness_score += (intervened_probs[subject_ids] * scorer_logp[scorer_ids]).sum()

            # Print tokens with high probability (if needed)
            # high_prob_mask = intervened_probs > 0.05
            # high_prob_tokens = subject_tokenizer.convert_ids_to_tokens(high_prob_mask.nonzero().squeeze())
            # high_prob_values = intervened_probs[high_prob_mask]
            # for tok, val in zip(high_prob_tokens, high_prob_values):
            #     print(tok, val.item())

        max_intervened_probs.append(max_intervened_prob)
        predictiveness_scores.append(predictiveness_score.item() / len(scorer_texts))
        
    predictiveness_score = sum(predictiveness_scores) / len(predictiveness_scores)
    max_intervened_prob = max(max_intervened_probs)
    all_results.append({
        "feat_idx": feat_idx,
        "feat_layer": feat_layer,
        "explanation": explanation,
        "predictiveness_score": predictiveness_score,
        "intervention_examples": intervention_examples,
        "max_intervened_prob": max_intervened_prob,
        "scorer_intervention_strengths": scorer_intervention_strengths,
        "explainer_intervention_strength": explainer_intervention_strength,
        "scorer_texts": scorer_texts,
        "explainer_texts": explainer_texts,
        "predictiveness_scores": predictiveness_scores,
        "max_intervened_probs": max_intervened_probs,
    })
all_results

  0%|          | 0/70 [00:00<?, ?it/s]

Top subtexts with highest feature activation:
1. Activation: 133.6860
   Text: Ak

2. Activation: 133.1981
   Text: //

3. Activation: 133.1981
   Text: //

4. Activation: 133.0989
   Text: Em

5. Activation: 133.0280
   Text: 1

6. Activation: 133.0280
   Text: 1

7. Activation: 133.0280
   Text: 1

8. Activation: 132.8722
   Text: US

9. Activation: 132.8492
   Text: K

10. Activation: 132.8124
   Text: Can

11. Activation: 132.7975
   Text: Ryan

12. Activation: 132.7851
   Text: ;

13. Activation: 132.6943
   Text: G

Counter({'1/2/3/4': 2, '2/3': 1, '1,2,3,4,5': 1, '1': 1, '1-2-3-4-5': 1, '2': 1, "2's": 1, "2's, 3's, and 4's": 1, '1, 2, 3, 4': 1})


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
100%|██████████| 5/5 [00:01<00:00,  2.61it/s]
  1%|▏         | 1/70 [00:07<08:18,  7.22s/it]

Top subtexts with highest feature activation:
1. Activation: 133.6860
   Text: Ak

2. Activation: 133.1981
   Text: //

3. Activation: 133.1981
   Text: //

4. Activation: 133.0989
   Text: Em

5. Activation: 133.0280
   Text: 1

6. Activation: 133.0280
   Text: 1

7. Activation: 133.0280
   Text: 1

8. Activation: 132.8722
   Text: US

9. Activation: 132.8492
   Text: K

10. Activation: 132.8124
   Text: Can

11. Activation: 132.7975
   Text: Ryan

12. Activation: 132.7851
   Text: ;

13. Activation: 132.6943
   Text: G

Counter({'2, 3, 4': 4, '2/3/4': 2, "2, 3, 4, '<|end_of_text|><|begin_of_text|>\n' (+0.001),'R' (+0.001)": 1, '2, 3, 4, and The': 1, '3/4 of the most increased tokens are punctuation, indicating the neuron promotes writing style': 1, "2's, 3's, 4's": 1})


100%|██████████| 5/5 [00:01<00:00,  2.62it/s]
  3%|▎         | 2/70 [00:20<11:30, 10.15s/it]


KeyboardInterrupt: 

In [14]:
all_df

Unnamed: 0,feat_idx,feat_layer,explanation,predictiveness_score,intervention_examples,max_intervened_prob,scorer_intervention_strengths,explainer_intervention_strength,scorer_texts,explainer_texts,predictiveness_scores,max_intervened_probs
547,182,6,1-2-3-4-0-9,-8.721251,"[ExplainerInterventionExample(prompt='Road', t...",0.231213,"[10, 32, 100, 320, 1000]",32,"[Project, March, 1, Sim, 2019, 2016, 1, A, Hou...","[Road, A, Value]","[-10.622151947021484, -10.456940460205079, -8....","[0.11679935455322266, 0.1188754290342331, 0.14..."
546,182,2,1234,-8.875147,"[ExplainerInterventionExample(prompt='Road', t...",0.226274,"[10, 32, 100, 320, 1000]",32,"[Project, March, 1, Sim, 2019, 2016, 1, A, Hou...","[Road, A, Value]","[-10.631166076660156, -10.523173522949218, -8....","[0.11783038079738617, 0.12202408164739609, 0.1..."
556,185,6,"2019, value, house",-8.937399,"[ExplainerInterventionExample(prompt='2019', t...",0.431593,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 200, 538, A, A, A, 1, ', 1]","[2019, Value, House]","[-10.816554260253906, -10.416304779052734, -9....","[0.11420270800590515, 0.11989623308181763, 0.1..."
555,185,2,2019,-8.978640,"[ExplainerInterventionExample(prompt='2019', t...",0.435408,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 200, 538, A, A, A, 1, ', 1]","[2019, Value, House]","[-10.726454925537109, -10.404013061523438, -9....","[0.11452289670705795, 0.12135178595781326, 0.1..."
183,61,2,2016,-9.364605,[ExplainerInterventionExample(prompt='Project'...,0.943791,"[10, 32, 100, 320, 1000]",32,"[Sim, 2019, G, Ay, A, A, A, 200, ', L]","[Project, G, 2016]","[-10.697615814208984, -10.711712646484376, -9....","[0.07175029814243317, 0.06015954166650772, 0.4..."
...,...,...,...,...,...,...,...,...,...,...,...,...
402,134,2,2019,-13.192244,"[ExplainerInterventionExample(prompt='2019', t...",0.992656,"[10, 32, 100, 320, 1000]",32,"[2016, Sim, 1, ', A, A, 1, X, A, H]","[2019, 200, House]","[-11.074691772460938, -11.361431121826172, -12...","[0.12653400003910065, 0.16054397821426392, 0.1..."
404,134,11,2019,-13.266629,"[ExplainerInterventionExample(prompt='2019', t...",0.962138,"[10, 32, 100, 320, 1000]",32,"[2016, Sim, 1, ', A, A, 1, X, A, H]","[2019, 200, House]","[-10.987806701660157, -11.11118392944336, -11....","[0.11369820684194565, 0.11930322647094727, 0.1..."
539,179,11,2019,-13.326119,"[ExplainerInterventionExample(prompt='2019', t...",0.962642,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 1, 200, A, A, 1, 538, A, Value]","[2019, ', X]","[-10.844153594970702, -10.966635131835938, -11...","[0.11189045011997223, 0.11463377624750137, 0.0..."
278,92,11,1-200,-13.544800,"[ExplainerInterventionExample(prompt='A', top_...",0.951868,"[10, 32, 100, 320, 1000]",32,"[2016, Road, L, 1, *, A, 538, Value, A, ']","[A, 200, 1]","[-10.737577056884765, -10.89888916015625, -11....","[0.11086613684892654, 0.1106131449341774, 0.08..."


In [15]:
all_df.iloc[0].intervention_examples

[ExplainerInterventionExample(prompt='Road', top_tokens=['-', '1', '2', '.', '/', '(', '4', 'R', '3', 'I'], top_p_increases=[0.02295052632689476, 0.009246092289686203, 0.007624265272170305, 0.006824463605880737, 0.006669001653790474, 0.003991384990513325, 0.0035209404304623604, 0.0030987514182925224, 0.003035462461411953, 0.0026108790189027786]),
 ExplainerInterventionExample(prompt='A', top_tokens=['-', '.', '1', '2', '/', '4', 'I', '0', '9', '3'], top_p_increases=[0.01804528199136257, 0.013529673218727112, 0.011766989715397358, 0.007346044294536114, 0.005036055110394955, 0.004025498405098915, 0.003730517579242587, 0.003692652564495802, 0.003685819683596492, 0.003536658128723502]),
 ExplainerInterventionExample(prompt='Value', top_tokens=['-', '.', '1', '2', '(', '/', '0', '4', '3', ':'], top_p_increases=[0.018474796786904335, 0.012588724493980408, 0.011959636583924294, 0.008491882123053074, 0.007572157308459282, 0.006167306564748287, 0.004818673245608807, 0.0037697781808674335, 0.003

: 

In [13]:
all_df = pd.DataFrame(all_results)
all_df = all_df.sort_values("predictiveness_score", ascending=False)
all_df.to_pickle(f"counterfactual_results/3layers_200feats.pkl")

In [17]:
all_df = pd.read_pickle(f"counterfactual_results/3layers_200feats.pkl")
all_df

Unnamed: 0,feat_idx,feat_layer,explanation,predictiveness_score,intervention_examples,max_intervened_prob,scorer_intervention_strengths,explainer_intervention_strength,scorer_texts,explainer_texts,predictiveness_scores,max_intervened_probs
547,182,6,1-2-3-4-0-9,-8.721251,"[ExplainerInterventionExample(prompt='Road', t...",0.231213,"[10, 32, 100, 320, 1000]",32,"[Project, March, 1, Sim, 2019, 2016, 1, A, Hou...","[Road, A, Value]","[-10.622151947021484, -10.456940460205079, -8....","[0.11679935455322266, 0.1188754290342331, 0.14..."
546,182,2,1234,-8.875147,"[ExplainerInterventionExample(prompt='Road', t...",0.226274,"[10, 32, 100, 320, 1000]",32,"[Project, March, 1, Sim, 2019, 2016, 1, A, Hou...","[Road, A, Value]","[-10.631166076660156, -10.523173522949218, -8....","[0.11783038079738617, 0.12202408164739609, 0.1..."
556,185,6,"2019, value, house",-8.937399,"[ExplainerInterventionExample(prompt='2019', t...",0.431593,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 200, 538, A, A, A, 1, ', 1]","[2019, Value, House]","[-10.816554260253906, -10.416304779052734, -9....","[0.11420270800590515, 0.11989623308181763, 0.1..."
555,185,2,2019,-8.978640,"[ExplainerInterventionExample(prompt='2019', t...",0.435408,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 200, 538, A, A, A, 1, ', 1]","[2019, Value, House]","[-10.726454925537109, -10.404013061523438, -9....","[0.11452289670705795, 0.12135178595781326, 0.1..."
183,61,2,2016,-9.364605,[ExplainerInterventionExample(prompt='Project'...,0.943791,"[10, 32, 100, 320, 1000]",32,"[Sim, 2019, G, Ay, A, A, A, 200, ', L]","[Project, G, 2016]","[-10.697615814208984, -10.711712646484376, -9....","[0.07175029814243317, 0.06015954166650772, 0.4..."
...,...,...,...,...,...,...,...,...,...,...,...,...
402,134,2,2019,-13.192244,"[ExplainerInterventionExample(prompt='2019', t...",0.992656,"[10, 32, 100, 320, 1000]",32,"[2016, Sim, 1, ', A, A, 1, X, A, H]","[2019, 200, House]","[-11.074691772460938, -11.361431121826172, -12...","[0.12653400003910065, 0.16054397821426392, 0.1..."
404,134,11,2019,-13.266629,"[ExplainerInterventionExample(prompt='2019', t...",0.962138,"[10, 32, 100, 320, 1000]",32,"[2016, Sim, 1, ', A, A, 1, X, A, H]","[2019, 200, House]","[-10.987806701660157, -11.11118392944336, -11....","[0.11369820684194565, 0.11930322647094727, 0.1..."
539,179,11,2019,-13.326119,"[ExplainerInterventionExample(prompt='2019', t...",0.962642,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 1, 200, A, A, 1, 538, A, Value]","[2019, ', X]","[-10.844153594970702, -10.966635131835938, -11...","[0.11189045011997223, 0.11463377624750137, 0.0..."
278,92,11,1-200,-13.544800,"[ExplainerInterventionExample(prompt='A', top_...",0.951868,"[10, 32, 100, 320, 1000]",32,"[2016, Road, L, 1, *, A, 538, Value, A, ']","[A, 200, 1]","[-10.737577056884765, -10.89888916015625, -11....","[0.11086613684892654, 0.1106131449341774, 0.08..."


In [None]:
from collections import Counter

# get explanation
def get_subject_logits(text, layer, intervention_strength=0.0, position=-1):
    for l in range(len(subject.transformer.h)):
        subject.transformer.h[l]._forward_hooks.clear()
    subject.transformer.h[layer].register_forward_hook(partial(intervene, intervention_strength=intervention_strength, position=-1))

    inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(device)
    with torch.inference_mode():
        outputs = subject(**inputs)

    return outputs.logits[0, -1, :]

In [98]:
intervention_examples = []
for text in explainer_texts:
    clean_logits = get_subject_logits(text, feat_layer, intervention_strength=0.0)
    intervened_logits = get_subject_logits(text, feat_layer, intervention_strength=explainer_intervention_strength)
    top_probs = (intervened_logits.softmax(dim=-1) - clean_logits.softmax(dim=-1)).topk(10)
    # top_logits = intervened_logits.topk(10)
    top_tokens = [subject_tokenizer.decode(i) for i in top_probs.indices]
    top_p_increases = top_probs.values.tolist()
    intervention_examples.append(
        ExplainerInterventionExample(
            prompt=text,
            top_tokens=top_tokens,
            top_p_increases=top_p_increases
        )
    )

neuron_prompter = ExplainerNeuronFormatter(
    intervention_examples=intervention_examples
)

# TODO: improve the few-shot examples
explainer_prompt = get_explainer_prompt(neuron_prompter, fs_examples)
explainer_input_ids = explainer_tokenizer(explainer_prompt, return_tensors="pt").input_ids.to(device)
with torch.inference_mode():
    samples = explainer.generate(explainer_input_ids, max_new_tokens=100, eos_token_id=explainer_tokenizer.encode("\n\n")[-1], num_return_sequences=10)[:, explainer_input_ids.shape[1]:]
explanations = Counter([explainer_tokenizer.decode(sample).split("\n\n")[0].strip() for sample in samples])
explanation = explanations.most_common(1)[0][0]
print(explanations)

Counter({'2020 Democratic candidates': 2, '2020 presidential candidates': 2, '2020 US presidential candidate Andrew Yang': 1, '2020 Democratic presidential candidates': 1, '2020 Democratic presidential candidate Andrew Yang': 1, '2020 presidential candidates Yang, Siren, Kali, Az, Li, and Karin': 1, '2020 presidential candidates Andrew Yang, Tulsi Gabbard, and Pete Buttigieg': 1, '2020 presidential candidate Andrew Yang': 1})


In [99]:
samples = explainer.generate(explainer_input_ids, max_new_tokens=100, eos_token_id=explainer_tokenizer.encode("\n\n")[-1], num_return_sequences=10)

In [100]:
print(explainer_tokenizer.decode(samples[2]))

<|begin_of_text|>We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.

For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.

The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.

Neuron 1
<PROMPT>My favorite food is</PROMPT>
Most increased tokens:'oranges' (+0.81),'bananas' (+0.09),'apples' (+0.02)

<PROMPT>Whenever I would see</PROMPT>
Most increased tokens:'fruit' (+0.09),'a' (+0.06),'apples' (+0.06),'red' (+0.5)

Explanation: fruits

Neuron 2
<PROMPT>Once upon a time</PROMPT>
Most increased tokens:'there was' (+0.22),'a' (+0.2),'a time' (+0.05)

Explanation: storytelling

Neuron 3
<PROMPT>He owned the watch for a long time. While he never said it was</PROMPT>
Most increased tokens:'hers' (+0.09),'hers' (+0.

In [126]:
from tqdm.auto import tqdm
import time

predictiveness_scores = []
max_intervened_probs = []
for scorer_intervention_strength in tqdm(scorer_intervention_strengths):
    
    predictiveness_score = torch.tensor(0.0, device=device)
    max_intervened_prob = 0.0
    total_inference_time = 0
    total_loop_time = 0
    for text in scorer_texts:
        inference_start = time.time()
        intervened_probs = get_subject_logits(text, feat_layer, intervention_strength=scorer_intervention_strength).softmax(dim=-1)
        max_intervened_prob = max(max_intervened_prob, intervened_probs.max().item())

        # get the explanation predictiveness
        scorer_predictiveness_prompt = get_scorer_predictiveness_prompt(text, explanation, few_shot_prompts, few_shot_explanations, few_shot_tokens)
        scorer_input_ids = scorer_tokenizer(scorer_predictiveness_prompt, return_tensors="pt").input_ids.to(device)
        with torch.inference_mode():
            scorer_logits = scorer(scorer_input_ids).logits[0, -1, :]
            scorer_logp = scorer_logits.log_softmax(dim=-1)
        inference_end = time.time()
        total_inference_time += inference_end - inference_start

        loop_start = time.time()

        predictiveness_score += (intervened_probs[subject_ids] * scorer_logp[scorer_ids]).sum()

        # Print tokens with high probability (if needed)
        high_prob_mask = intervened_probs > 0.05
        # high_prob_tokens = subject_tokenizer.convert_ids_to_tokens(high_prob_mask.nonzero().squeeze())
        # high_prob_values = intervened_probs[high_prob_mask]
        # for tok, val in zip(high_prob_tokens, high_prob_values):
        #     print(tok, val.item())

        loop_end = time.time()
        total_loop_time += loop_end - loop_start
    max_intervened_probs.append(max_intervened_prob)
    predictiveness_scores.append(predictiveness_score.item() / len(scorer_texts))
    
    print(f"Total inference time: {total_inference_time:.2f} seconds")
    print(f"Total innermost loop time: {total_loop_time:.2f} seconds")

predictiveness_score = sum(predictiveness_scores) / len(predictiveness_scores)
predictiveness_score

 20%|██        | 1/5 [00:00<00:01,  2.27it/s]

Total inference time: 0.43 seconds
Total innermost loop time: 0.00 seconds


 40%|████      | 2/5 [00:00<00:01,  2.27it/s]

Total inference time: 0.43 seconds
Total innermost loop time: 0.00 seconds


 60%|██████    | 3/5 [00:01<00:00,  2.26it/s]

Total inference time: 0.43 seconds
Total innermost loop time: 0.00 seconds


 80%|████████  | 4/5 [00:01<00:00,  2.26it/s]

Total inference time: 0.43 seconds
Total innermost loop time: 0.00 seconds


100%|██████████| 5/5 [00:02<00:00,  2.26it/s]

Total inference time: 0.43 seconds
Total innermost loop time: 0.00 seconds





-8.572526245117189

In [107]:
intervened_probs

tensor([1.3747e-08, 1.0227e-07, 6.0999e-09,  ..., 6.2786e-12, 7.6808e-13,
        1.2035e-07], device='cuda:0')

In [108]:
topk = intervened_probs.topk(10)
# topk = clean_logits.softmax(dim=-1).topk(10)
[(subject_tokenizer.decode(p[0]), p[1]) for p in list(zip(topk.indices, topk.values))]

[(' 1984', tensor(0.0322, device='cuda:0')),
 (' 1986', tensor(0.0288, device='cuda:0')),
 (' 1987', tensor(0.0286, device='cuda:0')),
 (' 1981', tensor(0.0269, device='cuda:0')),
 (' 1983', tensor(0.0249, device='cuda:0')),
 (' 1985', tensor(0.0237, device='cuda:0')),
 (' 1980', tensor(0.0235, device='cuda:0')),
 (' 1982', tensor(0.0225, device='cuda:0')),
 (' 1989', tensor(0.0219, device='cuda:0')),
 (' 1988', tensor(0.0212, device='cuda:0'))]

In [70]:
topk = intervened_probs.topk(10)
# topk = clean_logits.softmax(dim=-1).topk(10)
[(subject_tokenizer.decode(p[0]), p[1]) for p in list(zip(topk.indices, topk.values))]

[(' Islamic', tensor(0.3354, device='cuda:0')),
 (' Quran', tensor(0.2592, device='cuda:0')),
 ('abad', tensor(0.1282, device='cuda:0')),
 (' Sharia', tensor(0.0909, device='cuda:0')),
 ('uddin', tensor(0.0755, device='cuda:0')),
 (' holiest', tensor(0.0253, device='cuda:0')),
 (' Koran', tensor(0.0096, device='cuda:0')),
 (' Mecca', tensor(0.0094, device='cuda:0')),
 (' blasphemy', tensor(0.0068, device='cuda:0')),
 ('Islamic', tensor(0.0067, device='cuda:0'))]

In [71]:
sum(topk.values)

tensor(0.9468, device='cuda:0')

In [46]:
predictiveness_score

-8.2017333984375

In [None]:
predictiveness_scores

In [None]:
print(predictiveness_scores)

In [None]:
scorer_intervention_strengths