In [1]:
# load subject model
# load SAEs without attaching them to the model
# for now just use the Islam feature and explanation
# load a scorer. The prompt should have the input as well this time
# (for now) on random pretraining data, evaluate gpt2 with a hook that 
# adds a multiple of the Islam feature to the appropriate residual stream layer and position
# Get the pre- and post-intervention output distributions of gpt2
# (TODO: check if all the Islam features just have similar embeddings)
# Show this to the scorer and get a score (scorer should be able to have a good prior without being given the clean output distribution)
# Also get a simplicity score for the explanation

In [2]:
import json
import random

with open("pile.jsonl", "r") as f:
    pile = random.sample([json.loads(line) for line in f.readlines()], 100)

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = "cuda:0"

subject_name = "gpt2"
subject = AutoModelForCausalLM.from_pretrained(subject_name).to(device)
subject_tokenizer = AutoTokenizer.from_pretrained(subject_name)
subject_tokenizer.pad_token = subject_tokenizer.eos_token
subject.config.pad_token_id = subject_tokenizer.eos_token_id

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
scorer_name = "meta-llama/Meta-Llama-3.1-8B"
scorer = AutoModelForCausalLM.from_pretrained(scorer_name).to(torch.bfloat16).to(device)
scorer_tokenizer = AutoTokenizer.from_pretrained(scorer_name)
scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
scorer.config.pad_token_id = scorer_tokenizer.eos_token_id
scorer.generation_config.pad_token_id = scorer_tokenizer.eos_token_id

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.81it/s]


In [5]:
def get_scorer_simplicity_prompt(explanation):
    prefix = "Explanation\n\n"
    return f"{prefix}{explanation}{scorer_tokenizer.eos_token}", prefix

def get_scorer_predictiveness_prompt(prompt, explanation, few_shot_prompts=None, few_shot_explanations=None, few_shot_tokens=None):
    if few_shot_explanations is not None:
        assert few_shot_tokens is not None and few_shot_prompts is not None
        assert len(few_shot_explanations) == len(few_shot_tokens) == len(few_shot_prompts)
        few_shot_prompt = "\n\n".join(get_scorer_predictiveness_prompt(pr, expl) + token for pr, expl, token in zip(few_shot_prompts, few_shot_explanations, few_shot_tokens)) + "\n\n"
    else:
        few_shot_prompt = ""
    return few_shot_prompt + f"Explanation: {explanation}\n<PROMPT>{prompt}</PROMPT>"

few_shot_prompts = ["My favorite food is", "From west to east, the westmost of the seven", "He owned the watch for a long time. While he never said it was"]
few_shot_explanations = ["fruits and vegetables", "ateg", "she/her pronouns"]
few_shot_tokens = [" oranges", "WAY", " hers"]
print(get_scorer_predictiveness_prompt(few_shot_prompts[0], few_shot_explanations[0], few_shot_prompts, few_shot_explanations, few_shot_tokens))

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT> oranges

Explanation: ateg
<PROMPT>From west to east, the westmost of the seven</PROMPT>WAY

Explanation: she/her pronouns
<PROMPT>He owned the watch for a long time. While he never said it was</PROMPT> hers

Explanation: fruits and vegetables
<PROMPT>My favorite food is</PROMPT>


In [6]:
import torch
from sae_auto_interp.autoencoders.OpenAI.model import Autoencoder


weight_dir = "/mnt/ssd-1/gpaulo/SAE-Zoology/weights/gpt2_128k"

layer = 4
feat = 9
explanation = "Islam"
intervention_strengths = [10, 32, 100, 320, 1000]

path = f"{weight_dir}/{layer}.pt"
state_dict = torch.load(path)
ae = Autoencoder.from_state_dict(state_dict=state_dict)
feat = ae.decoder.weight[:, feat].to(device)
feat.shape

torch.Size([768])

In [7]:
subject

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [34]:
from functools import partial

def intervene(module, input, output, intervention_strength=10.0, position=-1):
    hiddens = output[0]  # the later elements of the tuple are the key value cache
    hiddens[:, position, :] += intervention_strength * feat.to(hiddens.device)

seed = 42
random.seed(seed)
texts = []
for _ in range(1000):
    # sample a random text from the pile, and stop it at a random token position, less than 64 tokens
    text = random.choice(pile)["text"]
    tokenized_text = subject_tokenizer.encode(text, add_special_tokens=False, max_length=64, truncation=True)
    stop_pos = random.randint(1, min(len(tokenized_text) - 1, 63))
    text = subject_tokenizer.decode(tokenized_text[:stop_pos])
    texts.append(text)

In [35]:
explanation

'Islam'

In [36]:
from tqdm.auto import tqdm

