In [7]:
import torch
from sae_auto_interp.config import ExperimentConfig, FeatureConfig
from sae_auto_interp.features import (
    FeatureDataset,
    FeatureLoader
)
from sae_auto_interp.features.constructors import default_constructor
from sae_auto_interp.features.samplers import sample

feat_layer = 32
sae_model = "gemma/131k"
module = f".model.layers.{feat_layer}"
n_train, n_test, n_quantiles = 5, 40, 5
n_feats = 30
feature_dict = {f"{module}": torch.arange(0, n_feats)}
feature_cfg = FeatureConfig(width=131072, n_splits=5, max_examples=100000, min_examples=200)
experiment_cfg = ExperimentConfig(n_random=0, example_ctx_len=64, n_quantiles=5, n_examples_test=0, n_examples_train=n_train + n_test // n_quantiles, train_type="quantiles", test_type="even")

In [8]:
from sae_auto_interp.features import FeatureDataset
from functools import partial
import random

dataset = FeatureDataset(
        raw_dir=f"/mnt/ssd-1/gpaulo/SAE-Zoology/raw_features/{sae_model}",
        cfg=feature_cfg,
        modules=[module],
        features=feature_dict,
)

constructor=partial(
            default_constructor,
            tokens=dataset.tokens,
            n_random=experiment_cfg.n_random, 
            ctx_len=experiment_cfg.example_ctx_len, 
            max_examples=feature_cfg.max_examples
        )

sampler=partial(sample,cfg=experiment_cfg)
loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler)


EleutherAI/rpj-v2-sample  train[:1%]


In [9]:
[i.feature.feature_index for i in loader]

[0, 2, 3, 6, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 22, 23, 24, 25, 26, 27]

In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import torch
from huggingface_hub import hf_hub_download
import numpy as np

subject_device = "cuda:6"

subject_name = "google/gemma-2-9b"
subject = AutoModelForCausalLM.from_pretrained(subject_name).to(subject_device)
subject_tokenizer = AutoTokenizer.from_pretrained(subject_name)
subject_tokenizer.pad_token = subject_tokenizer.eos_token
subject.config.pad_token_id = subject_tokenizer.eos_token_id
subject_layers = subject.model.layers

Loading checkpoint shards: 100%|██████████| 8/8 [00:00<00:00,  8.42it/s]


In [11]:
scorer_device = "cuda:7"
scorer_name = "google/gemma-2-27b"
scorer = AutoModelForCausalLM.from_pretrained(
    scorer_name,
    device_map={"": scorer_device},
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
)
scorer_tokenizer = AutoTokenizer.from_pretrained(scorer_name)
scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
scorer.config.pad_token_id = scorer_tokenizer.eos_token_id
scorer.generation_config.pad_token_id = scorer_tokenizer.eos_token_id

# explainer is the same model as the scorer
explainer_device = scorer_device
explainer = scorer
explainer_tokenizer = scorer_tokenizer

Loading checkpoint shards: 100%|██████████| 24/24 [00:19<00:00,  1.23it/s]


In [12]:
from dataclasses import dataclass
import copy

@dataclass
class ExplainerInterventionExample:
    prompt: str
    top_tokens: list[str]
    top_p_increases: list[float]

    def __post_init__(self):
        self.prompt = self.prompt.replace("\n", "\\n")

    def text(self) -> str:
        tokens_str = ", ".join(f"'{tok}' (+{round(p, 3)})" for tok, p in zip(self.top_tokens, self.top_p_increases))
        return f"<PROMPT>{self.prompt}</PROMPT>\nMost increased tokens: {tokens_str}"
    
@dataclass
class ExplainerNeuronFormatter:
    intervention_examples: list[ExplainerInterventionExample]
    explanation: str | None = None

    def text(self) -> str:
        text = "\n\n".join(example.text() for example in self.intervention_examples)
        text += "\n\nExplanation:"
        if self.explanation is not None:
            text += " " + self.explanation
        return text


def get_explainer_prompt(neuron_prompter: ExplainerNeuronFormatter, few_shot_examples: list[ExplainerNeuronFormatter] | None = None) -> str:
    prompt = "We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.\n\n" \
        "For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.\n\n" \
        "The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.\n\n"
    
    i = 1
    for few_shot_example in few_shot_examples or []:
        assert few_shot_example.explanation is not None
        prompt += f"Neuron {i}\n" + few_shot_example.text() + "\n\n"
        i += 1

    prompt += f"Neuron {i}\n"
    prompt += neuron_prompter.text()

    return prompt


fs_examples = [
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="My favorite food is",
                top_tokens=[" oranges", " bananas", " apples"],
                top_p_increases=[0.81, 0.09, 0.02]
            ),
            ExplainerInterventionExample(
                prompt="Whenever I would see",
                top_tokens=[" fruit", " a", " apples", " red"],
                top_p_increases=[0.09, 0.06, 0.06, 0.05]
            ),
            ExplainerInterventionExample(
                prompt="I like to eat",
                top_tokens=[" fro", " fruit", " oranges", " bananas", " strawberries"],
                top_p_increases=[0.14, 0.13, 0.11, 0.10, 0.03]
            )
        ],
        explanation="fruits"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="Once",
                top_tokens=[" upon", " in", " a", " long"],
                top_p_increases=[0.22, 0.2, 0.05, 0.04]
            ),
            ExplainerInterventionExample(
                prompt="Ryan Quarles\\n\\nRyan Francis Quarles (born October 20, 1983)",
                top_tokens=[" once", " happily", " for"],
                top_p_increases=[0.03, 0.31, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="MSI Going Full Throttle @ CeBIT",
                top_tokens=[" Once", " once", " in", " the", " a", " The"],
                top_p_increases=[0.02, 0.01, 0.01, 0.01, 0.01, 0.01]
            ),
        ],
        explanation="storytelling"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="Given 4x is less than 10,",
                top_tokens=[" 4", " 10", " 40", " 2"],
                top_p_increases=[0.11, 0.04, 0.02, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="For some reason",
                top_tokens=[" one", " 1", " fr"],
                top_p_increases=[0.14, 0.01, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="insurance does not cover claims for accounts with",
                top_tokens=[" one", " more", " 10"],
                top_p_increases=[0.10, 0.02, 0.01]
            )
        ],
        explanation="numbers"
    )
]

neuron_prompter = copy.deepcopy(fs_examples[0])
neuron_prompter.explanation = None
print(get_explainer_prompt(neuron_prompter, fs_examples))


We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.

For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.

The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.

Neuron 1
<PROMPT>My favorite food is</PROMPT>
Most increased tokens: ' oranges' (+0.81), ' bananas' (+0.09), ' apples' (+0.02)

<PROMPT>Whenever I would see</PROMPT>
Most increased tokens: ' fruit' (+0.09), ' a' (+0.06), ' apples' (+0.06), ' red' (+0.05)

