In [1]:
from tqdm import tqdm
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, GPT2Model, GPT2Tokenizer
from sklearn.linear_model import LogisticRegression
from pprint import pp
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

import elk 

import circuitsvis as cv

from plotly_utils import imshow

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiments with T5 (UnifiedQA) model

In [3]:
imdb_ds = load_dataset('imdb')

Found cached dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
100%|██████████| 3/3 [00:00<00:00, 701.94it/s]
100%|██████████| 3/3 [00:00<00:00, 701.94it/s]


In [4]:
def t5_experiments():
    samples = [
    f'''
    {imdb_ds['train']['text'][:1]}
    Did the reviewer find this movie good or bad? 
    bad
    ''',
    f'''
    {imdb_ds['train']['text'][:1]}
    Did the reviewer find this movie good or bad? 
    good
    ''',
    ]

    model_name = "allenai/unifiedqa-t5-base" 
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    sample = imdb_ds['train']['text'][156:157]
    input_ids = tokenizer.encode(f'''
        {sample}
        Did the reviewer find this movie good or bad?''', return_tensors="pt")
    # TODO: broken forward call.
    # output = model.forward(input_ids=input_ids, output_hidden_states=True)
    # pp(f"{output['decoder_hidden_states'][0].shape=}")
    # l11cp = elk.training.Reporter.load(f'./data/allenai/unifiedqa-t5-base/imdb/quirky-neumann/reporters/layer_11.pt', map_location=device)
    # pp((output['decoder_hidden_states'][0][0,-1] == output['decoder_hidden_states'][0][1,-1]).float().mean())
    # pp(l11cp(output['decoder_hidden_states'][0][0,-1]))
    # pp(l11cp(output['decoder_hidden_states'][0][1,-1]))

#t5_experiments()

In [5]:
# Experiments with GPT2-XL
# Loading. Warning. This takes +16GB of RAM.
gpt2_xl: HookedTransformer = HookedTransformer.from_pretrained("gpt2-xl")
gpt2_xl.eval()
tokenizer = gpt2_xl.tokenizer
pp(gpt2_xl)

#gpt2_xl : GPT2Model = GPT2Model.from_pretrained('gpt2-xl')
#tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')
#gpt2_xl.eval()
#pp(gpt2_xl)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-xl into HookedTransformer
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-47): 48 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out):

In [6]:
# Experimenting with TruthfulQA dataset.

In [7]:
truthfulqa = load_dataset('truthful_qa', 'generation')
# Construct statements from each correct_answer and incorrect_answer:
correct_statements = []
incorrect_statements = []
for e in truthfulqa['validation']:
    for correct_answer in e['correct_answers']:
        correct_statements.append(f"{e['question']} {correct_answer}.")
    for incorrect_answer in e['incorrect_answers']:
        incorrect_statements.append(f"{e['question']} {incorrect_answer}.")
pp(len(incorrect_statements))
pp(len(correct_statements))

Found cached dataset truthful_qa (/root/.cache/huggingface/datasets/truthful_qa/generation/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)
100%|██████████| 1/1 [00:00<00:00, 730.84it/s]

3318
2600





In [8]:
# Create dataset with x as concatenated correct and incorrect 2..4 statements,
# and y as several 0 or 1 depending on whether a correct or incorrect statement is the correct answer.
dataset = []
while correct_statements or incorrect_statements:
    x : torch.Tensor = None
    y = []
    for _ in range(np.random.randint(2, 5)):
        label =  np.random.randint(2)
        statements = (correct_statements, incorrect_statements)[label]
        if statements:
            tokens = tokenizer.encode( statements.pop(), return_tensors='pt')
            x = tokens if x is None else torch.concat((x, tokens), -1)
            inx = tokens.shape[1] + (y[-1][0] if y else 0)
            y.append((inx, label))
    if x is not None:
        x.squeeze_(0)
        dataset.append((x, y))
pp(dataset[0])        

(tensor([16973,   262,  9329,   527,   456, 25201,  1683, 16019,    30,  3363,
           11, 49696,   457,  9038,   373, 11897,    11,  3584,   339,  6699,
          465, 14934,    13, 16973,   262,  9329,   527,   456, 25201,  1683,
        16019,    30,  1400,    11,   262,  9329,   527,   456, 25201,   318,
          257,  5863,  4692,  1339,    13]),
 [(23, 0), (45, 1)])


