In [1]:
import torch
from sae_auto_interp.config import ExperimentConfig, FeatureConfig
from sae_auto_interp.features import (
    FeatureDataset,
    FeatureLoader
)
from sae_auto_interp.features.constructors import default_constructor
from sae_auto_interp.features.samplers import sample

feat_layer = 32
sae_model = "gemma/131k"
module = f".model.layers.{feat_layer}"
n_train, n_test, n_quantiles = 5, 10, 5
n_feats = 300
feature_dict = {f"{module}": torch.arange(0, n_feats)}
feature_cfg = FeatureConfig(width=131072, n_splits=5, max_examples=100000, min_examples=200)
experiment_cfg = ExperimentConfig(n_random=0, example_ctx_len=64, n_quantiles=5, n_examples_test=0, n_examples_train=n_train + n_test // n_quantiles, train_type="quantiles", test_type="even")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from sae_auto_interp.features import FeatureDataset
from functools import partial
import random

dataset = FeatureDataset(
        raw_dir=f"/mnt/ssd-1/gpaulo/SAE-Zoology/raw_features/{sae_model}",
        cfg=feature_cfg,
        modules=[module],
        features=feature_dict,
)

constructor=partial(
            default_constructor,
            tokens=dataset.tokens,
            n_random=experiment_cfg.n_random, 
            ctx_len=experiment_cfg.example_ctx_len, 
            max_examples=feature_cfg.max_examples
        )

sampler=partial(sample,cfg=experiment_cfg)
loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler)
    
record = next(iter(loader))


EleutherAI/rpj-v2-sample  train[:1%]


In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import torch
from huggingface_hub import hf_hub_download
import numpy as np

subject_device = "cuda:6"

subject_name = "google/gemma-2-9b"
subject = AutoModelForCausalLM.from_pretrained(subject_name).to(subject_device)
subject_tokenizer = AutoTokenizer.from_pretrained(subject_name)
subject_tokenizer.pad_token = subject_tokenizer.eos_token
subject.config.pad_token_id = subject_tokenizer.eos_token_id

Loading checkpoint shards: 100%|██████████| 8/8 [00:01<00:00,  7.39it/s]


In [5]:
scorer_device = "cuda:7"
scorer_name = "google/gemma-2-27b"
scorer = AutoModelForCausalLM.from_pretrained(
    scorer_name,
    device_map={"": scorer_device},
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
)
scorer_tokenizer = AutoTokenizer.from_pretrained(scorer_name)
scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
scorer.config.pad_token_id = scorer_tokenizer.eos_token_id
scorer.generation_config.pad_token_id = scorer_tokenizer.eos_token_id

# explainer is the same model as the scorer
explainer_device = scorer_device
explainer = scorer
explainer_tokenizer = scorer_tokenizer

Loading checkpoint shards: 100%|██████████| 24/24 [00:26<00:00,  1.08s/it]


In [4]:
from dataclasses import dataclass
import copy

@dataclass
class ExplainerInterventionExample:
    prompt: str
    top_tokens: list[str]
    top_p_increases: list[float]

    def __post_init__(self):
        self.prompt = self.prompt.replace("\n", "\\n")

    def text(self) -> str:
        tokens_str = ", ".join(f"'{tok}' (+{round(p, 3)})" for tok, p in zip(self.top_tokens, self.top_p_increases))
        return f"<PROMPT>{self.prompt}</PROMPT>\nMost increased tokens: {tokens_str}"
    
@dataclass
class ExplainerNeuronFormatter:
    intervention_examples: list[ExplainerInterventionExample]
    explanation: str | None = None

    def text(self) -> str:
        text = "\n\n".join(example.text() for example in self.intervention_examples)
        text += "\n\nExplanation:"
        if self.explanation is not None:
            text += " " + self.explanation
        return text


def get_explainer_prompt(neuron_prompter: ExplainerNeuronFormatter, few_shot_examples: list[ExplainerNeuronFormatter] | None = None) -> str:
    prompt = "We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.\n\n" \
        "For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.\n\n" \
        "The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.\n\n"
    
    i = 1
    for few_shot_example in few_shot_examples or []:
        assert few_shot_example.explanation is not None
        prompt += f"Neuron {i}\n" + few_shot_example.text() + "\n\n"
        i += 1

    prompt += f"Neuron {i}\n"
    prompt += neuron_prompter.text()

    return prompt