<PROMPT>I like to eat</PROMPT>
Most increased tokens: ' fro' (+0.14), ' fruit' (+0.13), ' oranges' (+0.11), ' bananas' (+0.1), ' strawberries' (+0.03)

Explanation: fruits

Neuron 2
<PROMPT>Once</PROMPT>
Most increased tokens: ' upon' (+0.22), ' in' (+0.2), ' a' (+0.05), ' long' (+0.04)

<PROMPT>Ryan

In [13]:
scorer_separator = "<PASSAGE>"

def get_scorer_simplicity_prompt(explanation):
    prefix = "Explanation\n\n"
    return f"{prefix}{explanation}{scorer_tokenizer.eos_token}", prefix

def get_scorer_surprisal_prompt(prompt, explanation, few_shot_prompts=None, few_shot_explanations=None, few_shot_generations=None):
    if few_shot_explanations is not None:
        assert few_shot_generations is not None and few_shot_prompts is not None
        assert len(few_shot_explanations) == len(few_shot_generations) == len(few_shot_prompts)
        few_shot_prompt = "\n\n".join(get_scorer_surprisal_prompt(pr, expl) + generation for pr, expl, generation in zip(few_shot_prompts, few_shot_explanations, few_shot_generations)) + "\n\n"
    else:
        few_shot_prompt = ""

    if explanation is None:
        return prompt
    return few_shot_prompt + f"{scorer_separator}The following passage was written with an amplified amount of \"{explanation}\":\n{prompt}"

few_shot_prompts = [
    "from west to east, the westmost of the seven",
    "Given 4x is less than 10,",
    "In information theory, the information content, self-information, surprisal, or Shannon information is a basic quantity derived",
    "My favorite food is",
]
few_shot_explanations = [
    "Asia",
    "numbers",
    "she/her pronouns",
    "fruits and vegetables",
]
few_shot_generations = [
    " wonders of the world is the great wall of china",
    " 4",
    " by her when she was a student at Windsor",
    " oranges",
]
print(get_scorer_surprisal_prompt(few_shot_prompts[0], few_shot_explanations[0], few_shot_prompts, few_shot_explanations, few_shot_generations))


<PASSAGE>The following passage was written with an amplified amount of "Asia":
from west to east, the westmost of the seven wonders of the world is the great wall of china

<PASSAGE>The following passage was written with an amplified amount of "numbers":
Given 4x is less than 10, 4

<PASSAGE>The following passage was written with an amplified amount of "she/her pronouns":
In information theory, the information content, self-information, surprisal, or Shannon information is a basic quantity derived by her when she was a student at Windsor

<PASSAGE>The following passage was written with an amplified amount of "fruits and vegetables":
My favorite food is oranges

<PASSAGE>The following passage was written with an amplified amount of "Asia":
from west to east, the westmost of the seven


In [14]:
n_intervention_tokens = 5
scorer_intervention_strengths = [0, 10, 32, 100, 320, 1000]
explainer_intervention_strength = 32

path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-9b-pt-res",
    filename=f"layer_{feat_layer}/width_131k/average_l0_51/params.npz",
    force_download=False,
)

params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).to(subject_device) for k, v in params.items()}


def get_encoder_decoder_weights(feat_idx, device, random_resid_direction):
    encoder_feat = pt_params["W_enc"][feat_idx, :]
    decoder_feat = pt_params["W_dec"][feat_idx, :]
    if random_resid_direction:
        decoder_feat = torch.randn_like(decoder_feat)
    return encoder_feat, decoder_feat


def garbage_collect():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        print("CUDA garbage collection performed.")


In [15]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
  def __init__(self, d_model, d_sae):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs then we recommend using blah
    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon
  
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1]).to(subject_device)
sae.load_state_dict(pt_params)

<All keys matched successfully>

In [16]:
from functools import partial
def addition_intervention(module, input, output, intervention_strength=10.0, position: int | slice = -1, feat=None):
    hiddens = output[0]  # the later elements of the tuple are the key value cache
    hiddens[:, position, :] += intervention_strength * feat.to(hiddens.device)  # type: ignore

def clamping_intervention(module, input, output, feat_idx=None, clamp_value=0.0, position: int | slice = slice(None)):
    hiddens = output[0]  # the later elements of the tuple are the key value cache
    
    encoding = sae.encode(hiddens)
    error = hiddens - sae.decode(encoding)
    encoding[:, position, feat_idx] = clamp_value
    hiddens = sae.decode(encoding) + error
    return (hiddens, *output[1:])

In [17]:
def get_activating_text(example):
    idxs = (example.activations > 0).nonzero(as_tuple=True)[0]
    idx = random.choice(idxs).item()
    act = example.activations[idx]
    return subject_tokenizer.decode(example.tokens[:idx]), act.item()

In [18]:
from tqdm.auto import tqdm
from itertools import product, islice
import time
from collections import Counter
import pandas as pd

random_resid_direction = False  # this is a random baseline
save_path = f"counterfactual_results/generative_{subject_name.split('/')[-1]}_{feat_layer}layer_{n_feats}feats{'_random_dir' if random_resid_direction else ''}.json"
all_results = []
n_explanations = 1
n_test = 40
n_train = 5
max_generation_length = 16