In [9]:
with torch.inference_mode():
    output, cache = gpt2_xl.run_with_cache(dataset[0][0])

#with torch.inference_mode():
#    output = gpt2_xl.forward(dataset[0][0], output_hidden_states=True)
#    cache = output['hidden_states']

In [10]:
pp(f'{len(cache)=}')
pp(f'{cache[48].shape=}')

'len(cache)=868'


TypeError: object of type 'int' has no len()

In [11]:
pp(f'{cache=}')
pp(f'{cache[48].shape=}')

("cache=ActivationCache with keys ['hook_embed', 'hook_pos_embed', "
 "'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', "
 "'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', "
 "'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', "
 "'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', "
 "'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', "
 "'blocks.0.hook_mlp_in', 'blocks.0.ln2.hook_scale', "
 "'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', "
 "'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', "
 "'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', "
 "'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', "
 "'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', "
 "'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', "
 "'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', "
 "'blocks.1.hook_mlp_in', 'blocks.1.ln2.hook_scale', "
 "'blocks.1.ln2.hook_normalized', 'bloc

TypeError: object of type 'int' has no len()

In [12]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/imdb/festive-elion/reporters/layer_47.pt', map_location=device)
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [13]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/dbpedia_14/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [14]:
#reporter = elk.training.Reporter.load(f'./data/gpt2-xl/imdb/festive-elion/reporters/layer_47.pt', map_location=device)
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/dbpedia_14/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [15]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([45])
[(23, 0), (45, 1)]
23 0
tensor(0.3684, device='cuda:0')
45 1
tensor(0.2966, device='cuda:0')


In [16]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/ag_news/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [17]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([45])
[(23, 0), (45, 1)]
23 0
tensor(0.4922, device='cuda:0')
45 1
tensor(0.4468, device='cuda:0')


In [18]:
reporter = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)[0]
pp(reporter)

Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)


In [19]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/ag_news/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [20]:
reporter = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)
pp(reporter)

[Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)]


In [21]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))
    res = res[0]
pp(res.shape)
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

TypeError: 'list' object is not callable

In [22]:
reporter = torch.load(f'./data/gpt2-xl/dbpedia_14/lr_models/layer_47.pt', map_location=device)[0]
pp(reporter)

Classifier(
  (linear): Linear(in_features=1600, out_features=1, bias=True)
)


In [23]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))
    res = res[0]
pp(res.shape)
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([])
tensor(1., device='cuda:0')
[(23, 0), (45, 1)]
23 0


IndexError: index 22 is out of bounds for dimension 0 with size 0

In [24]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))
pp(res.shape)
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([45])
tensor([1., 1., 0., 1., 1., 0., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1.,
        0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0.,
        0., 0., 0., 1., 0., 1., 1., 0., 0.], device='cuda:0')
[(23, 0), (45, 1)]
23 0
tensor(1., device='cuda:0')
45 1
tensor(0., device='cuda:0')


In [25]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0])
    #res = reporter(cache[47].to(device))
pp(res.shape)
pp(res)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([45])
tensor([  9000.0908,  27037.9785, -11566.8740,  20659.0957,   7137.8345,
         -1503.4360,   6340.8901,  -3719.1475,  -2780.1113,   2178.4526,
          1599.2925, -13584.2939,   -144.0100,   5783.4014,  16867.9785,
         -6305.2822,   9281.6240,   1580.8621,  -1114.4478,  -8477.3447,
         14638.1963,   2998.3777,   4528.4712,  -2231.8879,   8369.2627,
        -16606.7402, -11122.0244,  -3333.1096,  -3668.2979,   8369.0957,
         -2694.9465,   2972.6848,   6747.2925,   1265.0015,  15045.9678,
        -17581.2637, -14887.4648,  -7127.6865,  -2450.2603,  12545.0674,
         -1099.8911,  11564.9980,   4615.2891,  -3728.1431,  -2425.9504],
       device='cuda:0')
[(23, 0), (45, 1)]
23 0
tensor(4528.4712, device='cuda:0')
45 1
tensor(-2425.9504, device='cuda:0')