fs_examples = [
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="My favorite food is",
                top_tokens=[" oranges", " bananas", " apples"],
                top_p_increases=[0.81, 0.09, 0.02]
            ),
            ExplainerInterventionExample(
                prompt="Whenever I would see",
                top_tokens=[" fruit", " a", " apples", " red"],
                top_p_increases=[0.09, 0.06, 0.06, 0.05]
            ),
            ExplainerInterventionExample(
                prompt="I like to eat",
                top_tokens=[" fro", " fruit", " oranges", " bananas", " strawberries"],
                top_p_increases=[0.14, 0.13, 0.11, 0.10, 0.03]
            )
        ],
        explanation="fruits"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="Once",
                top_tokens=[" upon", " in", " a", " long"],
                top_p_increases=[0.22, 0.2, 0.05, 0.04]
            ),
            ExplainerInterventionExample(
                prompt="Ryan Quarles\\n\\nRyan Francis Quarles (born October 20, 1983)",
                top_tokens=[" once", " happily", " for"],
                top_p_increases=[0.03, 0.31, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="MSI Going Full Throttle @ CeBIT",
                top_tokens=[" Once", " once", " in", " the", " a", " The"],
                top_p_increases=[0.02, 0.01, 0.01, 0.01, 0.01, 0.01]
            ),
        ],
        explanation="storytelling"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="Given 4x is less than 10,",
                top_tokens=[" 4", " 10", " 40", " 2"],
                top_p_increases=[0.11, 0.04, 0.02, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="For some reason",
                top_tokens=[" one", " 1", " fr"],
                top_p_increases=[0.14, 0.01, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="insurance does not cover claims for accounts with",
                top_tokens=[" one", " more", " 10"],
                top_p_increases=[0.10, 0.02, 0.01]
            )
        ],
        explanation="numbers"
    )
]

neuron_prompter = copy.deepcopy(fs_examples[0])
neuron_prompter.explanation = None
print(get_explainer_prompt(neuron_prompter, fs_examples))


We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.

For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.

The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.

Neuron 1
<PROMPT>My favorite food is</PROMPT>
Most increased tokens: ' oranges' (+0.81), ' bananas' (+0.09), ' apples' (+0.02)

<PROMPT>Whenever I would see</PROMPT>
Most increased tokens: ' fruit' (+0.09), ' a' (+0.06), ' apples' (+0.06), ' red' (+0.05)

<PROMPT>I like to eat</PROMPT>
Most increased tokens: ' fro' (+0.14), ' fruit' (+0.13), ' oranges' (+0.11), ' bananas' (+0.1), ' strawberries' (+0.03)

Explanation: fruits

Neuron 2
<PROMPT>Once</PROMPT>
Most increased tokens: ' upon' (+0.22), ' in' (+0.2), ' a' (+0.05), ' long' (+0.04)

<PROMPT>Ryan

In [7]:
def get_scorer_simplicity_prompt(explanation):
    prefix = "Explanation\n\n"
    return f"{prefix}{explanation}{scorer_tokenizer.eos_token}", prefix

def get_scorer_predictiveness_prompt(prompt, explanation, few_shot_prompts=None, few_shot_explanations=None, few_shot_tokens=None):
    if few_shot_explanations is not None:
        assert few_shot_tokens is not None and few_shot_prompts is not None
        assert len(few_shot_explanations) == len(few_shot_tokens) == len(few_shot_prompts)
        few_shot_prompt = "\n\n".join(get_scorer_predictiveness_prompt(pr, expl) + token for pr, expl, token in zip(few_shot_prompts, few_shot_explanations, few_shot_tokens)) + "\n\n"
    else:
        few_shot_prompt = ""
    return few_shot_prompt + f"Explanation: {explanation}\n<PROMPT>{prompt}</PROMPT>"

few_shot_prompts = ["My favorite food is", "From west to east, the westmost of the seven", "Given 4x is less than 10,"]
few_shot_explanations = ["fruits and vegetables", "ateg", "numbers"]
few_shot_tokens = [" oranges", "WAY", " 4"]
print(get_scorer_predictiveness_prompt(few_shot_prompts[0], few_shot_explanations[0], few_shot_prompts, few_shot_explanations, few_shot_tokens))


Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT> oranges

Explanation: ateg
<PROMPT>From west to east, the westmost of the seven</PROMPT>WAY

Explanation: numbers
<PROMPT>Given 4x is less than 10,</PROMPT> 4

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT>


