In [1]:
# load subject model
# load SAEs without attaching them to the model
# for now just use the Islam feature and explanation
# load a scorer. The prompt should have the input as well this time
# (for now) on random pretraining data, evaluate gpt2 with a hook that 
# adds a multiple of the Islam feature to the appropriate residual stream layer and position
# Get the pre- and post-intervention output distributions of gpt2
# (TODO: check if all the Islam features just have similar embeddings)
# Show this to the scorer and get a score (scorer should be able to have a good prior without being given the clean output distribution)
# Also get a simplicity score for the explanation

In [2]:
import pandas as pd
from pathlib import Path
import json

results_dir = "/mnt/ssd-1/gpaulo/SAE-Zoology/results/gpt2_simulation/all_at_once"
results = dict()
for fname in Path(results_dir).iterdir():
    with open(fname, "r") as f:
        r = json.load(f)
    last = fname.stem.split(".")[-1]
    layer = int(last.split("_")[0])
    feat = int(last[last.index("_feature") + len("_feature"):])
    results[fname.stem] = {"ev_correlation_score": r["ev_correlation_score"], "layer": layer, "feature": feat}
input_scores_df = pd.DataFrame(results).T
input_scores_df["layer"] = input_scores_df["layer"].astype(int)
input_scores_df["feature"] = input_scores_df["feature"].astype(int)
input_scores_df = input_scores_df.sort_values("ev_correlation_score", ascending=False)
unq_layers = input_scores_df["layer"].unique()
input_scores_df

Unnamed: 0,ev_correlation_score,layer,feature
.transformer.h.2_feature0,0.970093,2,0
.transformer.h.2_feature19,0.966378,2,19
.transformer.h.0_feature0,0.952401,0,0
.transformer.h.4_feature4,0.952061,4,4
.transformer.h.0_feature5,0.949993,0,5
.transformer.h.2_feature4,0.941871,2,4
.transformer.h.2_feature11,0.930066,2,11
.transformer.h.4_feature19,0.918787,4,19
.transformer.h.0_feature14,0.906342,0,14
.transformer.h.0_feature3,0.89708,0,3


In [3]:
import json
import random

with open("pile.jsonl", "r") as f:
    pile = random.sample([json.loads(line) for line in f.readlines()], 100000)

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
subject_device = "cuda:1"

subject_name = "gpt2"
subject = AutoModelForCausalLM.from_pretrained(subject_name).to(subject_device)
subject_tokenizer = AutoTokenizer.from_pretrained(subject_name)
subject_tokenizer.pad_token = subject_tokenizer.eos_token
subject.config.pad_token_id = subject_tokenizer.eos_token_id

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [None]:
from transformers import TRANSFORMERS_CACHE

print(f"The Hugging Face cache directory is: {TRANSFORMERS_CACHE}")

The Hugging Face cache directory is: /home/alex/.cache/huggingface/hub


In [None]:
from transformers import BitsAndBytesConfig

scorer_device = "cuda:0"
scorer_name = "meta-llama/Meta-Llama-3.1-70B"
scorer = AutoModelForCausalLM.from_pretrained(
    scorer_name,
    device_map={"": scorer_device},
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
)
scorer_tokenizer = AutoTokenizer.from_pretrained(scorer_name)
scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
scorer.config.pad_token_id = scorer_tokenizer.eos_token_id
scorer.generation_config.pad_token_id = scorer_tokenizer.eos_token_id

# explainer is the same model as the scorer
explainer_device = scorer_device
explainer = scorer
explainer_tokenizer = scorer_tokenizer


Loading checkpoint shards: 100%|██████████| 30/30 [00:29<00:00,  1.01it/s]


In [None]:
from dataclasses import dataclass
import copy

@dataclass
class ExplainerInterventionExample:
    prompt: str
    top_tokens: list[str]
    top_p_increases: list[float]

    def __post_init__(self):
        self.prompt = self.prompt.replace("\n", "\\n")

    def text(self) -> str:
        tokens_str = ", ".join(f"'{tok}' (+{round(p, 3)})" for tok, p in zip(self.top_tokens, self.top_p_increases))
        return f"<PROMPT>{self.prompt}</PROMPT>\nMost increased tokens: {tokens_str}"
    
@dataclass
class ExplainerNeuronFormatter:
    intervention_examples: list[ExplainerInterventionExample]
    explanation: str | None = None

    def text(self) -> str:
        text = "\n\n".join(example.text() for example in self.intervention_examples)
        text += "\n\nExplanation:"
        if self.explanation is not None:
            text += " " + self.explanation
        return text


def get_explainer_prompt(neuron_prompter: ExplainerNeuronFormatter, few_shot_examples: list[ExplainerNeuronFormatter] | None = None) -> str:
    prompt = "We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.\n\n" \
        "For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.\n\n" \
        "The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.\n\n"
    
    i = 1
    for few_shot_example in few_shot_examples or []:
        assert few_shot_example.explanation is not None
        prompt += f"Neuron {i}\n" + few_shot_example.text() + "\n\n"
        i += 1

    prompt += f"Neuron {i}\n"
    prompt += neuron_prompter.text()

    return prompt


fs_examples = [
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="My favorite food is",
                top_tokens=[" oranges", " bananas", " apples"],
                top_p_increases=[0.81, 0.09, 0.02]
            ),
            ExplainerInterventionExample(
                prompt="Whenever I would see",
                top_tokens=[" fruit", " a", " apples", " red"],
                top_p_increases=[0.09, 0.06, 0.06, 0.05]
            ),
            ExplainerInterventionExample(
                prompt="I like to eat",
                top_tokens=[" fro", " fruit", " oranges", " bananas", " strawberries"],
                top_p_increases=[0.14, 0.13, 0.11, 0.10, 0.03]
            )
        ],
        explanation="fruits"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="Once",
                top_tokens=[" upon", " in", " a", " long"],
                top_p_increases=[0.22, 0.2, 0.05, 0.04]
            ),
            ExplainerInterventionExample(
                prompt="Ryan Quarles\\n\\nRyan Francis Quarles (born October 20, 1983)",
                top_tokens=[" once", " happily", " for"],
                top_p_increases=[0.03, 0.31, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="MSI Going Full Throttle @ CeBIT",
                top_tokens=[" Once", " once", " in", " the", " a", " The"],
                top_p_increases=[0.02, 0.01, 0.01, 0.01, 0.01, 0.01]
            ),
        ],
        explanation="storytelling"
    ),
    ExplainerNeuronFormatter(
        intervention_examples=[
            ExplainerInterventionExample(
                prompt="He owned the watch for a long time. While he never said it was",
                top_tokens=[" hers", " hers", " hers"],
                top_p_increases=[0.09, 0.06, 0.06, 0.5]
            ),
            ExplainerInterventionExample(
                prompt="For some reason",
                top_tokens=[" she", " her", " hers"],
                top_p_increases=[0.14, 0.01, 0.01]
            ),
            ExplainerInterventionExample(
                prompt="insurance does not cover",
                top_tokens=[" her", " women", " her's"],
                top_p_increases=[0.10, 0.02, 0.01]
            )
        ],
        explanation="she/her pronouns"
    )
]

neuron_prompter = copy.deepcopy(fs_examples[0])
neuron_prompter.explanation = None
print(get_explainer_prompt(neuron_prompter, fs_examples))


We're studying neurons in a transformer model. We want to know how intervening on them affects the model's output.

For each neuron, we'll show you a few prompts where we intervened on that neuron at the final token position, and the tokens whose logits increased the most.

The tokens are shown in descending order of their probability increase, given in parentheses. Your job is to give a short summary of what outputs the neuron promotes.

Neuron 1
<PROMPT>My favorite food is</PROMPT>
Most increased tokens: ' oranges' (+0.81), ' bananas' (+0.09), ' apples' (+0.02)

<PROMPT>Whenever I would see</PROMPT>
Most increased tokens: ' fruit' (+0.09), ' a' (+0.06), ' apples' (+0.06), ' red' (+0.05)

<PROMPT>I like to eat</PROMPT>
Most increased tokens: ' fro' (+0.14), ' fruit' (+0.13), ' oranges' (+0.11), ' bananas' (+0.1), ' strawberries' (+0.03)