for iter, record in enumerate(tqdm(loader)):
    garbage_collect()
    
    feat_idx = record.feature.feature_index
    encoder_feat, decoder_feat = get_encoder_decoder_weights(feat_idx, subject_device, random_resid_direction)

    # Remove any hooks
    for l in range(len(subject_layers)):
        subject_layers[l]._forward_hooks.clear()
    

    random.shuffle(record.train)
    scorer_examples = [get_activating_text(e) for e in record.train[:n_test]]
    explainer_examples = [get_activating_text(e) for e in record.train[:n_train]]
    
    # get explanation
    print("Getting explanations...")
    def get_subject_logits(text, layer, intervention_strength=0.0, position=-1, feat=None):
        for l in range(len(subject_layers)):
            subject_layers[l]._forward_hooks.clear()
        subject_layers[layer].register_forward_hook(partial(addition_intervention, intervention_strength=intervention_strength, position=-1, feat=feat))

        inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(subject_device)
        with torch.inference_mode():
            outputs = subject(**inputs)

        return outputs.logits[0, -1, :]
    
    def generate_with_intervention(text, layer, clamp_value=0.0, feat_idx=None):
        for l in range(len(subject_layers)):
            subject_layers[l]._forward_hooks.clear()
        
        inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(subject_device)
        # x[:, slice(None), :] is equivalent to x[:, :, :]
        subject_layers[layer].register_forward_hook(partial(clamping_intervention, clamp_value=clamp_value, feat_idx=feat_idx))
        with torch.inference_mode():
            out = subject.generate(
                **inputs,
                max_new_tokens=max_generation_length,
                num_return_sequences=1,
                temperature=1.0,
                do_sample=True,
            )

        return subject_tokenizer.decode(out[0])

    explainer_time = time.time()
    intervention_examples = []
    for text, act in explainer_examples:
        clean_logits = get_subject_logits(text, feat_layer, intervention_strength=0.0, feat=decoder_feat)
        intervened_logits = get_subject_logits(text, feat_layer, intervention_strength=explainer_intervention_strength, feat=decoder_feat)
        top_probs = (intervened_logits.softmax(dim=-1) - clean_logits.softmax(dim=-1)).topk(n_intervention_tokens)
        
        top_tokens = [subject_tokenizer.decode(i) for i in top_probs.indices]
        top_p_increases = top_probs.values.tolist()
        intervention_examples.append(
            ExplainerInterventionExample(
                prompt=text,
                top_tokens=top_tokens,
                top_p_increases=top_p_increases
            )
        )

    neuron_prompter = ExplainerNeuronFormatter(
        intervention_examples=intervention_examples
    )

    # TODO: improve the few-shot examples
    explainer_prompt = get_explainer_prompt(neuron_prompter, fs_examples)
    explainer_input_ids = explainer_tokenizer(explainer_prompt, return_tensors="pt").input_ids.to(explainer_device)
    with torch.inference_mode():
        samples = explainer.generate(
            explainer_input_ids,
            max_new_tokens=20,
            eos_token_id=explainer_tokenizer.encode("\n")[-1],
            num_return_sequences=n_explanations,
            temperature=0.7,
            do_sample=True,
        )[:, explainer_input_ids.shape[1]:]
    explanations = Counter([explainer_tokenizer.decode(sample).split("\n")[0].strip() for sample in samples])
    explainer_time = time.time() - explainer_time
    print(f"Explainer took {explainer_time:.2f} seconds")

    for ie in intervention_examples:
        print(ie.top_tokens)
        print(ie.top_p_increases)
    print(explanations)

    scoring_time = time.time()

    completions = []
    for text, act in scorer_examples:            
        # get generation with and without intervention
        completions.append({"text": text, "act": act, "completions": dict()})
        for name, strength in [("clean", 0), ("intervened", act)]:
            completions[-1]["completions"][name] = generate_with_intervention(text, feat_layer, clamp_value=strength, feat_idx=feat_idx)

    surprisals_by_explanation = dict()
    completions_by_explanation = dict()
    delta_conditional_entropy_by_explanation = dict()
    delta_conditional_entropy_sems_by_explanation = dict()
    delta_entropy_update_by_explanation = dict()
    delta_entropy_update_sems_by_explanation = dict()
    for explanation in explanations:
        surprisals = {"conditional": {"clean": [], "intervened": []}, "unconditional": {"clean": [], "intervened": []}}

        for conditional_name, expl in [("conditional", explanation), ("unconditional", None)]:
            for i, (text, act) in enumerate(scorer_examples):
                
                # get the explanation predictiveness
                scorer_surprisal_prompt = get_scorer_surprisal_prompt(text, expl, few_shot_prompts, few_shot_explanations, few_shot_generations)
                scorer_prompt_ids = scorer_tokenizer(scorer_surprisal_prompt, return_tensors="pt").input_ids.to(scorer_device)
                
                with torch.inference_mode():
                    kv_cache = scorer(scorer_prompt_ids, use_cache=True, return_dict=True).past_key_values
                
                for name, completion in completions[i]["completions"].items():
                    scorer_completion_ids = scorer_tokenizer(completion, return_tensors="pt").input_ids.to(scorer_device)
                    scorer_input_ids = torch.cat([scorer_prompt_ids, scorer_completion_ids], dim=1)
                    labels = scorer_input_ids.clone()
                    labels[:, :scorer_prompt_ids.shape[1]] = -100
                    with torch.inference_mode():
                        out = scorer(scorer_input_ids, labels=labels, return_dict=True, past_key_values=kv_cache)
                        # HF averages over the sequence length, so we undo that
                        surprisals[conditional_name][name].append(out.loss.item() * (scorer_input_ids.shape[1] - scorer_prompt_ids.shape[1]))
                
                del kv_cache


        surprisals = {k: {k2: np.array(v2) for k2, v2 in v.items()} for k, v in surprisals.items()}
        surprisals_by_explanation[explanation] = surprisals
        conditional_surprisals = {k: v for k, v in surprisals["conditional"].items()}
        delta_conditional_entropy_by_explanation[explanation] = (conditional_surprisals["intervened"] - conditional_surprisals["clean"]).mean()
        delta_conditional_entropy_sems_by_explanation[explanation] = (conditional_surprisals["intervened"] - conditional_surprisals["clean"]).std(ddof=1) / np.sqrt(len(conditional_surprisals["intervened"]))
        surprisal_updates = dict()
        for k in surprisals["conditional"]:
            surprisal_updates[k] = surprisals["conditional"][k] - surprisals["unconditional"][k]
        delta_entropy_update_by_explanation[explanation] = (surprisal_updates["intervened"] - surprisal_updates["clean"]).mean()
        delta_entropy_update_sems_by_explanation[explanation] = (surprisal_updates["intervened"] - surprisal_updates["clean"]).std(ddof=1) / np.sqrt(len(surprisal_updates["intervened"]))

    scoring_time = time.time() - scoring_time
    print(f"Scoring took {scoring_time:.2f} seconds")

    print(delta_conditional_entropy_by_explanation)
    print()
    print()
    all_results.append({
        "feat_idx": feat_idx,
        "feat_layer": feat_layer,
        "explanations": dict(explanations),
        "surprisals_by_explanation": surprisals_by_explanation,
        "delta_conditional_entropy_by_explanation": delta_conditional_entropy_by_explanation,
        "delta_conditional_entropy_sems_by_explanation": delta_conditional_entropy_sems_by_explanation,
        "delta_entropy_update_by_explanation": delta_entropy_update_by_explanation,
        "delta_entropy_update_sems_by_explanation": delta_entropy_update_sems_by_explanation,
        "max_delta_conditional_entropy": max(delta_conditional_entropy_by_explanation.values()),
        "max_delta_entropy_update": max(delta_entropy_update_by_explanation.values()),
        "explainer_prompts": [example.prompt for example in intervention_examples],
        "explainer_intervention_strength": explainer_intervention_strength,
        "scorer_examples": scorer_examples,
        "completions": completions,
        "explainer_examples": explainer_examples,
        "neuron_prompter": neuron_prompter,
    })
    if (iter - 1) % 10 == 0:
        all_df = pd.DataFrame(all_results)
        all_df = all_df.sort_values("max_delta_conditional_entropy", ascending=False)
        all_df.to_json(save_path)
all_df = pd.DataFrame(all_results)
all_df = all_df.sort_values("max_delta_conditional_entropy", ascending=False)
all_df.to_json(save_path)