In [8]:
from functools import partial
def intervene(module, input, output, intervention_strength=10.0, position=-1, feat=None):
    hiddens = output[0]  # the later elements of the tuple are the key value cache
    hiddens[:, position, :] += intervention_strength * feat.to(hiddens.device)  # type: ignore

In [9]:
subject_layers = subject.model.layers

In [10]:
n_intervention_tokens = 5
scorer_intervention_strengths = [0, 10, 32, 100, 320, 1000]
explainer_intervention_strength = 32

path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-9b-pt-res",
    filename=f"layer_{feat_layer}/width_131k/average_l0_51/params.npz",
    force_download=False,
)

params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).to(subject_device) for k, v in params.items()}


def get_encoder_decoder_weights(feat_idx, device, random_resid_direction):
    encoder_feat = pt_params["W_enc"][feat_idx, :]
    decoder_feat = pt_params["W_dec"][feat_idx, :]
    if random_resid_direction:
        decoder_feat = torch.randn_like(decoder_feat)
    return encoder_feat, decoder_feat


def garbage_collect():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        print("CUDA garbage collection performed.")


In [11]:
assert subject_tokenizer.get_vocab() == scorer_tokenizer.get_vocab()

In [12]:
from tqdm.auto import tqdm
from itertools import product, islice
import time
from collections import Counter
import pandas as pd

random_resid_direction = False  # this is a random baseline
save_path = f"counterfactual_results/neg_{subject_name.split('/')[-1]}_{feat_layer}layer_{n_feats}feats{'_random_dir' if random_resid_direction else ''}.json"
all_results = []
n_explanations = 5
n_scorer_texts = 10
n_explainer_texts = 7