Explanation: fruits

Neuron 2
<PROMPT>Once</PROMPT>
Most increased tokens: ' upon' (+0.22), ' in' (+0.2), ' a' (+0.05), ' long' (+0.04)

<PROMPT>Ryan

In [None]:
def get_scorer_simplicity_prompt(explanation):
    prefix = "Explanation\n\n"
    return f"{prefix}{explanation}{scorer_tokenizer.eos_token}", prefix

def get_scorer_predictiveness_prompt(prompt, explanation, few_shot_prompts=None, few_shot_explanations=None, few_shot_tokens=None):
    if few_shot_explanations is not None:
        assert few_shot_tokens is not None and few_shot_prompts is not None
        assert len(few_shot_explanations) == len(few_shot_tokens) == len(few_shot_prompts)
        few_shot_prompt = "\n\n".join(get_scorer_predictiveness_prompt(pr, expl) + token for pr, expl, token in zip(few_shot_prompts, few_shot_explanations, few_shot_tokens)) + "\n\n"
    else:
        few_shot_prompt = ""
    return few_shot_prompt + f"Explanation: {explanation}\n<PROMPT>{prompt}</PROMPT>"

few_shot_prompts = ["My favorite food is", "From west to east, the westmost of the seven", "He owned the watch for a long time. While he never said it was"]
few_shot_explanations = ["fruits and vegetables", "ateg", "she/her pronouns"]
few_shot_tokens = [" oranges", "WAY", " hers"]
print(get_scorer_predictiveness_prompt(few_shot_prompts[0], few_shot_explanations[0], few_shot_prompts, few_shot_explanations, few_shot_tokens))

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT> oranges

Explanation: ateg
<PROMPT>From west to east, the westmost of the seven</PROMPT>WAY

Explanation: she/her pronouns
<PROMPT>He owned the watch for a long time. While he never said it was</PROMPT> hers

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT>


In [None]:
from functools import partial

def intervene(module, input, output, intervention_strength=10.0, position=-1):
    hiddens = output[0]  # the later elements of the tuple are the key value cache
    hiddens[:, position, :] += intervention_strength * feat.to(hiddens.device)

def get_texts(n, seed=42, randomize_length=True):
    random.seed(seed)
    texts = []
    for _ in range(n):
        
        # sample a random text from the pile, and stop it at a random token position, less than 64 tokens
        text = random.choice(pile)["text"]
        tokenized_text = subject_tokenizer.encode(text, add_special_tokens=False, max_length=64, truncation=True)
        if len(tokenized_text) < 1:
            continue
        if randomize_length:
            stop_pos = random.randint(1, min(len(tokenized_text), 63))
        else:
            stop_pos = 63
        text = subject_tokenizer.decode(tokenized_text[:stop_pos])
        texts.append(text)
    return texts

n_explainer_texts = 3
n_scorer_texts = 10
# explainer_texts = get_texts(n_explainer_texts)
# explainer_texts = ["Current religion:", "A country that is", "Many people believe that"]
# scorer_texts = get_texts(n_scorer_texts)

In [None]:
scorer_vocab = scorer_tokenizer.get_vocab()
subject_vocab = subject_tokenizer.get_vocab()

# Pre-compute the mapping of subject tokens to scorer tokens
subject_to_scorer = {}
text_subject_to_scorer = {}
for subj_tok, subj_id in subject_vocab.items():
    if subj_tok in scorer_vocab:
        subject_to_scorer[subj_id] = scorer_vocab[subj_tok]
        text_subject_to_scorer[subj_tok] = subj_tok
    else:
        for i in range(len(subj_tok) - 1, 0, -1):
            if subj_tok[:i] in scorer_vocab:
                subject_to_scorer[subj_id] = scorer_vocab[subj_tok[:i]]
                text_subject_to_scorer[subj_tok] = subj_tok[:i]
                break
        else:
            raise ValueError(f"No scorer token found for {subj_tok}")
subject_ids = torch.tensor(list(subject_to_scorer.keys()), device=scorer_device)
scorer_ids = torch.tensor(list(subject_to_scorer.values()), device=scorer_device)

In [None]:
n_intervention_examples = 5
n_candidate_texts = 1_000
candidate_texts = get_texts(n_candidate_texts, randomize_length=False)

In [None]:
subject_tokenizer.decode(input_ids[0])

'Acrocercops telestis\n\nAcrocercops telestis is a moth of the family Gracillariidae. It is known from India (Bihar).\n\nThe larvae feed on Mallotus repandus, Trewia species (including Trewia nudiflor'

In [None]:
h.norm(dim=-1)

tensor([2568.3259,   51.8801,   56.1855,   56.7833,   56.3197,   73.1205,
          58.5472,   55.7751,   56.7515,   54.8194,   76.8825,   59.6489,
          56.3214,   57.5037,   58.4638,   71.8091,   61.4631,   58.9519,
          56.8281,   59.7283,   72.9409,   56.1726,   55.4350,   64.0503,
          81.6632,   59.4288,   59.9517,   74.9668,   55.2861,   61.9393,
          59.8316,   67.2749,   57.4120,   72.5435,   58.4150,   64.5726,
          69.5647,   58.5335,   52.1635,   51.9985,   61.7360,   83.2534,
          74.2880,   57.7624,   75.1908,   59.3623,   60.8752,   66.7576,
          57.9630,   58.7110,   54.8738,   63.6373,   59.0941,   58.1155,
          68.2508,   60.3722,   69.6522,   66.7617,   59.5492,   60.7646,
          78.1720,   63.0341,   65.7444], device='cuda:1')

In [None]:
print("Loading autoencoder...", end="")
path = f"{weight_dir}/{layer}.pt"
state_dict = torch.load(path)
ae = Autoencoder.from_state_dict(state_dict=state_dict)
feat = ae.decoder.weight[:, feat_idx].to(subject_device)
ae.encoder.to(subject_device)
ae.activation.to(subject_device)

### Find examples where the feature activates
# Remove any hooks
for l in range(len(subject.transformer.h)):
    subject.transformer.h[l]._forward_hooks.clear()

for text in tqdm(candidate_texts, total=len(candidate_texts)):
    input_ids = subject_tokenizer(text, return_tensors="pt").input_ids.to(subject_device)
    with torch.inference_mode():
        out = subject(input_ids, output_hidden_states=True)
        # hidden_states is actually one longer than the number of layers, because it includes the input embeddings
        h = out.hidden_states[layer + 1].squeeze(0)
        feat_acts = ae.encoder(h[-1, :])
        topk = ae.activation(feat_acts)
        print(f"{(feat_acts > 20).sum().item()=}")

Loading autoencoder...

  4%|▎         | 36/1000 [00:00<00:05, 174.81it/s]