0it [00:00, ?it/s]

CUDA garbage collection performed.
Getting explanations...
Explainer took 3.39 seconds
[' blue', ' waves', ' turquoise', ' bright', ' green']
[0.006616353988647461, 0.0028985068202018738, 0.0021059513092041016, 0.0016255183145403862, 0.0010466007515788078]
['),', ' Sub', '1', ':', '0']
[0.0018309354782104492, 4.4442713260650635e-06, 2.3280517780222e-06, 2.2363146854331717e-06, 2.127704647136852e-06]
[' “', ' the', ' bears', ' belonged', ' has']
[0.0028056171722710133, 0.0021553784608840942, 0.0015515962149947882, 0.0009813075885176659, 0.0005441978573799133]
[' of', ' on', ' around', ' with', ' "']
[0.011573970317840576, 0.002170249819755554, 0.0016705915331840515, 0.0013710465282201767, 0.0007698006229475141]
[' was', '’', ' announces', ' says', ' had']
[0.0013396013528108597, 0.0007813740521669388, 0.0007248418405652046, 0.0007178026135079563, 0.0006609400734305382]
Counter({'colors': 1})


1it [02:49, 169.77s/it]

Scoring took 164.95 seconds
{'colors': 1.0209539115428925}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.35 seconds
[' the', ' ', ' No', ' second', ' it']
[0.007344722747802734, 0.005746225826442242, 0.0022622139658778906, 0.001772892544977367, 0.0017191581428050995]
['The', 'Former', '<b>', 'But', 'World']
[0.005139157176017761, 0.001124335452914238, 0.0010590292513370514, 0.0009201206266880035, 0.0008139063720591366]
[' over', ' ', ' No', ' #', ' against']
[0.01616699993610382, 0.002311231568455696, 0.00219493149779737, 0.0018563976045697927, 0.0016187019646167755]
['No', 'Volume', 'Vol', '1', '2']
[0.00928506813943386, 0.0066351816058158875, 0.0036926493048667908, 0.0025421734899282455, 0.0016635488718748093]
[' (', ' ', ' No', ' #', '-']
[0.01666119694709778, 0.016381777822971344, 0.013269959017634392, 0.0012315623462200165, 0.0004499172791838646]
Counter({'numbers': 1})


2it [05:36, 167.82s/it]