explanation = "Islam"
predictiveness_scores = []
max_intervened_probs = []
for intervention_strength in tqdm(intervention_strengths):
    if "handle" in locals():
        handle.remove()  # type: ignore
    handle = subject.transformer.h[layer].register_forward_hook(partial(intervene, intervention_strength=intervention_strength, position=-1))

    predictiveness_score = torch.tensor(0.0, device=device)
    max_intervened_prob = 0.0
    for text in texts:
        inputs = subject_tokenizer(text, return_tensors="pt", add_special_tokens=True).to(device)
        with torch.inference_mode():
            outputs = subject(**inputs)
        intervened_probs = outputs.logits[0, -1, :].softmax(dim=-1)
        max_intervened_prob = max(max_intervened_prob, intervened_probs.max().item())

        # get the explanation predictiveness
        scorer_predictiveness_prompt = get_scorer_predictiveness_prompt(text, explanation, few_shot_prompts, few_shot_explanations, few_shot_tokens)
        scorer_input_ids = scorer_tokenizer(scorer_predictiveness_prompt, return_tensors="pt").input_ids.to(device)
        with torch.inference_mode():
            scorer_logits = scorer(scorer_input_ids).logits[0, -1, :]
            scorer_logp = scorer_logits.log_softmax(dim=-1)

        scorer_vocab = scorer_tokenizer.vocab  # for some reason this takes 34 ms so we need to factor it out
        for subj_tok, subj_id in subject_tokenizer.vocab.items():
            if subj_tok in scorer_vocab:
                scorer_tok = subj_tok
            else:
                # we need to map the subject model's tokens to the scorer model token that has the longest common prefix
                for i in range(len(subj_tok) - 1, 0, -1):
                    if subj_tok[:i] in scorer_vocab:
                        scorer_tok = subj_tok[:i]
                        break
                else:
                    raise ValueError(f"No scorer token found for {subj_tok}")
            predictiveness_score += intervened_probs[subj_id] * scorer_logp[scorer_vocab[scorer_tok]]
            if intervened_probs[subj_id] > 0.05:
                print(subj_tok, intervened_probs[subj_id].item())
    max_intervened_probs.append(max_intervened_prob)
    predictiveness_scores.append(predictiveness_score.item() / len(texts))

predictiveness_score = sum(predictiveness_scores) / len(predictiveness_scores)
predictiveness_score

  0%|          | 0/5 [00:00<?, ?it/s]

formation 0.6249427199363708
. 0.9983684420585632
? 0.4443819224834442
Ġand 0.06194838881492615
Ġany 0.0631946474313736
Ġthe 0.19016729295253754
Ġproposal 0.07402820140123367
Ġbill 0.40904274582862854
Ġresolution 0.05020170658826828
Ġdocumentation 0.5085869431495667
Ġfiles 0.359658420085907
Ġsoftware 0.08802788704633713
? 0.08232977986335754
Ġlike 0.05482378974556923
t 0.051101766526699066
mo 0.13094371557235718
pl 0.1599503755569458
Ċ 0.053369905799627304
: 0.05481799691915512
Ċ 0.1690206676721573
: 0.3712936043739319
quez 0.1644543707370758
. 0.05829939991235733
Ċ 0.3021225035190582
Ġthe 0.051967550069093704
Ġand 0.11504463106393814
Ġas 0.06074370816349983
ENSE 0.9990311861038208
clinical 0.38429343700408936
cut 0.2304704338312149
Ġ2003 0.05525459721684456
ascular 0.9879785776138306
Ġany 0.06181328743696213
Ġthe 0.07173586636781693
Ġhaving 0.09951714426279068
Ġcompromising 0.06998807936906815
Ġbeing 0.07718007266521454
heart 0.06815971434116364
H 0.060492679476737976
HF 0.14979508519

 20%|██        | 1/5 [32:39<2:10:37, 1959.30s/it]

fer 0.14507398009300232
formation 0.14565734565258026
fect 0.08775582909584045
plant 0.05260167643427849
cript 0.05486635863780975
. 0.9970888495445251
? 0.3693937659263611
Ġany 0.060874734073877335
Ġthe 0.17941290140151978
Ġplan 0.050435349345207214
Ġproposal 0.07430922240018845
Ġbill 0.3254750669002533
Ġresolution 0.05080261453986168
Ġdocumentation 0.37657538056373596
Ġfiles 0.36497265100479126
Ġsoftware 0.07299243658781052
? 0.11420799046754837
- 0.20435160398483276
_ 0.05548327416181564
_ 0.06377946585416794
Ċ 0.16804860532283783
: 0.44486531615257263
az 0.06679155677556992
. 0.050926655530929565
Ċ 0.2664683759212494
Ġthe 0.11851517111063004
Ġand 0.06859969347715378
Ġa 0.05842047184705734
Ġas 0.059706490486860275
ENSE 0.9782468676567078
clinical 0.39605990052223206
cut 0.11952727288007736
Ġ2013 0.0514795295894146
ascular 0.8447772860527039
Ġany 0.08061228692531586
Ġthe 0.05549406260251999
Ġhaving 0.08833151310682297
Ġcompromising 0.0769289955496788
Ġbeing 0.07131670415401459
heart 

 40%|████      | 2/5 [1:04:41<1:36:52, 1937.40s/it]