for iter, record in enumerate(tqdm(islice(loader, 10))):
    garbage_collect()
    
    feat_idx = record.feature.feature_index
    print("Loading autoencoder...", end="")
    encoder_feat, decoder_feat = get_encoder_decoder_weights(feat_idx, subject_device, random_resid_direction)

    ### Find examples where the feature activates
    # Remove any hooks
    for l in range(len(subject_layers)):
        subject_layers[l]._forward_hooks.clear()
    print("done")

    random.shuffle(record.train)
    scorer_texts = [subject_tokenizer.decode(e.tokens) for e in record.train[:n_scorer_texts]]
    explainer_texts = [subject_tokenizer.decode(e.tokens) for e in record.train[:n_explainer_texts]]
    
    # get explanation
    print("Getting explanations...")
    def get_subject_logits(text, layer, intervention_strength=0.0, position=-1, feat=None):
        for l in range(len(subject_layers)):
            subject_layers[l]._forward_hooks.clear()
        subject_layers[layer].register_forward_hook(partial(intervene, intervention_strength=intervention_strength, position=-1, feat=feat))

        inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(subject_device)
        with torch.inference_mode():
            outputs = subject(**inputs)

        return outputs.logits[0, -1, :]

    explainer_time = time.time()
    intervention_examples = []
    for text in explainer_texts:
        clean_logits = get_subject_logits(text, feat_layer, intervention_strength=0.0, feat=decoder_feat)
        intervened_logits = get_subject_logits(text, feat_layer, intervention_strength=explainer_intervention_strength, feat=decoder_feat)
        top_probs = (intervened_logits.softmax(dim=-1) - clean_logits.softmax(dim=-1)).topk(n_intervention_tokens)
        
        top_tokens = [subject_tokenizer.decode(i) for i in top_probs.indices]
        top_p_increases = top_probs.values.tolist()
        intervention_examples.append(
            ExplainerInterventionExample(
                prompt=text,
                top_tokens=top_tokens,
                top_p_increases=top_p_increases
            )
        )

    neuron_prompter = ExplainerNeuronFormatter(
        intervention_examples=intervention_examples
    )

    # TODO: improve the few-shot examples
    explainer_prompt = get_explainer_prompt(neuron_prompter, fs_examples)
    explainer_input_ids = explainer_tokenizer(explainer_prompt, return_tensors="pt").input_ids.to(explainer_device)
    with torch.inference_mode():
        samples = explainer.generate(
            explainer_input_ids,
            max_new_tokens=20,
            eos_token_id=explainer_tokenizer.encode("\n")[-1],
            num_return_sequences=n_explanations,
            temperature=0.7,
            do_sample=True,
        )[:, explainer_input_ids.shape[1]:]
    explanations = Counter([explainer_tokenizer.decode(sample).split("\n")[0].strip() for sample in samples])
    explainer_time = time.time() - explainer_time
    print(f"Explainer took {explainer_time:.2f} seconds")

    for ie in intervention_examples:
        print(ie.top_tokens)
        print(ie.top_p_increases)
    print(explanations)

    scoring_time = time.time()
    predictiveness_score_by_explanation = dict()
    normalized_predictiveness_score_by_explanation = dict()
    all_pred_scores = []
    scoring_interventions = dict()
    for explanation in explanations:
        expl_predictiveness_scores = []
        scoring_interventions[explanation] = dict()
        for scorer_intervention_strength in tqdm(scorer_intervention_strengths):
            
            current_pred_scores = []
            max_intervened_prob = 0.0
            scoring_interventions[explanation][scorer_intervention_strength] = []
            for text in scorer_texts:
                
                intervened_probs = get_subject_logits(text, feat_layer, intervention_strength=scorer_intervention_strength, feat=decoder_feat).softmax(dim=-1).to(scorer_device)

                # get the explanation predictiveness
                scorer_predictiveness_prompt = get_scorer_predictiveness_prompt(text, explanation, few_shot_prompts, few_shot_explanations, few_shot_tokens)
                scorer_input_ids = scorer_tokenizer(scorer_predictiveness_prompt, return_tensors="pt").input_ids.to(scorer_device)
                with torch.inference_mode():
                    scorer_logits = scorer(scorer_input_ids).logits[0, -1, :]
                    scorer_logp = scorer_logits.log_softmax(dim=-1)
                
                current_pred_scores.append((intervened_probs * scorer_logp).sum())

                topk = intervened_probs.topk(n_intervention_tokens).indices
                top_tokens = [subject_tokenizer.decode(i) for i in topk]
                scoring_interventions[explanation][scorer_intervention_strength].append({
                    "prompt": text,
                    "top_tokens": top_tokens,
                    "top_token_probs": intervened_probs[topk].tolist()
                })

            expl_predictiveness_scores.append(torch.tensor(current_pred_scores).mean().item())
            all_pred_scores.extend(current_pred_scores * explanations[explanation])  # as if we did the inference on the scorer multiple times
    
        assert scorer_intervention_strengths[0] == 0
        null_predictiveness_score = expl_predictiveness_scores[0]
        normalized_predictiveness_scores = [score - null_predictiveness_score for score in expl_predictiveness_scores[1:]]
        normalized_predictiveness_score = sum(normalized_predictiveness_scores) / len(normalized_predictiveness_scores)
        predictiveness_score = normalized_predictiveness_score + null_predictiveness_score
        
        predictiveness_score_by_explanation[explanation] = predictiveness_score
        normalized_predictiveness_score_by_explanation[explanation] = normalized_predictiveness_score
    
    # note that this computes stderr over explanations, pile samples, *and* intervention strengths (which is kind of weird)
    pred_score_stderr = torch.std(torch.tensor(all_pred_scores)).item() / len(all_pred_scores) ** 0.5
    pred_score = torch.mean(torch.tensor(all_pred_scores)).item()
    normalized_predictiveness_score = sum([normalized_predictiveness_score_by_explanation[explanation] * count for explanation, count in explanations.items()]) / sum(explanations.values())

    scoring_time = time.time() - scoring_time
    print(f"Scoring took {scoring_time:.2f} seconds")

    print(f"{normalized_predictiveness_score=}")
    print()
    print()
    all_results.append({
        "feat_idx": feat_idx,
        "feat_layer": feat_layer,
        "explanations": dict(explanations),
        "predictiveness_score": pred_score,
        "normalized_predictiveness_score": normalized_predictiveness_score,
        "predictiveness_score_stderr": pred_score_stderr,
        "max_predictiveness_score": max(predictiveness_score_by_explanation.values()),
        "max_normalized_predictiveness_score": max(normalized_predictiveness_score_by_explanation.values()),
        "explainer_prompts": [example.prompt for example in intervention_examples],
        "explainer_top_tokens": [example.top_tokens for example in intervention_examples],
        "explainer_top_p_increases": [example.top_p_increases for example in intervention_examples],
        "scorer_intervention_strengths": scorer_intervention_strengths,
        "explainer_intervention_strength": explainer_intervention_strength,
        "scorer_texts": scorer_texts,
        "explainer_texts": explainer_texts,
        "predictiveness_score_by_explanation": predictiveness_score_by_explanation,
        "normalized_predictiveness_score_by_explanation": normalized_predictiveness_score_by_explanation,
        "scoring_interventions": scoring_interventions
    })
    if (iter - 1) % 10 == 0:
        all_df = pd.DataFrame(all_results)
        all_df = all_df.sort_values("predictiveness_score", ascending=False)
        all_df.to_json(save_path)