Scoring took 164.03 seconds
{'numbers': -1.1523461163043975}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.34 seconds
[' distance', ' long', '\n', ' new', ' crowd']
[9.000301361083984e-05, 7.581750833196566e-05, 6.33835734333843e-05, 4.2144132748944685e-05, 2.6536141376709566e-05]
[' removed', ' illegally', ' disposed', ' conveyed', ' loaded']
[0.010568559169769287, 0.0015135153662413359, 0.0015118904411792755, 0.0015107141807675362, 0.0013500573113560677]
[' resolution', ' committee', ' letter', ' Committee', ' report']
[0.0023345034569501877, 0.0022897077724337578, 0.0021072691306471825, 0.001819138415157795, 0.0016888240352272987]
['drew', ' effect', ' the', ' Effect', ' immediate']
[0.005748845636844635, 0.005038110073655844, 0.0029590725898742676, 0.0019265650771558285, 0.0019240148831158876]
[' motion', ' recommendation', ' report', ' unanimous', ' formal']
[0.005860321223735809, 0.005414187908172607, 0.003897346556186676, 0.002665160223841667, 0.0

3it [08:23, 167.41s/it]

Scoring took 164.54 seconds
{'legal': -0.6783262521028519}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.51 seconds
[' big', ' annual', ' very', ' highly', ' exciting']
[0.003398401429876685, 0.0029312409460544586, 0.0028140763752162457, 0.0018012914806604385, 0.001514088362455368]
[' Sterling', ' the', ' of', ' my', ' our']
[0.0026035532355308533, 0.0017001628875732422, 0.001513361930847168, 0.0008129682391881943, 0.0006503479089587927]
[' such', ' great', ' this', ' these', ' some']
[0.015144867822527885, 0.0020169042982161045, 0.0017403773963451385, 0.0016172230243682861, 0.001038569025695324]
['<h1>', 'The', 'import', '<strong>', 'This']
[0.008058935403823853, 0.0026723891496658325, 0.0017859041690826416, 0.0015054121613502502, 0.0007402002811431885]
[' doing', ' committed', ' eager', ' dedicated', ' out']
[0.001983216032385826, 0.0014071771875023842, 0.0012412341311573982, 0.0011396100744605064, 0.000910482369363308]
Counter({'html': 1})


4it [11:13, 168.51s/it]

Scoring took 167.33 seconds
{'html': -1.6799896508455276}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.50 seconds
[' picket', ' march', ' pick', ' tent', ' tents']
[0.008191749453544617, 0.003865480422973633, 0.002566961571574211, 0.0022603031247854233, 0.0020143473520874977]
[' challenges', ' global', ' decline', ' backdrop', ' challenge']
[0.00542839989066124, 0.005042310804128647, 0.0015415344387292862, 0.0014315247535705566, 0.001291816122829914]
['1', ' ', 'en', ',', 'a']
[2.8073787689208984e-05, 9.0912237737939e-09, 6.627177029372433e-09, 6.0902785037342255e-09, 5.368544719885904e-09]
['<h1>', 'The', 'import', '#', '<?']
[0.011984691023826599, 0.002488851547241211, 0.0024669580161571503, 0.0014160741120576859, 0.0013186894357204437]
[' had', ' have', ' did', ' was', ' am']
[0.0033141151070594788, 0.001995634287595749, 0.0019497722387313843, 0.0017291754484176636, 0.001231982372701168]
Counter({'crowd control': 1})


5it [14:01, 168.31s/it]

Scoring took 165.35 seconds
{'crowd control': 1.3969963937997818}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.34 seconds
[' on', 'Sun', ' Observer', ' Province', ' ']
[0.0001633376523386687, 0.00010240927804261446, 3.3532716770423576e-05, 2.3101340048015118e-05, 1.7028592992573977e-05]
[' UK', '<eos>', '.', ' Writer', ' Twitter']
[0.018734633922576904, 0.0036741867661476135, 0.0015236581675708294, 0.001087625976651907, 0.0005191113450564444]
[' Foster', ' n', ' foster', ' NF', ' @_']
[0.005046458914875984, 0.0013898639008402824, 0.001245436491444707, 0.0007125362753868103, 0.0006150612607598305]
['boy', ' Patrick', ' Poy', ' Hut', ' Slack']
[0.02243751287460327, 1.139590182219763e-07, 6.676427233287541e-08, 6.104892236180604e-08, 5.691998694601352e-08]
[' columnist', ' senior', ' contributor', ' frequent', ' veteran']
[0.009208839386701584, 0.003234177827835083, 0.003029199317097664, 0.002862554043531418, 0.0014360202476382256]
Counter({'people': 1})


6it [16:46, 167.27s/it]

Scoring took 162.87 seconds
{'people': 1.1953076094388961}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.67 seconds
[' in', ' playing', ' running', ' engaged', ' doing']
[0.005746934562921524, 0.0038447193801403046, 0.002461749594658613, 0.0012081144377589226, 0.0007279776036739349]
[' private', ' the', ' practice', ' business', ' clinical']
[0.015860408544540405, 0.00823153555393219, 0.007704384624958038, 0.0007617436349391937, 0.000734982080757618]
[' the', ' Vegas', ' business', ' Las', ' love']
[0.01891876757144928, 0.002352602779865265, 0.0020235716365277767, 0.00173875130712986, 0.0017176903784275055]
['ve', 'be', 'ver', 'e', 'de']
[0.0005467534065246582, 2.1352698240661994e-06, 1.235372110386379e-06, 8.216884452849627e-07, 7.22118784324266e-07]
[' ', ' three', ' five', ' decade', ' several']
[0.002315536141395569, 0.0013240724802017212, 0.0011934489011764526, 0.0008529499173164368, 0.0007791928946971893]
Counter({'verb conjugation': 1})


7it [19:36, 168.03s/it]

Scoring took 166.86 seconds
{'verb conjugation': -0.36370945274829863}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.79 seconds
[' president', ' OC', ' O', ' President', ' Secretary']
[0.0015517733991146088, 0.0007413958664983511, 0.0007232725620269775, 0.000708308070898056, 0.0007061893120408058]
[' I', ' because', ' since', ' if', ' it']
[0.0014734268188476562, 0.0009105149656534195, 0.0008016414940357208, 0.0004820004105567932, 0.0003573410212993622]
[' Archbishop', ' Metropolitan', ' archbishop', ' Father', ' nun']
[0.02345559000968933, 0.0038253217935562134, 0.0020113252103328705, 0.0013424158096313477, 0.0004335385747253895]
['\n', ' of', ' (', ' M', '\n\n']
[0.004051417112350464, 0.0022763912566006184, 0.0005727289244532585, 0.0005524349398910999, 0.00045886263251304626]
[' tradition', ' edition', ' Greek', ' version', ' Slav']
[0.006039924919605255, 0.0037467852234840393, 0.002936873584985733, 0.00229557603597641, 0.0013339584693312645]
Counter({

8it [22:26, 168.64s/it]

Scoring took 167.10 seconds
{'people, names': -3.3071950703859327}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.45 seconds
[' of', ' consciousness', ' however', ' memories', ' scientifically']
[0.000569760799407959, 0.00044896267354488373, 0.0003902912139892578, 0.00033029436599463224, 0.0002587117487564683]
[' ecstasy', ' a', ' pleasure', ' and', ' an']
[0.009128708392381668, 0.0031279027462005615, 0.0018659960478544235, 0.0016910098493099213, 0.0015071425586938858]
[':', ',', ' States', '®', 'states']
[0.0027272403240203857, 0.0014891475439071655, 0.0006784589495509863, 0.0006706463173031807, 0.0005498445825651288]
[' deaths', ' brain', ' surgical', ' pain', ' "']
[0.007186059840023518, 0.0031084315851330757, 0.0030220653861761093, 0.0022018253803253174, 0.0019114408642053604]
[' of', ' death', ' heaven', ' I', ' it']
[0.0032614916563034058, 0.0002958296099677682, 0.0001791974646039307, 0.0001756441779434681, 0.00015206215903162956]
Counter({'consciou

9it [25:14, 168.51s/it]

Scoring took 165.70 seconds
{'consciousness': 1.797320768237114}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.61 seconds
[' back', ' to', ' into', ' all', ' onto']
[0.025292545557022095, 0.022760506719350815, 0.002355562523007393, 0.00019207783043384552, 0.00017909606685861945]
[' closer', ' to', ' how', ' a', ' what']
[0.008648838847875595, 0.0026722426991909742, 0.0016954019665718079, 0.0014309883117675781, 0.0012148767709732056]
[' to', ' To', ' I', ' It', ' it']
[0.004062035121023655, 0.002699372125789523, 0.0017854273319244385, 0.0009851865470409393, 0.00045781955122947693]
[' went', ' expanded', ' started', ' switched', ' returned']
[0.02057577669620514, 0.017432689666748047, 0.0027897804975509644, 0.002190239727497101, 0.0019120753277093172]
[' to', ' back', ' down', ' closer', ' where']
[0.0688936784863472, 0.017473474144935608, 0.000773855485022068, 0.0001362144248560071, 4.7338311560451984e-05]
Counter({'"to"': 1})


10it [28:01, 168.06s/it]

Scoring took 164.37 seconds
{'"to"': 2.876817652583122}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.36 seconds
[' bond', ' material', ' surface', ' core', ' plastic']
[0.008197270333766937, 0.008184259757399559, 0.005825452972203493, 0.005551476031541824, 0.0026715523563325405]
[' finish', ' texture', ' material', ' body', ' surface']
[0.022140152752399445, 0.006670035421848297, 0.0037151030264794827, 0.003677681088447571, 0.0034105107188224792]
[' quartz', ' glass', ' chal', ' calcium', ' poly']
[0.006028864532709122, 0.0006126246880739927, 0.0002299632760696113, 0.00022561238438356668, 0.00022142985835671425]
[' LED', ' halogen', ' head', ' technology', ' lamps']
[0.009180810302495956, 0.00885464996099472, 0.005249623209238052, 0.003585366765037179, 0.0032961200922727585]
[' materials', ' plastic', ' material', ' plastics', ' fabric']
[0.027014851570129395, 0.018321938812732697, 0.0056162504479289055, 0.0018920600414276123, 0.0017686416395008564]
Cou

11it [30:47, 167.60s/it]

Scoring took 164.17 seconds
{'materials': -0.8451362729072571}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.57 seconds
[' Report', ' Academic', ' Office', ' Task', ' Advisory']
[0.008257515728473663, 0.005141964182257652, 0.0028893500566482544, 0.0027386099100112915, 0.0027042143046855927]
[' assignments', ' exams', ' student', ' tests', ' work']
[0.006365146487951279, 0.0023198057897388935, 0.0019079670310020447, 0.0018182552885264158, 0.0017344430088996887]
[' Search', '/', 'Search', ' Data', ' Collection']
[0.0009974073618650436, 0.0009728670120239258, 0.0009381449781358242, 0.0009110597893595695, 0.0006774328649044037]
[' kindergarten', ' program', ' pre', ' Pre', ' preschool']
[0.0005812216550111771, 0.0004354987759143114, 0.00033621652983129025, 8.404711843468249e-05, 5.8566685765981674e-05]
[' forecast', ' break', ' period', ' test', ' school']
[0.022243589162826538, 0.0020490651950240135, 0.0010410472750663757, 0.0009611250134184957, 0.000676875

12it [33:36, 168.04s/it]

Scoring took 166.43 seconds
{'educational topics': 3.542952525615692}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.43 seconds
[' Squ', ' Phase', ' Solid', ' Safety', ' Sora']
[0.006059959530830383, 0.003203839063644409, 0.0025146454572677612, 0.0011530923657119274, 0.0009638285264372826]
[',', ' if', '\n\n', ' in', ' that']
[0.0019459016621112823, 0.0013047102838754654, 0.0008948482573032379, 0.00033948151394724846, 0.0003264863044023514]
['-', '_', ' for', ' is', 'A']
[0.0018615126609802246, 2.120155841112137e-06, 1.9248655007686466e-06, 1.6776812117313966e-06, 6.270356607274152e-07]
['2', '3', '0', '9', '5']
[0.0009372979402542114, 0.00016948767006397247, 0.00014485418796539307, 8.807331323623657e-05, 6.318127270787954e-05]
['.', ':', ' -', ' Contra', ' Animal']
[0.005579113960266113, 0.0025039352476596832, 0.0007274486124515533, 0.0005550949135795236, 0.0005169622600078583]
Counter({'hyphens': 1})


13it [36:25, 168.16s/it]

Scoring took 165.77 seconds
{'hyphens': 0.4642083734273911}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.39 seconds
[',', '.', ' occasion', ' and', ' list']
[0.005145657807588577, 0.003303617238998413, 0.0015910081565380096, 0.0015354985371232033, 0.0009806826710700989]
[' prevent', ' address', ' remedy', ' correct', ' date']
[0.004710003733634949, 0.003300607204437256, 0.0014124885201454163, 0.0011359453201293945, 0.0007658526301383972]
[',', ' thanks', ' out', ' absent', ' shut']
[0.014833509922027588, 0.004300955682992935, 0.0017358679324388504, 0.0013582780957221985, 0.0009760390967130661]
[' out', ' beat', ' jumped', ' pretty', ' tired']
[0.0038094576448202133, 0.0022468753159046173, 0.001934426836669445, 0.0016964096575975418, 0.0015841769054532051]
[' up', ' out', ' down', ' help', ' involved']
[0.008769191801548004, 0.00561721995472908, 0.0004990538582205772, 0.00037583732046186924, 0.000273791141808033]
Counter({'commas': 1})


14it [39:12, 167.82s/it]

Scoring took 164.56 seconds
{'commas': -0.8638851702213287}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.89 seconds
['iry', '’', '.', 'hay', "'"]
[0.056419193744659424, 0.00015911826631054282, 0.00013332029629964381, 5.850604793522507e-05, 5.3832889534533024e-05]
['ag', 'og', 'agar', 'iles', 'agit']
[0.0009372234344482422, 4.25254984293133e-06, 2.7293135644868016e-06, 2.627620688144816e-06, 2.2652384359389544e-06]
[' Kem', ' Ken', ' Bem', ' Kennedy', ' Kent']
[0.006206043530255556, 0.0015386404702439904, 0.0014558803522959352, 0.001052189152687788, 0.000989265157841146]
['ane', 'na', 'ana', 'le', 'ina']
[0.04078476130962372, 0.017986353486776352, 0.008936207741498947, 0.00856819748878479, 0.006629134062677622]
['aw', 'an', 'uga', 'ew', 'di']
[0.017063111066818237, 0.005953527987003326, 0.0040744394063949585, 0.003213898278772831, 0.0013712160289287567]
Counter({"people's names": 1})


15it [42:02, 168.54s/it]

Scoring took 167.25 seconds
{"people's names": 0.9845414325594902}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.47 seconds
[' Research', ' Fellow', ' Institute', ' researcher', ' has']
[0.0002831653691828251, 0.00023444992257282138, 0.00021839141845703125, 0.00020160869462415576, 0.00019326922483742237]
[' Institute', ' Ins', ' Physics', ' and', ' Institutes']
[0.0035710930824279785, 0.000306492205709219, 5.5649084970355034e-05, 5.726231393055059e-06, 5.032023182138801e-06]
[' Association', ' Council', ' Medical', ' Historical', ' Academy']
[0.00416015088558197, 0.00334961898624897, 0.003299463540315628, 0.0022214502096176147, 0.001766180619597435]
[' Drug', ' Survey', ' Longitudinal', ' Center', ' Epidemio']
[0.018623635172843933, 0.004183650016784668, 0.0035909898579120636, 0.0033325375989079475, 0.001754816621541977]
[' Center', ' institute', ' and', ' Group', ' Surgery']
[0.00010432134149596095, 0.00010229775216430426, 8.967396570369601e-05, 7.02135

16it [44:50, 168.41s/it]

Scoring took 165.59 seconds
{'Institute': 0.5122965931892395}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.59 seconds
['_', '<unused63>', 'work', '<unused61>', '<unused62>']
[0.0012223124504089355, 2.7826217774418183e-07, 1.3477688298735302e-07, 1.0930045846180292e-07, 6.985533218539786e-08]
['-', ' Tech', 'Tech', '_', ' Authors']
[0.007117331027984619, 0.0005229492671787739, 0.0004294624086469412, 1.7061771359294653e-05, 2.0921970644849353e-06]
[' people', ' days', ' things', ' words', ' times']
[0.015257984399795532, 0.0040536485612392426, 0.0034050457179546356, 0.002956768497824669, 0.0012649232521653175]
['ile', 'x', 'os', 'cursion', 'quisite']
[0.0044546183198690414, 0.0021925847977399826, 0.001888178288936615, 0.0015484625473618507, 0.0013073515146970749]
['-', 'nfl', 'game', 'stream', 'games']
[0.011549055576324463, 7.284164894372225e-05, 3.2102714612847194e-05, 9.069404768524691e-06, 4.352298674348276e-06]
Counter({'hyphens': 1})


17it [47:37, 167.85s/it]

Scoring took 163.89 seconds
{'hyphens': -2.910629057884216}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.24 seconds
[' Literacy', ' Education', ' and', ' Fitness', ' Performance']
[0.01689641922712326, 0.011858776211738586, 0.0059695616364479065, 0.0025273077189922333, 0.002030743286013603]
[' ', ' them', ' the', ' more', ' this']
[0.008704990148544312, 0.0018798373639583588, 0.0016981586813926697, 0.001597028225660324, 0.0011709630489349365]
['0', '-', '/', ' seconds', ' yards']
[0.014607250690460205, 0.002117551863193512, 0.0011430736631155014, 0.0011103986762464046, 0.0007668674224987626]
[' at', ' with', ' young', ' as', ' the']
[0.006610125303268433, 0.0017114393413066864, 0.001516668125987053, 0.0008101053535938263, 0.0007757833227515221]
[' complementary', ' complimentary', ' important', ' appropriate', ' reinforcing']
[0.004036247730255127, 0.003714296966791153, 0.0025321748107671738, 0.0016540652140974998, 0.0012468928471207619]
Counter({'numbe

18it [50:23, 167.46s/it]

Scoring took 164.26 seconds
{'numbers': 0.07373359799385071}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.33 seconds
[' bowling', ' turning', ' delivering', ' breaking', ' ripping']
[0.006886288523674011, 0.005131181329488754, 0.00209888257086277, 0.0016512805595993996, 0.001532379537820816]
[' up', ' with', ' and', ' from', ' down']
[0.0011432170867919922, 0.0009709112346172333, 0.0006813257932662964, 0.0006003528833389282, 0.0005827182903885841]
[' double', ' brace', ' straight', ' six', ' low']
[0.003560081124305725, 0.0033200010657310486, 0.00264078751206398, 0.0022058188915252686, 0.001187263522297144]
[' curve', ' knuckle', ' slider', ' screw', ' change']
[0.015515491366386414, 0.009159278124570847, 0.007772520184516907, 0.006143948063254356, 0.004962310194969177]
[' and', ' through', ',', ' out', ' from']
[0.014919057488441467, 0.002207871526479721, 0.0010562613606452942, 0.0006542066112160683, 0.00043057743459939957]
Counter({'baseball': 1})


19it [53:12, 167.71s/it]

Scoring took 165.93 seconds
{'baseball': 2.317998507618904}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.43 seconds
[' the', ' when', ' his', ' after', ' with']
[0.00912216305732727, 0.0026390738785266876, 0.0008386150002479553, 0.0006531886756420135, 0.0006415462121367455]
[' the', ' become', ' mid', ' create', ' establish']
[0.00584334135055542, 0.0015539191663265228, 0.0006602015346288681, 0.000597917940467596, 0.0005845336709171534]
[' the', ' passing', ' strangers', ' settlers', ' our']
[0.023125767707824707, 0.0027464600279927254, 0.0007011205889284611, 0.0006775651127099991, 0.00048504211008548737]
[' the', ' completion', ' opening', ' groundbreaking', ' passage']
[0.025888264179229736, 0.0034565189853310585, 0.0007279827259480953, 0.0003002043813467026, 0.00020474859047681093]
[' the', ' launch', ' entering', ' launching', ' taking']
[0.024612069129943848, 0.006872605532407761, 0.002876437269151211, 0.0020772824063897133, 0.0015620035119354725]


20it [56:00, 167.94s/it]

Scoring took 165.95 seconds
{'determiner': 2.7872099965810775}


CUDA garbage collection performed.
Getting explanations...
Explainer took 2.52 seconds
['zing', 'zes', ' against', 'zed', ' into']
[0.007471919059753418, 0.003767356276512146, 0.0020960974507033825, 0.0009140074253082275, 0.0008343983208760619]
[' was', ' hit', ' got', ' hurts', ' is']
[0.0021324753761291504, 0.0016379519365727901, 0.001039944589138031, 0.0009437147527933121, 0.0007550716400146484]
[' into', ' against', ' in', ' and', ' on']
[0.005905698984861374, 0.0015405853046104312, 0.0008394811302423477, 0.0007611960172653198, 0.0006763716228306293]
[' places', ' ranks', ' serves', ' matches', ' approximately']
[0.00980520248413086, 0.0019284100271761417, 0.001839357428252697, 0.0015939250588417053, 0.0012680711224675179]
[' by', ' to', ' along', ' on', ' next']
[0.0025678141973912716, 0.002399880439043045, 0.0023012354504317045, 0.002289455384016037, 0.002027984242886305]
Counter({'prepositions': 1})


21it [58:47, 167.97s/it]

Scoring took 163.95 seconds
{'prepositions': 1.7992319613695145}







In [23]:
all_df

Unnamed: 0,feat_idx,feat_layer,explanations,surprisals_by_explanation,delta_conditional_entropy_by_explanation,delta_conditional_entropy_sems_by_explanation,delta_entropy_update_by_explanation,delta_entropy_update_sems_by_explanation,max_delta_conditional_entropy,max_delta_entropy_update,explainer_prompts,explainer_intervention_strength,scorer_examples,completions,explainer_examples,neuron_prompter
11,16,32,{'educational topics': 1},{'educational topics': {'conditional': {'clean...,{'educational topics': 3.542952525615692},{'educational topics': 2.5029083376178978},{'educational topics': -0.5158071905374527},{'educational topics': 0.49042080764751483},3.542953,-0.515807,[6 As Prepared By The Academic Centers For The...,32,[(6 As Prepared By The Academic Centers For Th...,[{'text': '6 As Prepared By The Academic Cente...,[(6 As Prepared By The Academic Centers For Th...,ExplainerNeuronFormatter(intervention_examples...
9,14,32,"{'""to""': 1}","{'""to""': {'conditional': {'clean': [155.847466...","{'""to""': 2.876817652583122}","{'""to""': 2.1307858198931804}","{'""to""': 0.7332447558641434}","{'""to""': 0.7316730677874271}",2.876818,0.733245,[<bos> the player some slack when doing time-b...,32,[(<bos> the player some slack when doing time-...,[{'text': '<bos> the player some slack when do...,[(<bos> the player some slack when doing time-...,ExplainerNeuronFormatter(intervention_examples...
19,26,32,{'determiner': 1},{'determiner': {'conditional': {'clean': [158....,{'determiner': 2.7872099965810775},{'determiner': 2.327067918270164},{'determiner': -0.8904970318078995},{'determiner': 0.6167089532317725},2.78721,-0.890497,"[<bos> Adventures of Tom Bombadil]]'', and had...",32,"[(<bos> Adventures of Tom Bombadil]]'', and ha...",[{'text': '<bos> Adventures of Tom Bombadil]]'...,"[(<bos> Adventures of Tom Bombadil]]'', and ha...",ExplainerNeuronFormatter(intervention_examples...
18,25,32,{'baseball': 1},{'baseball': {'conditional': {'clean': [165.09...,{'baseball': 2.317998507618904},{'baseball': 1.7460312352564484},{'baseball': 0.4967650905251503},{'baseball': 0.9446117031906475},2.317999,0.496765,[<bos> years will be even better in the second...,32,[(<bos> years will be even better in the secon...,[{'text': '<bos> years will be even better in ...,[(<bos> years will be even better in the secon...,ExplainerNeuronFormatter(intervention_examples...
20,27,32,{'prepositions': 1},{'prepositions': {'conditional': {'clean': [12...,{'prepositions': 1.7992319613695145},{'prepositions': 1.9783536582497874},{'prepositions': -0.9043507277965546},{'prepositions': 0.8037593268727836},1.799232,-0.904351,"[ through. It worked at 800 High, and again at...",32,"[( through. It worked at 800 High, and again a...","[{'text': ' through. It worked at 800 High, an...","[( through. It worked at 800 High, and again a...",ExplainerNeuronFormatter(intervention_examples...
8,12,32,{'consciousness': 1},{'consciousness': {'conditional': {'clean': [1...,{'consciousness': 1.797320768237114},{'consciousness': 1.9868499372462793},{'consciousness': -0.09122436046600342},{'consciousness': 0.5530723667469144},1.797321,-0.091224,[ We can detect emotions through the physical ...,32,[( We can detect emotions through the physical...,[{'text': ' We can detect emotions through the...,[( We can detect emotions through the physical...,ExplainerNeuronFormatter(intervention_examples...
4,8,32,{'crowd control': 1},{'crowd control': {'conditional': {'clean': [1...,{'crowd control': 1.3969963937997818},{'crowd control': 1.9039305042411705},{'crowd control': -0.8511744707822799},{'crowd control': 0.9074922086568572},1.396996,-0.851174,"[U, Steve Harrison, wrote to the Australian Fe...",32,"[(U, Steve Harrison, wrote to the Australian F...","[{'text': 'U, Steve Harrison, wrote to the Aus...","[(U, Steve Harrison, wrote to the Australian F...",ExplainerNeuronFormatter(intervention_examples...
5,9,32,{'people': 1},{'people': {'conditional': {'clean': [104.2934...,{'people': 1.1953076094388961},{'people': 2.6257955071213086},{'people': 0.3281347781419754},{'people': 0.7986407790261398},1.195308,0.328135,[ Vancouver Sun to your Google Plus circlesFol...,32,[( Vancouver Sun to your Google Plus circlesFo...,[{'text': ' Vancouver Sun to your Google Plus ...,[( Vancouver Sun to your Google Plus circlesFo...,ExplainerNeuronFormatter(intervention_examples...
0,0,32,{'colors': 1},{'colors': {'conditional': {'clean': [170.3265...,{'colors': 1.0209539115428925},{'colors': 2.62119763125034},{'colors': -0.17870076596736909},{'colors': 0.6372228967597492},1.020954,-0.178701,"[<bos> to the beach, to get some photos of the...",32,"[(<bos> to the beach, to get some photos of th...","[{'text': '<bos> to the beach, to get some pho...","[(<bos> to the beach, to get some photos of th...",ExplainerNeuronFormatter(intervention_examples...
14,20,32,{'people's names': 1},{'people's names': {'conditional': {'clean': [...,{'people's names': 0.9845414325594902},{'people's names': 2.591538730736589},{'people's names': -0.3437474474310875},{'people's names': 0.6454999211707236},0.984541,-0.343747,[<bos> had also passed away. Bronia was now al...,32,[(<bos> had also passed away. Bronia was now a...,[{'text': '<bos> had also passed away. Bronia ...,[(<bos> had also passed away. Bronia was now a...,ExplainerNeuronFormatter(intervention_examples...


In [20]:
all_results[1]

{'feat_idx': 2,
 'feat_layer': 32,
 'explanations': {'numbers': 1},
 'surprisals_by_explanation': {'numbers': {'conditional': {'clean': array([130.05473042, 163.60789633, 166.69568157, 103.53930068,
           144.79606295, 134.44639206, 121.25616121, 110.21008015,
           132.32139587, 119.14018631, 181.58804178, 127.30354929,
           126.87774396, 118.04799104,  90.20391226,  87.20266342,
           158.55610538, 132.95434141, 161.08916187,  76.16211605,
           136.49475527, 124.82841778, 110.42610741, 111.96801186,
           147.94923782, 139.84703064, 104.38328743, 132.32879019,
           134.61856461,  98.30728507, 112.26144433, 136.04847765,
           126.61558342, 143.53449941, 114.09034729, 130.91243553,
           146.88883781, 132.55221641, 108.03164434, 125.39075208]),
    'intervened': array([141.80460763, 189.08174086, 168.02305508,  99.28200865,
           146.53147888, 122.80537033, 126.26939487, 131.14189339,
           131.25938416, 118.67609978, 139.88834

In [21]:
for row in all_results:
    compls = row["completions"]
    for compl in compls:
        if compl["completions"]["clean"] != compl["completions"]["intervened"]:
            print(compl["act"])
            print(compl["text"])
            print("clean:\n" + compl["completions"]["clean"])
            print("intervened:\n" + compl["completions"]["intervened"])
            print()
            print()
            

10.140625
<bos> to the beach, to get some photos of the golden sand and the
clean:
<bos><bos> to the beach, to get some photos of the golden sand and the waves breaking in the foreground, and then the sea and the skyline in the distance
intervened:
<bos><bos> to the beach, to get some photos of the golden sand and the white lighthouse, but, as she walked by the sea, she felt a strong


11.484375
<bos>13), Meeting Houses (10), Coastal Lighthouses (8), Pre-Victorian Historic Homes (7), Octagon Buildings (7), Converted Firehouses (6), View more Buildings subcategories...
SignsMilestones (279), Pictorial Pub Signs (279
clean:
<bos><bos>13), Meeting Houses (10), Coastal Lighthouses (8), Pre-Victorian Historic Homes (7), Octagon Buildings (7), Converted Firehouses (6), View more Buildings subcategories...
SignsMilestones (279), Pictorial Pub Signs (279), Historic Signs (237), View more Signs subcategories...
Other
intervened:
<bos><bos>13), Meeting Houses (10), Coastal Lighthouses (8), Pre-V

In [22]:
x = all_df.sort_values("max_normalized_predictiveness_score", ascending=False)
for i in list(range(5)) + [9, 49, 99, 499, 999]:

    row = x.iloc[i]
    best_expl = max(row.normalized_predictiveness_score_by_explanation, key=row.normalized_predictiveness_score_by_explanation.get)
    print(f"({i + 1}/{len(x)}) \"{best_expl}\" | layer {row.feat_layer}, idx {row.feat_idx}")
    print(f"\texplanations: {row.explanations}")
    print(f"\tnormalized predictiveness score: {row.normalized_predictiveness_score:.1f} ± {row.predictiveness_score_stderr:.1f}")
    print(f"\tmax normalized predictiveness score: {row.max_normalized_predictiveness_score:.1f} (expl={best_expl})")
    print("\n\tExample counterfactuals (intervention strength = 32):")
    for pr, tts, tpis in zip(row["explainer_prompts"], row["explainer_top_tokens"], row["explainer_top_p_increases"]):
        ex = ExplainerInterventionExample(
                    prompt=pr,
                    top_tokens=tts,
                    top_p_increases=tpis
                )
        print("\t" + ex.text().replace("\nMost", "\n\t\tMost") + "\n")

KeyError: 'max_normalized_predictiveness_score'