(feat_acts > 20).sum().item()=167
(feat_acts > 20).sum().item()=944
(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=187
(feat_acts > 20).sum().item()=496
(feat_acts > 20).sum().item()=282
(feat_acts > 20).sum().item()=97
(feat_acts > 20).sum().item()=127
(feat_acts > 20).sum().item()=682
(feat_acts > 20).sum().item()=187
(feat_acts > 20).sum().item()=33
(feat_acts > 20).sum().item()=189
(feat_acts > 20).sum().item()=46
(feat_acts > 20).sum().item()=79
(feat_acts > 20).sum().item()=167
(feat_acts > 20).sum().item()=876
(feat_acts > 20).sum().item()=332
(feat_acts > 20).sum().item()=143
(feat_acts > 20).sum().item()=18
(feat_acts > 20).sum().item()=69
(feat_acts > 20).sum().item()=62
(feat_acts > 20).sum().item()=103
(feat_acts > 20).sum().item()=117
(feat_acts > 20).sum().item()=8
(feat_acts > 20).sum().item()=32
(feat_acts > 20).sum().item()=171
(feat_acts > 20).sum().item()=102
(feat_acts > 20).sum().item()=171
(feat_acts > 20).sum().item()=331
(feat_acts > 20).sum().it

  7%|▋         | 72/1000 [00:00<00:05, 174.43it/s]

(feat_acts > 20).sum().item()=13
(feat_acts > 20).sum().item()=179
(feat_acts > 20).sum().item()=118
(feat_acts > 20).sum().item()=151
(feat_acts > 20).sum().item()=270
(feat_acts > 20).sum().item()=163
(feat_acts > 20).sum().item()=9
(feat_acts > 20).sum().item()=207
(feat_acts > 20).sum().item()=31
(feat_acts > 20).sum().item()=217
(feat_acts > 20).sum().item()=32
(feat_acts > 20).sum().item()=364
(feat_acts > 20).sum().item()=27
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=80
(feat_acts > 20).sum().item()=183
(feat_acts > 20).sum().item()=66
(feat_acts > 20).sum().item()=93
(feat_acts > 20).sum().item()=104
(feat_acts > 20).sum().item()=594
(feat_acts > 20).sum().item()=76
(feat_acts > 20).sum().item()=3
(feat_acts > 20).sum().item()=117
(feat_acts > 20).sum().item()=131
(feat_acts > 20).sum().item()=199
(feat_acts > 20).sum().item()=102
(feat_acts > 20).sum().item()=167
(feat_acts > 20).sum().item()=156
(feat_acts > 20).sum().item()=453
(feat_acts > 20).sum().item

  9%|▉         | 90/1000 [00:00<00:05, 174.38it/s]

(feat_acts > 20).sum().item()=10
(feat_acts > 20).sum().item()=199
(feat_acts > 20).sum().item()=224
(feat_acts > 20).sum().item()=118
(feat_acts > 20).sum().item()=22
(feat_acts > 20).sum().item()=609
(feat_acts > 20).sum().item()=450
(feat_acts > 20).sum().item()=252
(feat_acts > 20).sum().item()=256
(feat_acts > 20).sum().item()=41
(feat_acts > 20).sum().item()=431
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()=117
(feat_acts > 20).sum().item()=264
(feat_acts > 20).sum().item()=94
(feat_acts > 20).sum().item()=149
(feat_acts > 20).sum().item()=215
(feat_acts > 20).sum().item()=4
(feat_acts > 20).sum().item()=89
(feat_acts > 20).sum().item()=265
(feat_acts > 20).sum().item()=66
(feat_acts > 20).sum().item()=146
(feat_acts > 20).sum().item()=91
(feat_acts > 20).sum().item()=146
(feat_acts > 20).sum().item()=2605
(feat_acts > 20).sum().item()=828
(feat_acts > 20).sum().item()=580
(feat_acts > 20).sum().item()=89
(feat_acts > 20).sum().item()=5
(feat_acts > 20).sum().ite

 13%|█▎        | 126/1000 [00:00<00:05, 174.05it/s]

(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=437
(feat_acts > 20).sum().item()=13
(feat_acts > 20).sum().item()=240
(feat_acts > 20).sum().item()=135
(feat_acts > 20).sum().item()=24
(feat_acts > 20).sum().item()=46
(feat_acts > 20).sum().item()=40
(feat_acts > 20).sum().item()=53
(feat_acts > 20).sum().item()=9
(feat_acts > 20).sum().item()=164
(feat_acts > 20).sum().item()=141
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=489
(feat_acts > 20).sum().item()=149
(feat_acts > 20).sum().item()=112
(feat_acts > 20).sum().item()=15
(feat_acts > 20).sum().item()=416
(feat_acts > 20).sum().item()=95
(feat_acts > 20).sum().item()=39
(feat_acts > 20).sum().item()=275
(feat_acts > 20).sum().item()=282
(feat_acts > 20).sum().item()=484
(feat_acts > 20).sum().item()=92
(feat_acts > 20).sum().item()=35
(feat_acts > 20).sum().item()=190
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=95
(feat_acts > 20).sum().item()=8

 16%|█▌        | 162/1000 [00:00<00:04, 174.19it/s]

(feat_acts > 20).sum().item()=42
(feat_acts > 20).sum().item()=124
(feat_acts > 20).sum().item()=143
(feat_acts > 20).sum().item()=45
(feat_acts > 20).sum().item()=33
(feat_acts > 20).sum().item()=201
(feat_acts > 20).sum().item()=280
(feat_acts > 20).sum().item()=71
(feat_acts > 20).sum().item()=34
(feat_acts > 20).sum().item()=604
(feat_acts > 20).sum().item()=308
(feat_acts > 20).sum().item()=84
(feat_acts > 20).sum().item()=72
(feat_acts > 20).sum().item()=334
(feat_acts > 20).sum().item()=874
(feat_acts > 20).sum().item()=143
(feat_acts > 20).sum().item()=4
(feat_acts > 20).sum().item()=61
(feat_acts > 20).sum().item()=310
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=32
(feat_acts > 20).sum().item()=155
(feat_acts > 20).sum().item()=254
(feat_acts > 20).sum().item()=42
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=153
(feat_acts > 20).sum().item()=392
(feat_acts > 20).sum().item()=53
(feat_acts > 20).sum().item()=32
(feat_acts > 20).sum().item()=

 20%|█▉        | 198/1000 [00:01<00:04, 174.26it/s]

(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=113
(feat_acts > 20).sum().item()=164
(feat_acts > 20).sum().item()=248
(feat_acts > 20).sum().item()=105
(feat_acts > 20).sum().item()=436
(feat_acts > 20).sum().item()=686
(feat_acts > 20).sum().item()=179
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=149
(feat_acts > 20).sum().item()=108
(feat_acts > 20).sum().item()=10
(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()=89
(feat_acts > 20).sum().item()=80
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=150
(feat_acts > 20).sum().item()=949
(feat_acts > 20).sum().item()=92
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=19
(feat_acts > 20).sum().item()=41
(feat_acts > 20).sum().item()=98
(feat_acts > 20).sum().item()=103
(feat_acts > 20).sum().item()=81
(feat_acts > 20).sum().item()=46
(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=271
(feat_acts > 20).sum().item()=92
(feat_acts > 20).sum().item()=

 23%|██▎       | 234/1000 [00:01<00:04, 174.42it/s]

(feat_acts > 20).sum().item()=378
(feat_acts > 20).sum().item()=9
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=18
(feat_acts > 20).sum().item()=145
(feat_acts > 20).sum().item()=272
(feat_acts > 20).sum().item()=483
(feat_acts > 20).sum().item()=56
(feat_acts > 20).sum().item()=268
(feat_acts > 20).sum().item()=158
(feat_acts > 20).sum().item()=176
(feat_acts > 20).sum().item()=204
(feat_acts > 20).sum().item()=54
(feat_acts > 20).sum().item()=107
(feat_acts > 20).sum().item()=34
(feat_acts > 20).sum().item()=73
(feat_acts > 20).sum().item()=250
(feat_acts > 20).sum().item()=42
(feat_acts > 20).sum().item()=70
(feat_acts > 20).sum().item()=98
(feat_acts > 20).sum().item()=21
(feat_acts > 20).sum().item()=145
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()=40
(feat_acts > 20).sum().item()=199
(feat_acts > 20).sum().item()=320
(feat_acts > 20).sum().item()=275
(feat_acts > 20).sum().item()=68
(feat_acts > 20).sum().item()=

 27%|██▋       | 270/1000 [00:01<00:04, 174.60it/s]

(feat_acts > 20).sum().item()=391
(feat_acts > 20).sum().item()=179
(feat_acts > 20).sum().item()=16
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=196
(feat_acts > 20).sum().item()=400
(feat_acts > 20).sum().item()=84
(feat_acts > 20).sum().item()=177
(feat_acts > 20).sum().item()=151
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=806
(feat_acts > 20).sum().item()=178
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=176
(feat_acts > 20).sum().item()=79
(feat_acts > 20).sum().item()=8
(feat_acts > 20).sum().item()=253
(feat_acts > 20).sum().item()=258
(feat_acts > 20).sum().item()=36
(feat_acts > 20).sum().item()=271
(feat_acts > 20).sum().item()=97
(feat_acts > 20).sum().item()=243
(feat_acts > 20).sum().item()=134
(feat_acts > 20).sum().item()=460
(feat_acts > 20).sum().item()=206
(feat_acts > 20).sum().item()=561
(feat_acts > 20).sum().item()=22
(feat_acts > 20).sum().item()=376
(feat_acts > 20).sum().item

 31%|███       | 306/1000 [00:01<00:03, 174.27it/s]

(feat_acts > 20).sum().item()=168
(feat_acts > 20).sum().item()=100
(feat_acts > 20).sum().item()=378
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=111
(feat_acts > 20).sum().item()=105
(feat_acts > 20).sum().item()=88
(feat_acts > 20).sum().item()=300
(feat_acts > 20).sum().item()=47
(feat_acts > 20).sum().item()=76
(feat_acts > 20).sum().item()=49
(feat_acts > 20).sum().item()=332
(feat_acts > 20).sum().item()=173
(feat_acts > 20).sum().item()=254
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=45
(feat_acts > 20).sum().item()=117
(feat_acts > 20).sum().item()=94
(feat_acts > 20).sum().item()=18
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()=402
(feat_acts > 20).sum().item()=35
(feat_acts > 20).sum().item()=302
(feat_acts > 20).sum().item()=255
(feat_acts > 20).sum().item()=168
(feat_acts > 20).sum().item()=137
(feat_acts > 20).sum().item()=74
(feat_acts > 20).sum().item()=22
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()

 34%|███▍      | 342/1000 [00:01<00:03, 173.95it/s]

(feat_acts > 20).sum().item()=159
(feat_acts > 20).sum().item()=69
(feat_acts > 20).sum().item()=277
(feat_acts > 20).sum().item()=177
(feat_acts > 20).sum().item()=47
(feat_acts > 20).sum().item()=363
(feat_acts > 20).sum().item()=19
(feat_acts > 20).sum().item()=65
(feat_acts > 20).sum().item()=144
(feat_acts > 20).sum().item()=153
(feat_acts > 20).sum().item()=15
(feat_acts > 20).sum().item()=74
(feat_acts > 20).sum().item()=97
(feat_acts > 20).sum().item()=1192
(feat_acts > 20).sum().item()=334
(feat_acts > 20).sum().item()=308
(feat_acts > 20).sum().item()=144
(feat_acts > 20).sum().item()=37
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()=221
(feat_acts > 20).sum().item()=99
(feat_acts > 20).sum().item()=147
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()=346
(feat_acts > 20).sum().item()=1101
(feat_acts > 20).sum().item()=49
(feat_acts > 20).sum().item()=148
(feat_acts > 20).sum().item()=44
(feat_acts > 20).sum().item()=35
(feat_acts > 20).sum().ite

 38%|███▊      | 378/1000 [00:02<00:03, 174.03it/s]

(feat_acts > 20).sum().item()=218
(feat_acts > 20).sum().item()=43
(feat_acts > 20).sum().item()=159
(feat_acts > 20).sum().item()=138
(feat_acts > 20).sum().item()=120
(feat_acts > 20).sum().item()=70
(feat_acts > 20).sum().item()=175
(feat_acts > 20).sum().item()=54
(feat_acts > 20).sum().item()=283
(feat_acts > 20).sum().item()=251
(feat_acts > 20).sum().item()=94
(feat_acts > 20).sum().item()=315
(feat_acts > 20).sum().item()=145
(feat_acts > 20).sum().item()=43
(feat_acts > 20).sum().item()=275
(feat_acts > 20).sum().item()=522
(feat_acts > 20).sum().item()=43
(feat_acts > 20).sum().item()=36
(feat_acts > 20).sum().item()=271
(feat_acts > 20).sum().item()=126
(feat_acts > 20).sum().item()=130
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=61
(feat_acts > 20).sum().item()=44
(feat_acts > 20).sum().item()=32
(feat_acts > 20).sum().item()=236
(feat_acts > 20).sum().item()=201
(feat_acts > 20).sum().item()=115
(feat_acts > 20).sum().item()=51
(feat_acts > 20).sum().item

 41%|████▏     | 414/1000 [00:02<00:03, 173.57it/s]

(feat_acts > 20).sum().item()=42
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=55
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=859
(feat_acts > 20).sum().item()=569
(feat_acts > 20).sum().item()=212
(feat_acts > 20).sum().item()=52
(feat_acts > 20).sum().item()=86
(feat_acts > 20).sum().item()=517
(feat_acts > 20).sum().item()=236
(feat_acts > 20).sum().item()=61
(feat_acts > 20).sum().item()=85
(feat_acts > 20).sum().item()=65
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=150
(feat_acts > 20).sum().item()=8
(feat_acts > 20).sum().item()=91
(feat_acts > 20).sum().item()=94
(feat_acts > 20).sum().item()=42
(feat_acts > 20).sum().item()=286
(feat_acts > 20).sum().item()=310
(feat_acts > 20).sum().item()=125
(feat_acts > 20).sum().item()=699
(feat_acts > 20).sum().item()=99
(feat_acts > 20).sum().item()=427
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=63
(feat_acts > 20).sum().item()=176

 45%|████▌     | 450/1000 [00:02<00:03, 170.38it/s]

(feat_acts > 20).sum().item()=162
(feat_acts > 20).sum().item()=646
(feat_acts > 20).sum().item()=388
(feat_acts > 20).sum().item()=71
(feat_acts > 20).sum().item()=224
(feat_acts > 20).sum().item()=4
(feat_acts > 20).sum().item()=116
(feat_acts > 20).sum().item()=227
(feat_acts > 20).sum().item()=107
(feat_acts > 20).sum().item()=250
(feat_acts > 20).sum().item()=85
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=512
(feat_acts > 20).sum().item()=100
(feat_acts > 20).sum().item()=21
(feat_acts > 20).sum().item()=18
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=428
(feat_acts > 20).sum().item()=276
(feat_acts > 20).sum().item()=61
(feat_acts > 20).sum().item()=85
(feat_acts > 20).sum().item()=179
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=168
(feat_acts > 20).sum().item()=572
(feat_acts > 20).sum().item()=5
(feat_acts > 20).sum().item()=82
(feat_acts > 20).sum().item()=73
(feat_acts > 20).sum().item()=

 49%|████▊     | 486/1000 [00:02<00:03, 170.35it/s]

(feat_acts > 20).sum().item()=153
(feat_acts > 20).sum().item()=45
(feat_acts > 20).sum().item()=616
(feat_acts > 20).sum().item()=156
(feat_acts > 20).sum().item()=62
(feat_acts > 20).sum().item()=338
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=67
(feat_acts > 20).sum().item()=258
(feat_acts > 20).sum().item()=178
(feat_acts > 20).sum().item()=331
(feat_acts > 20).sum().item()=95
(feat_acts > 20).sum().item()=51
(feat_acts > 20).sum().item()=54
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=170
(feat_acts > 20).sum().item()=80
(feat_acts > 20).sum().item()=363
(feat_acts > 20).sum().item()=100
(feat_acts > 20).sum().item()=13
(feat_acts > 20).sum().item()=38
(feat_acts > 20).sum().item()=139
(feat_acts > 20).sum().item()=327
(feat_acts > 20).sum().item()=177
(feat_acts > 20).sum().item()=147
(feat_acts > 20).sum().item()=21
(feat_acts > 20).sum().item()=180
(feat_acts > 20).sum().item()=115
(feat_acts > 20).sum().item()=49
(feat_acts > 20).sum().item()

 52%|█████▏    | 522/1000 [00:03<00:02, 172.88it/s]

(feat_acts > 20).sum().item()=99
(feat_acts > 20).sum().item()=220
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=16
(feat_acts > 20).sum().item()=150
(feat_acts > 20).sum().item()=223
(feat_acts > 20).sum().item()=9
(feat_acts > 20).sum().item()=170
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=86
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()=67
(feat_acts > 20).sum().item()=913
(feat_acts > 20).sum().item()=138
(feat_acts > 20).sum().item()=165
(feat_acts > 20).sum().item()=101
(feat_acts > 20).sum().item()=186
(feat_acts > 20).sum().item()=27
(feat_acts > 20).sum().item()=8
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=260
(feat_acts > 20).sum().item()=162
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=176
(feat_acts > 20).sum().item()=616
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=246
(feat_acts > 20).sum().item()=1847
(feat_acts > 20).sum().item()=199
(feat_acts > 20).sum().item(

 56%|█████▌    | 558/1000 [00:03<00:02, 170.86it/s]

(feat_acts > 20).sum().item()=113
(feat_acts > 20).sum().item()=270
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=217
(feat_acts > 20).sum().item()=458
(feat_acts > 20).sum().item()=106
(feat_acts > 20).sum().item()=18
(feat_acts > 20).sum().item()=945
(feat_acts > 20).sum().item()=256
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=124
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=227
(feat_acts > 20).sum().item()=1051
(feat_acts > 20).sum().item()=22
(feat_acts > 20).sum().item()=199
(feat_acts > 20).sum().item()=19
(feat_acts > 20).sum().item()=70
(feat_acts > 20).sum().item()=25
(feat_acts > 20).sum().item()=96
(feat_acts > 20).sum().item()=208
(feat_acts > 20).sum().item()=222
(feat_acts > 20).sum().item()=410
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=557
(feat_acts > 20).sum().item()=444
(feat_acts > 20).sum().item()=127
(feat_acts > 20).sum().item()=123
(feat_acts > 20).sum().item()=213
(feat_acts > 20).sum().

 59%|█████▉    | 594/1000 [00:03<00:02, 171.68it/s]

(feat_acts > 20).sum().item()=353
(feat_acts > 20).sum().item()=83
(feat_acts > 20).sum().item()=284
(feat_acts > 20).sum().item()=27
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()=192
(feat_acts > 20).sum().item()=21
(feat_acts > 20).sum().item()=276
(feat_acts > 20).sum().item()=72
(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()=246
(feat_acts > 20).sum().item()=1201
(feat_acts > 20).sum().item()=138
(feat_acts > 20).sum().item()=39
(feat_acts > 20).sum().item()=16
(feat_acts > 20).sum().item()=506
(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()=220
(feat_acts > 20).sum().item()=170
(feat_acts > 20).sum().item()=119
(feat_acts > 20).sum().item()=34
(feat_acts > 20).sum().item()=57
(feat_acts > 20).sum().item()=61
(feat_acts > 20).sum().item()=45
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=440
(feat_acts > 20).sum().item()=21
(feat_acts > 20).sum().item()=246
(feat_acts > 20).sum().item()=

 63%|██████▎   | 630/1000 [00:03<00:02, 171.63it/s]

(feat_acts > 20).sum().item()=264
(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()=55
(feat_acts > 20).sum().item()=13
(feat_acts > 20).sum().item()=27
(feat_acts > 20).sum().item()=141
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=169
(feat_acts > 20).sum().item()=110
(feat_acts > 20).sum().item()=65
(feat_acts > 20).sum().item()=66
(feat_acts > 20).sum().item()=375
(feat_acts > 20).sum().item()=262
(feat_acts > 20).sum().item()=83
(feat_acts > 20).sum().item()=70
(feat_acts > 20).sum().item()=274
(feat_acts > 20).sum().item()=121
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=125
(feat_acts > 20).sum().item()=144
(feat_acts > 20).sum().item()=51
(feat_acts > 20).sum().item()=161
(feat_acts > 20).sum().item()=254
(feat_acts > 20).sum().item()=284
(feat_acts > 20).sum().item()=33
(feat_acts > 20).sum().item()=142
(feat_acts > 20).sum().item()=257
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()

 67%|██████▋   | 666/1000 [00:03<00:01, 168.44it/s]

(feat_acts > 20).sum().item()=107
(feat_acts > 20).sum().item()=385
(feat_acts > 20).sum().item()=22
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=223
(feat_acts > 20).sum().item()=304
(feat_acts > 20).sum().item()=76
(feat_acts > 20).sum().item()=83
(feat_acts > 20).sum().item()=256
(feat_acts > 20).sum().item()=135
(feat_acts > 20).sum().item()=25
(feat_acts > 20).sum().item()=383
(feat_acts > 20).sum().item()=92
(feat_acts > 20).sum().item()=234
(feat_acts > 20).sum().item()=227
(feat_acts > 20).sum().item()=84
(feat_acts > 20).sum().item()=19
(feat_acts > 20).sum().item()=127
(feat_acts > 20).sum().item()=141
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=327
(feat_acts > 20).sum().item()=361
(feat_acts > 20).sum().item()=153
(feat_acts > 20).sum().item()=367
(feat_acts > 20).sum().item()=148
(feat_acts > 20).sum().item()=95
(feat_acts > 20).sum().item()=63
(feat_acts > 20).sum().item()=191
(feat_acts > 20).sum().item()=19
(feat_acts > 20).sum().ite

 70%|███████   | 702/1000 [00:04<00:01, 171.17it/s]

(feat_acts > 20).sum().item()=5
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()=74
(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=90
(feat_acts > 20).sum().item()=148
(feat_acts > 20).sum().item()=7
(feat_acts > 20).sum().item()=89
(feat_acts > 20).sum().item()=93
(feat_acts > 20).sum().item()=252
(feat_acts > 20).sum().item()=54
(feat_acts > 20).sum().item()=278
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()=497
(feat_acts > 20).sum().item()=59
(feat_acts > 20).sum().item()=56
(feat_acts > 20).sum().item()=428
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()=335
(feat_acts > 20).sum().item()=396
(feat_acts > 20).sum().item()=136
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()=124
(feat_acts > 20).sum().item()=191
(feat_acts > 20).sum().item()=121
(feat_acts > 20).sum().item()=26
(feat_acts > 20).sum().item()=63
(feat_acts > 20).sum().item()=32
(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=13
(

 74%|███████▍  | 738/1000 [00:04<00:01, 172.67it/s]

(feat_acts > 20).sum().item()=183
(feat_acts > 20).sum().item()=107
(feat_acts > 20).sum().item()=278
(feat_acts > 20).sum().item()=238
(feat_acts > 20).sum().item()=94
(feat_acts > 20).sum().item()=70
(feat_acts > 20).sum().item()=162
(feat_acts > 20).sum().item()=292
(feat_acts > 20).sum().item()=626
(feat_acts > 20).sum().item()=201
(feat_acts > 20).sum().item()=159
(feat_acts > 20).sum().item()=190
(feat_acts > 20).sum().item()=133
(feat_acts > 20).sum().item()=49
(feat_acts > 20).sum().item()=8
(feat_acts > 20).sum().item()=6
(feat_acts > 20).sum().item()=706
(feat_acts > 20).sum().item()=178
(feat_acts > 20).sum().item()=71
(feat_acts > 20).sum().item()=78
(feat_acts > 20).sum().item()=21
(feat_acts > 20).sum().item()=124
(feat_acts > 20).sum().item()=69
(feat_acts > 20).sum().item()=65
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=38
(feat_acts > 20).sum().item()=1143
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=313
(feat_acts > 20).sum().item(

 77%|███████▋  | 774/1000 [00:04<00:01, 173.82it/s]

(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=122
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=50
(feat_acts > 20).sum().item()=52
(feat_acts > 20).sum().item()=125
(feat_acts > 20).sum().item()=14
(feat_acts > 20).sum().item()=60
(feat_acts > 20).sum().item()=736
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=611
(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=31
(feat_acts > 20).sum().item()=40
(feat_acts > 20).sum().item()=89
(feat_acts > 20).sum().item()=252
(feat_acts > 20).sum().item()=158
(feat_acts > 20).sum().item()=36
(feat_acts > 20).sum().item()=395
(feat_acts > 20).sum().item()=122
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()=403
(feat_acts > 20).sum().item()=319
(feat_acts > 20).sum().item()=4
(feat_acts > 20).sum().item()=288
(feat_acts > 20).sum().item()=191
(feat_acts > 20).sum().item()=8
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=35
(feat_acts > 20).sum().item()=322

 81%|████████  | 810/1000 [00:04<00:01, 169.05it/s]

(feat_acts > 20).sum().item()=457
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()=15
(feat_acts > 20).sum().item()=31
(feat_acts > 20).sum().item()=96
(feat_acts > 20).sum().item()=67
(feat_acts > 20).sum().item()=100
(feat_acts > 20).sum().item()=359
(feat_acts > 20).sum().item()=60
(feat_acts > 20).sum().item()=264
(feat_acts > 20).sum().item()=521
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=71
(feat_acts > 20).sum().item()=412
(feat_acts > 20).sum().item()=46
(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()=264
(feat_acts > 20).sum().item()=106
(feat_acts > 20).sum().item()=130
(feat_acts > 20).sum().item()=186
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=150
(feat_acts > 20).sum().item()=87
(feat_acts > 20).sum().item()=87
(feat_acts > 20).sum().item()=108
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=118
(feat_acts > 20).sum().item()=382
(feat_acts > 20).sum().item()=30
(feat_acts > 20).sum().item()

 84%|████████▍ | 845/1000 [00:04<00:00, 170.28it/s]

(feat_acts > 20).sum().item()=405
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=123
(feat_acts > 20).sum().item()=156
(feat_acts > 20).sum().item()=22
(feat_acts > 20).sum().item()=155
(feat_acts > 20).sum().item()=29
(feat_acts > 20).sum().item()=51
(feat_acts > 20).sum().item()=1827
(feat_acts > 20).sum().item()=59
(feat_acts > 20).sum().item()=284
(feat_acts > 20).sum().item()=261
(feat_acts > 20).sum().item()=23
(feat_acts > 20).sum().item()=219
(feat_acts > 20).sum().item()=235
(feat_acts > 20).sum().item()=215
(feat_acts > 20).sum().item()=138
(feat_acts > 20).sum().item()=283
(feat_acts > 20).sum().item()=182
(feat_acts > 20).sum().item()=223
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=37
(feat_acts > 20).sum().item()=829
(feat_acts > 20).sum().item()=869
(feat_acts > 20).sum().item()=131
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=231
(feat_acts > 20).sum().item()=266
(feat_acts > 20).sum().item()=44
(feat_acts > 20).sum().

 86%|████████▋ | 863/1000 [00:05<00:00, 171.23it/s]

(feat_acts > 20).sum().item()=529
(feat_acts > 20).sum().item()=362
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()=11
(feat_acts > 20).sum().item()=135
(feat_acts > 20).sum().item()=1565
(feat_acts > 20).sum().item()=97
(feat_acts > 20).sum().item()=102
(feat_acts > 20).sum().item()=77
(feat_acts > 20).sum().item()=480
(feat_acts > 20).sum().item()=99
(feat_acts > 20).sum().item()=91
(feat_acts > 20).sum().item()=509
(feat_acts > 20).sum().item()=654
(feat_acts > 20).sum().item()=4
(feat_acts > 20).sum().item()=150
(feat_acts > 20).sum().item()=15
(feat_acts > 20).sum().item()=1
(feat_acts > 20).sum().item()=5
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=20
(feat_acts > 20).sum().item()=326
(feat_acts > 20).sum().item()=78
(feat_acts > 20).sum().item()=250
(feat_acts > 20).sum().item()=57
(feat_acts > 20).sum().item()=709
(feat_acts > 20).sum().item()=3
(feat_acts > 20).sum().item()=68
(feat_acts > 20).sum().item()=163
(feat_acts > 20).sum().item()=11


 90%|████████▉ | 899/1000 [00:05<00:00, 171.84it/s]

(feat_acts > 20).sum().item()=135
(feat_acts > 20).sum().item()=36
(feat_acts > 20).sum().item()=390
(feat_acts > 20).sum().item()=54
(feat_acts > 20).sum().item()=148
(feat_acts > 20).sum().item()=409
(feat_acts > 20).sum().item()=5
(feat_acts > 20).sum().item()=502
(feat_acts > 20).sum().item()=58
(feat_acts > 20).sum().item()=72
(feat_acts > 20).sum().item()=52
(feat_acts > 20).sum().item()=424
(feat_acts > 20).sum().item()=117
(feat_acts > 20).sum().item()=46
(feat_acts > 20).sum().item()=105
(feat_acts > 20).sum().item()=489
(feat_acts > 20).sum().item()=85
(feat_acts > 20).sum().item()=34
(feat_acts > 20).sum().item()=251
(feat_acts > 20).sum().item()=19
(feat_acts > 20).sum().item()=330
(feat_acts > 20).sum().item()=5
(feat_acts > 20).sum().item()=201
(feat_acts > 20).sum().item()=25
(feat_acts > 20).sum().item()=25
(feat_acts > 20).sum().item()=89
(feat_acts > 20).sum().item()=33
(feat_acts > 20).sum().item()=55
(feat_acts > 20).sum().item()=9
(feat_acts > 20).sum().item()=51
(

 94%|█████████▎| 935/1000 [00:05<00:00, 172.47it/s]

(feat_acts > 20).sum().item()=2
(feat_acts > 20).sum().item()=101
(feat_acts > 20).sum().item()=223
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()=821
(feat_acts > 20).sum().item()=315
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=240
(feat_acts > 20).sum().item()=129
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=183
(feat_acts > 20).sum().item()=108
(feat_acts > 20).sum().item()=12
(feat_acts > 20).sum().item()=16
(feat_acts > 20).sum().item()=231
(feat_acts > 20).sum().item()=213
(feat_acts > 20).sum().item()=95
(feat_acts > 20).sum().item()=64
(feat_acts > 20).sum().item()=75
(feat_acts > 20).sum().item()=196
(feat_acts > 20).sum().item()=396
(feat_acts > 20).sum().item()=104
(feat_acts > 20).sum().item()=185
(feat_acts > 20).sum().item()=52
(feat_acts > 20).sum().item()=71
(feat_acts > 20).sum().item()=43
(feat_acts > 20).sum().item()=233
(feat_acts > 20).sum().item()=62
(feat_acts > 20).sum().item()=17
(feat_acts > 20).sum().item()

 97%|█████████▋| 971/1000 [00:05<00:00, 172.95it/s]

(feat_acts > 20).sum().item()=79
(feat_acts > 20).sum().item()=402
(feat_acts > 20).sum().item()=80
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=35
(feat_acts > 20).sum().item()=194
(feat_acts > 20).sum().item()=80
(feat_acts > 20).sum().item()=15
(feat_acts > 20).sum().item()=1570
(feat_acts > 20).sum().item()=48
(feat_acts > 20).sum().item()=152
(feat_acts > 20).sum().item()=249
(feat_acts > 20).sum().item()=76
(feat_acts > 20).sum().item()=72
(feat_acts > 20).sum().item()=429
(feat_acts > 20).sum().item()=295
(feat_acts > 20).sum().item()=66
(feat_acts > 20).sum().item()=106
(feat_acts > 20).sum().item()=112
(feat_acts > 20).sum().item()=441
(feat_acts > 20).sum().item()=154
(feat_acts > 20).sum().item()=28
(feat_acts > 20).sum().item()=298
(feat_acts > 20).sum().item()=124
(feat_acts > 20).sum().item()=24
(feat_acts > 20).sum().item()=56
(feat_acts > 20).sum().item()=469
(feat_acts > 20).sum().item()=200
(feat_acts > 20).sum().item()=213
(feat_acts > 20).sum().ite

100%|██████████| 1000/1000 [00:05<00:00, 172.42it/s]

(feat_acts > 20).sum().item()=93
(feat_acts > 20).sum().item()=1149
(feat_acts > 20).sum().item()=258
(feat_acts > 20).sum().item()=299
(feat_acts > 20).sum().item()=145
(feat_acts > 20).sum().item()=68
(feat_acts > 20).sum().item()=390
(feat_acts > 20).sum().item()=165
(feat_acts > 20).sum().item()=146
(feat_acts > 20).sum().item()=161
(feat_acts > 20).sum().item()=232
(feat_acts > 20).sum().item()=148
(feat_acts > 20).sum().item()=459
(feat_acts > 20).sum().item()=321
(feat_acts > 20).sum().item()=66





In [None]:
from collections import Counter
import torch
import numpy as np
from sae_auto_interp.autoencoders.OpenAI.model import Autoencoder
from itertools import product
from tqdm import tqdm
import time

weight_dir = "/mnt/ssd-1/gpaulo/SAE-Zoology/weights/gpt2_128k"

all_results = []

feat_idxs = list(range(10))
feat_layers = [2, 6, 11]
total_iterations = len(feat_idxs) * len(feat_layers)
for feat_idx, feat_layer in tqdm(product(feat_idxs, feat_layers), total=total_iterations):
    scorer_intervention_strengths = [10, 32, 100, 320, 1000]
    explainer_intervention_strength = 32

    print("Loading autoencoder...", end="")
    path = f"{weight_dir}/{layer}.pt"
    state_dict = torch.load(path)
    ae = Autoencoder.from_state_dict(state_dict=state_dict)
    feat = ae.decoder.weight[:, feat_idx].to(subject_device)
    encoder_feat = ae.encoder.weight[feat_idx, :].to(subject_device)
    # ae.encoder.to(device)
    # ae.activation.to(device)

    ### Find examples where the feature activates
    # Remove any hooks
    for l in range(len(subject.transformer.h)):
        subject.transformer.h[l]._forward_hooks.clear()
    print("done")

    subtexts = []
    subtext_acts = []
    for text in tqdm(candidate_texts, total=len(candidate_texts)):
        input_ids = subject_tokenizer(text, return_tensors="pt").input_ids.to(subject_device)
        with torch.inference_mode():
            out = subject(input_ids, output_hidden_states=True)
            # hidden_states is actually one longer than the number of layers, because it includes the input embeddings
            h = out.hidden_states[layer + 1].squeeze(0)
            # feat_acts = ae.activation(ae.encoder(h))[:, feat_idx]
            feat_acts = h @ encoder_feat
            # the first token position just has way higher norm all the time for some reason
            feat_acts[0] = 0

        for i in range(1, len(feat_acts) + 1):
            reassembled_text = subject_tokenizer.decode(input_ids[0, :i])
            subtexts.append(reassembled_text)
            subtext_acts.append(feat_acts[i - 1].item())

    del ae
    # get a random sample of activating contexts
    subtext_acts = torch.tensor(subtext_acts)
    n_candidates = 500
    candidate_indices = subtext_acts.topk(n_candidates).indices
    sampled_indices = np.random.choice(candidate_indices.numpy(), n_scorer_texts + n_explainer_texts, replace=False)
    
    # Get top k subtexts and their activations
    sampled_subtexts = [subtexts[i] for i in sampled_indices]
    sampled_activations = subtext_acts[sampled_indices]

    # Print top k results
    print("Top subtexts with highest feature activation:")
    for i, (subtext, activation) in enumerate(zip(sampled_subtexts, sampled_activations), 1):
        print(f"{i}. Activation: {activation:.4f}")
        print(f"   Text: {subtext}")
        print()

    random.shuffle(sampled_subtexts)  # just as assurance
    scorer_texts = sampled_subtexts[:n_scorer_texts]
    explainer_texts = sampled_subtexts[n_scorer_texts:]

    # get explanation
    def get_subject_logits(text, layer, intervention_strength=0.0, position=-1):
        for l in range(len(subject.transformer.h)):
            subject.transformer.h[l]._forward_hooks.clear()
        subject.transformer.h[layer].register_forward_hook(partial(intervene, intervention_strength=intervention_strength, position=-1))

        inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(subject_device)
        with torch.inference_mode():
            outputs = subject(**inputs)

        return outputs.logits[0, -1, :]

    intervention_examples = []
    for text in explainer_texts:
        clean_logits = get_subject_logits(text, feat_layer, intervention_strength=0.0)
        intervened_logits = get_subject_logits(text, feat_layer, intervention_strength=explainer_intervention_strength)
        top_probs = (intervened_logits.softmax(dim=-1) - clean_logits.softmax(dim=-1)).topk(n_intervention_examples)
        
        top_tokens = [subject_tokenizer.decode(i) for i in top_probs.indices]
        top_p_increases = top_probs.values.tolist()
        intervention_examples.append(
            ExplainerInterventionExample(
                prompt=text,
                top_tokens=top_tokens,
                top_p_increases=top_p_increases
            )
        )

    neuron_prompter = ExplainerNeuronFormatter(
        intervention_examples=intervention_examples
    )

    # TODO: improve the few-shot examples
    explainer_prompt = get_explainer_prompt(neuron_prompter, fs_examples)
    explainer_input_ids = explainer_tokenizer(explainer_prompt, return_tensors="pt").input_ids.to(explainer_device)
    with torch.inference_mode():
        samples = explainer.generate(explainer_input_ids, max_new_tokens=100, eos_token_id=explainer_tokenizer.encode("\n\n")[-1], num_return_sequences=5)[:, explainer_input_ids.shape[1]:]
    explanations = Counter([explainer_tokenizer.decode(sample).split("\n\n")[0].strip() for sample in samples])
    explanation = explanations.most_common(1)[0][0]
    print(explanations)

    predictiveness_scores = []
    max_intervened_probs = []
    for scorer_intervention_strength in tqdm(scorer_intervention_strengths):
        
        predictiveness_score = torch.tensor(0.0, device=scorer_device)
        max_intervened_prob = 0.0
        
        for text in scorer_texts:
            
            intervened_probs = get_subject_logits(text, feat_layer, intervention_strength=scorer_intervention_strength).softmax(dim=-1).to(scorer_device)
            max_intervened_prob = max(max_intervened_prob, intervened_probs.max().item())

            # get the explanation predictiveness
            scorer_predictiveness_prompt = get_scorer_predictiveness_prompt(text, explanation, few_shot_prompts, few_shot_explanations, few_shot_tokens)
            scorer_input_ids = scorer_tokenizer(scorer_predictiveness_prompt, return_tensors="pt").input_ids.to(scorer_device)
            with torch.inference_mode():
                scorer_logits = scorer(scorer_input_ids).logits[0, -1, :]
                scorer_logp = scorer_logits.log_softmax(dim=-1)
            
            predictiveness_score += (intervened_probs[subject_ids] * scorer_logp[scorer_ids]).sum()

        max_intervened_probs.append(max_intervened_prob)
        predictiveness_scores.append(predictiveness_score.item() / len(scorer_texts))
        
    predictiveness_score = sum(predictiveness_scores) / len(predictiveness_scores)
    max_intervened_prob = max(max_intervened_probs)
    all_results.append({
        "feat_idx": feat_idx,
        "feat_layer": feat_layer,
        "explanation": explanation,
        "predictiveness_score": predictiveness_score,
        "intervention_examples": intervention_examples,
        "max_intervened_prob": max_intervened_prob,
        "scorer_intervention_strengths": scorer_intervention_strengths,
        "explainer_intervention_strength": explainer_intervention_strength,
        "scorer_texts": scorer_texts,
        "explainer_texts": explainer_texts,
        "predictiveness_scores": predictiveness_scores,
        "max_intervened_probs": max_intervened_probs,
    })
all_results

  0%|          | 0/30 [00:00<?, ?it/s]

Loading autoencoder...done




In [None]:
all_df = pd.DataFrame(all_results)
all_df = all_df.sort_values("predictiveness_score", ascending=False)
# all_df.to_pickle(f"counterfactual_results/3layers_200feats.pkl")
all_df

Unnamed: 0,feat_idx,feat_layer,explanation,predictiveness_score,intervention_examples,max_intervened_prob,scorer_intervention_strengths,explainer_intervention_strength,scorer_texts,explainer_texts,predictiveness_scores,max_intervened_probs
1,10000,6,commas,-9.046546,[ExplainerInterventionExample(prompt='Larry Sh...,0.989414,"[10, 32, 100, 320, 1000]",32,[1998 Piberstein Styrian Open – Singles\n\nBar...,[Larry Sharpe\n\nLarry Sharpe may refer to:\n\...,"[-6.448710632324219, -6.621741485595703, -9.79...","[0.9894140958786011, 0.9211587905883789, 0.721..."
4,10001,6,verbs,-9.411,[ExplainerInterventionExample(prompt='Continuo...,0.995093,"[10, 32, 100, 320, 1000]",32,"[Arvid Kramer\n\nArvid Kramer (born October 3,...",[Continuous external counterpressure during cl...,"[-6.721965789794922, -6.747789764404297, -7.75...","[0.9905104041099548, 0.9950931072235107, 0.993..."
0,10000,2,the,-9.426538,[ExplainerInterventionExample(prompt='Effects ...,0.99962,"[10, 32, 100, 320, 1000]",32,[1. Field of the Invention\nThis invention gen...,[Effects of acute olanzapine exposure on centr...,"[-8.017688751220703, -8.224576568603515, -10.1...","[0.9996201992034912, 0.9982855916023254, 0.506..."
2,10000,11,punctuation,-10.581014,[ExplainerInterventionExample(prompt='z(w) = -...,0.949169,"[10, 32, 100, 320, 1000]",32,"[\nIf you're, What to do with Doubt... The ver...",[z(w) = -w**3 - 3*w**2 - 3*w - 1. Suppose 4*l ...,"[-7.878225708007813, -7.977081298828125, -8.57...","[0.9491687417030334, 0.9463328123092651, 0.859..."
5,10001,11,capitalization\n<|end_of_text|><|begin_of_text...,-10.795758,[ExplainerInterventionExample(prompt='He procl...,0.816027,"[10, 32, 100, 320, 1000]",32,[Background {#Sec1}\n==========\n\nHereditary ...,"[He proclaims Ireland's most famous day ""a gre...","[-8.285784149169922, -8.428328704833984, -9.03...","[0.8160274028778076, 0.7745303511619568, 0.651..."
3,10001,2,capitalization,-10.861929,[ExplainerInterventionExample(prompt='El ojo d...,0.940045,"[10, 32, 100, 320, 1000]",32,"[From the dome to your home. Ramblings, mutter...",[El ojo de vidrio\n\nEl ojo de vidrio may refe...,"[-7.5106353759765625, -7.581437683105468, -12....","[0.9400449991226196, 0.9275412559509277, 0.855..."


In [None]:
all_df.iloc[0].intervention_examples

[ExplainerInterventionExample(prompt='Larry Sharpe\\n\\nLarry Sharpe may refer to:\\n\\nLarry Sharpe (politician) (born 1968), American business consultant and political activist \\nLarry Sharpe (wrestler) (1951–2017', top_tokens=[')', ');', ',', ' and', ':'], top_p_increases=[0.0795612633228302, 0.0013725794851779938, 0.0009166579693555832, 0.0003981509362347424, 0.0002951501519419253]),
 ExplainerInterventionExample(prompt='Reflux gastritis and dysplasia.\\nIn order to evaluate if duodenogastric reflux (DGR) is associated with a different frequency of gastric dysplasia in comparison with the absence of DGR. 40', top_tokens=[',', '-', ',', ' All', '–'], top_p_increases=[0.09463351964950562, 0.029508035629987717, 0.024005532264709473, 0.003698201384395361, 0.003605559468269348]),
 ExplainerInterventionExample(prompt='Careers\\n\\nOpen Positions\\n\\nAt', top_tokens=[' the', ' least', ' The', ' our', ' a'], top_p_increases=[0.15025703608989716, 0.020413324236869812, 0.006305559538304806

In [None]:
all_df.iloc[0].scorer_texts

['1998 Piberstein Styrian Open – Singles\n\nBarbara Schett was the defending champion but lost in the quarterfinals to Emmanuelle Gagliardi.\n\nPatty Schnyder won in the final 6–2, 4–6, 6–3 against Gala León Garc',
 'Q:\n\nfilesystem for archiving\n\nI have some complex read-only data in',
 'Happy to help/happy to be here: Identifying components of successful clinical placements for undergraduate nursing students',
 'Silica nanoparticle sols. Part 3: Monitoring the state of agglomeration at the air/water interface',
 "According to a newly published report by Dell'Oro Group, the number of point-to-point",
 'West Virginia History OnView (WVHOV) in the West Virginia & Regional History Center’s online database that includes over 50,000 images digitized from our rich and',
 'I am giving this daylily a neutral rating because of the height of the scapes. It is listed as blooming on 24 inch scapes, but my plant has always been much, much shorter. The flowers are large and very pretty, though. 

In [None]:
all_df = pd.DataFrame(all_results)
all_df = all_df.sort_values("predictiveness_score", ascending=False)
all_df.to_pickle(f"counterfactual_results/{len(feat_layers)}layers_{len(feat_idxs)}feats.pkl")

In [None]:
# all_df = pd.read_pickle(f"counterfactual_results/3layers_200feats.pkl")
all_df

Unnamed: 0,feat_idx,feat_layer,explanation,predictiveness_score,intervention_examples,max_intervened_prob,scorer_intervention_strengths,explainer_intervention_strength,scorer_texts,explainer_texts,predictiveness_scores,max_intervened_probs
547,182,6,1-2-3-4-0-9,-8.721251,"[ExplainerInterventionExample(prompt='Road', t...",0.231213,"[10, 32, 100, 320, 1000]",32,"[Project, March, 1, Sim, 2019, 2016, 1, A, Hou...","[Road, A, Value]","[-10.622151947021484, -10.456940460205079, -8....","[0.11679935455322266, 0.1188754290342331, 0.14..."
546,182,2,1234,-8.875147,"[ExplainerInterventionExample(prompt='Road', t...",0.226274,"[10, 32, 100, 320, 1000]",32,"[Project, March, 1, Sim, 2019, 2016, 1, A, Hou...","[Road, A, Value]","[-10.631166076660156, -10.523173522949218, -8....","[0.11783038079738617, 0.12202408164739609, 0.1..."
556,185,6,"2019, value, house",-8.937399,"[ExplainerInterventionExample(prompt='2019', t...",0.431593,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 200, 538, A, A, A, 1, ', 1]","[2019, Value, House]","[-10.816554260253906, -10.416304779052734, -9....","[0.11420270800590515, 0.11989623308181763, 0.1..."
555,185,2,2019,-8.978640,"[ExplainerInterventionExample(prompt='2019', t...",0.435408,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 200, 538, A, A, A, 1, ', 1]","[2019, Value, House]","[-10.726454925537109, -10.404013061523438, -9....","[0.11452289670705795, 0.12135178595781326, 0.1..."
183,61,2,2016,-9.364605,[ExplainerInterventionExample(prompt='Project'...,0.943791,"[10, 32, 100, 320, 1000]",32,"[Sim, 2019, G, Ay, A, A, A, 200, ', L]","[Project, G, 2016]","[-10.697615814208984, -10.711712646484376, -9....","[0.07175029814243317, 0.06015954166650772, 0.4..."
...,...,...,...,...,...,...,...,...,...,...,...,...
402,134,2,2019,-13.192244,"[ExplainerInterventionExample(prompt='2019', t...",0.992656,"[10, 32, 100, 320, 1000]",32,"[2016, Sim, 1, ', A, A, 1, X, A, H]","[2019, 200, House]","[-11.074691772460938, -11.361431121826172, -12...","[0.12653400003910065, 0.16054397821426392, 0.1..."
404,134,11,2019,-13.266629,"[ExplainerInterventionExample(prompt='2019', t...",0.962138,"[10, 32, 100, 320, 1000]",32,"[2016, Sim, 1, ', A, A, 1, X, A, H]","[2019, 200, House]","[-10.987806701660157, -11.11118392944336, -11....","[0.11369820684194565, 0.11930322647094727, 0.1..."
539,179,11,2019,-13.326119,"[ExplainerInterventionExample(prompt='2019', t...",0.962642,"[10, 32, 100, 320, 1000]",32,"[2016, Q, 1, 200, A, A, 1, 538, A, Value]","[2019, ', X]","[-10.844153594970702, -10.966635131835938, -11...","[0.11189045011997223, 0.11463377624750137, 0.0..."
278,92,11,1-200,-13.544800,"[ExplainerInterventionExample(prompt='A', top_...",0.951868,"[10, 32, 100, 320, 1000]",32,"[2016, Road, L, 1, *, A, 538, Value, A, ']","[A, 200, 1]","[-10.737577056884765, -10.89888916015625, -11....","[0.11086613684892654, 0.1106131449341774, 0.08..."