is 0.08886314183473587
cript 0.05505191534757614
www 0.22829239070415497
. 0.7026323080062866
? 0.28746911883354187
Ġand 0.07437486946582794
, 0.05991649255156517
Ġthe 0.17841146886348724
. 0.06704571098089218
Ġban 0.12245253473520279
Ġmedia 0.054003652185201645
Ġbooks 0.1001199409365654
Ġauthor 0.0616995170712471
Ġbook 0.0961933508515358
Ġtext 0.1303848922252655
? 0.05816570669412613
- 0.07153365015983582
_ 0.20559267699718475
" 0.05509497970342636
am 0.18314483761787415
, 0.05216279625892639
_ 0.10098462551832199
. 0.05039078742265701
Ġ( 0.0905800610780716
Ċ 0.09243728220462799
: 0.17706900835037231
Ġthe 0.10966522246599197
Ġbe 0.11209169775247574
, 0.06103073060512543
aw 0.1072627380490303
. 0.07548834383487701
Ġal 0.07034623622894287
Ġare 0.05552219599485397
Ġis 0.11789793521165848
Ċ 0.11980939656496048
Ġthe 0.10554169863462448
_ 0.08777158707380295
. 0.08833185583353043
ĠLIC 0.1224309578537941
Ġ( 0.11458756774663925
- 0.4167388081550598
r 0.1240687146782875
, 0.12444642931222916
Ġ

 60%|██████    | 3/5 [1:37:06<1:04:41, 1940.86s/it]

ĠIslamic 0.27174901962280273
Ġholy 0.09149555116891861
ĠIslamic 0.16914771497249603
ĠIslamic 0.14202526211738586
Ġholy 0.08144627511501312
ĠIslamic 0.18198588490486145
ĠSharia 0.05937403440475464
ĠIslamic 0.5508764386177063
Ġholy 0.053447041660547256
ĠIslamic 0.5266122221946716
ĠIslamic 0.11943002045154572
ĠIslamic 0.1861426830291748
Ġal 0.056493356823921204
ĠIslamic 0.4969257712364197
ĠSharia 0.09936697036027908
ĠIslamic 0.3622719347476959
ĠIslamic 0.32689252495765686
ĠSharia 0.05428048223257065
ĠIslamic 0.6681840419769287
Ġholy 0.06511695683002472
ĠIslamic 0.1002923846244812
ĠIslamic 0.2882227897644043
Ġholy 0.12484189122915268
ĠIslamic 0.3520832061767578
ĠIslamic 0.5978910326957703
ĠIslamic 0.47825315594673157
ĠIslamic 0.22544407844543457
Ġal 0.09131168574094772
Ġthe 0.07269874960184097
Ġholy 0.07813403755426407
ĠIslamic 0.1682371199131012
Ġcal 0.05337674543261528
Ġholy 0.052153199911117554
ĠIslamic 0.2276928722858429
ĠIslamic 0.48530760407447815
ĠIslamic 0.2734593152999878
ĠIslamic

 80%|████████  | 4/5 [2:09:57<32:32, 1952.96s/it]  

ĠIslamic 0.9386752843856812
ĠIslamic 0.9442702531814575
ĠIslamic 0.9244183897972107
ĠIslamic 0.9397141337394714
ĠIslamic 0.9536890387535095
ĠIslamic 0.9598962068557739
ĠIslamic 0.9465983510017395
ĠIslamic 0.9358542561531067
ĠSharia 0.09653507173061371
ĠIslamic 0.669718861579895
ĠIslamic 0.9433712363243103
ĠIslamic 0.9520204067230225
ĠSharia 0.08512207865715027
ĠIslamic 0.6549578905105591
ĠIslamic 0.930568277835846
ĠIslamic 0.9558537602424622
ĠIslamic 0.9435129165649414
ĠIslamic 0.9568827152252197
ĠIslamic 0.9660479426383972
ĠIslamic 0.9486666917800903
ĠIslamic 0.9485344290733337
ĠIslamic 0.9356427788734436
ĠIslamic 0.957284688949585
ĠIslamic 0.9460194706916809
ĠIslamic 0.9656252861022949
ĠIslamic 0.9592301845550537
ĠIslamic 0.9495181441307068
ĠIslamic 0.9542229175567627
ĠIslamic 0.9629397392272949
ĠIslamic 0.9695368409156799
ĠIslamic 0.9603853821754456
ĠIslamic 0.9670529365539551
ĠIslamic 0.950974702835083
ĠIslamic 0.9553609490394592
ĠIslamic 0.9658759236335754
ĠIslamic 0.9512910246849

In [27]:
predictiveness_score

-8.042637176513672

In [30]:
predictiveness_scores

[-7.000522613525391,
 -7.397779083251953,
 -9.827688598632813,
 -11.150653839111328,
 -11.35284881591797]

In [32]:
predictiveness_scores

[-7.026526641845703,
 -7.4196624755859375,
 -9.777403259277344,
 -8.839710235595703,
 -7.149883270263672]

In [33]:
intervention_strengths

[10, 32, 100, 320, 1000]