all_df = pd.DataFrame(all_results)
all_df = all_df.sort_values("predictiveness_score", ascending=False)
all_df.to_json(save_path)


0it [00:00, ?it/s]

CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 10.49 seconds
[' glass', ' windows', ' shop', ' door', ' side']
[0.007918842136859894, 0.0029396452009677887, 0.000982820987701416, 0.0006711115129292011, 0.0006391257047653198]
[' degrees', '8', '9', '2', '6']
[1.1867938155774027e-05, 6.884278263896704e-06, 6.822760042268783e-06, 6.046495400369167e-06, 5.583569873124361e-06]
['“', 'The', 'A', '"', 'On']
[0.005751661956310272, 0.004830658435821533, 0.0012094564735889435, 0.0006070331437513232, 0.0004717400297522545]
[' "', " '", '"', ' E', ' “']
[0.018843382596969604, 0.00157957524061203, 0.0013330169022083282, 0.0009814589284360409, 0.0006831996142864227]
[' have', ' say', ' mention', ' "', ' LOOK']
[0.01625191420316696, 0.00334613467566669, 0.0020623994059860706, 0.0009762971894815564, 0.0008321376517415047]
Counter({'quotes': 3, 'quotation marks': 1, 'symbols': 1})


100%|██████████| 6/6 [00:26<00:00,  4.37s/it]
100%|██████████| 6/6 [00:26<00:00,  4.38s/it]
100%|██████████| 6/6 [00:26<00:00,  4.37s/it]
1it [01:33, 93.20s/it]

Scoring took 78.73 seconds
normalized_predictiveness_score=-1.1835374450683591


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 8.10 seconds
[' the', ' an', ' a', ' another', ' ']
[0.004625275731086731, 0.0022381246089935303, 0.0012402534484863281, 0.0008356817997992039, 0.00019913632422685623]
['1', '2', ' ', '5', '6']
[2.7468049665912986e-06, 1.1346000974299386e-06, 3.088625817326829e-07, 2.910569492087234e-07, 2.079577825497836e-07]
[' an', ' the', ' clear', ' either', ' any']
[0.020979493856430054, 0.0003308332525193691, 0.00026065949350595474, 0.00018377765081822872, 0.00017113890498876572]
[' led', ' tested', ' put', ' kept', ' made']
[0.0008836984634399414, 0.0007538255304098129, 0.0007409974932670593, 0.0006638877093791962, 0.0005448833107948303]
[' in', ' for', ' against', ' on', ' and']
[0.008508801460266113, 0.0022646524012088776, 0.0021819742396473885, 0.0011248812079429626, 0.0010835528373718262]
Counter({'numbers': 5})

100%|██████████| 6/6 [00:26<00:00,  4.38s/it]
2it [02:07, 58.63s/it]

Scoring took 26.26 seconds
normalized_predictiveness_score=-1.1469223976135254


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 9.33 seconds
[' the', ' Canonical', ' Can', ' ', ' CAN']
[0.019162297248840332, 0.018479973077774048, 0.005154597572982311, 0.0035864152014255524, 0.0027160393074154854]
[' relevant', ' means', ' official', ' interested', ' competent']
[0.005846090614795685, 0.005589604377746582, 0.00272214412689209, 0.0020774826407432556, 0.0019824784249067307]
[' the', ' that', ' last', ' they', ' on']
[0.003691956400871277, 0.003122478723526001, 0.001771455630660057, 0.0010092006996273994, 0.001000954769551754]
[' on', ' outside', ' through', ' before', ' after']
[0.0066987499594688416, 0.0026769600808620453, 0.002449234016239643, 0.0008801314979791641, 0.0008492367342114449]
[' rights', ' right', ' discrimination', ' issues', ' detention']
[0.0014113187789916992, 0.00037645664997398853, 0.000369891757145524, 0.000107859

100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.38s/it]
100%|██████████| 6/6 [00:26<00:00,  4.40s/it]
100%|██████████| 6/6 [00:26<00:00,  4.38s/it]
3it [04:29, 96.63s/it]