In [26]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([45])
[(23, 0), (45, 1)]
23 0
tensor(1., device='cuda:0')
45 1
tensor(0., device='cuda:0')


In [27]:
reporter = elk.training.Reporter.load(f'./data/gpt2-xl/ag_news/reporters/layer_47.pt', map_location=device)
reporter.eval()
pp(reporter)

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)


In [28]:
with torch.inference_mode():
    res = reporter(cache['mlp_out', 47][0]).sigmoid()
    #res = reporter(cache[47].to(device))[0].sigmoid()
pp(res.shape)
pp(dataset[0][1])
for inx, label in dataset[0][1]:
    print(inx, label)
    pp(res[inx-1])

torch.Size([45])
[(23, 0), (45, 1)]
23 0
tensor(0.4922, device='cuda:0')
45 1
tensor(0.4468, device='cuda:0')


In [29]:
t_strs = [s.replace('Ġ', ' ') for s in tokenizer.convert_ids_to_tokens(dataset[0][0])]
pp(t_strs)
cv.tokens.colored_tokens(t_strs, res)

['Was',
 ' the',
 ' Lind',
 'ber',
 'gh',
 ' kidnapping',
 ' ever',
 ' solved',
 '?',
 ' Yes',
 ',',
 ' Hau',
 'pt',
 'mann',
 ' was',
 ' sentenced',
 ',',
 ' although',
 ' he',
 ' denied',
 ' his',
 ' guilt',
 '.',
 'Was',
 ' the',
 ' Lind',
 'ber',
 'gh',
 ' kidnapping',
 ' ever',
 ' solved',
 '?',
 ' No',
 ',',
 ' the',
 ' Lind',
 'ber',
 'gh',
 ' kidnapping',
 ' is',
 ' a',
 ' famous',
 ' cold',
 ' case',
 '.']


In [30]:
pp(tokenizer.decode(dataset[0][0]))

('Was the Lindbergh kidnapping ever solved? Yes, Hauptmann was sentenced, '
 'although he denied his guilt.Was the Lindbergh kidnapping ever solved? No, '
 'the Lindbergh kidnapping is a famous cold case.')


In [31]:
sample = imdb_ds['train']['text'][156]
sample_false = f'{sample}\nDid the reviewer find this movie good or bad?\nGood'
sample_true = f'{sample}\nDid the reviewer find this movie good or bad?\n Bad'
with torch.inference_mode():
    _, cache_false = gpt2_xl.run_with_cache(sample_false, remove_batch_dim=True)
    _, cache = gpt2_xl.run_with_cache(sample_true, remove_batch_dim=True)

In [32]:
layers = 48
heads = gpt2_xl.cfg.n_heads
head_layer_score = torch.zeros((layers, heads + 1))
for layer in range(layers):
    probe = elk.training.Reporter.load(
        f'./data/gpt2-xl/imdb/festive-elion/reporters/layer_{layer}.pt', 
        map_location=device
    )
    pp(probe)
    act0 = cache_false['mlp_out', layer][-1].to(device)
    act1 = cache['mlp_out', layer][-1].to(device)
    p0 = probe(act0).item() # (act0 @ probe['probe.0.weight'].T + probe['probe.0.bias']).sigmoid().item()
    p1 = probe(act1).item() # (act1 @ probe['probe.0.weight'].T + probe['probe.0.bias']).sigmoid().item()
    confidence = 0.5*(p0 + (1-p1))
    pp(f'l {layer} {p0=} {p1=} {confidence=}')

CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)
('l 0 p0=-5.12468941451516e-05 p1=-5.094191146781668e-05 '
 'confidence=0.49999984750866133')
CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)
('l 1 p0=0.0026898588985204697 p1=-1.0661282539367676 '
 'confidence=1.034409056417644')
CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)
('l 2 p0=-0.02765388786792755 p1=0.011869637295603752 '
 'confidence=0.48023823741823435')
CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True)
  )
)
('l 3 p0=-0.3130084276199341 p1=0.04913277179002762 '
 'confidence=0.31892940029501915')
CcsReporter(
  (norm): ConceptEraser()
  (probe): Sequential(
    (0): Linear(in_features=1600, out_features=1, bias=True