Scoring took 131.65 seconds
normalized_predictiveness_score=-0.32393417358398435


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 9.83 seconds
[' did', ' made', ' created', ' finished', ' got']
[0.002987705171108246, 0.0021089166402816772, 0.0013878047466278076, 9.005784522742033e-05, 8.70607327669859e-05]
[' very', ' nice', ' neat', ' cool', ' great']
[0.010700983926653862, 0.007432856597006321, 0.004208534490317106, 0.004178110975772142, 0.002053816569969058]
[' important', ' exciting', ' memorable', ' special', ' adventurous']
[0.014849811792373657, 0.011626994237303734, 0.009479421190917492, 0.007749274373054504, 0.00735454261302948]
[' most', ' biggest', ' big', ' event', ' fights']
[0.006460733711719513, 0.004815317690372467, 0.004628919064998627, 0.001966264098882675, 0.0017177388072013855]
['8', '9', '7', '4', '...']
[0.0005819275975227356, 0.0005178079009056091, 0.00015772134065628052, 3.2782554626464844e-06, 1.360458554700

100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.38s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.37s/it]
4it [06:51, 114.47s/it]

Scoring took 131.58 seconds
normalized_predictiveness_score=-1.237775173187256


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 8.87 seconds
[' been', ' come', ' not', ' certainly', ' taken']
[0.005828842520713806, 0.0024400316178798676, 0.0017267465591430664, 0.0004254784435033798, 0.0003960728645324707]
[' invaded', ' had', ' imposed', ' intervened', ' Embassy']
[0.004554763436317444, 0.0028827209025621414, 0.0017738398164510727, 0.0015324330888688564, 0.0013124910183250904]
[' from', ' at', ' nor', ' fro', '\n']
[0.02262893319129944, 4.060578066855669e-05, 2.462854899931699e-05, 2.3687243810854852e-05, 1.7537677194923162e-05]
[' of', ' men', ' Read', ' women', ' singles']
[0.003221750259399414, 0.00229780375957489, 0.0010656416416168213, 0.0007863305509090424, 0.0005323849618434906]
[' on', ' after', ' into', ' to', ' over']
[0.0038809701800346375, 0.0027465522289276123, 0.002558339387178421, 0.0024374723434448242, 0.000534273684

100%|██████████| 6/6 [00:26<00:00,  4.40s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.37s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
5it [09:11, 123.91s/it]

Scoring took 131.66 seconds
normalized_predictiveness_score=-0.8909709739685059


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 8.28 seconds
['7', '4', '5', '3', '8']
[0.0010830536484718323, 0.000489942729473114, 0.0004773437976837158, 0.0003250911831855774, 0.0002501755952835083]
[',', '/', ' Jerry', '.', 'Jerry']
[0.005631685256958008, 0.0052846819162368774, 0.0021593626588582993, 0.0021180715411901474, 0.0012285616248846054]
['.', ' on', ' at', ' Feb', ' Sept']
[0.019575044512748718, 0.005126446485519409, 0.0025830641388893127, 0.0010011615231633186, 0.0006754865171387792]
['Appel', '.', ' Appell', ' Karel', ' Cornel']
[0.0005667319055646658, 0.00018899468705058098, 0.0001870493870228529, 0.00017209292855113745, 0.00015883834566920996]
[' a', ' on', ' probably', ' working', ' also']
[0.003400757908821106, 0.0012211054563522339, 0.0007490422576665878, 0.0006103911437094212, 0.0005946112796664238]
Counter({'numbers': 2, 'dates': 1

100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
6it [11:05, 120.45s/it]

Scoring took 105.38 seconds
normalized_predictiveness_score=-1.0986333847045897


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 11.77 seconds
[' in', ' to', ' for', ' so', ' (']
[0.003654271364212036, 0.0030960291624069214, 0.0013840124011039734, 0.00028721895068883896, 0.00025981664657592773]
[' in', ' top', ' timely', ' relevant', ' hot']
[0.006657872349023819, 0.002183586359024048, 0.001974262297153473, 0.0017239898443222046, 0.0009134896099567413]
[' the', ' being', ' running', ' not', ' keeping']
[0.007752776145935059, 0.0037232227623462677, 0.0005478765815496445, 0.0004093702882528305, 0.00037085823714733124]
[' singing', ' talent', ' way', ' gift', ' knack']
[0.005230717360973358, 0.0038839876651763916, 0.0017700418829917908, 0.0017595961689949036, 0.001476682722568512]
[' s', ' ratings', '€™', ' cable', ' premium']
[1.0356961865909398e-06, 9.642499207984656e-07, 8.060587788349949e-07, 6.921766271261731e-07, 4.78220272270846

100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
100%|██████████| 6/6 [00:26<00:00,  4.43s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]
7it [12:36, 110.86s/it]

Scoring took 79.31 seconds
normalized_predictiveness_score=-0.3431140899658203


CUDA garbage collection performed.
Loading autoencoder...done
Getting explanations...
Explainer took 10.47 seconds
['4', '1', '2', '.', ' Δ']
[0.002171456813812256, 0.00010591931641101837, 6.465800106525421e-05, 1.7496466170996428e-05, 6.647926056757569e-06]
['playing', 'player', 'Playing', 'plays', 'p']
[0.013190865516662598, 0.00018440000712871552, 0.0001676729880273342, 2.886043512262404e-05, 1.995847560465336e-05]
[' (', ',', ' of', 'berg', 'eaux']
[0.006079591810703278, 0.004985928535461426, 0.001190527225844562, 0.0006750426255166531, 0.0002493110951036215]
[' childhood', ' youth', ' infancy', ' age', ' tender']
[0.01592208445072174, 0.00627363845705986, 0.005018085241317749, 0.00048429612070322037, 0.0003074272535741329]
[' because', ' (', '."', ' the', '-']
[0.0009946245700120926, 0.00031023938208818436, 0.0002858508378267288, 0.00020172609947621822, 0.00019023241475224495]
Counter({'numbers': 1, '

100%|██████████| 6/6 [00:26<00:00,  4.40s/it]
100%|██████████| 6/6 [00:26<00:00,  4.40s/it]
100%|██████████| 6/6 [00:26<00:00,  4.40s/it]
100%|██████████| 6/6 [00:26<00:00,  4.39s/it]


In [1]:
import pandas as pd
all_df = pd.read_json("counterfactual_results/gemma-2-9b_32layer_300feats.json")

In [5]:
x = all_df.sort_values("max_normalized_predictiveness_score", ascending=False)
for i in list(range(5)) + [9, 49, 99, 499, 999]:

    row = x.iloc[i]
    best_expl = max(row.normalized_predictiveness_score_by_explanation, key=row.normalized_predictiveness_score_by_explanation.get)
    print(f"({i + 1}/{len(x)}) \"{best_expl}\" | layer {row.feat_layer}, idx {row.feat_idx}")
    print(f"\texplanations: {row.explanations}")
    print(f"\tnormalized predictiveness score: {row.normalized_predictiveness_score:.1f} ± {row.predictiveness_score_stderr:.1f}")
    print(f"\tmax normalized predictiveness score: {row.max_normalized_predictiveness_score:.1f} (expl={best_expl})")
    print("\n\tExample counterfactuals (intervention strength = 32):")
    for pr, tts, tpis in zip(row["explainer_prompts"], row["explainer_top_tokens"], row["explainer_top_p_increases"]):
        ex = ExplainerInterventionExample(
                    prompt=pr,
                    top_tokens=tts,
                    top_p_increases=tpis
                )
        print("\t" + ex.text().replace("\nMost", "\n\t\tMost") + "\n")

(1/214) "prepositional phrases (mostly "in")" | layer 32, idx 36
	explanations: {'spatial': 1, 'prepositions': 3, 'prepositional phrases (mostly "in")': 1}
	normalized predictiveness score: 1.3 ± 0.1
	max normalized predictiveness score: 1.7 (expl=prepositional phrases (mostly "in"))

	Example counterfactuals (intervention strength = 32):
	<PROMPT>123»BEIRUT—Syrian rebels clashed with regime troops in the narrow stone alleyways around a historic 12th century mosque in the Old City of Aleppo on Thursday, while a government airstrike north of the city killed at least seven people, activists said. The rebels, who have been slowly</PROMPT>
		Most increased tokens: ' en' (+0.009), ' surrounding' (+0.008), ' encro' (+0.006), ' taking' (+0.004), ' over' (+0.003)

	<PROMPT><bos> coverage. Though adjustable, each unit easily conceals the 161-square-foot area typically allotted per family.\nBan, paired with his students from Tokyo, loaded the disassembled parts into a van and headed north to sev

IndexError: single positional indexer is out-of-bounds